Skip to content

Add @disjoint_base decorator in the stdlib #14599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/stubtest_stdlib.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,8 @@ jobs:
check-latest: true
- name: Install dependencies
run: pip install -r requirements-tests.txt
# Temporary to get @disjoint_base support; can remove once mypy 1.18 is released
- name: Install mypy from git
run: pip install git+https://github.com/JelleZijlstra/mypy.git@03ce7f0f0ece2b0dbcda701f6df1488eda484363
- name: Run stubtest
run: python tests/stubtest_stdlib.py
1 change: 1 addition & 0 deletions stdlib/@tests/stubtest_allowlists/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ tkinter.simpledialog.[A-Z_]+
tkinter.simpledialog.TclVersion
tkinter.simpledialog.TkVersion
tkinter.Text.count # stubtest somehow thinks that index1 parameter has a default value, but it doesn't in any of the overloads
builtins.tuple # should have @disjoint_base but hits pyright issue


# ===============================================================
Expand Down
4 changes: 3 additions & 1 deletion stdlib/_asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ from collections.abc import Awaitable, Callable, Coroutine, Generator
from contextvars import Context
from types import FrameType, GenericAlias
from typing import Any, Literal, TextIO, TypeVar
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, disjoint_base

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_TaskYieldType: TypeAlias = Future[object] | None

@disjoint_base
class Future(Awaitable[_T]):
_state: str
@property
Expand Down Expand Up @@ -49,6 +50,7 @@ else:
# While this is true in general, here it's sort-of okay to have a covariant subclass,
# since the only reason why `asyncio.Future` is invariant is the `set_result()` method,
# and `asyncio.Task.set_result()` always raises.
@disjoint_base
class Task(Future[_T_co]): # type: ignore[type-var] # pyright: ignore[reportInvalidTypeArguments]
if sys.version_info >= (3, 12):
def __init__(
Expand Down
5 changes: 4 additions & 1 deletion stdlib/_csv.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import sys
from _typeshed import SupportsWrite
from collections.abc import Iterable
from typing import Any, Final, Literal, type_check_only
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, disjoint_base

__version__: Final[str]

Expand All @@ -24,6 +24,7 @@ class Error(Exception): ...

_DialectLike: TypeAlias = str | Dialect | csv.Dialect | type[Dialect | csv.Dialect]

@disjoint_base
class Dialect:
delimiter: str
quotechar: str | None
Expand All @@ -48,6 +49,7 @@ class Dialect:

if sys.version_info >= (3, 10):
# This class calls itself _csv.reader.
@disjoint_base
class Reader:
@property
def dialect(self) -> Dialect: ...
Expand All @@ -56,6 +58,7 @@ if sys.version_info >= (3, 10):
def __next__(self) -> list[str]: ...

# This class calls itself _csv.writer.
@disjoint_base
class Writer:
@property
def dialect(self) -> Dialect: ...
Expand Down
3 changes: 2 additions & 1 deletion stdlib/_hashlib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from _typeshed import ReadableBuffer
from collections.abc import Callable
from types import ModuleType
from typing import AnyStr, Protocol, final, overload, type_check_only
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, disjoint_base

_DigestMod: TypeAlias = str | Callable[[], _HashObject] | ModuleType | None

Expand All @@ -22,6 +22,7 @@ class _HashObject(Protocol):
def hexdigest(self) -> str: ...
def update(self, obj: ReadableBuffer, /) -> None: ...

@disjoint_base
class HASH:
@property
def digest_size(self) -> int: ...
Expand Down
3 changes: 2 additions & 1 deletion stdlib/_interpreters.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import types
from collections.abc import Callable
from typing import Any, Final, Literal, SupportsIndex, TypeVar, overload
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, disjoint_base

_R = TypeVar("_R")

Expand All @@ -12,6 +12,7 @@ class InterpreterError(Exception): ...
class InterpreterNotFoundError(InterpreterError): ...
class NotShareableError(ValueError): ...

@disjoint_base
class CrossInterpreterBufferView:
def __buffer__(self, flags: int, /) -> memoryview: ...

Expand Down
93 changes: 66 additions & 27 deletions stdlib/_io.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ from io import BufferedIOBase, RawIOBase, TextIOBase, UnsupportedOperation as Un
from os import _Opener
from types import TracebackType
from typing import IO, Any, BinaryIO, Final, Generic, Literal, Protocol, TextIO, TypeVar, overload, type_check_only
from typing_extensions import Self
from typing_extensions import Self, disjoint_base

_T = TypeVar("_T")

Expand All @@ -22,32 +22,62 @@ def open_code(path: str) -> IO[bytes]: ...

BlockingIOError = builtins.BlockingIOError

class _IOBase:
def __iter__(self) -> Iterator[bytes]: ...
def __next__(self) -> bytes: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ...
def close(self) -> None: ...
def fileno(self) -> int: ...
def flush(self) -> None: ...
def isatty(self) -> bool: ...
def readable(self) -> bool: ...
read: Callable[..., Any]
def readlines(self, hint: int = -1, /) -> list[bytes]: ...
def seek(self, offset: int, whence: int = 0, /) -> int: ...
def seekable(self) -> bool: ...
def tell(self) -> int: ...
def truncate(self, size: int | None = None, /) -> int: ...
def writable(self) -> bool: ...
write: Callable[..., Any]
def writelines(self, lines: Iterable[ReadableBuffer], /) -> None: ...
def readline(self, size: int | None = -1, /) -> bytes: ...
def __del__(self) -> None: ...
@property
def closed(self) -> bool: ...
def _checkClosed(self) -> None: ... # undocumented
if sys.version_info >= (3, 12):
@disjoint_base
class _IOBase:
def __iter__(self) -> Iterator[bytes]: ...
def __next__(self) -> bytes: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ...
def close(self) -> None: ...
def fileno(self) -> int: ...
def flush(self) -> None: ...
def isatty(self) -> bool: ...
def readable(self) -> bool: ...
read: Callable[..., Any]
def readlines(self, hint: int = -1, /) -> list[bytes]: ...
def seek(self, offset: int, whence: int = 0, /) -> int: ...
def seekable(self) -> bool: ...
def tell(self) -> int: ...
def truncate(self, size: int | None = None, /) -> int: ...
def writable(self) -> bool: ...
write: Callable[..., Any]
def writelines(self, lines: Iterable[ReadableBuffer], /) -> None: ...
def readline(self, size: int | None = -1, /) -> bytes: ...
def __del__(self) -> None: ...
@property
def closed(self) -> bool: ...
def _checkClosed(self) -> None: ... # undocumented

else:
class _IOBase:
def __iter__(self) -> Iterator[bytes]: ...
def __next__(self) -> bytes: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ...
def close(self) -> None: ...
def fileno(self) -> int: ...
def flush(self) -> None: ...
def isatty(self) -> bool: ...
def readable(self) -> bool: ...
read: Callable[..., Any]
def readlines(self, hint: int = -1, /) -> list[bytes]: ...
def seek(self, offset: int, whence: int = 0, /) -> int: ...
def seekable(self) -> bool: ...
def tell(self) -> int: ...
def truncate(self, size: int | None = None, /) -> int: ...
def writable(self) -> bool: ...
write: Callable[..., Any]
def writelines(self, lines: Iterable[ReadableBuffer], /) -> None: ...
def readline(self, size: int | None = -1, /) -> bytes: ...
def __del__(self) -> None: ...
@property
def closed(self) -> bool: ...
def _checkClosed(self) -> None: ... # undocumented

class _RawIOBase(_IOBase):
def readall(self) -> bytes: ...
Expand All @@ -65,6 +95,7 @@ class _BufferedIOBase(_IOBase):
def read(self, size: int | None = -1, /) -> bytes: ...
def read1(self, size: int = -1, /) -> bytes: ...

@disjoint_base
class FileIO(RawIOBase, _RawIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of writelines in the base classes
mode: str
# The type of "name" equals the argument passed in to the constructor,
Expand All @@ -79,6 +110,7 @@ class FileIO(RawIOBase, _RawIOBase, BinaryIO): # type: ignore[misc] # incompat
def seek(self, pos: int, whence: int = 0, /) -> int: ...
def read(self, size: int | None = -1, /) -> bytes | MaybeNone: ...

@disjoint_base
class BytesIO(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of methods in the base classes
def __init__(self, initial_bytes: ReadableBuffer = b"") -> None: ...
# BytesIO does not contain a "name" field. This workaround is necessary
Expand Down Expand Up @@ -119,6 +151,7 @@ class _BufferedReaderStream(Protocol):

_BufferedReaderStreamT = TypeVar("_BufferedReaderStreamT", bound=_BufferedReaderStream, default=_BufferedReaderStream)

@disjoint_base
class BufferedReader(BufferedIOBase, _BufferedIOBase, BinaryIO, Generic[_BufferedReaderStreamT]): # type: ignore[misc] # incompatible definitions of methods in the base classes
raw: _BufferedReaderStreamT
if sys.version_info >= (3, 14):
Expand All @@ -130,6 +163,7 @@ class BufferedReader(BufferedIOBase, _BufferedIOBase, BinaryIO, Generic[_Buffere
def seek(self, target: int, whence: int = 0, /) -> int: ...
def truncate(self, pos: int | None = None, /) -> int: ...

@disjoint_base
class BufferedWriter(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of writelines in the base classes
raw: RawIOBase
if sys.version_info >= (3, 14):
Expand All @@ -141,6 +175,7 @@ class BufferedWriter(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore
def seek(self, target: int, whence: int = 0, /) -> int: ...
def truncate(self, pos: int | None = None, /) -> int: ...

@disjoint_base
class BufferedRandom(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of methods in the base classes
mode: str
name: Any
Expand All @@ -154,6 +189,7 @@ class BufferedRandom(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore
def peek(self, size: int = 0, /) -> bytes: ...
def truncate(self, pos: int | None = None, /) -> int: ...

@disjoint_base
class BufferedRWPair(BufferedIOBase, _BufferedIOBase, Generic[_BufferedReaderStreamT]):
if sys.version_info >= (3, 14):
def __init__(self, reader: _BufferedReaderStreamT, writer: RawIOBase, buffer_size: int = 131072, /) -> None: ...
Expand Down Expand Up @@ -200,6 +236,7 @@ class _WrappedBuffer(Protocol):

_BufferT_co = TypeVar("_BufferT_co", bound=_WrappedBuffer, default=_WrappedBuffer, covariant=True)

@disjoint_base
class TextIOWrapper(TextIOBase, _TextIOBase, TextIO, Generic[_BufferT_co]): # type: ignore[misc] # incompatible definitions of write in the base classes
def __init__(
self,
Expand Down Expand Up @@ -234,6 +271,7 @@ class TextIOWrapper(TextIOBase, _TextIOBase, TextIO, Generic[_BufferT_co]): # t
def seek(self, cookie: int, whence: int = 0, /) -> int: ...
def truncate(self, pos: int | None = None, /) -> int: ...

@disjoint_base
class StringIO(TextIOBase, _TextIOBase, TextIO): # type: ignore[misc] # incompatible definitions of write in the base classes
def __init__(self, initial_value: str | None = "", newline: str | None = "\n") -> None: ...
# StringIO does not contain a "name" field. This workaround is necessary
Expand All @@ -246,6 +284,7 @@ class StringIO(TextIOBase, _TextIOBase, TextIO): # type: ignore[misc] # incomp
def seek(self, pos: int, whence: int = 0, /) -> int: ...
def truncate(self, pos: int | None = None, /) -> int: ...

@disjoint_base
class IncrementalNewlineDecoder:
def __init__(self, decoder: codecs.IncrementalDecoder | None, translate: bool, errors: str = "strict") -> None: ...
def decode(self, input: ReadableBuffer | str, final: bool = False) -> str: ...
Expand Down
2 changes: 2 additions & 0 deletions stdlib/_lsprof.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ from _typeshed import structseq
from collections.abc import Callable
from types import CodeType
from typing import Any, Final, final
from typing_extensions import disjoint_base

@disjoint_base
class Profiler:
def __init__(
self, timer: Callable[[], float] | None = None, timeunit: float = 0.0, subcalls: bool = True, builtins: bool = True
Expand Down
5 changes: 5 additions & 0 deletions stdlib/_multibytecodec.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from _typeshed import ReadableBuffer
from codecs import _ReadableStream, _WritableStream
from collections.abc import Iterable
from typing import final, type_check_only
from typing_extensions import disjoint_base

# This class is not exposed. It calls itself _multibytecodec.MultibyteCodec.
@final
Expand All @@ -10,6 +11,7 @@ class _MultibyteCodec:
def decode(self, input: ReadableBuffer, errors: str | None = None) -> str: ...
def encode(self, input: str, errors: str | None = None) -> bytes: ...

@disjoint_base
class MultibyteIncrementalDecoder:
errors: str
def __init__(self, errors: str = "strict") -> None: ...
Expand All @@ -18,6 +20,7 @@ class MultibyteIncrementalDecoder:
def reset(self) -> None: ...
def setstate(self, state: tuple[bytes, int], /) -> None: ...

@disjoint_base
class MultibyteIncrementalEncoder:
errors: str
def __init__(self, errors: str = "strict") -> None: ...
Expand All @@ -26,6 +29,7 @@ class MultibyteIncrementalEncoder:
def reset(self) -> None: ...
def setstate(self, state: int, /) -> None: ...

@disjoint_base
class MultibyteStreamReader:
errors: str
stream: _ReadableStream
Expand All @@ -35,6 +39,7 @@ class MultibyteStreamReader:
def readlines(self, sizehintobj: int | None = None, /) -> list[str]: ...
def reset(self) -> None: ...

@disjoint_base
class MultibyteStreamWriter:
errors: str
stream: _WritableStream
Expand Down
4 changes: 3 additions & 1 deletion stdlib/_pickle.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from _typeshed import ReadableBuffer, SupportsWrite
from collections.abc import Callable, Iterable, Iterator, Mapping
from pickle import PickleBuffer as PickleBuffer
from typing import Any, Protocol, type_check_only
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, disjoint_base

@type_check_only
class _ReadableFileobj(Protocol):
Expand Down Expand Up @@ -57,6 +57,7 @@ class PicklerMemoProxy:
def clear(self, /) -> None: ...
def copy(self, /) -> dict[int, tuple[int, Any]]: ...

@disjoint_base
class Pickler:
fast: bool
dispatch_table: Mapping[type, Callable[[Any], _ReducedType]]
Expand Down Expand Up @@ -84,6 +85,7 @@ class UnpicklerMemoProxy:
def clear(self, /) -> None: ...
def copy(self, /) -> dict[int, tuple[int, Any]]: ...

@disjoint_base
class Unpickler:
def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions stdlib/_queue.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from types import GenericAlias
from typing import Any, Generic, TypeVar
from typing_extensions import disjoint_base

_T = TypeVar("_T")

class Empty(Exception): ...

@disjoint_base
class SimpleQueue(Generic[_T]):
def __init__(self) -> None: ...
def empty(self) -> bool: ...
Expand Down
3 changes: 2 additions & 1 deletion stdlib/_random.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, disjoint_base

# Actually Tuple[(int,) * 625]
_State: TypeAlias = tuple[int, ...]

@disjoint_base
class Random:
if sys.version_info >= (3, 10):
def __init__(self, seed: object = ..., /) -> None: ...
Expand Down
3 changes: 2 additions & 1 deletion stdlib/_socket.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from _typeshed import ReadableBuffer, WriteableBuffer
from collections.abc import Iterable
from socket import error as error, gaierror as gaierror, herror as herror, timeout as timeout
from typing import Any, Final, SupportsIndex, overload
from typing_extensions import CapsuleType, TypeAlias
from typing_extensions import CapsuleType, TypeAlias, disjoint_base

_CMSG: TypeAlias = tuple[int, int, bytes]
_CMSGArg: TypeAlias = tuple[int, int, ReadableBuffer]
Expand Down Expand Up @@ -731,6 +731,7 @@ if sys.platform != "win32" and sys.platform != "darwin":

# ===== Classes =====

@disjoint_base
class socket:
@property
def family(self) -> int: ...
Expand Down
Loading
Loading