|
8 | 8 |
|
9 | 9 | import sys |
10 | 10 | from collections.abc import Sequence |
11 | | -from typing import Callable, Final, Optional |
| 11 | +from typing import Callable, Final, Optional, cast |
| 12 | +from typing_extensions import TypeGuard |
12 | 13 |
|
13 | 14 | from mypy.argmap import map_actuals_to_formals |
14 | 15 | from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind |
|
185 | 186 | from mypyc.primitives.str_ops import ( |
186 | 187 | str_check_if_true, |
187 | 188 | str_eq, |
| 189 | + str_eq_literal, |
188 | 190 | str_ssize_t_size_op, |
189 | 191 | unicode_compare, |
190 | 192 | ) |
@@ -1551,9 +1553,33 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) - |
1551 | 1553 | def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: |
1552 | 1554 | """Compare two strings""" |
1553 | 1555 | if op == "==": |
| 1556 | + # We can specialize this case if one or both values are string literals |
| 1557 | + literal_fastpath = False |
| 1558 | + |
| 1559 | + def is_string_literal(value: Value) -> TypeGuard[LoadLiteral]: |
| 1560 | + return isinstance(value, LoadLiteral) and is_str_rprimitive(value.type) |
| 1561 | + |
| 1562 | + if is_string_literal(lhs): |
| 1563 | + if is_string_literal(rhs): |
| 1564 | + # we can optimize out the check entirely in some constant-folded cases |
| 1565 | + return self.true() if lhs.value == rhs.value else self.false() |
| 1566 | + |
| 1567 | + # if lhs argument is string literal, switch sides to match specializer C api |
| 1568 | + lhs, rhs = rhs, lhs |
| 1569 | + literal_fastpath = True |
| 1570 | + elif is_string_literal(rhs): |
| 1571 | + literal_fastpath = True |
| 1572 | + |
| 1573 | + if literal_fastpath: |
| 1574 | + literal_string = cast(str, cast(LoadLiteral, rhs).value) |
| 1575 | + literal_length = Integer(len(literal_string), c_pyssize_t_rprimitive, line) |
| 1576 | + return self.primitive_op(str_eq_literal, [lhs, rhs, literal_length], line) |
| 1577 | + |
1554 | 1578 | return self.primitive_op(str_eq, [lhs, rhs], line) |
| 1579 | + |
1555 | 1580 | elif op == "!=": |
1556 | | - eq = self.primitive_op(str_eq, [lhs, rhs], line) |
| 1581 | + # perform a standard equality check, then negate |
| 1582 | + eq = self.compare_strings(lhs, rhs, "==", line) |
1557 | 1583 | return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line)) |
1558 | 1584 |
|
1559 | 1585 | # TODO: modify 'str' to use same interface as 'compare_bytes' as it would avoid |
|
0 commit comments