diff --git a/crates/ty_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md b/crates/ty_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md index f880106c0f6db9..895faaf224f8c3 100644 --- a/crates/ty_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md +++ b/crates/ty_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md @@ -494,10 +494,12 @@ def compare_str_bound(a: V, b: V) -> bool: ### Constrained TypeVar comparisons -Constrained TypeVars support comparisons if all constraints support the operation: +Constrained TypeVars support comparisons if all constraints support the operation. Comparisons +between two occurrences of the same constrained TypeVar preserve the correlation that both +occurrences have the same specialization: ```py -from typing import TypeVar +from typing import Literal, TypeVar W = TypeVar("W", int, str) @@ -510,6 +512,12 @@ X = TypeVar("X", int, str) def compare_constrained_lt(a: X, b: X) -> bool: # Both int and str support < return a < b + +Y = TypeVar("Y", Literal[1], Literal[2]) + +def compare_same_constrained_literal(value: Y): + reveal_type(value == value) # revealed: Literal[True] + reveal_type(value != value) # revealed: Literal[False] ``` ### TypeVar with `complex` bound diff --git a/crates/ty_python_semantic/resources/mdtest/comparison/unions.md b/crates/ty_python_semantic/resources/mdtest/comparison/unions.md index a2418d71b1e650..b4d5c2c801f0fc 100644 --- a/crates/ty_python_semantic/resources/mdtest/comparison/unions.md +++ b/crates/ty_python_semantic/resources/mdtest/comparison/unions.md @@ -63,6 +63,33 @@ def _(small: Literal[1, 2], large: Literal[2, 3]): reveal_type(small > large) # revealed: Literal[False] ``` +Equality inference still preserves custom return types for every union arm: + +```py +class AEq: ... +class ANe: ... +class BEq: ... +class BNe: ... + +class A: + def __eq__(self, other: object) -> AEq: # error: [invalid-method-override] + return AEq() + + def __ne__(self, other: object) -> ANe: # error: [invalid-method-override] + return ANe() + +class B: + def __eq__(self, other: object) -> BEq: # error: [invalid-method-override] + return BEq() + + def __ne__(self, other: object) -> BNe: # error: [invalid-method-override] + return BNe() + +def _(value: A | B): + reveal_type(value == object()) # revealed: AEq | BEq + reveal_type(value != object()) # revealed: ANe | BNe +``` + ## Unsupported operations Make sure we emit a diagnostic if *any* of the possible comparisons is unsupported. For now, we fall diff --git a/crates/ty_python_semantic/src/types/equality.rs b/crates/ty_python_semantic/src/types/equality.rs index 6faf396dc6d83d..95f9de262585bf 100644 --- a/crates/ty_python_semantic/src/types/equality.rs +++ b/crates/ty_python_semantic/src/types/equality.rs @@ -10,10 +10,12 @@ use rustc_hash::FxHashSet; use crate::{Db, place::PlaceAndQualifiers}; use super::{ - EnumLiteralType, IntersectionBuilder, KnownBoundMethodType, KnownClass, LiteralValueType, - LiteralValueTypeKind, MemberLookupPolicy, Truthiness, Type, TypeVarBoundOrConstraints, - UnionBuilder, + EnumLiteralType, IntersectionBuilder, KnownBoundMethodType, KnownClass, KnownInstanceType, + LiteralValueType, LiteralValueTypeKind, MemberLookupPolicy, Truthiness, Type, + TypeVarBoundOrConstraints, UnionBuilder, + constraints::ConstraintSetBuilder, enums::{enum_member_literals, enum_metadata}, + tuple::TupleSpec, }; mod enums; @@ -57,6 +59,12 @@ enum ComparisonResult<'db> { /// The comparison may evaluate to true or false, depending on runtime values. Ambiguous, + + /// The comparison may evaluate to true or false, but its result is known to be a `bool`. + /// + /// This distinction lets expression inference avoid repeating comparison dispatch merely to + /// recover the builtin `bool` return type. + AmbiguousBoolean, } /// The branch of a comparison for which a narrowing constraint is being computed. @@ -102,7 +110,7 @@ impl<'db> ComparisonResult<'db> { (branch == ComparisonBranch::Positive).then_some(Type::Never) } ComparisonResult::CanNarrow(narrowed) => Some(narrowed), - ComparisonResult::Ambiguous => None, + ComparisonResult::Ambiguous | ComparisonResult::AmbiguousBoolean => None, } } @@ -118,6 +126,17 @@ impl<'db> ComparisonResult<'db> { result => result, } } + + /// Defer ambiguous result-type inference to expression inference. + /// + /// Definite comparison results are preserved; only the knowledge that an ambiguous result is a + /// builtin `bool` is discarded. + fn defer(self) -> Self { + match self { + ComparisonResult::AmbiguousBoolean => ComparisonResult::Ambiguous, + result => result, + } + } } /// Return a constraint for `left` in a branch where `left == right` has the given truthiness. @@ -245,35 +264,30 @@ pub(crate) fn equality_truthiness<'db>( left: Type<'db>, right: Type<'db>, ) -> Truthiness { - comparison_truthiness(db, left, right, ComparisonOperator::Equality) + equality_result_truthiness(db, left, right, ComparisonOperator::Equality) + .unwrap_or(Truthiness::Ambiguous) } -/// Return the truthiness of `left != right` when it is known for every represented runtime value. +/// Return the result truthiness when equality or inequality is known to return `bool`. /// -/// A result that only permits narrowing remains ambiguous because it can still evaluate either way. -pub(super) fn inequality_truthiness<'db>( - db: &'db dyn Db, - left: Type<'db>, - right: Type<'db>, -) -> Truthiness { - comparison_truthiness(db, left, right, ComparisonOperator::Inequality) -} - -fn comparison_truthiness<'db>( +/// Returns `None` when expression inference must inspect the comparison methods' actual return +/// types. +pub(super) fn equality_result_truthiness<'db>( db: &'db dyn Db, left: Type<'db>, right: Type<'db>, operator: ComparisonOperator, -) -> Truthiness { +) -> Option { match ComparisonEvaluator::for_truthiness(db).evaluate( left, right, ComparisonBranch::Positive, operator, ) { - ComparisonResult::AlwaysTrue => Truthiness::AlwaysTrue, - ComparisonResult::AlwaysFalse => Truthiness::AlwaysFalse, - ComparisonResult::CanNarrow(_) | ComparisonResult::Ambiguous => Truthiness::Ambiguous, + ComparisonResult::AlwaysTrue => Some(Truthiness::AlwaysTrue), + ComparisonResult::AlwaysFalse => Some(Truthiness::AlwaysFalse), + ComparisonResult::AmbiguousBoolean => Some(Truthiness::Ambiguous), + ComparisonResult::CanNarrow(_) | ComparisonResult::Ambiguous => None, } } @@ -281,9 +295,9 @@ fn comparison_truthiness<'db>( /// /// The goal is only an optimization; both modes use the same comparison semantics and agree on /// which results are definite. [`Constraint`](Self::Constraint) preserves branch-specific narrowing -/// for the left operand. [`Truthiness`](Self::Truthiness) can discard those constraints because its -/// caller only needs to know whether every expanded alternative agrees, and can stop as soon as the -/// comparison cannot be definite. +/// for the left operand. [`Truthiness`](Self::Truthiness) can discard those constraints, but also +/// tracks whether an ambiguous result is known to be a `bool` so expression inference can avoid a +/// second dispatch. /// /// For example, truthiness evaluation proves that this comparison is always false by checking the /// finite alternatives on both sides, without constructing a narrowing constraint: @@ -453,6 +467,8 @@ fn evaluate_comparison_once<'db>( } (_, Type::Dynamic(_)) => ComparisonResult::Ambiguous, + // Expression inference preserves correlations with a constrained TypeVar that can make a + // later comparison definite; expanding it here loses that information. (Type::TypeVar(var), other) => match var.typevar(db).bound_or_constraints(db) { None => ComparisonResult::Ambiguous, Some(TypeVarBoundOrConstraints::UpperBound(_)) => { @@ -464,14 +480,14 @@ fn evaluate_comparison_once<'db>( ComparisonResult::Ambiguous } } - Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { - evaluator.evaluate(constraints.as_type(db), other, branch, operator) - } + Some(TypeVarBoundOrConstraints::Constraints(constraints)) => evaluator + .evaluate(constraints.as_type(db), other, branch, operator) + .defer(), }, (other, Type::TypeVar(var)) => match var.typevar(db).bound_or_constraints(db) { - Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { - evaluator.evaluate(other, constraints.as_type(db), branch, operator) - } + Some(TypeVarBoundOrConstraints::Constraints(constraints)) => evaluator + .evaluate(other, constraints.as_type(db), branch, operator) + .defer(), None | Some(TypeVarBoundOrConstraints::UpperBound(_)) => ComparisonResult::Ambiguous, }, @@ -530,10 +546,11 @@ fn evaluate_comparison_once<'db>( LiteralOperand::Other, ), - (Type::TypedDict(_), Type::TypedDict(_)) => ComparisonResult::Ambiguous, + (Type::TypedDict(_), Type::TypedDict(_)) => ComparisonResult::AmbiguousBoolean, (Type::TypedDict(_), other) | (other, Type::TypedDict(_)) => { match KnownComparisonSemantics::of_type(db, other, operator) { - Some(KnownComparisonSemantics::Dict) | None => ComparisonResult::Ambiguous, + Some(KnownComparisonSemantics::Dict) => ComparisonResult::AmbiguousBoolean, + None => ComparisonResult::Ambiguous, Some(_) => operator.result_from_equality(false), } } @@ -559,6 +576,15 @@ fn evaluate_comparison_once<'db>( Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderCall(left_function)), Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderCall(right_function)), ) if left_function == right_function => operator.result_from_equality(true), + ( + Type::KnownInstance(KnownInstanceType::ConstraintSet(left)), + Type::KnownInstance(KnownInstanceType::ConstraintSet(right)), + ) => { + let constraints = ConstraintSetBuilder::new(); + let left = constraints.load(db, left.constraints(db)); + let right = constraints.load(db, right.constraints(db)); + operator.result_from_equality(left.iff(db, &constraints, right).is_always_satisfied(db)) + } (Type::KnownInstance(left_instance), Type::KnownInstance(right_instance)) if left_instance == right_instance && left.is_single_valued(db) @@ -574,6 +600,10 @@ fn evaluate_comparison_once<'db>( } (Type::NominalInstance(left_instance), Type::NominalInstance(right_instance)) => { + if let Some(result) = compare_tuples(evaluator, left_instance, right_instance, operator) + { + return result; + } compare_nominal_instances(db, left_instance, right_instance, operator) } @@ -795,7 +825,7 @@ fn evaluate_union_left<'db>( operator: ComparisonOperator, ) -> ComparisonResult<'db> { if evaluator.goal == ComparisonGoal::Truthiness { - return combine_definite_truthiness( + return combine_truthiness( elements .iter() .map(|element| evaluator.evaluate(*element, other, branch, operator)), @@ -855,7 +885,7 @@ fn evaluate_target_union<'db>( all_false = false; narrowed.push(Some(narrowed_element)); } - ComparisonResult::Ambiguous => { + ComparisonResult::Ambiguous | ComparisonResult::AmbiguousBoolean => { all_true = false; all_false = false; narrowed.push(Some(*element)); @@ -896,7 +926,7 @@ fn evaluate_union_right<'db>( operator: ComparisonOperator, ) -> ComparisonResult<'db> { if evaluator.goal == ComparisonGoal::Truthiness { - return combine_definite_truthiness( + return combine_truthiness( elements .iter() .map(|element| evaluator.evaluate(left, *element, branch, operator)), @@ -914,32 +944,41 @@ fn evaluate_union_right<'db>( ) } -/// Combine results when the caller only needs definite truthiness. +/// Combine results when the caller only needs truthiness and the builtin result type. /// -/// Any ambiguous or narrowing result, or any disagreement between definite results, makes the -/// aggregate ambiguous. In each case, later alternatives cannot make it definite again. -fn combine_definite_truthiness<'db>( +/// Disagreement between definite results, or an ambiguous builtin result, produces an ambiguous +/// boolean. An unknown or narrowing result still requires expression inference. +fn combine_truthiness<'db>( results: impl IntoIterator>, ) -> ComparisonResult<'db> { - let mut definite = None; + let mut any = false; + let mut all_true = true; + let mut all_false = true; for result in results { - let current = match result { - ComparisonResult::AlwaysTrue => true, - ComparisonResult::AlwaysFalse => false, + any = true; + match result { + ComparisonResult::AlwaysTrue => all_false = false, + ComparisonResult::AlwaysFalse => all_true = false, + ComparisonResult::AmbiguousBoolean => { + all_true = false; + all_false = false; + } ComparisonResult::CanNarrow(_) | ComparisonResult::Ambiguous => { return ComparisonResult::Ambiguous; } - }; - - match definite { - Some(previous) if previous != current => return ComparisonResult::Ambiguous, - Some(_) => {} - None => definite = Some(current), } } - definite.map_or(ComparisonResult::Ambiguous, ComparisonResult::from_bool) + if !any { + ComparisonResult::Ambiguous + } else if all_true { + ComparisonResult::AlwaysTrue + } else if all_false { + ComparisonResult::AlwaysFalse + } else { + ComparisonResult::AmbiguousBoolean + } } /// Combine comparison results produced by alternatives of the non-target operand. @@ -977,7 +1016,7 @@ fn evaluate_against_results<'db>( all_false = false; builder = builder.add(narrowed); } - ComparisonResult::Ambiguous => { + ComparisonResult::Ambiguous | ComparisonResult::AmbiguousBoolean => { all_true = false; all_false = false; builder = builder.add(target); @@ -1006,7 +1045,7 @@ fn evaluate_intersection_left<'db>( operator: ComparisonOperator, ) -> ComparisonResult<'db> { if evaluator.goal == ComparisonGoal::Truthiness { - return combine_definite_truthiness( + return combine_truthiness( positive .iter() .map(|element| evaluator.evaluate(*element, other, branch, operator)), @@ -1028,7 +1067,9 @@ fn evaluate_intersection_left<'db>( any_narrowing = true; builder = builder.add_positive(narrowed); } - ComparisonResult::Ambiguous => any_ambiguous = true, + ComparisonResult::Ambiguous | ComparisonResult::AmbiguousBoolean => { + any_ambiguous = true; + } } } @@ -1147,7 +1188,7 @@ fn compare_literal_to_other<'db>( ) -> ComparisonResult<'db> { if matches!(literal, LiteralValueTypeKind::LiteralString) { return match KnownComparisonSemantics::of_type(db, other, operator) { - Some(KnownComparisonSemantics::Str) => ComparisonResult::Ambiguous, + Some(KnownComparisonSemantics::Str) => ComparisonResult::AmbiguousBoolean, Some(_) => ComparisonResult::from_bool(operator == ComparisonOperator::Inequality), None => ComparisonResult::Ambiguous, }; @@ -1177,7 +1218,7 @@ fn compare_literal_to_other<'db>( { ComparisonResult::CanNarrow(literal_type.negate_if(db, !condition_expects_equality)) } - Some(_) => ComparisonResult::Ambiguous, + Some(_) => ComparisonResult::AmbiguousBoolean, None if literal_operand == LiteralOperand::Other && !condition_expects_equality && literal_type.is_single_valued(db) => @@ -1188,6 +1229,70 @@ fn compare_literal_to_other<'db>( } } +/// Evaluate tuple equality from element comparisons when every ambiguous result is known boolean. +/// +/// Unknown element return types are left to expression inference, which also diagnoses invalid +/// boolean conversions performed by tuple comparison. +fn compare_tuples<'db>( + evaluator: &mut ComparisonEvaluator<'db>, + left_instance: super::NominalInstanceType<'db>, + right_instance: super::NominalInstanceType<'db>, + operator: ComparisonOperator, +) -> Option> { + let db = evaluator.db; + let left = Type::NominalInstance(left_instance); + let right = Type::NominalInstance(right_instance); + let left_tuple = left_instance.tuple_spec(db)?; + let right_tuple = right_instance.tuple_spec(db)?; + + if KnownComparisonSemantics::of_type(db, left, operator) + != Some(KnownComparisonSemantics::Tuple) + || KnownComparisonSemantics::of_type(db, right, operator) + != Some(KnownComparisonSemantics::Tuple) + { + return None; + } + + if left == right && left.is_singleton(db) { + return Some(operator.result_from_equality(true)); + } + + let (TupleSpec::Fixed(left), TupleSpec::Fixed(right)) = + (left_tuple.as_ref(), right_tuple.as_ref()) + else { + return Some(ComparisonResult::AmbiguousBoolean); + }; + + let mut ambiguous = false; + for (left, right) in left.iter_all_elements().zip(right.iter_all_elements()) { + match evaluator.evaluate( + left, + right, + ComparisonBranch::Positive, + ComparisonOperator::Equality, + ) { + ComparisonResult::AlwaysTrue => {} + ComparisonResult::AlwaysFalse => { + return Some(operator.result_from_equality(false)); + } + ComparisonResult::AmbiguousBoolean => ambiguous = true, + ComparisonResult::CanNarrow(_) | ComparisonResult::Ambiguous => { + // Inference still needs to inspect the element's actual return type so that tuple + // comparison can diagnose an invalid boolean conversion. + return Some(ComparisonResult::Ambiguous); + } + } + } + + if left.len() != right.len() { + Some(operator.result_from_equality(false)) + } else if ambiguous { + Some(ComparisonResult::AmbiguousBoolean) + } else { + Some(operator.result_from_equality(true)) + } +} + /// Compare nominal instances when their inherited comparison implementations are known. /// /// The result is definite only when the implementations cannot compare equal, or when both types @@ -1218,12 +1323,13 @@ fn compare_nominal_instances<'db>( if left == right && left.is_singleton(db) { ComparisonResult::from_bool(operator == ComparisonOperator::Equality) } else { - ComparisonResult::Ambiguous + ComparisonResult::AmbiguousBoolean } } +/// The equality operation whose runtime semantics are being evaluated. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -enum ComparisonOperator { +pub(super) enum ComparisonOperator { Equality, Inequality, } diff --git a/crates/ty_python_semantic/src/types/infer/comparisons.rs b/crates/ty_python_semantic/src/types/infer/comparisons.rs index 6d3305191df892..a0cf75b1157278 100644 --- a/crates/ty_python_semantic/src/types/infer/comparisons.rs +++ b/crates/ty_python_semantic/src/types/infer/comparisons.rs @@ -4,15 +4,14 @@ use smallvec::SmallVec; use crate::Db; use crate::types::call::{CallArguments, CallDunderError}; -use crate::types::constraints::ConstraintSetBuilder; use crate::types::context::InferContext; use crate::types::cyclic::CycleDetector; -use crate::types::equality::{equality_truthiness, inequality_truthiness}; +use crate::types::equality::{ComparisonOperator as EqualityOperator, equality_result_truthiness}; use crate::types::tuple::TupleSpec; use crate::types::{ - DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, - LiteralValueType, LiteralValueTypeKind, MemberLookupPolicy, Type, TypeContext, - TypeVarBoundOrConstraints, UnionBuilder, + DynamicType, IntersectionBuilder, IntersectionType, KnownClass, LiteralValueType, + LiteralValueTypeKind, MemberLookupPolicy, Type, TypeContext, TypeVarBoundOrConstraints, + UnionBuilder, }; use ty_python_core::Truthiness; @@ -95,6 +94,35 @@ impl From for ast::CmpOp { } } +/// A comparison operator not handled by the equality evaluator. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum NonEqualityOperator { + Lt, + LtE, + Gt, + GtE, + In, + NotIn, + Is, + IsNot, +} + +impl NonEqualityOperator { + const fn from_ast(operator: ast::CmpOp) -> Option { + match operator { + ast::CmpOp::Eq | ast::CmpOp::NotEq => None, + ast::CmpOp::Lt => Some(Self::Lt), + ast::CmpOp::LtE => Some(Self::LtE), + ast::CmpOp::Gt => Some(Self::Gt), + ast::CmpOp::GtE => Some(Self::GtE), + ast::CmpOp::In => Some(Self::In), + ast::CmpOp::NotIn => Some(Self::NotIn), + ast::CmpOp::Is => Some(Self::Is), + ast::CmpOp::IsNot => Some(Self::IsNot), + } + } +} + /// Context for a failed comparison operation. /// /// `left_ty` and `right_ty` are the "low-level" types @@ -171,13 +199,15 @@ pub(super) fn infer_binary_type_comparison<'db>( } }; - let comparison_truthiness = match op { - ast::CmpOp::Eq => equality_truthiness(db, left, right), - ast::CmpOp::NotEq => inequality_truthiness(db, left, right), - _ => Truthiness::Ambiguous, + let equality_operator = match op { + ast::CmpOp::Eq => Some(EqualityOperator::Equality), + ast::CmpOp::NotEq => Some(EqualityOperator::Inequality), + _ => None, }; - if comparison_truthiness != Truthiness::Ambiguous { - return Ok(Type::from_truthiness(db, comparison_truthiness)); + if let Some(operator) = equality_operator + && let Some(truthiness) = equality_result_truthiness(db, left, right, operator) + { + return Ok(Type::from_truthiness(db, truthiness)); } let comparison_result = match (left, right) { @@ -405,25 +435,26 @@ pub(super) fn infer_binary_type_comparison<'db>( } (Type::LiteralValue(left_literal), Type::LiteralValue(right_literal)) => { - match (left_literal.kind(), right_literal.kind()) { + NonEqualityOperator::from_ast(op).and_then(|operator| match ( + left_literal.kind(), + right_literal.kind(), + ) { (LiteralValueTypeKind::Int(n), LiteralValueTypeKind::Int(m)) => { - Some(match op { - ast::CmpOp::Eq => Ok(Type::bool_literal(n == m)), - ast::CmpOp::NotEq => Ok(Type::bool_literal(n != m)), - ast::CmpOp::Lt => Ok(Type::bool_literal(n < m)), - ast::CmpOp::LtE => Ok(Type::bool_literal(n <= m)), - ast::CmpOp::Gt => Ok(Type::bool_literal(n > m)), - ast::CmpOp::GtE => Ok(Type::bool_literal(n >= m)), + Some(match operator { + NonEqualityOperator::Lt => Ok(Type::bool_literal(n < m)), + NonEqualityOperator::LtE => Ok(Type::bool_literal(n <= m)), + NonEqualityOperator::Gt => Ok(Type::bool_literal(n > m)), + NonEqualityOperator::GtE => Ok(Type::bool_literal(n >= m)), // We cannot say that two equal int Literals will return True from an `is` or `is not` comparison. // Even if they are the same value, they may not be the same object. - ast::CmpOp::Is => { + NonEqualityOperator::Is => { if n == m { Ok(KnownClass::Bool.to_instance(db)) } else { Ok(Type::bool_literal(false)) } } - ast::CmpOp::IsNot => { + NonEqualityOperator::IsNot => { if n == m { Ok(KnownClass::Bool.to_instance(db)) } else { @@ -431,11 +462,13 @@ pub(super) fn infer_binary_type_comparison<'db>( } } // Undefined for (int, int) - ast::CmpOp::In | ast::CmpOp::NotIn => Err(UnsupportedComparisonError { - op, - left_ty: left, - right_ty: right, - }), + NonEqualityOperator::In | NonEqualityOperator::NotIn => { + Err(UnsupportedComparisonError { + op, + left_ty: left, + right_ty: right, + }) + } }) } // Booleans are coded as integers (False = 0, True = 1) @@ -491,23 +524,25 @@ pub(super) fn infer_binary_type_comparison<'db>( ) => { let s1 = salsa_s1.value(db); let s2 = salsa_s2.value(db); - let result = match op { - ast::CmpOp::Eq => Type::bool_literal(s1 == s2), - ast::CmpOp::NotEq => Type::bool_literal(s1 != s2), - ast::CmpOp::Lt => Type::bool_literal(s1 < s2), - ast::CmpOp::LtE => Type::bool_literal(s1 <= s2), - ast::CmpOp::Gt => Type::bool_literal(s1 > s2), - ast::CmpOp::GtE => Type::bool_literal(s1 >= s2), - ast::CmpOp::In => Type::bool_literal(s2.contains(s1)), - ast::CmpOp::NotIn => Type::bool_literal(!s2.contains(s1)), - ast::CmpOp::Is => { + let result = match operator { + NonEqualityOperator::Lt => Type::bool_literal(s1 < s2), + NonEqualityOperator::LtE => Type::bool_literal(s1 <= s2), + NonEqualityOperator::Gt => Type::bool_literal(s1 > s2), + NonEqualityOperator::GtE => Type::bool_literal(s1 >= s2), + NonEqualityOperator::In => { + Type::bool_literal(s2.contains(s1)) + } + NonEqualityOperator::NotIn => { + Type::bool_literal(!s2.contains(s1)) + } + NonEqualityOperator::Is => { if s1 == s2 { KnownClass::Bool.to_instance(db) } else { Type::bool_literal(false) } } - ast::CmpOp::IsNot => { + NonEqualityOperator::IsNot => { if s1 == s2 { KnownClass::Bool.to_instance(db) } else { @@ -524,27 +559,25 @@ pub(super) fn infer_binary_type_comparison<'db>( ) => { let b1 = salsa_b1.value(db); let b2 = salsa_b2.value(db); - let result = match op { - ast::CmpOp::Eq => Type::bool_literal(b1 == b2), - ast::CmpOp::NotEq => Type::bool_literal(b1 != b2), - ast::CmpOp::Lt => Type::bool_literal(b1 < b2), - ast::CmpOp::LtE => Type::bool_literal(b1 <= b2), - ast::CmpOp::Gt => Type::bool_literal(b1 > b2), - ast::CmpOp::GtE => Type::bool_literal(b1 >= b2), - ast::CmpOp::In => { + let result = match operator { + NonEqualityOperator::Lt => Type::bool_literal(b1 < b2), + NonEqualityOperator::LtE => Type::bool_literal(b1 <= b2), + NonEqualityOperator::Gt => Type::bool_literal(b1 > b2), + NonEqualityOperator::GtE => Type::bool_literal(b1 >= b2), + NonEqualityOperator::In => { Type::bool_literal(memchr::memmem::find(b2, b1).is_some()) } - ast::CmpOp::NotIn => { + NonEqualityOperator::NotIn => { Type::bool_literal(memchr::memmem::find(b2, b1).is_none()) } - ast::CmpOp::Is => { + NonEqualityOperator::Is => { if b1 == b2 { KnownClass::Bool.to_instance(db) } else { Type::bool_literal(false) } } - ast::CmpOp::IsNot => { + NonEqualityOperator::IsNot => { if b1 == b2 { KnownClass::Bool.to_instance(db) } else { @@ -555,52 +588,8 @@ pub(super) fn infer_binary_type_comparison<'db>( Some(Ok(result)) } - // Same-kind exact literals and the special relationship between `int` and `bool` - // are handled above. Any remaining pair of exact builtin literals compares - // unequal. `LiteralString` also compares unequal to non-string literals, but its - // comparison with an exact string literal remains ambiguous. - ( - LiteralValueTypeKind::Int(_) - | LiteralValueTypeKind::Bool(_) - | LiteralValueTypeKind::String(_) - | LiteralValueTypeKind::Bytes(_), - LiteralValueTypeKind::Int(_) - | LiteralValueTypeKind::Bool(_) - | LiteralValueTypeKind::String(_) - | LiteralValueTypeKind::Bytes(_), - ) - | ( - LiteralValueTypeKind::LiteralString, - LiteralValueTypeKind::Int(_) - | LiteralValueTypeKind::Bool(_) - | LiteralValueTypeKind::Bytes(_), - ) - | ( - LiteralValueTypeKind::Int(_) - | LiteralValueTypeKind::Bool(_) - | LiteralValueTypeKind::Bytes(_), - LiteralValueTypeKind::LiteralString, - ) if matches!(op, ast::CmpOp::Eq | ast::CmpOp::NotEq) => { - Some(Ok(Type::bool_literal(op == ast::CmpOp::NotEq))) - } - _ => None, - } - } - - ( - Type::KnownInstance(KnownInstanceType::ConstraintSet(left)), - Type::KnownInstance(KnownInstanceType::ConstraintSet(right)), - ) => { - let constraints = ConstraintSetBuilder::new(); - let left = constraints.load(db, left.constraints(db)); - let right = constraints.load(db, right.constraints(db)); - let result = left.iff(db, &constraints, right); - let equivalent = result.is_always_satisfied(db); - match op { - ast::CmpOp::Eq => Some(Ok(Type::bool_literal(equivalent))), - ast::CmpOp::NotEq => Some(Ok(Type::bool_literal(!equivalent))), _ => None, - } + }) } (Type::NominalInstance(nominal1), Type::NominalInstance(nominal2)) => nominal1