From 07a30791562f626c63d3f85a5c87dd411e3108cc Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sat, 20 Jan 2024 13:19:37 +0000 Subject: [PATCH 1/6] Backport recent improvements to the implementation of `Protocol` --- CHANGELOG.md | 11 +++++ src/test_typing_extensions.py | 73 +++++++++++++++++++++------- src/typing_extensions.py | 91 +++++++++++++++++++++++++++-------- 3 files changed, 136 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fedc2a3f..c3957a29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +# Unreleased + +- Speedup `issubclass()` checks against simple runtime-checkable protocols by + around 6% (backporting https://github.com/python/cpython/pull/112717, by Alex + Waygood). +- Fix a regression in the implementation of protocols where `typing.Protocol` + classes that were not marked as `@runtime-checkable` would be unnecessarily + introspected, potentially causing exceptions to be raised if the protocol had + problematic members. Patch by Alex Waygood, backporting + https://github.com/python/cpython/pull/113401. + # Release 4.9.0 (December 9, 2023) This feature release adds `typing_extensions.ReadOnly`, as specified diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 77876b7f..11208d95 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -1872,18 +1872,6 @@ class E(C, BP): pass self.assertNotIsInstance(D(), E) self.assertNotIsInstance(E(), D) - def test_runtimecheckable_on_typing_dot_Protocol(self): - @runtime_checkable - class Foo(typing.Protocol): - x: int - - class Bar: - def __init__(self): - self.x = 42 - - self.assertIsInstance(Bar(), Foo) - self.assertNotIsInstance(object(), Foo) - def test_typing_dot_runtimecheckable_on_Protocol(self): @typing.runtime_checkable class Foo(Protocol): @@ -2817,8 +2805,8 @@ def meth(self): pass # noqa: B027 self.assertNotIn("__protocol_attrs__", vars(NonP)) self.assertNotIn("__protocol_attrs__", vars(NonPR)) - self.assertNotIn("__callable_proto_members_only__", vars(NonP)) - self.assertNotIn("__callable_proto_members_only__", vars(NonPR)) + self.assertNotIn("__non_callable_proto_members__", vars(NonP)) + self.assertNotIn("__non_callable_proto_members__", vars(NonPR)) acceptable_extra_attrs = { '_is_protocol', '_is_runtime_protocol', '__parameters__', @@ -2891,11 +2879,26 @@ def __subclasshook__(cls, other): @skip_if_py312b1 def test_issubclass_fails_correctly(self): @runtime_checkable - class P(Protocol): + class NonCallableMembers(Protocol): x = 1 + + class NotRuntimeCheckable(Protocol): + def callable_member(self) -> int: ... + + @runtime_checkable + class RuntimeCheckable(Protocol): + def callable_member(self) -> int: ... + class C: pass - with self.assertRaisesRegex(TypeError, r"issubclass\(\) arg 1 must be a class"): - issubclass(C(), P) + + # These three all exercise different code paths, + # but should result in the same error message: + for protocol in NonCallableMembers, NotRuntimeCheckable, RuntimeCheckable: + with self.subTest(proto_name=protocol.__name__): + with self.assertRaisesRegex( + TypeError, r"issubclass\(\) arg 1 must be a class" + ): + issubclass(C(), protocol) def test_defining_generic_protocols(self): T = TypeVar('T') @@ -3456,6 +3459,7 @@ def method(self) -> None: ... @skip_if_early_py313_alpha def test_protocol_issubclass_error_message(self): + @runtime_checkable class Vec2D(Protocol): x: float y: float @@ -3471,6 +3475,39 @@ def square_norm(self) -> float: with self.assertRaisesRegex(TypeError, re.escape(expected_error_message)): issubclass(int, Vec2D) + def test_nonruntime_protocol_interaction_with_evil_classproperty(self): + class classproperty: + def __get__(self, instance, type): + raise RuntimeError("NO") + + class Commentable(Protocol): + evil = classproperty() + + # recognised as a protocol attr, + # but not actually accessed by the protocol metaclass + # (which would raise RuntimeError) for non-runtime protocols. + # See gh-113320 + self.assertEqual(get_protocol_members(Commentable), {"evil"}) + + def test_runtime_protocol_interaction_with_evil_classproperty(self): + class CustomError(Exception): pass + + class classproperty: + def __get__(self, instance, type): + raise CustomError + + with self.assertRaises(TypeError) as cm: + @runtime_checkable + class Commentable(Protocol): + evil = classproperty() + + exc = cm.exception + self.assertEqual( + exc.args[0], + "Failed to determine whether protocol member 'evil' is a method member" + ) + self.assertIs(type(exc.__cause__), CustomError) + class Point2DGeneric(Generic[T], TypedDict): a: T @@ -5263,7 +5300,7 @@ def test_typing_extensions_defers_when_possible(self): 'SupportsRound', 'Unpack', } if sys.version_info < (3, 13): - exclude |= {'NamedTuple', 'Protocol'} + exclude |= {'NamedTuple', 'Protocol', 'runtime_checkable'} if not hasattr(typing, 'ReadOnly'): exclude |= {'TypedDict', 'is_typeddict'} for item in typing_extensions.__all__: diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 1666e96b..7de29cd5 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -473,7 +473,7 @@ def clear_overloads(): "_is_runtime_protocol", "__dict__", "__slots__", "__parameters__", "__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__", "__subclasshook__", "__orig_class__", "__init__", "__new__", - "__protocol_attrs__", "__callable_proto_members_only__", + "__protocol_attrs__", "__non_callable_proto_members__", "__match_args__", } @@ -521,6 +521,22 @@ def _no_init(self, *args, **kwargs): if type(self)._is_protocol: raise TypeError('Protocols cannot be instantiated') + def _type_check_issubclass_arg_1(arg): + """Raise TypeError if `arg` is not an instance of `type` + in `issubclass(arg, )`. + + In most cases, this is verified by type.__subclasscheck__. + Checking it again unnecessarily would slow down issubclass() checks, + so, we don't perform this check unless we absolutely have to. + + For various error paths, however, + we want to ensure that *this* error message is shown to the user + where relevant, rather than a typing.py-specific error message. + """ + if not isinstance(arg, type): + # Same error message as for issubclass(1, int). + raise TypeError('issubclass() arg 1 must be a class') + # Inheriting from typing._ProtocolMeta isn't actually desirable, # but is necessary to allow typing.Protocol and typing_extensions.Protocol # to mix without getting TypeErrors about "metaclass conflict" @@ -551,11 +567,6 @@ def __init__(cls, *args, **kwargs): abc.ABCMeta.__init__(cls, *args, **kwargs) if getattr(cls, "_is_protocol", False): cls.__protocol_attrs__ = _get_protocol_attrs(cls) - # PEP 544 prohibits using issubclass() - # with protocols that have non-method members. - cls.__callable_proto_members_only__ = all( - callable(getattr(cls, attr, None)) for attr in cls.__protocol_attrs__ - ) def __subclasscheck__(cls, other): if cls is Protocol: @@ -564,26 +575,23 @@ def __subclasscheck__(cls, other): getattr(cls, '_is_protocol', False) and not _allow_reckless_class_checks() ): - if not isinstance(other, type): - # Same error message as for issubclass(1, int). - raise TypeError('issubclass() arg 1 must be a class') + if not getattr(cls, '_is_runtime_protocol', False): + _type_check_issubclass_arg_1(other) + raise TypeError( + "Instance and class checks can only be used with " + "@runtime_checkable protocols" + ) if ( - not cls.__callable_proto_members_only__ + # this attribute is set by @runtime_checkable: + cls.__non_callable_proto_members__ and cls.__dict__.get("__subclasshook__") is _proto_hook ): - non_method_attrs = sorted( - attr for attr in cls.__protocol_attrs__ - if not callable(getattr(cls, attr, None)) - ) + _type_check_issubclass_arg_1(other) + non_method_attrs = sorted(cls.__non_callable_proto_members__) raise TypeError( "Protocols with non-method members don't support issubclass()." f" Non-method members: {str(non_method_attrs)[1:-1]}." ) - if not getattr(cls, '_is_runtime_protocol', False): - raise TypeError( - "Instance and class checks can only be used with " - "@runtime_checkable protocols" - ) return abc.ABCMeta.__subclasscheck__(cls, other) def __instancecheck__(cls, instance): @@ -610,7 +618,8 @@ def __instancecheck__(cls, instance): val = inspect.getattr_static(instance, attr) except AttributeError: break - if val is None and callable(getattr(cls, attr, None)): + # this attribute is set by @runtime_checkable: + if val is None and attr not in cls.__non_callable_proto_members__: break else: return True @@ -678,8 +687,48 @@ def __init_subclass__(cls, *args, **kwargs): cls.__init__ = _no_init +if sys.version_info >= (3, 13): + runtime_checkable = typing.runtime_checkable +else: + def runtime_checkable(cls): + """Mark a protocol class as a runtime protocol. + Such protocol can be used with isinstance() and issubclass(). + Raise TypeError if applied to a non-protocol class. + This allows a simple-minded structural check very similar to + one trick ponies in collections.abc such as Iterable. + For example:: + @runtime_checkable + class Closable(Protocol): + def close(self): ... + assert isinstance(open('/some/file'), Closable) + Warning: this will check only the presence of the required methods, + not their type signatures! + """ + if not issubclass(cls, typing.Generic) or not getattr(cls, '_is_protocol', False): + raise TypeError('@runtime_checkable can be only applied to protocol classes,' + ' got %r' % cls) + cls._is_runtime_protocol = True + # PEP 544 prohibits using issubclass() + # with protocols that have non-method members. + # See gh-113320 for why we compute this attribute here, + # rather than in `_ProtocolMeta.__init__` + cls.__non_callable_proto_members__ = set() + for attr in cls.__protocol_attrs__: + try: + is_callable = callable(getattr(cls, attr, None)) + except Exception as e: + raise TypeError( + f"Failed to determine whether protocol member {attr!r} " + "is a method member" + ) from e + else: + if not is_callable: + cls.__non_callable_proto_members__.add(attr) + return cls + + # The "runtime" alias exists for backwards compatibility. -runtime = runtime_checkable = typing.runtime_checkable +runtime = runtime_checkable # Our version of runtime-checkable protocols is faster on Python 3.8-3.11 From c9fdbc103c1a707669932ee42ce52e375582c6db Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sat, 20 Jan 2024 15:06:34 +0000 Subject: [PATCH 2/6] Restore being able to use `typing_extensions.runtime_checkable` on `typing.Protocol` --- src/test_typing_extensions.py | 12 ++++ src/typing_extensions.py | 125 +++++++++++++++++----------------- 2 files changed, 75 insertions(+), 62 deletions(-) diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 11208d95..58dc1851 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -1872,6 +1872,18 @@ class E(C, BP): pass self.assertNotIsInstance(D(), E) self.assertNotIsInstance(E(), D) + def test_runtimecheckable_on_typing_dot_Protocol(self): + @runtime_checkable + class Foo(typing.Protocol): + x: int + + class Bar: + def __init__(self): + self.x = 42 + + self.assertIsInstance(Bar(), Foo) + self.assertNotIsInstance(object(), Foo) + def test_typing_dot_runtimecheckable_on_Protocol(self): @typing.runtime_checkable class Foo(Protocol): diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 7de29cd5..10c5795d 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -687,6 +687,52 @@ def __init_subclass__(cls, *args, **kwargs): cls.__init__ = _no_init +if hasattr(typing, "is_protocol"): + is_protocol = typing.is_protocol + get_protocol_members = typing.get_protocol_members +else: + def is_protocol(tp: type, /) -> bool: + """Return True if the given type is a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, is_protocol + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> is_protocol(P) + True + >>> is_protocol(int) + False + """ + return ( + isinstance(tp, type) + and getattr(tp, '_is_protocol', False) + and tp is not Protocol + and tp is not typing.Protocol + ) + + def get_protocol_members(tp: type, /) -> typing.FrozenSet[str]: + """Return the set of members defined in a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, get_protocol_members + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> get_protocol_members(P) + frozenset({'a', 'b'}) + + Raise a TypeError for arguments that are not Protocols. + """ + if not is_protocol(tp): + raise TypeError(f'{tp!r} is not a Protocol') + if hasattr(tp, '__protocol_attrs__'): + return frozenset(tp.__protocol_attrs__) + return frozenset(_get_protocol_attrs(tp)) + + if sys.version_info >= (3, 13): runtime_checkable = typing.runtime_checkable else: @@ -708,22 +754,23 @@ def close(self): ... raise TypeError('@runtime_checkable can be only applied to protocol classes,' ' got %r' % cls) cls._is_runtime_protocol = True - # PEP 544 prohibits using issubclass() - # with protocols that have non-method members. - # See gh-113320 for why we compute this attribute here, - # rather than in `_ProtocolMeta.__init__` - cls.__non_callable_proto_members__ = set() - for attr in cls.__protocol_attrs__: - try: - is_callable = callable(getattr(cls, attr, None)) - except Exception as e: - raise TypeError( - f"Failed to determine whether protocol member {attr!r} " - "is a method member" - ) from e - else: - if not is_callable: - cls.__non_callable_proto_members__.add(attr) + if isinstance(cls, _ProtocolMeta): + # PEP 544 prohibits using issubclass() + # with protocols that have non-method members. + # See gh-113320 for why we compute this attribute here, + # rather than in `_ProtocolMeta.__init__` + cls.__non_callable_proto_members__ = set() + for attr in get_protocol_members(cls): + try: + is_callable = callable(getattr(cls, attr, None)) + except Exception as e: + raise TypeError( + f"Failed to determine whether protocol member {attr!r} " + "is a method member" + ) from e + else: + if not is_callable: + cls.__non_callable_proto_members__.add(attr) return cls @@ -2978,52 +3025,6 @@ def __ror__(self, left): return typing.Union[left, self] -if hasattr(typing, "is_protocol"): - is_protocol = typing.is_protocol - get_protocol_members = typing.get_protocol_members -else: - def is_protocol(tp: type, /) -> bool: - """Return True if the given type is a Protocol. - - Example:: - - >>> from typing_extensions import Protocol, is_protocol - >>> class P(Protocol): - ... def a(self) -> str: ... - ... b: int - >>> is_protocol(P) - True - >>> is_protocol(int) - False - """ - return ( - isinstance(tp, type) - and getattr(tp, '_is_protocol', False) - and tp is not Protocol - and tp is not typing.Protocol - ) - - def get_protocol_members(tp: type, /) -> typing.FrozenSet[str]: - """Return the set of members defined in a Protocol. - - Example:: - - >>> from typing_extensions import Protocol, get_protocol_members - >>> class P(Protocol): - ... def a(self) -> str: ... - ... b: int - >>> get_protocol_members(P) - frozenset({'a', 'b'}) - - Raise a TypeError for arguments that are not Protocols. - """ - if not is_protocol(tp): - raise TypeError(f'{tp!r} is not a Protocol') - if hasattr(tp, '__protocol_attrs__'): - return frozenset(tp.__protocol_attrs__) - return frozenset(_get_protocol_attrs(tp)) - - if hasattr(typing, "Doc"): Doc = typing.Doc else: From 8dd0b5b51ebfa428d21f2a1d4639ecec9b3a15e0 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sat, 20 Jan 2024 15:08:31 +0000 Subject: [PATCH 3/6] reduce diff --- src/typing_extensions.py | 94 ++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 10c5795d..1f03f15e 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -687,52 +687,6 @@ def __init_subclass__(cls, *args, **kwargs): cls.__init__ = _no_init -if hasattr(typing, "is_protocol"): - is_protocol = typing.is_protocol - get_protocol_members = typing.get_protocol_members -else: - def is_protocol(tp: type, /) -> bool: - """Return True if the given type is a Protocol. - - Example:: - - >>> from typing_extensions import Protocol, is_protocol - >>> class P(Protocol): - ... def a(self) -> str: ... - ... b: int - >>> is_protocol(P) - True - >>> is_protocol(int) - False - """ - return ( - isinstance(tp, type) - and getattr(tp, '_is_protocol', False) - and tp is not Protocol - and tp is not typing.Protocol - ) - - def get_protocol_members(tp: type, /) -> typing.FrozenSet[str]: - """Return the set of members defined in a Protocol. - - Example:: - - >>> from typing_extensions import Protocol, get_protocol_members - >>> class P(Protocol): - ... def a(self) -> str: ... - ... b: int - >>> get_protocol_members(P) - frozenset({'a', 'b'}) - - Raise a TypeError for arguments that are not Protocols. - """ - if not is_protocol(tp): - raise TypeError(f'{tp!r} is not a Protocol') - if hasattr(tp, '__protocol_attrs__'): - return frozenset(tp.__protocol_attrs__) - return frozenset(_get_protocol_attrs(tp)) - - if sys.version_info >= (3, 13): runtime_checkable = typing.runtime_checkable else: @@ -760,7 +714,7 @@ def close(self): ... # See gh-113320 for why we compute this attribute here, # rather than in `_ProtocolMeta.__init__` cls.__non_callable_proto_members__ = set() - for attr in get_protocol_members(cls): + for attr in cls.__protocol_attrs__: try: is_callable = callable(getattr(cls, attr, None)) except Exception as e: @@ -3025,6 +2979,52 @@ def __ror__(self, left): return typing.Union[left, self] +if hasattr(typing, "is_protocol"): + is_protocol = typing.is_protocol + get_protocol_members = typing.get_protocol_members +else: + def is_protocol(tp: type, /) -> bool: + """Return True if the given type is a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, is_protocol + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> is_protocol(P) + True + >>> is_protocol(int) + False + """ + return ( + isinstance(tp, type) + and getattr(tp, '_is_protocol', False) + and tp is not Protocol + and tp is not typing.Protocol + ) + + def get_protocol_members(tp: type, /) -> typing.FrozenSet[str]: + """Return the set of members defined in a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, get_protocol_members + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> get_protocol_members(P) + frozenset({'a', 'b'}) + + Raise a TypeError for arguments that are not Protocols. + """ + if not is_protocol(tp): + raise TypeError(f'{tp!r} is not a Protocol') + if hasattr(tp, '__protocol_attrs__'): + return frozenset(tp.__protocol_attrs__) + return frozenset(_get_protocol_attrs(tp)) + + if hasattr(typing, "Doc"): Doc = typing.Doc else: From 0d92d3bda68a1e016b1f30641e6250260d253138 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sat, 20 Jan 2024 15:10:01 +0000 Subject: [PATCH 4/6] add comment --- src/typing_extensions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 1f03f15e..1e7e9e33 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -708,6 +708,9 @@ def close(self): ... raise TypeError('@runtime_checkable can be only applied to protocol classes,' ' got %r' % cls) cls._is_runtime_protocol = True + + # Only execute the following block if it's a typing_extensions.Protocol class. + # typing.Protocol classes don't need it. if isinstance(cls, _ProtocolMeta): # PEP 544 prohibits using issubclass() # with protocols that have non-method members. @@ -725,6 +728,7 @@ def close(self): ... else: if not is_callable: cls.__non_callable_proto_members__.add(attr) + return cls From ecd711bfe88f4839797311cdee1956daaa58d6c1 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sat, 20 Jan 2024 15:11:57 +0000 Subject: [PATCH 5/6] fix docstring --- src/typing_extensions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 1e7e9e33..4007594c 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -692,15 +692,20 @@ def __init_subclass__(cls, *args, **kwargs): else: def runtime_checkable(cls): """Mark a protocol class as a runtime protocol. + Such protocol can be used with isinstance() and issubclass(). Raise TypeError if applied to a non-protocol class. This allows a simple-minded structural check very similar to one trick ponies in collections.abc such as Iterable. + For example:: + @runtime_checkable class Closable(Protocol): def close(self): ... + assert isinstance(open('/some/file'), Closable) + Warning: this will check only the presence of the required methods, not their type signatures! """ From c858502ddf1bf3474511c29b871400539f1c0acc Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 20 Jan 2024 09:55:03 -0800 Subject: [PATCH 6/6] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3957a29..be1c16d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ around 6% (backporting https://github.com/python/cpython/pull/112717, by Alex Waygood). - Fix a regression in the implementation of protocols where `typing.Protocol` - classes that were not marked as `@runtime-checkable` would be unnecessarily + classes that were not marked as `@runtime_checkable` would be unnecessarily introspected, potentially causing exceptions to be raised if the protocol had problematic members. Patch by Alex Waygood, backporting https://github.com/python/cpython/pull/113401.