From 8abc3ff07179e1281fb8bce68d0506f7aac76230 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Tue, 11 Feb 2025 09:49:37 -0800 Subject: [PATCH] Exempted `__slots__` symbol when determining whether a class is a callback protocol. This addresses #9878. (#9879) --- .../src/analyzer/typeEvaluator.ts | 30 +++++++++++++------ .../src/tests/samples/callbackProtocol5.py | 1 + 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index f4427db0d6aa..8a6ea47508b7 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -25928,19 +25928,31 @@ export function createTypeEvaluator( for (const mroClass of objType.shared.mro) { if (isClass(mroClass) && ClassType.isProtocolClass(mroClass)) { for (const field of ClassType.getSymbolTable(mroClass)) { - if (field[0] !== '__call__' && !field[1].isIgnoredForProtocolMatch()) { - let fieldIsPartOfFunction = false; + const fieldName = field[0]; + const fieldSymbol = field[1]; - if (prefetched?.functionClass && isClass(prefetched.functionClass)) { - if (ClassType.getSymbolTable(prefetched.functionClass).has(field[0])) { - fieldIsPartOfFunction = true; - } - } + // We're expecting a __call__ method. We will also ignore a + // __slots__ definition, which is (by convention) ignored for + // protocol matching. + if (fieldName === '__call__' || fieldName === '__slots__') { + continue; + } - if (!fieldIsPartOfFunction) { - return undefined; + if (fieldSymbol.isIgnoredForProtocolMatch()) { + continue; + } + + let fieldIsPartOfFunction = false; + + if (prefetched?.functionClass && isClass(prefetched.functionClass)) { + if (ClassType.getSymbolTable(prefetched.functionClass).has(field[0])) { + fieldIsPartOfFunction = true; } } + + if (!fieldIsPartOfFunction) { + return undefined; + } } } } diff --git a/packages/pyright-internal/src/tests/samples/callbackProtocol5.py b/packages/pyright-internal/src/tests/samples/callbackProtocol5.py index 6f3c498dd102..249f7d799425 100644 --- a/packages/pyright-internal/src/tests/samples/callbackProtocol5.py +++ b/packages/pyright-internal/src/tests/samples/callbackProtocol5.py @@ -54,6 +54,7 @@ class CallbackProto2(Protocol): __module__: str __qualname__: str __annotations__: dict[str, Any] + __slots__ = () def __call__(self) -> None: ...