Skip to content

Commit 78dcec5

Browse files
authored
[mypyc] Enable --strict-bytes by default in mypyc (and require it) (#20548)
We'll be switching this on by default in mypy 2.0, but we can do it earlier in mypyc since backward compatibility requirements are not as strict. This makes the `bytes` primitive type not include `bytearray` any more. This allows more efficient primitives and makes the semantics more consistent. I'll ensure that the bytes C primitives don't have dead bytearray code paths after this has been merged. I'll also add a separate bytearray primitive type at some point.
1 parent d61ad17 commit 78dcec5

File tree

13 files changed

+80
-35
lines changed

13 files changed

+80
-35
lines changed

mypy/build.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2397,6 +2397,8 @@ def parse_file(self, *, temporary: bool = False) -> None:
23972397
self.source_hash = compute_hash(source)
23982398

23992399
self.parse_inline_configuration(source)
2400+
self.check_for_invalid_options()
2401+
24002402
self.size_hint = len(source)
24012403
if not cached:
24022404
self.tree = manager.parse_file(
@@ -2447,6 +2449,13 @@ def parse_inline_configuration(self, source: str) -> None:
24472449
for lineno, error in config_errors:
24482450
self.manager.errors.report(lineno, 0, error)
24492451

2452+
def check_for_invalid_options(self) -> None:
2453+
if self.options.mypyc and not self.options.strict_bytes:
2454+
self.manager.errors.set_file(self.xpath, self.id, options=self.options)
2455+
self.manager.errors.report(
2456+
1, 0, "Option --strict-bytes cannot be disabled when using mypyc", blocker=True
2457+
)
2458+
24502459
def semantic_analysis_pass1(self) -> None:
24512460
"""Perform pass 1 of semantic analysis, which happens immediately after parsing.
24522461

mypy/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,7 @@ def process_options(
13711371
fscache: FileSystemCache | None = None,
13721372
program: str = "mypy",
13731373
header: str = HEADER,
1374+
mypyc: bool = False,
13741375
) -> tuple[list[BuildSource], Options]:
13751376
"""Parse command line arguments.
13761377
@@ -1398,6 +1399,9 @@ def process_options(
13981399

13991400
options = Options()
14001401
strict_option_set = False
1402+
if mypyc:
1403+
# Mypyc has strict_bytes enabled by default
1404+
options.strict_bytes = True
14011405

14021406
def set_strict_flags() -> None:
14031407
nonlocal strict_option_set

mypy/typeshed/stubs/librt/librt/strings.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from mypy_extensions import i64, u8
55
@final
66
class BytesWriter:
77
def append(self, /, x: int) -> None: ...
8-
def write(self, /, b: bytes) -> None: ...
8+
def write(self, /, b: bytes | bytearray) -> None: ...
99
def getvalue(self) -> bytes: ...
1010
def truncate(self, /, size: i64) -> None: ...
1111
def __len__(self) -> i64: ...

mypyc/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def get_mypy_config(
189189
fscache: FileSystemCache | None,
190190
) -> tuple[list[BuildSource], list[BuildSource], Options]:
191191
"""Construct mypy BuildSources and Options from file and options lists"""
192-
all_sources, options = process_options(mypy_options, fscache=fscache)
192+
all_sources, options = process_options(mypy_options, fscache=fscache, mypyc=True)
193193
if only_compile_paths is not None:
194194
paths_set = set(only_compile_paths)
195195
mypyc_sources = [s for s in all_sources if s.path in paths_set]

mypyc/codegen/emit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def emit_cast(
657657
elif is_bytes_rprimitive(typ):
658658
if declare_dest:
659659
self.emit_line(f"PyObject *{dest};")
660-
check = "(PyBytes_Check({}) || PyByteArray_Check({}))"
660+
check = "(PyBytes_Check({}))"
661661
if likely:
662662
check = f"(likely{check})"
663663
self.emit_arg_check(src, dest, typ, check.format(src, src), optional)

mypyc/test-data/commandline.test

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,13 @@ print(type(Eggs(obj1=pkg1.A.B())["obj1"]).__module__)
312312
[out]
313313
B
314314
pkg2.mod2
315+
316+
[case testStrictBytesRequired]
317+
# cmd: --no-strict-bytes a.py
318+
319+
[file a.py]
320+
def f(b: bytes) -> None: pass
321+
f(bytearray())
322+
323+
[out]
324+
a.py:1: error: Option --strict-bytes cannot be disabled when using mypyc

mypyc/test-data/fixtures/ir.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class bytes:
167167
def __init__(self) -> None: ...
168168
@overload
169169
def __init__(self, x: object) -> None: ...
170-
def __add__(self, x: bytes) -> bytes: ...
170+
def __add__(self, x: bytes | bytearray) -> bytes: ...
171171
def __mul__(self, x: int) -> bytes: ...
172172
def __rmul__(self, x: int) -> bytes: ...
173173
def __eq__(self, x: object) -> bool: ...
@@ -178,8 +178,8 @@ def __getitem__(self, i: int) -> int: ...
178178
def __getitem__(self, i: slice) -> bytes: ...
179179
def join(self, x: Iterable[object]) -> bytes: ...
180180
def decode(self, encoding: str=..., errors: str=...) -> str: ...
181-
def translate(self, t: bytes) -> bytes: ...
182-
def startswith(self, t: bytes) -> bool: ...
181+
def translate(self, t: bytes | bytearray) -> bytes: ...
182+
def startswith(self, t: bytes | bytearray) -> bool: ...
183183
def __iter__(self) -> Iterator[int]: ...
184184

185185
class bytearray:
@@ -189,9 +189,12 @@ def __init__(self) -> None: pass
189189
def __init__(self, x: object) -> None: pass
190190
@overload
191191
def __init__(self, string: str, encoding: str, err: str = ...) -> None: pass
192-
def __add__(self, s: bytes) -> bytearray: ...
192+
def __add__(self, s: bytes | bytearray) -> bytearray: ...
193193
def __setitem__(self, i: int, o: int) -> None: ...
194+
@overload
194195
def __getitem__(self, i: int) -> int: ...
196+
@overload
197+
def __getitem__(self, i: slice) -> bytearray: ...
195198
def decode(self, x: str = ..., y: str = ...) -> str: ...
196199
def startswith(self, t: bytes) -> bool: ...
197200

mypyc/test-data/irbuild-bytes.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,16 @@ L0:
261261
r0 = CPyBytes_Startswith(a, b)
262262
r1 = truncate r0: i32 to builtins.bool
263263
return r1
264+
265+
[case testBytesVsBytearray]
266+
def bytes_func(b: bytes) -> None: pass
267+
def bytearray_func(ba: bytearray) -> None: pass
268+
269+
def foo(b: bytes, ba: bytearray) -> None:
270+
bytes_func(b)
271+
bytearray_func(ba)
272+
bytes_func(ba)
273+
bytearray_func(b)
274+
[out]
275+
main:7: error: Argument 1 to "bytes_func" has incompatible type "bytearray"; expected "bytes"
276+
main:8: error: Argument 1 to "bytearray_func" has incompatible type "bytes"; expected "bytearray"

mypyc/test-data/run-base64.test

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[case testAllBase64Features_librt_experimental]
2-
from typing import Any
2+
from typing import Any, cast
33
import base64
44
import binascii
55
import random
@@ -14,7 +14,7 @@ def test_encode_basic() -> None:
1414
assert b64encode(b"x") == b"eA=="
1515

1616
with assertRaises(TypeError):
17-
b64encode(bytearray(b"x"))
17+
b64encode(cast(Any, bytearray(b"x")))
1818

1919
def check_encode(b: bytes) -> None:
2020
assert b64encode(b) == getattr(base64, "b64encode")(b)
@@ -56,7 +56,7 @@ def test_decode_basic() -> None:
5656
assert b64decode(b"eA==") == b"x"
5757

5858
with assertRaises(TypeError):
59-
b64decode(bytearray(b"eA=="))
59+
b64decode(cast(Any, bytearray(b"eA==")))
6060

6161
for non_ascii in "\x80", "foo\u100bar", "foo\ua1234bar":
6262
with assertRaises(ValueError):

mypyc/test-data/run-bytes.test

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def test_concat() -> None:
7979
assert type(b1) == bytes
8080
assert type(b2) == bytes
8181
assert type(b3) == bytes
82-
brr1: bytes = bytearray(3)
83-
brr2: bytes = bytearray(range(5))
82+
brr1 = bytearray(3)
83+
brr2 = bytearray(range(5))
8484
b4 = b1 + brr1
8585
assert b4 == b'123\x00\x00\x00'
8686
assert type(brr1) == bytearray
@@ -94,9 +94,9 @@ def test_concat() -> None:
9494
b5 = brr2 + b2
9595
assert b5 == bytearray(b'\x00\x01\x02\x03\x04456')
9696
assert type(b5) == bytearray
97-
b5 = b2 + brr2
98-
assert b5 == b'456\x00\x01\x02\x03\x04'
99-
assert type(b5) == bytes
97+
b6 = b2 + brr2
98+
assert b6 == b'456\x00\x01\x02\x03\x04'
99+
assert type(b6) == bytes
100100

101101
def test_join() -> None:
102102
seq = (b'1', b'"', b'\xf0')
@@ -217,9 +217,9 @@ def test_startswith() -> None:
217217
assert test.startswith(bytearray(b'some'))
218218
assert not test.startswith(bytearray(b'other'))
219219

220-
test = bytearray(b'some string')
221-
assert test.startswith(b'some')
222-
assert not test.startswith(b'other')
220+
test2 = bytearray(b'some string')
221+
assert test2.startswith(b'some')
222+
assert not test2.startswith(b'other')
223223

224224
[case testBytesSlicing]
225225
def test_bytes_slicing() -> None:
@@ -257,34 +257,38 @@ def test_bytes_slicing() -> None:
257257
[case testBytearrayBasics]
258258
from typing import Any
259259

260+
from testutil import assertRaises
261+
260262
def test_basics() -> None:
261-
brr1: bytes = bytearray(3)
263+
brr1 = bytearray(3)
262264
assert brr1 == bytearray(b'\x00\x00\x00')
263265
assert brr1 == b'\x00\x00\x00'
264266
l = [10, 20, 30, 40]
265-
brr2: bytes = bytearray(l)
267+
brr2 = bytearray(l)
266268
assert brr2 == bytearray(b'\n\x14\x1e(')
267269
assert brr2 == b'\n\x14\x1e('
268-
brr3: bytes = bytearray(range(5))
270+
brr3 = bytearray(range(5))
269271
assert brr3 == bytearray(b'\x00\x01\x02\x03\x04')
270272
assert brr3 == b'\x00\x01\x02\x03\x04'
271-
brr4: bytes = bytearray('string', 'utf-8')
273+
brr4 = bytearray('string', 'utf-8')
272274
assert brr4 == bytearray(b'string')
273275
assert brr4 == b'string'
274276
assert len(brr1) == 3
275277
assert len(brr2) == 4
276278

277-
def f(b: bytes) -> bool:
278-
return True
279+
def f(b: bytes) -> str:
280+
return "xy"
279281

280282
def test_bytearray_passed_into_bytes() -> None:
281-
assert f(bytearray(3))
282283
brr1: Any = bytearray()
283-
assert f(brr1)
284+
with assertRaises(TypeError, "bytes object expected; got bytearray"):
285+
f(brr1)
286+
with assertRaises(TypeError, "bytes object expected; got bytearray"):
287+
b: bytes = brr1
284288

285289
[case testBytearraySlicing]
286290
def test_bytearray_slicing() -> None:
287-
b: bytes = bytearray(b'abcdefg')
291+
b = bytearray(b'abcdefg')
288292
zero = int()
289293
ten = 10 + zero
290294
two = 2 + zero
@@ -318,7 +322,7 @@ def test_bytearray_slicing() -> None:
318322
from testutil import assertRaises
319323

320324
def test_bytearray_indexing() -> None:
321-
b: bytes = bytearray(b'\xae\x80\xfe\x15')
325+
b = bytearray(b'\xae\x80\xfe\x15')
322326
assert b[0] == 174
323327
assert b[1] == 128
324328
assert b[2] == 254
@@ -347,10 +351,6 @@ def test_bytes_join() -> None:
347351
assert b' '.join([b'a', b'b']) == b'a b'
348352
assert b' '.join([]) == b''
349353

350-
x: bytes = bytearray(b' ')
351-
assert x.join([b'a', b'b']) == b'a b'
352-
assert type(x.join([b'a', b'b'])) == bytearray
353-
354354
y: bytes = bytes_subclass()
355355
assert y.join([]) == b'spook'
356356

0 commit comments

Comments
 (0)