Skip to content

Commit b8f57fd

Browse files
[mypyc] feat: further optimize equality check with string literals [1/1] (#19883)
This PR further optimizes string equality checks against literals by getting rid of the PyUnicode_GET_LENGTH call against the literal value, which is not necessary since the value is known at compile-time I think this optimization will be helpful in cases where the non-literal string DOES match but is actually a subtype of string (actual strings instances that match would be caught by the identity check), or in cases where an exact string does NOT match.
1 parent d69419c commit b8f57fd

File tree

8 files changed

+96
-17
lines changed

8 files changed

+96
-17
lines changed

mypyc/irbuild/ll_builder.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
import sys
1010
from collections.abc import Sequence
11-
from typing import Callable, Final, Optional
11+
from typing import Callable, Final, Optional, cast
12+
from typing_extensions import TypeGuard
1213

1314
from mypy.argmap import map_actuals_to_formals
1415
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
@@ -185,6 +186,7 @@
185186
from mypyc.primitives.str_ops import (
186187
str_check_if_true,
187188
str_eq,
189+
str_eq_literal,
188190
str_ssize_t_size_op,
189191
unicode_compare,
190192
)
@@ -1551,9 +1553,33 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -
15511553
def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
15521554
"""Compare two strings"""
15531555
if op == "==":
1556+
# We can specialize this case if one or both values are string literals
1557+
literal_fastpath = False
1558+
1559+
def is_string_literal(value: Value) -> TypeGuard[LoadLiteral]:
1560+
return isinstance(value, LoadLiteral) and is_str_rprimitive(value.type)
1561+
1562+
if is_string_literal(lhs):
1563+
if is_string_literal(rhs):
1564+
# we can optimize out the check entirely in some constant-folded cases
1565+
return self.true() if lhs.value == rhs.value else self.false()
1566+
1567+
# if lhs argument is string literal, switch sides to match specializer C api
1568+
lhs, rhs = rhs, lhs
1569+
literal_fastpath = True
1570+
elif is_string_literal(rhs):
1571+
literal_fastpath = True
1572+
1573+
if literal_fastpath:
1574+
literal_string = cast(str, cast(LoadLiteral, rhs).value)
1575+
literal_length = Integer(len(literal_string), c_pyssize_t_rprimitive, line)
1576+
return self.primitive_op(str_eq_literal, [lhs, rhs, literal_length], line)
1577+
15541578
return self.primitive_op(str_eq, [lhs, rhs], line)
1579+
15551580
elif op == "!=":
1556-
eq = self.primitive_op(str_eq, [lhs, rhs], line)
1581+
# perform a standard equality check, then negate
1582+
eq = self.compare_strings(lhs, rhs, "==", line)
15571583
return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line))
15581584

15591585
# TODO: modify 'str' to use same interface as 'compare_bytes' as it would avoid

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) {
735735
#define BOTHSTRIP 2
736736

737737
char CPyStr_Equal(PyObject *str1, PyObject *str2);
738+
char CPyStr_EqualLiteral(PyObject *str, PyObject *literal_str, Py_ssize_t literal_length);
738739
PyObject *CPyStr_Build(Py_ssize_t len, ...);
739740
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
740741
PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index);

mypyc/lib-rt/str_ops.c

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,33 @@ make_bloom_mask(int kind, const void* ptr, Py_ssize_t len)
6464
#undef BLOOM_UPDATE
6565
}
6666

67-
// Adapted from CPython 3.13.1 (_PyUnicode_Equal)
68-
char CPyStr_Equal(PyObject *str1, PyObject *str2) {
69-
if (str1 == str2) {
70-
return 1;
71-
}
72-
Py_ssize_t len = PyUnicode_GET_LENGTH(str1);
73-
if (PyUnicode_GET_LENGTH(str2) != len)
67+
static inline char _CPyStr_Equal_NoIdentCheck(PyObject *str1, PyObject *str2, Py_ssize_t str2_length) {
68+
// This helper function only exists to deduplicate code in CPyStr_Equal and CPyStr_EqualLiteral
69+
Py_ssize_t str1_length = PyUnicode_GET_LENGTH(str1);
70+
if (str1_length != str2_length)
7471
return 0;
7572
int kind = PyUnicode_KIND(str1);
7673
if (PyUnicode_KIND(str2) != kind)
7774
return 0;
7875
const void *data1 = PyUnicode_DATA(str1);
7976
const void *data2 = PyUnicode_DATA(str2);
80-
return memcmp(data1, data2, len * kind) == 0;
77+
return memcmp(data1, data2, str1_length * kind) == 0;
78+
}
79+
80+
// Adapted from CPython 3.13.1 (_PyUnicode_Equal)
81+
char CPyStr_Equal(PyObject *str1, PyObject *str2) {
82+
if (str1 == str2) {
83+
return 1;
84+
}
85+
Py_ssize_t str2_length = PyUnicode_GET_LENGTH(str2);
86+
return _CPyStr_Equal_NoIdentCheck(str1, str2, str2_length);
87+
}
88+
89+
char CPyStr_EqualLiteral(PyObject *str, PyObject *literal_str, Py_ssize_t literal_length) {
90+
if (str == literal_str) {
91+
return 1;
92+
}
93+
return _CPyStr_Equal_NoIdentCheck(str, literal_str, literal_length);
8194
}
8295

8396
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {

mypyc/primitives/str_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@
8888
error_kind=ERR_NEVER,
8989
)
9090

91+
str_eq_literal = custom_primitive_op(
92+
name="str_eq_literal",
93+
c_function_name="CPyStr_EqualLiteral",
94+
arg_types=[str_rprimitive, str_rprimitive, c_pyssize_t_rprimitive],
95+
return_type=bool_rprimitive,
96+
error_kind=ERR_NEVER,
97+
)
98+
9199
unicode_compare = custom_op(
92100
arg_types=[str_rprimitive, str_rprimitive],
93101
return_type=c_int_rprimitive,

mypyc/test-data/irbuild-classes.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2325,15 +2325,15 @@ def SetAttr.__setattr__(self, key, val):
23252325
r12 :: bit
23262326
L0:
23272327
r0 = 'regular_attr'
2328-
r1 = CPyStr_Equal(key, r0)
2328+
r1 = CPyStr_EqualLiteral(key, r0, 12)
23292329
if r1 goto L1 else goto L2 :: bool
23302330
L1:
23312331
r2 = unbox(int, val)
23322332
self.regular_attr = r2; r3 = is_error
23332333
goto L6
23342334
L2:
23352335
r4 = 'class_var'
2336-
r5 = CPyStr_Equal(key, r4)
2336+
r5 = CPyStr_EqualLiteral(key, r4, 9)
23372337
if r5 goto L3 else goto L4 :: bool
23382338
L3:
23392339
r6 = builtins :: module
@@ -2468,15 +2468,15 @@ def SetAttr.__setattr__(self, key, val):
24682468
r12 :: bit
24692469
L0:
24702470
r0 = 'regular_attr'
2471-
r1 = CPyStr_Equal(key, r0)
2471+
r1 = CPyStr_EqualLiteral(key, r0, 12)
24722472
if r1 goto L1 else goto L2 :: bool
24732473
L1:
24742474
r2 = unbox(int, val)
24752475
self.regular_attr = r2; r3 = is_error
24762476
goto L6
24772477
L2:
24782478
r4 = 'class_var'
2479-
r5 = CPyStr_Equal(key, r4)
2479+
r5 = CPyStr_EqualLiteral(key, r4, 9)
24802480
if r5 goto L3 else goto L4 :: bool
24812481
L3:
24822482
r6 = builtins :: module

mypyc/test-data/irbuild-dict.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ L2:
410410
k = r8
411411
v = r7
412412
r9 = 'name'
413-
r10 = CPyStr_Equal(k, r9)
413+
r10 = CPyStr_EqualLiteral(k, r9, 4)
414414
if r10 goto L3 else goto L4 :: bool
415415
L3:
416416
name = v

mypyc/test-data/irbuild-str.test

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,34 @@ L2:
740740
L3:
741741
keep_alive x
742742
return r2
743+
744+
[case testStrEqLiteral]
745+
from typing import Final
746+
literal: Final = "literal"
747+
def literal_rhs(x: str) -> bool:
748+
return x == literal
749+
def literal_lhs(x: str) -> bool:
750+
return literal == x
751+
def literal_both() -> bool:
752+
return literal == "literal"
753+
[out]
754+
def literal_rhs(x):
755+
x, r0 :: str
756+
r1 :: bool
757+
L0:
758+
r0 = 'literal'
759+
r1 = CPyStr_EqualLiteral(x, r0, 7)
760+
return r1
761+
def literal_lhs(x):
762+
x, r0 :: str
763+
r1 :: bool
764+
L0:
765+
r0 = 'literal'
766+
r1 = CPyStr_EqualLiteral(x, r0, 7)
767+
return r1
768+
def literal_both():
769+
r0, r1 :: str
770+
L0:
771+
r0 = 'literal'
772+
r1 = 'literal'
773+
return 1

mypyc/test-data/irbuild-unreachable.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ L0:
2020
r2 = CPyObject_GetAttr(r0, r1)
2121
r3 = cast(str, r2)
2222
r4 = 'x'
23-
r5 = CPyStr_Equal(r3, r4)
23+
r5 = CPyStr_EqualLiteral(r3, r4, 1)
2424
if r5 goto L2 else goto L1 :: bool
2525
L1:
2626
r6 = r5
@@ -54,7 +54,7 @@ L0:
5454
r2 = CPyObject_GetAttr(r0, r1)
5555
r3 = cast(str, r2)
5656
r4 = 'x'
57-
r5 = CPyStr_Equal(r3, r4)
57+
r5 = CPyStr_EqualLiteral(r3, r4, 1)
5858
if r5 goto L2 else goto L1 :: bool
5959
L1:
6060
r6 = r5

0 commit comments

Comments
 (0)