Skip to content
Closed
74 changes: 74 additions & 0 deletions crates/ruff_benchmark/benches/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,77 @@ fn benchmark_many_enum_members(criterion: &mut Criterion) {
});
}

fn benchmark_enum_comparison(criterion: &mut Criterion, name: &str, code: &str) {
setup_rayon();

criterion.bench_function(name, |b| {
b.iter_batched_ref(
|| setup_micro_case(code),
|case| {
let Case { db, .. } = case;
let result = db.check();
assert_eq!(result.len(), 0);
},
BatchSize::SmallInput,
);
});
}

/// Regression benchmark for <https://github.com/astral-sh/ty/issues/3830>.
fn benchmark_narrowed_str_enum_comparison(criterion: &mut Criterion) {
const NUM_ENUM_MEMBERS: usize = 256;

let mut code = "from enum import StrEnum\n\nclass LargeEnum(StrEnum):\n".to_string();
for index in 0..NUM_ENUM_MEMBERS {
writeln!(&mut code, " VALUE_{index} = \"value_{index}\"").ok();
}
code.push_str(
"\n\ndef compare(left: LargeEnum, right: LargeEnum):\n if right and left != right:\n return\n return left == right\n",
);

benchmark_enum_comparison(criterion, "ty_micro[narrowed_str_enum_comparison]", &code);
}

/// Ensure explicit enum-literal unions are compared as value sets, not member pairs.
fn benchmark_enum_literal_union_comparison(criterion: &mut Criterion) {
const NUM_ENUM_MEMBERS: usize = 256;

let mut code =
"from enum import StrEnum\nfrom typing import Literal\n\nclass LargeEnum(StrEnum):\n"
.to_string();
for index in 0..NUM_ENUM_MEMBERS {
writeln!(&mut code, " VALUE_{index} = \"value_{index}\"").ok();
}
code.push_str("\nLeft = Literal[\n");
for index in 0..NUM_ENUM_MEMBERS / 2 {
writeln!(&mut code, " LargeEnum.VALUE_{index},").ok();
}
code.push_str("]\nRight = Literal[\n");
for index in NUM_ENUM_MEMBERS / 2..NUM_ENUM_MEMBERS {
writeln!(&mut code, " LargeEnum.VALUE_{index},").ok();
}
code.push_str("]\n\n\ndef compare(left: Left, right: Right):\n return left == right\n");

benchmark_enum_comparison(criterion, "ty_micro[enum_literal_union_comparison]", &code);
}

/// Ensure the comparison profile is reused instead of scanning every member per expression.
fn benchmark_repeated_str_enum_comparisons(criterion: &mut Criterion) {
const NUM_ENUM_MEMBERS: usize = 1_024;
const NUM_COMPARISONS: usize = 1_000;

let mut code = "from enum import StrEnum\n\nclass LargeEnum(StrEnum):\n".to_string();
for index in 0..NUM_ENUM_MEMBERS {
writeln!(&mut code, " VALUE_{index} = \"value_{index}\"").ok();
}
code.push_str("\n\ndef compare(left: LargeEnum, right: LargeEnum):\n");
for _ in 0..NUM_COMPARISONS {
code.push_str(" left == right\n");
}

benchmark_enum_comparison(criterion, "ty_micro[repeated_str_enum_comparisons]", &code);
}

/// Micro-benchmark that tests our performance when slicing and unpacking
/// a very large tuple that has many varied literal strings inside it.
///
Expand Down Expand Up @@ -1509,6 +1580,9 @@ criterion_group!(
benchmark_complex_constrained_attributes_2,
benchmark_complex_constrained_attributes_3,
benchmark_many_enum_members,
benchmark_narrowed_str_enum_comparison,
benchmark_enum_literal_union_comparison,
benchmark_repeated_str_enum_comparisons,
benchmark_many_enum_members_2,
benchmark_many_protocol_members_mismatch,
benchmark_gradual_vararg_call,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,21 +255,21 @@ from enum import Enum, Flag
class Permission(Flag):
READ = 1

class OpenEnum(Enum):
class MissingValueEnum(Enum):
ONLY = 1

@classmethod
def _missing_(cls, value: object) -> "OpenEnum":
def _missing_(cls, value: object) -> "MissingValueEnum":
return object.__new__(cls)

def match_flag(value: Permission) -> int: # error: [invalid-return-type]
match value:
case Permission.READ:
return 1

def match_open_enum(value: OpenEnum) -> int: # error: [invalid-return-type]
def match_open_enum(value: MissingValueEnum) -> int: # error: [invalid-return-type]
match value:
case OpenEnum.ONLY:
case MissingValueEnum.ONLY:
return 1
```

Expand Down
168 changes: 159 additions & 9 deletions crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,119 @@ def enum_complement_rhs(x: Color, y: Intersection[Color, Not[Literal[Color.RED]]
reveal_type(x) # revealed: Literal[Color.GREEN, Color.BLUE]
```

When both operands are restricted to members of the same enum, equality narrows each operand to the
members allowed by both. If the restrictions do not overlap, the comparison is always false:

```py
from enum import Enum, IntEnum, StrEnum
from typing import Literal

class Choice(StrEnum):
FIRST = "first"
SECOND = "second"
THIRD = "third"
FOURTH = "fourth"

def compare_after_truthiness_check(left: Choice, right: Choice):
if right and left != right:
reveal_type(right) # revealed: Choice & ~AlwaysFalsy

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe unrelated, but ideally we should simplify out the AlwaysFalsy here. Choice has no instances that are falsy and ty should be able to figure that out.

return

reveal_type(right) # revealed: Choice

def compare_with_narrowed_right(left: Choice, right: Choice):
if right == Choice.FIRST:
return
if left == right:
reveal_type(left) # revealed: Literal[Choice.SECOND, Choice.THIRD, Choice.FOURTH]

def compare_non_overlapping_narrowed_values(left: Choice, right: Choice):
if left == Choice.FIRST or left == Choice.SECOND:
return
if right == Choice.THIRD or right == Choice.FOURTH:
return

reveal_type(left == right) # revealed: Literal[False]

def compare_literal_unions(
left: Literal[Choice.FIRST, Choice.SECOND],
right: Literal[Choice.SECOND, Choice.THIRD],
):
if left == right:
reveal_type(left) # revealed: Literal[Choice.SECOND]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool that ty can do this

reveal_type(right) # revealed: Literal[Choice.SECOND]

def compare_non_overlapping_literal_unions(
left: Literal[Choice.FIRST, Choice.SECOND],
right: Literal[Choice.THIRD, Choice.FOURTH],
):
reveal_type(left == right) # revealed: Literal[False]
```

Members with the same known value are aliases, even when one value comes from a function call.
Comparisons between their canonical members are always true:

```py
def make_value() -> Literal["value"]:
return "value"

class RuntimeAlias(StrEnum):
FIRST = make_value()
SECOND = "value"

reveal_type(RuntimeAlias.FIRST == RuntimeAlias.SECOND) # revealed: Literal[True]

def make_int_value() -> Literal[1]:
return 1

class RuntimeIntAlias(IntEnum):
FIRST = make_int_value()
SECOND = 1

reveal_type(RuntimeIntAlias.FIRST == RuntimeIntAlias.SECOND) # revealed: Literal[True]
```

A scalar mixin can normalize member values before `Enum` checks for aliases. Here, `str` converts
`1` to `"1"`, so the two members are aliases at runtime. Since ty does not model this constructor
call, the comparison remains `bool`:

```py
class CoercingAlias(str, Enum):
FIRST = 1
SECOND = "1"

reveal_type(CoercingAlias.FIRST == CoercingAlias.SECOND) # revealed: bool
```

Equality can transfer restrictions on enum members, but other intersection elements must stay on the
operand where they originated:

```py
from enum import StrEnum
from typing import Any, Literal, NewType
from ty_extensions import Intersection

class Response(StrEnum):
ACCEPT = "accept"
REJECT = "reject"

Tag = NewType("Tag", str)

def compare_any(
left: Response,
right: Intersection[Literal[Response.REJECT], Any],
):
if left != right:
return
reveal_type(left) # revealed: Literal[Response.REJECT]
reveal_type(right) # revealed: Literal[Response.REJECT] & Any

def compare_newtype(left: Response, right: Intersection[Response, Tag]):
if left != right:
return
reveal_type(left) # revealed: Response
```

`Flag` and `IntFlag` values can include zero and unnamed combinations, so their named members do not
cover every possible value:

Expand Down Expand Up @@ -167,23 +280,18 @@ even when only one member is declared:
```py
from enum import Enum

class OpenEnum(Enum):
class MissingValueEnum(Enum):
ONLY = 1

@classmethod
def _missing_(cls, value: object) -> "OpenEnum":
def _missing_(cls, value: object) -> "MissingValueEnum":
return object.__new__(cls)

def compare_open_enums(left: OpenEnum, right: OpenEnum):
def compare_open_enums(left: MissingValueEnum, right: MissingValueEnum):
reveal_type(left == right) # revealed: bool

if left != right:
reveal_type(left) # revealed: OpenEnum

def exclude_declared_member(value: OpenEnum):
if value is OpenEnum.ONLY:
return
reveal_type(value) # revealed: OpenEnum & ~Literal[OpenEnum.ONLY]
reveal_type(left) # revealed: MissingValueEnum
```

A custom enum metaclass can add members that do not appear in the class body. Two values of a
Expand All @@ -204,6 +312,48 @@ def compare_transformed_enums(left: TransformedEnum, right: TransformedEnum):
reveal_type(left == right) # revealed: bool
```

A custom comparison method determines the result even when both operands have the same enum type:

```py
from enum import Enum
from typing import Literal

class NeverEqual(Enum):
FIRST = 1
SECOND = 2
THIRD = 3

def __eq__(self, other: object) -> Literal[False]:
return False

def compare_custom(left: NeverEqual, right: NeverEqual):
reveal_type(left == right) # revealed: Literal[False]

if left is NeverEqual.FIRST:
return
reveal_type(left == right) # revealed: Literal[False]
Comment on lines +332 to +334

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could expand this test case:

Suggested change
if left is NeverEqual.FIRST:
return
reveal_type(left == right) # revealed: Literal[False]
if left is NeverEqual.FIRST and right is NeverEqual.FIRST:
reveal_type(left == right) # revealed: Literal[False]
class AlwaysEqual(Enum):
FIRST = 1
SECOND = 2
THIRD = 3
def __eq__(self, other: object) -> Literal[True]:
return True
def _(left: AlwaysEqual, right: AlwaysEqual):
reveal_type(left == right) # revealed: Literal[True]
if left is AlwaysEqual.FIRST and right is AlwaysEqual.SECOND:
reveal_type(left == right) # revealed: Literal[True]
class EqualityUnknown(Enum):
FIRST = 1
SECOND = 2
THIRD = 3
def __eq__(self, other: object): ...
def _(left: EqualityUnknown, right: EqualityUnknown):
reveal_type(left == right) # revealed: bool
if left is EqualityUnknown.FIRST and right is EqualityUnknown.FIRST:
reveal_type(left == right) # revealed: bool
if left is EqualityUnknown.FIRST and right is EqualityUnknown.SECOND:
reveal_type(left == right) # revealed: bool

```

When member values are not known statically, two different members may still compare equal:

```py
from enum import StrEnum
from typing import Literal

def runtime_value(value: str) -> str:
return value

class UnknownValues(StrEnum):
FIRST = runtime_value("first")
SECOND = runtime_value("second")

def compare_unknown_values(
left: Literal[UnknownValues.FIRST],
right: Literal[UnknownValues.SECOND],
):
reveal_type(left == right) # revealed: bool
```

Unlike plain `Enum` members, `IntEnum` members inherit integer equality. Members of different
`IntEnum` classes therefore compare equal when they have the same integer value, so both equality
and inequality narrowing must account for matching members from every class in the union:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def _(x: LiteralString | int):

```py
from enum import Enum
from typing import Literal

class Color(Enum):
RED = "red"
Expand Down Expand Up @@ -310,6 +311,24 @@ def after_excluding_red_mixed(x: Color | int):
reveal_type(x) # revealed: Literal[Color.BLUE] | int
```

When the container's element type is a union of enum literals, membership narrows to that union.
Without the annotation, the tuple's elements are widened to `Color`, so the comprehension remains
`list[Color]`:

```py
SelectedColor = Literal[Color.RED, Color.GREEN]
SELECTED_COLORS: tuple[SelectedColor, ...] = (Color.RED, Color.GREEN)

def selected_colors(colors: list[Color]) -> list[SelectedColor]:
result: list[SelectedColor] = []
result.extend([color for color in colors if color in SELECTED_COLORS])
return result

def _(colors: list[Color]):
inline = [color for color in colors if color in (Color.RED, Color.GREEN)]
reveal_type(inline) # revealed: list[Color]
```

An enum that can have additional runtime members can still be narrowed by a membership test against
an explicit member. The other branch excludes that member without assuming that the declared members
are exhaustive.
Expand All @@ -322,14 +341,14 @@ class InjectingEnumMeta(EnumMeta):
namespace["INJECTED"] = 2
return super().__new__(metacls, name, bases, namespace, **kwargs)

class OpenEnum(Enum, metaclass=InjectingEnumMeta):
class InjectedEnum(Enum, metaclass=InjectingEnumMeta):
ONLY = 1

def _(value: OpenEnum):
if value in (OpenEnum.ONLY,):
reveal_type(value) # revealed: Literal[OpenEnum.ONLY]
def _(value: InjectedEnum):
if value in (InjectedEnum.ONLY,):
reveal_type(value) # revealed: Literal[InjectedEnum.ONLY]
else:
reveal_type(value) # revealed: OpenEnum & ~Literal[OpenEnum.ONLY]
reveal_type(value) # revealed: InjectedEnum & ~Literal[InjectedEnum.ONLY]
```

## Union with enum and `int`
Expand Down
Loading
Loading