Skip to content

Commit

Permalink
Fixed a bug that leads to incorrect (unsound) type narrowing when usi…
Browse files Browse the repository at this point in the history
…ng the `x in y` type guard pattern. The `in` operator uses equality checks, and `__eq__` can succeed for objects of disjoint types, which means disjointedness cannot be used as the basis for narrowing here. This change also affects the `reportUnnecessaryContains` check, which leverages the same logic. This addresses #9338. (#9868)
  • Loading branch information
erictraut authored Feb 10, 2025
1 parent ed53f3e commit 0006588
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 106 deletions.
93 changes: 1 addition & 92 deletions packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2125,7 +2125,7 @@ export class Checker extends ParseTreeWalker {
return;
}

if (this._isTypeComparable(leftSubtype, rightSubtype)) {
if (this._evaluator.isTypeComparable(leftSubtype, rightSubtype)) {
isComparable = true;
}

Expand All @@ -2151,97 +2151,6 @@ export class Checker extends ParseTreeWalker {
}
}

// Determines whether the two types are potentially comparable -- i.e.
// their types overlap in such a way that it makes sense for them to
// be compared with an == or != operator.
private _isTypeComparable(leftType: Type, rightType: Type) {
if (isAnyOrUnknown(leftType) || isAnyOrUnknown(rightType)) {
return true;
}

if (isNever(leftType) || isNever(rightType)) {
return false;
}

if (isModule(leftType) || isModule(rightType)) {
return isTypeSame(leftType, rightType, { ignoreConditions: true });
}

const isLeftCallable = isFunction(leftType) || isOverloaded(leftType);
const isRightCallable = isFunction(rightType) || isOverloaded(rightType);
if (isLeftCallable !== isRightCallable) {
return false;
}

if (isInstantiableClass(leftType) || (isClassInstance(leftType) && ClassType.isBuiltIn(leftType, 'type'))) {
if (
isInstantiableClass(rightType) ||
(isClassInstance(rightType) && ClassType.isBuiltIn(rightType, 'type'))
) {
const genericLeftType = ClassType.specialize(leftType, /* typeArgs */ undefined);
const genericRightType = ClassType.specialize(rightType, /* typeArgs */ undefined);

if (
this._evaluator.assignType(genericLeftType, genericRightType) ||
this._evaluator.assignType(genericRightType, genericLeftType)
) {
return true;
}
}

// Does the class have an operator overload for eq?
const metaclass = leftType.shared.effectiveMetaclass;
if (metaclass && isClass(metaclass)) {
if (lookUpClassMember(metaclass, '__eq__', MemberAccessFlags.SkipObjectBaseClass)) {
return true;
}
}

return false;
}

if (isClassInstance(leftType)) {
if (isClass(rightType)) {
const genericLeftType = ClassType.specialize(leftType, /* typeArgs */ undefined);
const genericRightType = ClassType.specialize(rightType, /* typeArgs */ undefined);

if (
this._evaluator.assignType(genericLeftType, genericRightType) ||
this._evaluator.assignType(genericRightType, genericLeftType)
) {
return true;
}

// Assume that if the types are disjoint and built-in classes that they
// will never be comparable.
if (ClassType.isBuiltIn(leftType) && ClassType.isBuiltIn(rightType) && TypeBase.isInstance(rightType)) {
return false;
}
}

// Does the class have an operator overload for eq?
const eqMethod = lookUpClassMember(
ClassType.cloneAsInstantiable(leftType),
'__eq__',
MemberAccessFlags.SkipObjectBaseClass
);

if (eqMethod) {
// If this is a synthesized method for a dataclass, we can assume
// that other dataclass types will not be comparable.
if (ClassType.isDataClass(leftType) && eqMethod.symbol.getSynthesizedType()) {
return false;
}

return true;
}

return false;
}

return true;
}

// If the function is a generator, validates that its annotated return type
// is appropriate for a generator.
private _validateGeneratorReturnType(node: FunctionNode, functionType: FunctionType) {
Expand Down
115 changes: 115 additions & 0 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25540,6 +25540,120 @@ export function createTypeEvaluator(
);
}

// Determines whether the two types are potentially comparable -- i.e.
// their types overlap in such a way that it makes sense for them to
// be compared with an == or != operator.
function isTypeComparable(leftType: Type, rightType: Type) {
if (isAnyOrUnknown(leftType) || isAnyOrUnknown(rightType)) {
return true;
}

if (isNever(leftType) || isNever(rightType)) {
return false;
}

if (isModule(leftType) || isModule(rightType)) {
return isTypeSame(leftType, rightType, { ignoreConditions: true });
}

const isLeftCallable = isFunction(leftType) || isOverloaded(leftType);
const isRightCallable = isFunction(rightType) || isOverloaded(rightType);
if (isLeftCallable !== isRightCallable) {
return false;
}

if (isInstantiableClass(leftType) || (isClassInstance(leftType) && ClassType.isBuiltIn(leftType, 'type'))) {
if (
isInstantiableClass(rightType) ||
(isClassInstance(rightType) && ClassType.isBuiltIn(rightType, 'type'))
) {
const genericLeftType = ClassType.specialize(leftType, /* typeArgs */ undefined);
const genericRightType = ClassType.specialize(rightType, /* typeArgs */ undefined);

if (assignType(genericLeftType, genericRightType) || assignType(genericRightType, genericLeftType)) {
return true;
}
}

// Does the class have an operator overload for eq?
const metaclass = leftType.shared.effectiveMetaclass;
if (metaclass && isClass(metaclass)) {
if (lookUpClassMember(metaclass, '__eq__', MemberAccessFlags.SkipObjectBaseClass)) {
return true;
}
}

return false;
}

if (isClassInstance(leftType)) {
if (isClass(rightType)) {
const genericLeftType = ClassType.specialize(leftType, /* typeArgs */ undefined);
const genericRightType = ClassType.specialize(rightType, /* typeArgs */ undefined);

if (assignType(genericLeftType, genericRightType) || assignType(genericRightType, genericLeftType)) {
return true;
}

// Assume that if the types are disjoint and built-in classes that they
// will never be comparable.
if (ClassType.isBuiltIn(leftType) && ClassType.isBuiltIn(rightType) && TypeBase.isInstance(rightType)) {
// We need to be careful with bool and int literals because
// they are comparable under certain circumstances.
let boolType: ClassType | undefined;
let intType: ClassType | undefined;
if (ClassType.isBuiltIn(leftType, 'bool') && ClassType.isBuiltIn(rightType, 'int')) {
boolType = leftType;
intType = rightType;
} else if (ClassType.isBuiltIn(rightType, 'bool') && ClassType.isBuiltIn(leftType, 'int')) {
boolType = rightType;
intType = leftType;
}

if (boolType && intType) {
const intVal = intType.priv?.literalValue as number | BigInt | undefined;
if (intVal === undefined) {
return true;
}
if (intVal !== 0 && intVal !== 1) {
return false;
}

const boolVal = boolType.priv?.literalValue as boolean | undefined;
if (boolVal === undefined) {
return true;
}

return boolVal === (intVal === 1);
}

return false;
}
}

// Does the class have an operator overload for eq?
const eqMethod = lookUpClassMember(
ClassType.cloneAsInstantiable(leftType),
'__eq__',
MemberAccessFlags.SkipObjectBaseClass
);

if (eqMethod) {
// If this is a synthesized method for a dataclass, we can assume
// that other dataclass types will not be comparable.
if (ClassType.isDataClass(leftType) && eqMethod.symbol.getSynthesizedType()) {
return false;
}

return true;
}

return false;
}

return true;
}

function assignToUnionType(
destType: UnionType,
srcType: Type,
Expand Down Expand Up @@ -28325,6 +28439,7 @@ export function createTypeEvaluator(
getCallSignatureInfo,
getAbstractSymbols,
narrowConstrainedTypeVar,
isTypeComparable,
assignType,
validateOverrideMethod,
validateCallArgs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ export interface TypeEvaluator {
getCallSignatureInfo: (node: CallNode, activeIndex: number, activeOrFake: boolean) => CallSignatureInfo | undefined;
getAbstractSymbols: (classType: ClassType) => AbstractSymbol[];
narrowConstrainedTypeVar: (node: ParseNode, typeVar: TypeVarType) => Type | undefined;
isTypeComparable: (leftType: Type, rightType: Type) => boolean;

assignType: (
destType: Type,
Expand Down
31 changes: 28 additions & 3 deletions packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2066,15 +2066,40 @@ export function narrowTypeForContainerElementType(evaluator: TypeEvaluator, refe
return referenceSubtype;
}

if (evaluator.assignType(referenceSubtype, elementSubtype)) {
// If the two types are disjoint (i.e. are not comparable), eliminate this subtype.
if (!evaluator.isTypeComparable(elementSubtype, referenceSubtype)) {
return undefined;
}

// If one of the two types is a literal, we can narrow to that type.
if (
isClassInstance(elementSubtype) &&
(isLiteralType(elementSubtype) || isNoneInstance(elementSubtype)) &&
evaluator.assignType(referenceSubtype, elementSubtype)
) {
return stripTypeForm(addConditionToType(elementSubtype, referenceSubtype.props?.condition));
}

if (evaluator.assignType(elementSubtype, referenceSubtype)) {
if (
isClassInstance(referenceSubtype) &&
(isLiteralType(referenceSubtype) || isNoneInstance(referenceSubtype)) &&
evaluator.assignType(elementSubtype, referenceSubtype)
) {
return stripTypeForm(addConditionToType(referenceSubtype, elementSubtype.props?.condition));
}

return undefined;
// If the element type is a known class object that is assignable to
// the reference type, we can narrow to that class object.
if (
isInstantiableClass(elementSubtype) &&
!elementSubtype.priv.includeSubclasses &&
evaluator.assignType(referenceSubtype, elementSubtype)
) {
return stripTypeForm(addConditionToType(elementSubtype, referenceSubtype.props?.condition));
}

// It's not safe to narrow.
return referenceSubtype;
});
});
}
Expand Down
48 changes: 38 additions & 10 deletions packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def func6(x: type):

def func7(x: object | bytes, y: str, z: int):
if x in (y, z):
reveal_type(x, expected_text="str | int")
reveal_type(x, expected_text="object")
else:
reveal_type(x, expected_text="object | bytes")
reveal_type(x, expected_text="str | int | object | bytes")
reveal_type(x, expected_text="object | bytes")


def func8(x: object):
Expand Down Expand Up @@ -127,13 +127,6 @@ class TD2(TypedDict):
y: str


def func11(x: dict[str, str]):
if x in (TD1(x="a"), TD2(y="b")):
reveal_type(x, expected_text="TD1 | TD2")
else:
reveal_type(x, expected_text="dict[str, str]")


T1 = TypeVar("T1", TD1, TD2)


Expand Down Expand Up @@ -175,4 +168,39 @@ def func14(x: str, y: dict[Any, Any]):

def func15(x: Any, y: dict[str, str]):
if x in y:
reveal_type(x, expected_text="str")
reveal_type(x, expected_text="Any")


def func16(x: int, y: list[Literal[0, 1]]):
if x in y:
reveal_type(x, expected_text="Literal[0, 1]")


def func17(x: Literal[-1, 0], y: list[Literal[0, 1]]):
if x in y:
reveal_type(x, expected_text="Literal[0]")


def func18(x: Literal[0, 1, 2], y: list[Literal[0, 1]]):
if x in y:
reveal_type(x, expected_text="Literal[0, 1]")


def func19(x: float, y: list[int]):
if x in y:
reveal_type(x, expected_text="float")


def func20(x: float, y: list[Literal[0, 1]]):
if x in y:
reveal_type(x, expected_text="Literal[0, 1]")


def func21(x: int, y: list[Literal[0, True]]):
if x in y:
reveal_type(x, expected_text="Literal[0, True]")


def func22(x: bool, y: list[Literal[0, 1]]):
if x in y:
reveal_type(x, expected_text="bool")
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def func3(x: list[str]):
return

# This should generate an error if "reportUnnecessaryContains" is enabled.
if x not in ([1, 2], [3]):
if x not in ((1, 2), (3,)):
pass


Expand Down

0 comments on commit 0006588

Please sign in to comment.