diff --git a/crates/ty_python_semantic/resources/mdtest/comparison/enums.md b/crates/ty_python_semantic/resources/mdtest/comparison/enums.md index e4034adf59690..ff0b480c70713 100644 --- a/crates/ty_python_semantic/resources/mdtest/comparison/enums.md +++ b/crates/ty_python_semantic/resources/mdtest/comparison/enums.md @@ -19,3 +19,63 @@ reveal_type(Answer.NO is Answer.YES) # revealed: Literal[False] reveal_type(Answer.NO is not Answer.NO) # revealed: Literal[False] reveal_type(Answer.NO is not Answer.YES) # revealed: Literal[True] ``` + +Equality result inference uses the same runtime semantics as equality narrowing. This allows us to +recognize when two enum domains cannot contain equal values, even if neither operand is a singleton: + +```py +from enum import Enum +from typing import Literal + +class Choice(str, Enum): + FIRST = "first" + SECOND = "second" + THIRD = "third" + FOURTH = "fourth" + +def compare( + left: Literal[Choice.FIRST, Choice.SECOND], + right: Literal[Choice.THIRD, Choice.FOURTH], +): + reveal_type(left == right) # revealed: Literal[False] + reveal_type(left != right) # revealed: Literal[True] +``` + +Scalar enum members from different classes compare according to their underlying values: + +```py +from enum import Enum + +class X(str, Enum): + A = "a" + B = "b" + +class Y(str, Enum): + A = "a" + B = "b" + +reveal_type(X.A == Y.A) # revealed: Literal[True] +reveal_type(X.A != Y.A) # revealed: Literal[False] +reveal_type(X.A == Y.B) # revealed: Literal[False] +``` + +When equality semantics are custom, result inference falls back to the corresponding dunder method: + +```py +from enum import Enum + +class EqResult: ... +class NeResult: ... + +class CustomEquality(Enum): + MEMBER = 1 + + def __eq__(self, other: object) -> EqResult: # error: [invalid-method-override] + return EqResult() + + def __ne__(self, other: object) -> NeResult: # error: [invalid-method-override] + return NeResult() + +reveal_type(CustomEquality.MEMBER == CustomEquality.MEMBER) # revealed: EqResult +reveal_type(CustomEquality.MEMBER != CustomEquality.MEMBER) # revealed: NeResult +``` diff --git a/crates/ty_python_semantic/resources/mdtest/doc/public_type_undeclared_symbols.md b/crates/ty_python_semantic/resources/mdtest/doc/public_type_undeclared_symbols.md index 1246e72949d23..5d9ababc5d020 100644 --- a/crates/ty_python_semantic/resources/mdtest/doc/public_type_undeclared_symbols.md +++ b/crates/ty_python_semantic/resources/mdtest/doc/public_type_undeclared_symbols.md @@ -40,6 +40,12 @@ class TestResponse: def check(self) -> None: reveal_type(self.response_class) # revealed: type[Response] reveal_type(self.response_classes) # revealed: tuple[type[Response]] + reveal_type(self.response_class == Response) # revealed: bool + + if self.response_class == Response: + true_branch: int = "not an int" # error: [invalid-assignment] + else: + false_branch: int = "not an int" # error: [invalid-assignment] class TestHtmlResponse(TestResponse): response_class = HtmlResponse @@ -92,12 +98,14 @@ class AnnotatedResponse: def check(self) -> None: reveal_type(self.response_class) # revealed: type[Response] + reveal_type(self.response_class == Response) # revealed: bool class FixedResponse: response_class: Final = Response def check(self) -> None: reveal_type(self.response_class) # revealed: + reveal_type(self.response_class == Response) # revealed: Literal[True] ``` The same widening applies to undeclared instance attributes assigned in methods: @@ -127,6 +135,24 @@ class EitherClass: reveal_type(EitherClass().value) # revealed: type[UnionA | UnionB] ``` +Module-level variables keep their narrow inferred type. In particular, class literals in an +invariant collection remain precise enough for exhaustive equality checks: + +```py +class OffsetA: ... +class OffsetB: ... + +classes = {"a": OffsetA, "b": OffsetB} + +def choose(name: str) -> None: + class_value = classes[name] + if class_value == OffsetA: + expected = 1 + elif class_value == OffsetB: + expected = 2 + reveal_type(expected) # revealed: Literal[1, 2] +``` + ## Widening of non-literal singleton types It's similarly unlikely that an unannotated attribute initialized to a singleton type (like `None`) diff --git a/crates/ty_python_semantic/src/types/equality.rs b/crates/ty_python_semantic/src/types/equality.rs index d6c5e4feb4d44..c9e874af026c1 100644 --- a/crates/ty_python_semantic/src/types/equality.rs +++ b/crates/ty_python_semantic/src/types/equality.rs @@ -221,11 +221,31 @@ pub(crate) fn equality_truthiness<'db>( left: Type<'db>, right: Type<'db>, ) -> Truthiness { - match ComparisonEvaluator::new(db).evaluate( + comparison_truthiness(db, left, right, ComparisonOperator::Equality) +} + +/// Return the truthiness of `left != right` when it is known for every represented runtime value. +/// +/// A result that only permits narrowing remains ambiguous because it can still evaluate either way. +pub(super) fn inequality_truthiness<'db>( + db: &'db dyn Db, + left: Type<'db>, + right: Type<'db>, +) -> Truthiness { + comparison_truthiness(db, left, right, ComparisonOperator::Inequality) +} + +fn comparison_truthiness<'db>( + db: &'db dyn Db, + left: Type<'db>, + right: Type<'db>, + operator: ComparisonOperator, +) -> Truthiness { + match ComparisonEvaluator::for_truthiness(db).evaluate( left, right, ComparisonBranch::Positive, - ComparisonOperator::Equality, + operator, ) { ComparisonResult::AlwaysTrue => Truthiness::AlwaysTrue, ComparisonResult::AlwaysFalse => Truthiness::AlwaysFalse, @@ -233,6 +253,36 @@ pub(crate) fn equality_truthiness<'db>( } } +/// Selects how recursive comparison results are combined. +/// +/// The goal is only an optimization; both modes use the same comparison semantics and agree on +/// which results are definite. [`Constraint`](Self::Constraint) preserves branch-specific narrowing +/// for the left operand. [`Truthiness`](Self::Truthiness) can discard those constraints because its +/// caller only needs to know whether every expanded alternative agrees, and can stop as soon as the +/// comparison cannot be definite. +/// +/// For example, truthiness evaluation proves that this comparison is always false by checking the +/// finite alternatives on both sides, without constructing a narrowing constraint: +/// +/// ```python +/// from enum import Enum +/// from typing import Literal +/// +/// class Choice(Enum): +/// A = 1 +/// B = 2 +/// C = 3 +/// D = 4 +/// +/// def compare(left: Literal[Choice.A, Choice.B], right: Literal[Choice.C, Choice.D]): +/// reveal_type(left == right) # Literal[False] +/// ``` +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ComparisonGoal { + Constraint, + Truthiness, +} + /// Identifies an active comparison evaluation. /// /// Operand order and branch are significant because the left operand is the narrowing target. @@ -248,6 +298,7 @@ struct ComparisonKey<'db> { struct ComparisonEvaluator<'db> { db: &'db dyn Db, active: FxHashSet>, + goal: ComparisonGoal, } impl<'db> ComparisonEvaluator<'db> { @@ -255,6 +306,15 @@ impl<'db> ComparisonEvaluator<'db> { Self { db, active: FxHashSet::default(), + goal: ComparisonGoal::Constraint, + } + } + + fn for_truthiness(db: &'db dyn Db) -> Self { + Self { + db, + active: FxHashSet::default(), + goal: ComparisonGoal::Truthiness, } } @@ -273,8 +333,10 @@ impl<'db> ComparisonEvaluator<'db> { /// reveal_type(x) # Any & ~EQUAL_VALUES /// ``` /// - /// `branch` selects the branch whose constraint is accumulated when either operand expands - /// into multiple alternatives. Re-entering an active comparison conservatively returns an + /// In [`ComparisonGoal::Constraint`] mode, `branch` selects the branch whose constraint is + /// accumulated when either operand expands into multiple alternatives. In + /// [`ComparisonGoal::Truthiness`] mode, expansion instead requires every alternative to agree + /// on the comparison result. Re-entering an active comparison conservatively returns an /// ambiguous result instead of recursing indefinitely. fn evaluate( &mut self, @@ -698,6 +760,14 @@ fn evaluate_union_left<'db>( branch: ComparisonBranch, operator: ComparisonOperator, ) -> ComparisonResult<'db> { + if evaluator.goal == ComparisonGoal::Truthiness { + return combine_definite_truthiness( + elements + .iter() + .map(|element| evaluator.evaluate(*element, other, branch, operator)), + ); + } + let db = evaluator.db; evaluate_target_union(db, elements, branch, |element| { evaluator.evaluate(element, other, branch, operator) @@ -791,6 +861,14 @@ fn evaluate_union_right<'db>( branch: ComparisonBranch, operator: ComparisonOperator, ) -> ComparisonResult<'db> { + if evaluator.goal == ComparisonGoal::Truthiness { + return combine_definite_truthiness( + elements + .iter() + .map(|element| evaluator.evaluate(left, *element, branch, operator)), + ); + } + let db = evaluator.db; evaluate_against_results( db, @@ -802,6 +880,34 @@ fn evaluate_union_right<'db>( ) } +/// Combine results when the caller only needs definite truthiness. +/// +/// Any ambiguous or narrowing result, or any disagreement between definite results, makes the +/// aggregate ambiguous. In each case, later alternatives cannot make it definite again. +fn combine_definite_truthiness<'db>( + results: impl IntoIterator>, +) -> ComparisonResult<'db> { + let mut definite = None; + + for result in results { + let current = match result { + ComparisonResult::AlwaysTrue => true, + ComparisonResult::AlwaysFalse => false, + ComparisonResult::CanNarrow(_) | ComparisonResult::Ambiguous => { + return ComparisonResult::Ambiguous; + } + }; + + match definite { + Some(previous) if previous != current => return ComparisonResult::Ambiguous, + Some(_) => {} + None => definite = Some(current), + } + } + + definite.map_or(ComparisonResult::Ambiguous, ComparisonResult::from_bool) +} + /// Combine comparison results produced by alternatives of the non-target operand. /// /// The target remains possible when any alternative can satisfy the selected branch; definite @@ -865,6 +971,14 @@ fn evaluate_intersection_left<'db>( branch: ComparisonBranch, operator: ComparisonOperator, ) -> ComparisonResult<'db> { + if evaluator.goal == ComparisonGoal::Truthiness { + return combine_definite_truthiness( + positive + .iter() + .map(|element| evaluator.evaluate(*element, other, branch, operator)), + ); + } + let db = evaluator.db; let mut any_true = false; let mut any_false = false; diff --git a/crates/ty_python_semantic/src/types/infer/comparisons.rs b/crates/ty_python_semantic/src/types/infer/comparisons.rs index 6b7e63111d108..6d3305191df89 100644 --- a/crates/ty_python_semantic/src/types/infer/comparisons.rs +++ b/crates/ty_python_semantic/src/types/infer/comparisons.rs @@ -7,6 +7,7 @@ use crate::types::call::{CallArguments, CallDunderError}; use crate::types::constraints::ConstraintSetBuilder; use crate::types::context::InferContext; use crate::types::cyclic::CycleDetector; +use crate::types::equality::{equality_truthiness, inequality_truthiness}; use crate::types::tuple::TupleSpec; use crate::types::{ DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, @@ -170,6 +171,15 @@ pub(super) fn infer_binary_type_comparison<'db>( } }; + let comparison_truthiness = match op { + ast::CmpOp::Eq => equality_truthiness(db, left, right), + ast::CmpOp::NotEq => inequality_truthiness(db, left, right), + _ => Truthiness::Ambiguous, + }; + if comparison_truthiness != Truthiness::Ambiguous { + return Ok(Type::from_truthiness(db, comparison_truthiness)); + } + let comparison_result = match (left, right) { (Type::EnumComplement(complement), right) => Some(infer_binary_type_comparison( context, @@ -573,28 +583,6 @@ pub(super) fn infer_binary_type_comparison<'db>( ) if matches!(op, ast::CmpOp::Eq | ast::CmpOp::NotEq) => { Some(Ok(Type::bool_literal(op == ast::CmpOp::NotEq))) } - - (LiteralValueTypeKind::Enum(literal_1), LiteralValueTypeKind::Enum(literal_2)) - if op == ast::CmpOp::Eq => - { - Some(Ok( - match try_dunder(MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK) { - Ok(ty) => ty, - Err(_) => Type::bool_literal(literal_1 == literal_2), - }, - )) - } - (LiteralValueTypeKind::Enum(literal_1), LiteralValueTypeKind::Enum(literal_2)) - if op == ast::CmpOp::NotEq => - { - Some(Ok( - match try_dunder(MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK) { - Ok(ty) => ty, - Err(_) => Type::bool_literal(literal_1 != literal_2), - }, - )) - } - _ => None, } }