diff --git a/crates/ruff_benchmark/benches/ty.rs b/crates/ruff_benchmark/benches/ty.rs index 8f8305fbcc0a11..bbae742c2563d8 100644 --- a/crates/ruff_benchmark/benches/ty.rs +++ b/crates/ruff_benchmark/benches/ty.rs @@ -600,6 +600,77 @@ fn benchmark_many_enum_members(criterion: &mut Criterion) { }); } +fn benchmark_enum_comparison(criterion: &mut Criterion, name: &str, code: &str) { + setup_rayon(); + + criterion.bench_function(name, |b| { + b.iter_batched_ref( + || setup_micro_case(code), + |case| { + let Case { db, .. } = case; + let result = db.check(); + assert_eq!(result.len(), 0); + }, + BatchSize::SmallInput, + ); + }); +} + +/// Regression benchmark for . +fn benchmark_narrowed_str_enum_comparison(criterion: &mut Criterion) { + const NUM_ENUM_MEMBERS: usize = 256; + + let mut code = "from enum import StrEnum\n\nclass LargeEnum(StrEnum):\n".to_string(); + for index in 0..NUM_ENUM_MEMBERS { + writeln!(&mut code, " VALUE_{index} = \"value_{index}\"").ok(); + } + code.push_str( + "\n\ndef compare(left: LargeEnum, right: LargeEnum):\n if right and left != right:\n return\n return left == right\n", + ); + + benchmark_enum_comparison(criterion, "ty_micro[narrowed_str_enum_comparison]", &code); +} + +/// Ensure explicit enum-literal unions are compared as value sets, not member pairs. +fn benchmark_enum_literal_union_comparison(criterion: &mut Criterion) { + const NUM_ENUM_MEMBERS: usize = 256; + + let mut code = + "from enum import StrEnum\nfrom typing import Literal\n\nclass LargeEnum(StrEnum):\n" + .to_string(); + for index in 0..NUM_ENUM_MEMBERS { + writeln!(&mut code, " VALUE_{index} = \"value_{index}\"").ok(); + } + code.push_str("\nLeft = Literal[\n"); + for index in 0..NUM_ENUM_MEMBERS / 2 { + writeln!(&mut code, " LargeEnum.VALUE_{index},").ok(); + } + code.push_str("]\nRight = Literal[\n"); + for index in NUM_ENUM_MEMBERS / 2..NUM_ENUM_MEMBERS { + writeln!(&mut code, " LargeEnum.VALUE_{index},").ok(); + } + code.push_str("]\n\n\ndef compare(left: Left, right: Right):\n return left == right\n"); + + benchmark_enum_comparison(criterion, "ty_micro[enum_literal_union_comparison]", &code); +} + +/// Ensure the comparison profile is reused instead of scanning every member per expression. +fn benchmark_repeated_str_enum_comparisons(criterion: &mut Criterion) { + const NUM_ENUM_MEMBERS: usize = 1_024; + const NUM_COMPARISONS: usize = 1_000; + + let mut code = "from enum import StrEnum\n\nclass LargeEnum(StrEnum):\n".to_string(); + for index in 0..NUM_ENUM_MEMBERS { + writeln!(&mut code, " VALUE_{index} = \"value_{index}\"").ok(); + } + code.push_str("\n\ndef compare(left: LargeEnum, right: LargeEnum):\n"); + for _ in 0..NUM_COMPARISONS { + code.push_str(" left == right\n"); + } + + benchmark_enum_comparison(criterion, "ty_micro[repeated_str_enum_comparisons]", &code); +} + /// Micro-benchmark that tests our performance when slicing and unpacking /// a very large tuple that has many varied literal strings inside it. /// @@ -1509,6 +1580,9 @@ criterion_group!( benchmark_complex_constrained_attributes_2, benchmark_complex_constrained_attributes_3, benchmark_many_enum_members, + benchmark_narrowed_str_enum_comparison, + benchmark_enum_literal_union_comparison, + benchmark_repeated_str_enum_comparisons, benchmark_many_enum_members_2, benchmark_many_protocol_members_mismatch, benchmark_gradual_vararg_call, diff --git a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md index a60a30cb4f9388..e4eda7a8c5661f 100644 --- a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md +++ b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md @@ -255,11 +255,11 @@ from enum import Enum, Flag class Permission(Flag): READ = 1 -class OpenEnum(Enum): +class MissingValueEnum(Enum): ONLY = 1 @classmethod - def _missing_(cls, value: object) -> "OpenEnum": + def _missing_(cls, value: object) -> "MissingValueEnum": return object.__new__(cls) def match_flag(value: Permission) -> int: # error: [invalid-return-type] @@ -267,9 +267,9 @@ def match_flag(value: Permission) -> int: # error: [invalid-return-type] case Permission.READ: return 1 -def match_open_enum(value: OpenEnum) -> int: # error: [invalid-return-type] +def match_open_enum(value: MissingValueEnum) -> int: # error: [invalid-return-type] match value: - case OpenEnum.ONLY: + case MissingValueEnum.ONLY: return 1 ``` diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md index 32232137f3a006..bedfb5a0a6cd6d 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md @@ -122,6 +122,119 @@ def enum_complement_rhs(x: Color, y: Intersection[Color, Not[Literal[Color.RED]] reveal_type(x) # revealed: Literal[Color.GREEN, Color.BLUE] ``` +When both operands are restricted to members of the same enum, equality narrows each operand to the +members allowed by both. If the restrictions do not overlap, the comparison is always false: + +```py +from enum import Enum, IntEnum, StrEnum +from typing import Literal + +class Choice(StrEnum): + FIRST = "first" + SECOND = "second" + THIRD = "third" + FOURTH = "fourth" + +def compare_after_truthiness_check(left: Choice, right: Choice): + if right and left != right: + reveal_type(right) # revealed: Choice & ~AlwaysFalsy + return + + reveal_type(right) # revealed: Choice + +def compare_with_narrowed_right(left: Choice, right: Choice): + if right == Choice.FIRST: + return + if left == right: + reveal_type(left) # revealed: Literal[Choice.SECOND, Choice.THIRD, Choice.FOURTH] + +def compare_non_overlapping_narrowed_values(left: Choice, right: Choice): + if left == Choice.FIRST or left == Choice.SECOND: + return + if right == Choice.THIRD or right == Choice.FOURTH: + return + + reveal_type(left == right) # revealed: Literal[False] + +def compare_literal_unions( + left: Literal[Choice.FIRST, Choice.SECOND], + right: Literal[Choice.SECOND, Choice.THIRD], +): + if left == right: + reveal_type(left) # revealed: Literal[Choice.SECOND] + reveal_type(right) # revealed: Literal[Choice.SECOND] + +def compare_non_overlapping_literal_unions( + left: Literal[Choice.FIRST, Choice.SECOND], + right: Literal[Choice.THIRD, Choice.FOURTH], +): + reveal_type(left == right) # revealed: Literal[False] +``` + +Members with the same known value are aliases, even when one value comes from a function call. +Comparisons between their canonical members are always true: + +```py +def make_value() -> Literal["value"]: + return "value" + +class RuntimeAlias(StrEnum): + FIRST = make_value() + SECOND = "value" + +reveal_type(RuntimeAlias.FIRST == RuntimeAlias.SECOND) # revealed: Literal[True] + +def make_int_value() -> Literal[1]: + return 1 + +class RuntimeIntAlias(IntEnum): + FIRST = make_int_value() + SECOND = 1 + +reveal_type(RuntimeIntAlias.FIRST == RuntimeIntAlias.SECOND) # revealed: Literal[True] +``` + +A scalar mixin can normalize member values before `Enum` checks for aliases. Here, `str` converts +`1` to `"1"`, so the two members are aliases at runtime. Since ty does not model this constructor +call, the comparison remains `bool`: + +```py +class CoercingAlias(str, Enum): + FIRST = 1 + SECOND = "1" + +reveal_type(CoercingAlias.FIRST == CoercingAlias.SECOND) # revealed: bool +``` + +Equality can transfer restrictions on enum members, but other intersection elements must stay on the +operand where they originated: + +```py +from enum import StrEnum +from typing import Any, Literal, NewType +from ty_extensions import Intersection + +class Response(StrEnum): + ACCEPT = "accept" + REJECT = "reject" + +Tag = NewType("Tag", str) + +def compare_any( + left: Response, + right: Intersection[Literal[Response.REJECT], Any], +): + if left != right: + return + reveal_type(left) # revealed: Literal[Response.REJECT] + reveal_type(right) # revealed: Literal[Response.REJECT] & Any + +def compare_newtype(left: Response, right: Intersection[Response, Tag]): + if left != right: + return + reveal_type(left) # revealed: Response +``` + `Flag` and `IntFlag` values can include zero and unnamed combinations, so their named members do not cover every possible value: @@ -167,23 +280,18 @@ even when only one member is declared: ```py from enum import Enum -class OpenEnum(Enum): +class MissingValueEnum(Enum): ONLY = 1 @classmethod - def _missing_(cls, value: object) -> "OpenEnum": + def _missing_(cls, value: object) -> "MissingValueEnum": return object.__new__(cls) -def compare_open_enums(left: OpenEnum, right: OpenEnum): +def compare_open_enums(left: MissingValueEnum, right: MissingValueEnum): reveal_type(left == right) # revealed: bool if left != right: - reveal_type(left) # revealed: OpenEnum - -def exclude_declared_member(value: OpenEnum): - if value is OpenEnum.ONLY: - return - reveal_type(value) # revealed: OpenEnum & ~Literal[OpenEnum.ONLY] + reveal_type(left) # revealed: MissingValueEnum ``` A custom enum metaclass can add members that do not appear in the class body. Two values of a @@ -204,6 +312,48 @@ def compare_transformed_enums(left: TransformedEnum, right: TransformedEnum): reveal_type(left == right) # revealed: bool ``` +A custom comparison method determines the result even when both operands have the same enum type: + +```py +from enum import Enum +from typing import Literal + +class NeverEqual(Enum): + FIRST = 1 + SECOND = 2 + THIRD = 3 + + def __eq__(self, other: object) -> Literal[False]: + return False + +def compare_custom(left: NeverEqual, right: NeverEqual): + reveal_type(left == right) # revealed: Literal[False] + + if left is NeverEqual.FIRST: + return + reveal_type(left == right) # revealed: Literal[False] +``` + +When member values are not known statically, two different members may still compare equal: + +```py +from enum import StrEnum +from typing import Literal + +def runtime_value(value: str) -> str: + return value + +class UnknownValues(StrEnum): + FIRST = runtime_value("first") + SECOND = runtime_value("second") + +def compare_unknown_values( + left: Literal[UnknownValues.FIRST], + right: Literal[UnknownValues.SECOND], +): + reveal_type(left == right) # revealed: bool +``` + Unlike plain `Enum` members, `IntEnum` members inherit integer equality. Members of different `IntEnum` classes therefore compare equal when they have the same integer value, so both equality and inequality narrowing must account for matching members from every class in the union: diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md index 7214410d30ad61..57c1b46094ac9d 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md @@ -281,6 +281,7 @@ def _(x: LiteralString | int): ```py from enum import Enum +from typing import Literal class Color(Enum): RED = "red" @@ -310,6 +311,24 @@ def after_excluding_red_mixed(x: Color | int): reveal_type(x) # revealed: Literal[Color.BLUE] | int ``` +When the container's element type is a union of enum literals, membership narrows to that union. +Without the annotation, the tuple's elements are widened to `Color`, so the comprehension remains +`list[Color]`: + +```py +SelectedColor = Literal[Color.RED, Color.GREEN] +SELECTED_COLORS: tuple[SelectedColor, ...] = (Color.RED, Color.GREEN) + +def selected_colors(colors: list[Color]) -> list[SelectedColor]: + result: list[SelectedColor] = [] + result.extend([color for color in colors if color in SELECTED_COLORS]) + return result + +def _(colors: list[Color]): + inline = [color for color in colors if color in (Color.RED, Color.GREEN)] + reveal_type(inline) # revealed: list[Color] +``` + An enum that can have additional runtime members can still be narrowed by a membership test against an explicit member. The other branch excludes that member without assuming that the declared members are exhaustive. @@ -322,14 +341,14 @@ class InjectingEnumMeta(EnumMeta): namespace["INJECTED"] = 2 return super().__new__(metacls, name, bases, namespace, **kwargs) -class OpenEnum(Enum, metaclass=InjectingEnumMeta): +class InjectedEnum(Enum, metaclass=InjectingEnumMeta): ONLY = 1 -def _(value: OpenEnum): - if value in (OpenEnum.ONLY,): - reveal_type(value) # revealed: Literal[OpenEnum.ONLY] +def _(value: InjectedEnum): + if value in (InjectedEnum.ONLY,): + reveal_type(value) # revealed: Literal[InjectedEnum.ONLY] else: - reveal_type(value) # revealed: OpenEnum & ~Literal[OpenEnum.ONLY] + reveal_type(value) # revealed: InjectedEnum & ~Literal[InjectedEnum.ONLY] ``` ## Union with enum and `int` diff --git a/crates/ty_python_semantic/src/types/enums.rs b/crates/ty_python_semantic/src/types/enums.rs index 581cd131726da4..0bc9ec464254c4 100644 --- a/crates/ty_python_semantic/src/types/enums.rs +++ b/crates/ty_python_semantic/src/types/enums.rs @@ -778,30 +778,33 @@ pub(crate) fn enum_ignored_names<'db>(db: &'db dyn Db, scope_id: ScopeId<'db>) - } } -/// If `value_ty` is a hashable literal and already exists in `enum_values`, -/// record it as an alias and return `true`. Otherwise track it as canonical. +/// If `value_ty` has the same supported literal kind and payload as a value in `enum_values`, record +/// it as an alias and return `true`. Otherwise track it as canonical. Literal metadata does not +/// affect enum aliasing at runtime, so the map is keyed by [`LiteralValueTypeKind`] rather than +/// [`Type`]. fn try_register_alias<'db>( value_ty: Type<'db>, name: &Name, - enum_values: &mut FxHashMap, Name>, + enum_values: &mut FxHashMap, Name>, aliases: &mut FxHashMap, ) -> bool { + let Some(value) = value_ty.as_literal_value_kind() else { + return false; + }; if !matches!( - value_ty.as_literal_value_kind(), - Some( - LiteralValueTypeKind::Bool(_) - | LiteralValueTypeKind::Int(_) - | LiteralValueTypeKind::String(_) - | LiteralValueTypeKind::Bytes(_) - ) + value, + LiteralValueTypeKind::Bool(_) + | LiteralValueTypeKind::Int(_) + | LiteralValueTypeKind::String(_) + | LiteralValueTypeKind::Bytes(_) ) { return false; } - if let Some(canonical) = enum_values.get(&value_ty) { + if let Some(canonical) = enum_values.get(&value) { aliases.insert(name.clone(), canonical.clone()); return true; } - enum_values.insert(value_ty, name.clone()); + enum_values.insert(value, name.clone()); false } @@ -833,7 +836,7 @@ pub(crate) fn enum_metadata<'db>( } let mut members = FxIndexMap::default(); let mut aliases = FxHashMap::default(); - let mut enum_values: FxHashMap, Name> = FxHashMap::default(); + let mut enum_values: FxHashMap, Name> = FxHashMap::default(); for (name, ty) in spec.members(db) { if try_register_alias(*ty, name, &mut enum_values, &mut aliases) { continue; @@ -868,7 +871,7 @@ pub(crate) fn enum_metadata<'db>( let use_def_map = use_def_map(db, scope_id); let table = place_table(db, scope_id); - let mut enum_values: FxHashMap, Name> = FxHashMap::default(); + let mut enum_values: FxHashMap, Name> = FxHashMap::default(); let mut auto_counter = 0; let mut auto_members = FxHashSet::default(); let mut prev_value_was_non_literal_int = false; diff --git a/crates/ty_python_semantic/src/types/equality.rs b/crates/ty_python_semantic/src/types/equality.rs index d6c5e4feb4d44f..998cb1760ab369 100644 --- a/crates/ty_python_semantic/src/types/equality.rs +++ b/crates/ty_python_semantic/src/types/equality.rs @@ -18,6 +18,8 @@ use super::{ mod enums; pub(super) use self::enums::enum_membership_constraint; +use self::enums::evaluate_same_enum_domains; +pub(super) use self::enums::{EnumComparison, same_enum_comparison}; /// The result of evaluating a runtime comparison between two types. /// @@ -316,6 +318,10 @@ fn evaluate_comparison_once<'db>( ) -> ComparisonResult<'db> { let db = evaluator.db; + if let Some(result) = evaluate_same_enum_domains(db, left, right, branch, operator) { + return result; + } + if let Some(alternatives) = finite_alternatives(db, left, operator) { return evaluate_union_left(evaluator, &alternatives, right, branch, operator); } diff --git a/crates/ty_python_semantic/src/types/equality/enums.rs b/crates/ty_python_semantic/src/types/equality/enums.rs index bee716acc0e680..8809f045c32646 100644 --- a/crates/ty_python_semantic/src/types/equality/enums.rs +++ b/crates/ty_python_semantic/src/types/equality/enums.rs @@ -1,77 +1,515 @@ -//! Equality reasoning for values from the same enum class. +//! Equality reasoning for values from the same enum class without expanding member unions. -use crate::Db; -use crate::types::{EnumClassLiteral, IntersectionBuilder, LiteralValueTypeKind, Type}; +use ruff_python_ast::name::Name; +use rustc_hash::FxHashSet; +use ty_python_core::Truthiness; -/// Return the enum class when `ty` represents an open enum domain. -fn open_enum_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option> { - match ty.resolve_type_alias(db) { - Type::NominalInstance(instance) => { - let enum_class = instance.class_literal(db).into_enum_class(db)?; - (!enum_class.members_are_exhaustive(db)).then_some(enum_class) - } - Type::Intersection(intersection) => { - let mut enum_classes = intersection - .positive(db) - .iter() - .filter_map(|positive| open_enum_class(db, *positive)); - let enum_class = enum_classes.next()?; - enum_classes - .all(|other| other == enum_class) - .then_some(enum_class) +use crate::types::literal::IntLiteralType; +use crate::types::{ + EnumClassLiteral, EnumComplementType, EnumLiteralType, IntersectionBuilder, IntersectionType, + LiteralValueType, LiteralValueTypeKind, Type, UnionBuilder, +}; +use crate::{Db, FxOrderMap, FxOrderSet}; + +use super::{ComparisonBranch, ComparisonOperator, ComparisonResult, KnownComparisonSemantics}; + +/// Compare two value domains from the same enum without comparing every pair of members. +/// +/// Any narrowing constraint produced here contains only enum-membership facts. In particular, +/// equality never transfers gradual or nominal intersection state from one operand to the other. +pub(super) fn evaluate_same_enum_domains<'db>( + db: &'db dyn Db, + target: Type<'db>, + other: Type<'db>, + branch: ComparisonBranch, + operator: ComparisonOperator, +) -> Option> { + let comparison = SameEnumComparison::from_types(db, target, other, operator)?; + match comparison.truthiness(db, operator)? { + Truthiness::AlwaysTrue => Some(ComparisonResult::AlwaysTrue), + Truthiness::AlwaysFalse => Some(ComparisonResult::AlwaysFalse), + Truthiness::Ambiguous if !comparison.supports_domain_narrowing() => { + Some(ComparisonResult::Ambiguous) } - _ => None, + Truthiness::Ambiguous if operator.condition_expects_equality(branch) => Some( + ComparisonResult::CanNarrow(comparison.right.restriction_type(db)), + ), + Truthiness::Ambiguous => Some(comparison.right.singleton_type(db).map_or( + ComparisonResult::Ambiguous, + |singleton| { + ComparisonResult::CanNarrow( + IntersectionBuilder::new(db) + .add_positive(comparison.left.restriction_type(db)) + .add_negative(singleton) + .build(), + ) + }, + )), } } -/// Return the enum class when `ty` is an exact set of literals from one enum. -fn exact_enum_member_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option> { - match ty.resolve_type_alias(db) { - Type::LiteralValue(literal) => { - let LiteralValueTypeKind::Enum(literal) = literal.kind() else { - return None; - }; - Some(literal.enum_class_literal(db)) - } - Type::Union(union) => { - let mut enum_classes = union - .elements(db) - .iter() - .map(|element| exact_enum_member_class(db, *element)); - let enum_class = enum_classes.next()??; - enum_classes - .all(|other| other == Some(enum_class)) - .then_some(enum_class) - } - _ => None, - } +/// The result of comparing two value domains from the same enum. +pub(in crate::types) enum EnumComparison { + /// The comparison method is modeled and has the given truthiness. + Known(Truthiness), + /// A custom or otherwise unmodeled comparison method must be called directly. + Unmodeled, } -/// Return the constraint established by membership in an exact set of open-enum members. +/// Compare two value domains from the same enum without comparing every pair of members. +/// +/// `None` means that an operand is not structurally an enum domain. +pub(in crate::types) fn same_enum_comparison<'db>( + db: &'db dyn Db, + left: Type<'db>, + right: Type<'db>, + is_equality: bool, +) -> Option { + let operator = if is_equality { + ComparisonOperator::Equality + } else { + ComparisonOperator::Inequality + }; + let comparison = SameEnumComparison::from_types(db, left, right, operator)?; + Some( + comparison + .truthiness(db, operator) + .map_or(EnumComparison::Unmodeled, EnumComparison::Known), + ) +} + +/// Return the constraint established by membership in an exact set of members from the same enum. pub(in crate::types) fn enum_membership_constraint<'db>( db: &'db dyn Db, target: Type<'db>, members: Type<'db>, is_positive: bool, ) -> Option> { - let enum_class = open_enum_class(db, target)?; - if exact_enum_member_class(db, members)? != enum_class { - return None; - } - - let enum_instance = enum_class.class_literal(db).to_non_generic_instance(db); - if enum_instance.overrides_equality(db) { + let comparison = + SameEnumComparison::from_types(db, target, members, ComparisonOperator::Equality)?; + if !comparison.supports_domain_narrowing() { return None; } + let members = comparison.right.restriction_type(db); if is_positive { Some(members) } else { Some( IntersectionBuilder::new(db) - .add_positive(enum_instance) + .add_positive(comparison.left.restriction_type(db)) .add_negative(members) .build(), ) } } + +/// Two non-empty value domains from the same enum and the semantics used to compare them. +/// +/// This representation avoids constructing and pairwise comparing unions of every declared +/// member. +struct SameEnumComparison<'db> { + left: EnumValueSet<'db>, + right: EnumValueSet<'db>, + profile: EnumComparisonProfile, +} + +impl<'db> SameEnumComparison<'db> { + fn from_types( + db: &'db dyn Db, + left: Type<'db>, + right: Type<'db>, + operator: ComparisonOperator, + ) -> Option { + let left = EnumValueSet::from_type(db, left)?; + let right = EnumValueSet::from_type(db, right)?; + if left.enum_class != right.enum_class { + return None; + } + let enum_class = left.enum_class; + + Some(Self { + left, + right, + profile: enum_comparison_profile(db, enum_class, operator), + }) + } + + /// Return `None` only when comparison behavior is custom or otherwise unmodeled. + fn truthiness(&self, db: &'db dyn Db, operator: ComparisonOperator) -> Option { + let comparison_keys = self.profile.comparison_keys?; + let members_are_exhaustive = self.profile.members_are_exhaustive; + let domains_are_closed = self.left.is_closed(members_are_exhaustive) + && self.right.is_closed(members_are_exhaustive); + let equality = if domains_are_closed + && comparison_keys == EnumComparisonKeys::Distinct + && !self.left.overlaps(&self.right, db) + { + Truthiness::AlwaysFalse + } else if self.left.is_singleton(db, members_are_exhaustive) + && self.right.is_singleton(db, members_are_exhaustive) + && self.left.overlaps(&self.right, db) + { + Truthiness::AlwaysTrue + } else { + Truthiness::Ambiguous + }; + + Some(equality.negate_if(operator == ComparisonOperator::Inequality)) + } + + /// Return whether equality can soundly transfer the right-hand enum restriction to the left. + /// + /// The right domain must be closed, and the left must either be closed as well or use identity + /// comparison, for which undeclared runtime members cannot equal a declared member. + fn supports_domain_narrowing(&self) -> bool { + matches!( + self.profile.comparison_keys, + Some(EnumComparisonKeys::Distinct) + ) && self.right.is_closed(self.profile.members_are_exhaustive) + && (self.left.is_closed(self.profile.members_are_exhaustive) + || self.profile.members_compare_by_identity) + } +} + +/// The enum-member values represented by an operand, excluding non-enum intersection state. +/// +/// This is an upper bound on the operand's enum values. Gradual and nominal rest components can +/// make the operand more specific, but they must never be transferred to the other operand by an +/// equality constraint. +struct EnumValueSet<'db> { + enum_class: EnumClassLiteral<'db>, + members: EnumValueSetMembers<'db>, +} + +/// Compact representation of the member names admitted by an [`EnumValueSet`]. +enum EnumValueSetMembers<'db> { + /// The entire enum domain, including undeclared runtime values when the enum is open. + All, + /// One canonical member name after resolving aliases, and whether its literal is promotable. + One { name: &'db Name, promotable: bool }, + /// An exact set of canonical member names and their literal promotability. + Included(FxOrderMap<&'db Name, bool>), + /// The enum domain except for the declared members excluded by an enum complement. + AllExcept(EnumComplementType<'db>), +} + +impl<'db> EnumValueSet<'db> { + /// Extract only structural enum membership facts from `ty`. + /// + /// This deliberately does not use subtyping: a `NewType` over an enum is a subtype of the + /// enum but remains disjoint from the enum's literal members. + fn from_type(db: &'db dyn Db, ty: Type<'db>) -> Option { + let value_set = match ty.resolve_type_alias(db) { + Type::LiteralValue(literal) => { + let LiteralValueTypeKind::Enum(enum_literal) = literal.kind() else { + return None; + }; + let enum_class = enum_literal.enum_class_literal(db); + let name = enum_class.resolve_member(db, enum_literal.name(db))?; + Self { + enum_class, + members: EnumValueSetMembers::One { + name, + promotable: literal.is_promotable(), + }, + } + } + Type::NominalInstance(instance) => Self { + enum_class: instance.class_literal(db).into_enum_class(db)?, + members: EnumValueSetMembers::All, + }, + Type::EnumComplement(complement) => Self { + enum_class: complement.enum_class_literal(db), + members: EnumValueSetMembers::AllExcept(complement), + }, + Type::Union(union) => Self::from_union(db, union.elements(db))?, + Type::Intersection(intersection) => Self::from_intersection(db, intersection)?, + _ => return None, + }; + (value_set.member_count(db) > 0).then_some(value_set) + } + + /// Extract an exact included-member set from a union of enum domains. + /// + /// Whole-domain and complement arms are rejected because they are not exact included sets. + fn from_union(db: &'db dyn Db, elements: &[Type<'db>]) -> Option { + let mut enum_class = None; + let mut included = FxOrderMap::default(); + for element in elements { + let value_set = Self::from_type(db, *element)?; + if let Some(enum_class) = enum_class + && enum_class != value_set.enum_class + { + return None; + } + enum_class = Some(value_set.enum_class); + match value_set.members { + EnumValueSetMembers::One { name, promotable } => { + Self::insert_member(&mut included, name, promotable); + } + EnumValueSetMembers::Included(members) => { + for (name, promotable) in members { + Self::insert_member(&mut included, name, promotable); + } + } + EnumValueSetMembers::All | EnumValueSetMembers::AllExcept(_) => return None, + } + } + + let enum_class = enum_class?; + let members = if included.len() == 1 { + let (name, promotable) = included.into_iter().next()?; + EnumValueSetMembers::One { name, promotable } + } else { + EnumValueSetMembers::Included(included) + }; + Some(Self { + enum_class, + members, + }) + } + + /// Insert a member, preserving unpromotable literal provenance if either occurrence has it. + fn insert_member( + included: &mut FxOrderMap<&'db Name, bool>, + name: &'db Name, + promotable: bool, + ) { + match included.entry(name) { + ordermap::map::Entry::Vacant(entry) => { + entry.insert(promotable); + } + ordermap::map::Entry::Occupied(mut entry) => { + *entry.get_mut() &= promotable; + } + } + } + + /// Extract the enum restriction while discarding unrelated positive intersection state. + fn from_intersection(db: &'db dyn Db, intersection: IntersectionType<'db>) -> Option { + if let Some(complement) = intersection.enum_complement(db) { + return Self::from_type(db, Type::EnumComplement(complement)); + } + + // Other intersection components can only reduce the represented enum values. Ignoring + // them therefore preserves a safe upper bound without transferring them during narrowing. + let mut value_sets = intersection + .positive(db) + .iter() + .filter_map(|positive| Self::from_type(db, *positive)); + let value_set = value_sets.next()?; + value_sets + .all(|other| other.enum_class == value_set.enum_class) + .then_some(value_set) + } + + fn member_count(&self, db: &'db dyn Db) -> usize { + match &self.members { + EnumValueSetMembers::All => self.enum_class.member_count(db), + EnumValueSetMembers::One { .. } => 1, + EnumValueSetMembers::Included(names) => names.len(), + EnumValueSetMembers::AllExcept(complement) => { + self.enum_class.member_count(db) - complement.excluded_names(db).len() + } + } + } + + /// Return whether this set excludes every value not named by its representation. + /// + /// A whole-domain or complement representation is not closed for an enum that can create + /// undeclared members at runtime. + fn is_closed(&self, members_are_exhaustive: bool) -> bool { + members_are_exhaustive + || matches!( + self.members, + EnumValueSetMembers::One { .. } | EnumValueSetMembers::Included(_) + ) + } + + fn is_singleton(&self, db: &'db dyn Db, members_are_exhaustive: bool) -> bool { + self.member_count(db) == 1 && self.is_closed(members_are_exhaustive) + } + + fn overlaps(&self, other: &Self, db: &'db dyn Db) -> bool { + debug_assert_eq!(self.enum_class, other.enum_class); + match (&self.members, &other.members) { + (EnumValueSetMembers::All, _) | (_, EnumValueSetMembers::All) => true, + ( + EnumValueSetMembers::One { name: left, .. }, + EnumValueSetMembers::One { name: right, .. }, + ) => left == right, + (EnumValueSetMembers::One { name, .. }, EnumValueSetMembers::Included(names)) + | (EnumValueSetMembers::Included(names), EnumValueSetMembers::One { name, .. }) => { + names.contains_key(name) + } + (EnumValueSetMembers::Included(left), EnumValueSetMembers::Included(right)) => { + let (smaller, larger) = if left.len() < right.len() { + (left, right) + } else { + (right, left) + }; + smaller.keys().any(|name| larger.contains_key(name)) + } + (EnumValueSetMembers::One { name, .. }, EnumValueSetMembers::AllExcept(complement)) + | (EnumValueSetMembers::AllExcept(complement), EnumValueSetMembers::One { name, .. }) => { + !complement.excluded_names(db).contains(*name) + } + (EnumValueSetMembers::Included(names), EnumValueSetMembers::AllExcept(complement)) + | (EnumValueSetMembers::AllExcept(complement), EnumValueSetMembers::Included(names)) => { + names + .keys() + .any(|name| !complement.excluded_names(db).contains(*name)) + } + (EnumValueSetMembers::AllExcept(left), EnumValueSetMembers::AllExcept(right)) => { + let left = left.excluded_names(db); + let right = right.excluded_names(db); + let excluded = + left.len() + right.iter().filter(|name| !left.contains(*name)).count(); + excluded < self.enum_class.member_count(db) + } + } + } + + /// Reconstruct a constraint containing only this enum value restriction. + fn restriction_type(&self, db: &'db dyn Db) -> Type<'db> { + match &self.members { + EnumValueSetMembers::All => self + .enum_class + .class_literal(db) + .to_non_generic_instance(db), + EnumValueSetMembers::One { name, promotable } => { + self.member_type(db, name, *promotable) + } + EnumValueSetMembers::Included(members) => members + .iter() + .fold(UnionBuilder::new(db), |builder, (name, promotable)| { + builder.add(self.member_type(db, name, *promotable)) + }) + .build(), + EnumValueSetMembers::AllExcept(complement) => { + if complement.rest(db).is_empty() { + Type::EnumComplement(*complement) + } else { + Type::EnumComplement(EnumComplementType::new( + db, + self.enum_class, + complement.excluded_names(db).clone(), + FxOrderSet::default(), + )) + } + } + } + } + + /// Reconstruct the only declared member left in this set. + /// + /// The caller must separately establish that the domain is closed before treating this as the + /// operand's only possible runtime value. + fn singleton_type(&self, db: &'db dyn Db) -> Option> { + if self.member_count(db) != 1 { + return None; + } + let (name, promotable) = match &self.members { + EnumValueSetMembers::All => (self.enum_class.member_names(db).next()?, true), + EnumValueSetMembers::One { name, promotable } => (*name, *promotable), + EnumValueSetMembers::Included(members) => { + let (name, promotable) = members.first()?; + (*name, *promotable) + } + EnumValueSetMembers::AllExcept(complement) => ( + self.enum_class + .member_names(db) + .find(|name| !complement.excluded_names(db).contains(*name))?, + true, + ), + }; + Some(self.member_type(db, name, promotable)) + } + + fn member_type(&self, db: &'db dyn Db, name: &Name, promotable: bool) -> Type<'db> { + LiteralValueType::new( + EnumLiteralType::new(db, self.enum_class, name.clone()), + promotable, + ) + .into() + } +} + +/// Whether distinct declared members are known to have distinct runtime comparison keys. +#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +enum EnumComparisonKeys { + /// Different member names cannot compare equal. + Distinct, + /// Values are unknown or repeated, so different member names may compare equal. + UnknownOrRepeated, +} + +/// Class-wide facts required to compare enum value sets without member expansion. +#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +struct EnumComparisonProfile { + members_are_exhaustive: bool, + /// Whether distinct enum members compare by identity, including members created at runtime. + members_compare_by_identity: bool, + /// `None` means comparison behavior is custom or otherwise unmodeled. + comparison_keys: Option, +} + +/// Compute and cache the class-wide work needed for repeated comparisons of the same enum. +#[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] +fn enum_comparison_profile<'db>( + db: &'db dyn Db, + enum_class: EnumClassLiteral<'db>, + operator: ComparisonOperator, +) -> EnumComparisonProfile { + let semantics = KnownComparisonSemantics::of_instance( + db, + enum_class.class_literal(db).to_non_generic_instance(db), + operator, + ); + let (comparison_keys, members_compare_by_identity) = match semantics { + None => (None, false), + Some(KnownComparisonSemantics::Object) => (Some(EnumComparisonKeys::Distinct), true), + Some( + semantics @ (KnownComparisonSemantics::Int + | KnownComparisonSemantics::Str + | KnownComparisonSemantics::Bytes), + ) if enum_members_have_distinct_value_keys(db, enum_class, semantics) => { + (Some(EnumComparisonKeys::Distinct), false) + } + Some(_) => (Some(EnumComparisonKeys::UnknownOrRepeated), false), + }; + EnumComparisonProfile { + members_are_exhaustive: enum_class.members_are_exhaustive(db), + members_compare_by_identity, + comparison_keys, + } +} + +/// Return whether every declared member has a unique modeled runtime comparison key. +/// +/// A value is only used when its literal kind matches the inherited scalar semantics; other values +/// may be normalized by the scalar constructor before enum alias detection. +/// Keys exclude literal metadata such as promotability, which does not affect runtime equality. +/// Boolean keys are normalized to integers because Python considers `False == 0` and `True == 1`. +fn enum_members_have_distinct_value_keys<'db>( + db: &'db dyn Db, + enum_class: EnumClassLiteral<'db>, + semantics: KnownComparisonSemantics, +) -> bool { + let mut keys = FxHashSet::default(); + enum_class.members(db).iter().all(|(_, value)| { + let key = match (semantics, value.as_literal_value_kind()) { + (KnownComparisonSemantics::Int, Some(LiteralValueTypeKind::Bool(value))) => { + LiteralValueTypeKind::Int(IntLiteralType::from_i64(i64::from(value))) + } + (KnownComparisonSemantics::Int, Some(kind @ LiteralValueTypeKind::Int(_))) + | (KnownComparisonSemantics::Str, Some(kind @ LiteralValueTypeKind::String(_))) + | (KnownComparisonSemantics::Bytes, Some(kind @ LiteralValueTypeKind::Bytes(_))) => { + kind + } + _ => return false, + }; + keys.insert(key) + }) +} diff --git a/crates/ty_python_semantic/src/types/infer/comparisons.rs b/crates/ty_python_semantic/src/types/infer/comparisons.rs index 6b7e63111d1089..8cc5286fabf06d 100644 --- a/crates/ty_python_semantic/src/types/infer/comparisons.rs +++ b/crates/ty_python_semantic/src/types/infer/comparisons.rs @@ -7,6 +7,7 @@ use crate::types::call::{CallArguments, CallDunderError}; use crate::types::constraints::ConstraintSetBuilder; use crate::types::context::InferContext; use crate::types::cyclic::CycleDetector; +use crate::types::equality::{EnumComparison, same_enum_comparison}; use crate::types::tuple::TupleSpec; use crate::types::{ DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, @@ -170,6 +171,19 @@ pub(super) fn infer_binary_type_comparison<'db>( } }; + if matches!(op, ast::CmpOp::Eq | ast::CmpOp::NotEq) + && let Some(comparison) = same_enum_comparison(db, left, right, op == ast::CmpOp::Eq) + { + return match comparison { + EnumComparison::Known(truthiness) => Ok(match truthiness { + Truthiness::AlwaysTrue => Type::bool_literal(true), + Truthiness::AlwaysFalse => Type::bool_literal(false), + Truthiness::Ambiguous => KnownClass::Bool.to_instance(db), + }), + EnumComparison::Unmodeled => try_dunder(MemberLookupPolicy::default()), + }; + } + let comparison_result = match (left, right) { (Type::EnumComplement(complement), right) => Some(infer_binary_type_comparison( context,