Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/comparison/enums.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,63 @@ reveal_type(Answer.NO is Answer.YES) # revealed: Literal[False]
reveal_type(Answer.NO is not Answer.NO) # revealed: Literal[False]
reveal_type(Answer.NO is not Answer.YES) # revealed: Literal[True]
```

Equality result inference uses the same runtime semantics as equality narrowing. This allows us to
recognize when two enum domains cannot contain equal values, even if neither operand is a singleton:

```py
from enum import Enum
from typing import Literal

class Choice(str, Enum):
FIRST = "first"
SECOND = "second"
THIRD = "third"
FOURTH = "fourth"

def compare(
left: Literal[Choice.FIRST, Choice.SECOND],
right: Literal[Choice.THIRD, Choice.FOURTH],
):
reveal_type(left == right) # revealed: Literal[False]
reveal_type(left != right) # revealed: Literal[True]
```

Scalar enum members from different classes compare according to their underlying values:

```py
from enum import Enum

class X(str, Enum):
A = "a"
B = "b"

class Y(str, Enum):
A = "a"
B = "b"

reveal_type(X.A == Y.A) # revealed: Literal[True]
reveal_type(X.A != Y.A) # revealed: Literal[False]
reveal_type(X.A == Y.B) # revealed: Literal[False]
```

When equality semantics are custom, result inference falls back to the corresponding dunder method:

```py
from enum import Enum

class EqResult: ...
class NeResult: ...

class CustomEquality(Enum):
MEMBER = 1

def __eq__(self, other: object) -> EqResult: # error: [invalid-method-override]
return EqResult()

def __ne__(self, other: object) -> NeResult: # error: [invalid-method-override]
return NeResult()

reveal_type(CustomEquality.MEMBER == CustomEquality.MEMBER) # revealed: EqResult
reveal_type(CustomEquality.MEMBER != CustomEquality.MEMBER) # revealed: NeResult
```
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class TestResponse:
def check(self) -> None:
reveal_type(self.response_class) # revealed: type[Response]
reveal_type(self.response_classes) # revealed: tuple[type[Response]]
reveal_type(self.response_class == Response) # revealed: bool

if self.response_class == Response:
true_branch: int = "not an int" # error: [invalid-assignment]
else:
false_branch: int = "not an int" # error: [invalid-assignment]

class TestHtmlResponse(TestResponse):
response_class = HtmlResponse
Expand Down Expand Up @@ -92,12 +98,14 @@ class AnnotatedResponse:

def check(self) -> None:
reveal_type(self.response_class) # revealed: type[Response]
reveal_type(self.response_class == Response) # revealed: bool

class FixedResponse:
response_class: Final = Response

def check(self) -> None:
reveal_type(self.response_class) # revealed: <class 'Response'>
reveal_type(self.response_class == Response) # revealed: Literal[True]
```

The same widening applies to undeclared instance attributes assigned in methods:
Expand Down Expand Up @@ -127,6 +135,24 @@ class EitherClass:
reveal_type(EitherClass().value) # revealed: type[UnionA | UnionB]
```

Module-level variables keep their narrow inferred type. In particular, class literals in an
invariant collection remain precise enough for exhaustive equality checks:

```py
class OffsetA: ...
class OffsetB: ...

classes = {"a": OffsetA, "b": OffsetB}

def choose(name: str) -> None:
class_value = classes[name]
if class_value == OffsetA:
expected = 1
elif class_value == OffsetB:
expected = 2
reveal_type(expected) # revealed: Literal[1, 2]
```

## Widening of non-literal singleton types

It's similarly unlikely that an unannotated attribute initialized to a singleton type (like `None`)
Expand Down
122 changes: 118 additions & 4 deletions crates/ty_python_semantic/src/types/equality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,18 +221,68 @@ pub(crate) fn equality_truthiness<'db>(
left: Type<'db>,
right: Type<'db>,
) -> Truthiness {
match ComparisonEvaluator::new(db).evaluate(
comparison_truthiness(db, left, right, ComparisonOperator::Equality)
}

/// Return the truthiness of `left != right` when it is known for every represented runtime value.
///
/// A result that only permits narrowing remains ambiguous because it can still evaluate either way.
pub(super) fn inequality_truthiness<'db>(
db: &'db dyn Db,
left: Type<'db>,
right: Type<'db>,
) -> Truthiness {
comparison_truthiness(db, left, right, ComparisonOperator::Inequality)
}

fn comparison_truthiness<'db>(
db: &'db dyn Db,
left: Type<'db>,
right: Type<'db>,
operator: ComparisonOperator,
) -> Truthiness {
match ComparisonEvaluator::for_truthiness(db).evaluate(
left,
right,
ComparisonBranch::Positive,
ComparisonOperator::Equality,
operator,
) {
ComparisonResult::AlwaysTrue => Truthiness::AlwaysTrue,
ComparisonResult::AlwaysFalse => Truthiness::AlwaysFalse,
ComparisonResult::CanNarrow(_) | ComparisonResult::Ambiguous => Truthiness::Ambiguous,
}
}

/// Selects how recursive comparison results are combined.
///
/// The goal is only an optimization; both modes use the same comparison semantics and agree on
/// which results are definite. [`Constraint`](Self::Constraint) preserves branch-specific narrowing
/// for the left operand. [`Truthiness`](Self::Truthiness) can discard those constraints because its
/// caller only needs to know whether every expanded alternative agrees, and can stop as soon as the
/// comparison cannot be definite.
///
/// For example, truthiness evaluation proves that this comparison is always false by checking the
/// finite alternatives on both sides, without constructing a narrowing constraint:
///
/// ```python
/// from enum import Enum
/// from typing import Literal
///
/// class Choice(Enum):
/// A = 1
/// B = 2
/// C = 3
/// D = 4
///
/// def compare(left: Literal[Choice.A, Choice.B], right: Literal[Choice.C, Choice.D]):
/// reveal_type(left == right) # Literal[False]
/// ```
Comment on lines +256 to +279

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is clever. Thanks for documenting it so clearly!

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ComparisonGoal {
Constraint,
Truthiness,
}

/// Identifies an active comparison evaluation.
///
/// Operand order and branch are significant because the left operand is the narrowing target.
Expand All @@ -248,13 +298,23 @@ struct ComparisonKey<'db> {
struct ComparisonEvaluator<'db> {
db: &'db dyn Db,
active: FxHashSet<ComparisonKey<'db>>,
goal: ComparisonGoal,
}

impl<'db> ComparisonEvaluator<'db> {
fn new(db: &'db dyn Db) -> Self {
Self {
db,
active: FxHashSet::default(),
goal: ComparisonGoal::Constraint,
}
}

fn for_truthiness(db: &'db dyn Db) -> Self {
Self {
db,
active: FxHashSet::default(),
goal: ComparisonGoal::Truthiness,
}
}

Expand All @@ -273,8 +333,10 @@ impl<'db> ComparisonEvaluator<'db> {
/// reveal_type(x) # Any & ~EQUAL_VALUES
/// ```
///
/// `branch` selects the branch whose constraint is accumulated when either operand expands
/// into multiple alternatives. Re-entering an active comparison conservatively returns an
/// In [`ComparisonGoal::Constraint`] mode, `branch` selects the branch whose constraint is
/// accumulated when either operand expands into multiple alternatives. In
/// [`ComparisonGoal::Truthiness`] mode, expansion instead requires every alternative to agree
/// on the comparison result. Re-entering an active comparison conservatively returns an
/// ambiguous result instead of recursing indefinitely.
fn evaluate(
&mut self,
Expand Down Expand Up @@ -698,6 +760,14 @@ fn evaluate_union_left<'db>(
branch: ComparisonBranch,
operator: ComparisonOperator,
) -> ComparisonResult<'db> {
if evaluator.goal == ComparisonGoal::Truthiness {
return combine_definite_truthiness(
elements
.iter()
.map(|element| evaluator.evaluate(*element, other, branch, operator)),
);
}

let db = evaluator.db;
evaluate_target_union(db, elements, branch, |element| {
evaluator.evaluate(element, other, branch, operator)
Expand Down Expand Up @@ -791,6 +861,14 @@ fn evaluate_union_right<'db>(
branch: ComparisonBranch,
operator: ComparisonOperator,
) -> ComparisonResult<'db> {
if evaluator.goal == ComparisonGoal::Truthiness {
return combine_definite_truthiness(
elements
.iter()
.map(|element| evaluator.evaluate(left, *element, branch, operator)),
);
}

let db = evaluator.db;
evaluate_against_results(
db,
Expand All @@ -802,6 +880,34 @@ fn evaluate_union_right<'db>(
)
}

/// Combine results when the caller only needs definite truthiness.
///
/// Any ambiguous or narrowing result, or any disagreement between definite results, makes the
/// aggregate ambiguous. In each case, later alternatives cannot make it definite again.
fn combine_definite_truthiness<'db>(
results: impl IntoIterator<Item = ComparisonResult<'db>>,
) -> ComparisonResult<'db> {
let mut definite = None;

for result in results {
let current = match result {
ComparisonResult::AlwaysTrue => true,
ComparisonResult::AlwaysFalse => false,
ComparisonResult::CanNarrow(_) | ComparisonResult::Ambiguous => {
return ComparisonResult::Ambiguous;
}
};

match definite {
Some(previous) if previous != current => return ComparisonResult::Ambiguous,
Some(_) => {}
None => definite = Some(current),
}
}

definite.map_or(ComparisonResult::Ambiguous, ComparisonResult::from_bool)
}

/// Combine comparison results produced by alternatives of the non-target operand.
///
/// The target remains possible when any alternative can satisfy the selected branch; definite
Expand Down Expand Up @@ -865,6 +971,14 @@ fn evaluate_intersection_left<'db>(
branch: ComparisonBranch,
operator: ComparisonOperator,
) -> ComparisonResult<'db> {
if evaluator.goal == ComparisonGoal::Truthiness {
return combine_definite_truthiness(
positive
.iter()
.map(|element| evaluator.evaluate(*element, other, branch, operator)),
);
}

let db = evaluator.db;
let mut any_true = false;
let mut any_false = false;
Expand Down
32 changes: 10 additions & 22 deletions crates/ty_python_semantic/src/types/infer/comparisons.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::types::call::{CallArguments, CallDunderError};
use crate::types::constraints::ConstraintSetBuilder;
use crate::types::context::InferContext;
use crate::types::cyclic::CycleDetector;
use crate::types::equality::{equality_truthiness, inequality_truthiness};
use crate::types::tuple::TupleSpec;
use crate::types::{
DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType,
Expand Down Expand Up @@ -170,6 +171,15 @@ pub(super) fn infer_binary_type_comparison<'db>(
}
};

let comparison_truthiness = match op {
ast::CmpOp::Eq => equality_truthiness(db, left, right),
ast::CmpOp::NotEq => inequality_truthiness(db, left, right),
_ => Truthiness::Ambiguous,
};
if comparison_truthiness != Truthiness::Ambiguous {
return Ok(Type::from_truthiness(db, comparison_truthiness));
}

let comparison_result = match (left, right) {
(Type::EnumComplement(complement), right) => Some(infer_binary_type_comparison(
context,
Expand Down Expand Up @@ -573,28 +583,6 @@ pub(super) fn infer_binary_type_comparison<'db>(
) if matches!(op, ast::CmpOp::Eq | ast::CmpOp::NotEq) => {
Some(Ok(Type::bool_literal(op == ast::CmpOp::NotEq)))
}

(LiteralValueTypeKind::Enum(literal_1), LiteralValueTypeKind::Enum(literal_2))
if op == ast::CmpOp::Eq =>
{
Some(Ok(
match try_dunder(MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK) {
Ok(ty) => ty,
Err(_) => Type::bool_literal(literal_1 == literal_2),
},
))
}
(LiteralValueTypeKind::Enum(literal_1), LiteralValueTypeKind::Enum(literal_2))
if op == ast::CmpOp::NotEq =>
{
Some(Ok(
match try_dunder(MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK) {
Ok(ty) => ty,
Err(_) => Type::bool_literal(literal_1 != literal_2),
},
))
}

_ => None,
}
}
Expand Down
Loading