Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 69 additions & 40 deletions src/ducktools/classbuilder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,13 @@
from ._version import __version__, __version_tuple__ # noqa: F401

try:
from ._cached_methods import eq_cache, replace_cache, repr_cache
from ._cached_methods import eq_cache, replace_cache, repr_cache, delattr_cache
except ImportError: # pragma: nocover
# Needed for generating cached methods after deletion
eq_cache = {}
replace_cache = {}
repr_cache = {}
delattr_cache = {}


# Change this name if you make heavy modifications
Expand Down Expand Up @@ -375,11 +376,15 @@ def __get__(self, inst, cls):
)

if self.cached_generator:
gen, method = self.cached_generator(gen_cls, self.funcname)
method = self.cached_generator(gen_cls, self.funcname)
else:
gen = self.code_generator(gen_cls, self.funcname)
method = gen.generate()

# Annotations are only supported in non-cached generators
if gen.annotations:
apply_annotations(method, gen.annotations)

# Patch up the method name and annotations
try:
method.__qualname__ = f"{cls.__qualname__}.{self.funcname}"
Expand All @@ -388,9 +393,6 @@ def __get__(self, inst, cls):
# descriptor. Don't try to rename.
pass

if gen.annotations:
apply_annotations(method, gen.annotations)

if self.decorator:
method = self.decorator(method)

Expand Down Expand Up @@ -478,13 +480,13 @@ def _fix_consts(consts, active_pair, pairs):
return tuple(new_consts)


def counter_to_class_generator(
def convert_to_class_generator(
generic_generator,
argument_getter,
cache=None,
replace_strings=False,
):
# This takes a counting source generator and converts it into a function
# This takes a counting or no argument source generator and converts it into a function
# generator with cached methods backing it
@_simple_cache(cache_seed=cache)
def source_exec(*args, funcname):
Expand All @@ -494,11 +496,14 @@ def source_exec(*args, funcname):

def method_generator(cls, funcname):
args = argument_getter(cls)
argnames = args[0]
argcount = len(args[0])
exec_args = (argcount, *args[1:])
if len(args) > 0:
argnames = args[0]
argcount = len(args[0])
exec_args = (argcount, *args[1:])
else:
argnames = []
exec_args = ()

gen = generic_generator(*exec_args, funcname=funcname)
raw_func = source_exec(*exec_args, funcname=funcname)

arg_fixes = {
Expand All @@ -525,12 +530,14 @@ def method_generator(cls, funcname):
new_co_names = co_names
new_co_consts = co_consts

globs = {}

method = _FunctionType(
raw_func.__code__.replace(
co_names=new_co_names,
co_consts=new_co_consts,
),
gen.globs,
globs,
name=funcname,
argdefs=raw_func.__defaults__,
closure=raw_func.__closure__,
Expand All @@ -540,7 +547,7 @@ def method_generator(cls, funcname):
# Remove the module reference to avoid retrieving incorrect code
method.__module__ = None # type: ignore

return gen, method
return method

method_generator.get_stats = source_exec.get_stats # type: ignore
method_generator.clear_cache = source_exec.clear_cache # type: ignore
Expand Down Expand Up @@ -641,7 +648,6 @@ def class_repr_generator(cls, funcname="__repr__"):
return generic_repr_generator(field_names, funcname=funcname)


@_simple_cache()
def _counter_repr_generator(argcount, *, funcname="__repr__"):
field_names = [
f"{REPLACE_NAME}{i}_"
Expand Down Expand Up @@ -685,7 +691,6 @@ def class_eq_generator(cls, funcname="__eq__"):
return generic_eq_generator(field_names, funcname=funcname)


@_simple_cache()
def _counter_eq_generator(argcount, *, funcname="__eq__"):
# This is a cached accelerated eq generator
# It returns uglier source, but the source can be cached
Expand Down Expand Up @@ -746,28 +751,25 @@ def _get_counter_order_generator(argcount, operator, *, funcname):
def class_lt_generator(cls, funcname="__lt__"):
return get_class_order_generator(cls, "<", funcname=funcname)

@_simple_cache()
def _counter_lt_generator(argcount, *, funcname="__lt__"):
return _get_counter_order_generator(argcount, "<", funcname=funcname)

def class_le_generator(cls, funcname="__le__"):
return get_class_order_generator(cls, "<=", funcname=funcname)

@_simple_cache()
def _counter_le_generator(argcount, *, funcname="__le__"):
return _get_counter_order_generator(argcount, "<=", funcname=funcname)

def class_gt_generator(cls, funcname="__gt__"):
return get_class_order_generator(cls, ">", funcname=funcname)

@_simple_cache()

def _counter_gt_generator(argcount, *, funcname="__gt__"):
return _get_counter_order_generator(argcount, ">", funcname=funcname)

def class_ge_generator(cls, funcname="__ge__"):
return get_class_order_generator(cls, ">=", funcname=funcname)

@_simple_cache()
def _counter_ge_generator(argcount, *, funcname="__ge__"):
return _get_counter_order_generator(argcount, ">=", funcname=funcname)

Expand Down Expand Up @@ -816,7 +818,6 @@ def _field_replace_generator(cls, funcname="__replace__"):
]
return generic_replace_generator(field_pairs, funcname=funcname)

@_simple_cache()
def _counter_replace_generator(argcount, *, funcname="__replace__"):
field_pairs = [
(f"{REPLACE_NAME}{i}_", f"{REPLACE_NAME}{i}_") for i in range(argcount)
Expand Down Expand Up @@ -857,39 +858,52 @@ def frozen_setattr_generator(cls, funcname="__setattr__"):
return GeneratedCode(code, globs)


def frozen_delattr_generator(cls, funcname="__delattr__"):
def generic_frozen_delattr_generator(*, funcname="__delattr__"):
body = (
' raise TypeError(\n'
' f"{type(self).__name__!r} object "\n'
' f"does not support attribute deletion"\n'
' f"{type(self).__name__!r} object does not support attribute deletion"\n'
' )\n'
)
code = f"def {funcname}(self, name):\n{body}"
globs = {}
return GeneratedCode(code, globs)


def hash_generator(cls, funcname="__hash__"):
fields = get_fields(cls)
vals = ", ".join(
f"self.{name}"
for name, attrib in fields.items()
if attrib.compare
)
if len(fields) == 1:
def frozen_delattr_generator(cls, funcname="__delattr__"):
return generic_frozen_delattr_generator(funcname=funcname)


def generic_hash_generator(field_names, *, funcname="__hash__"):
vals = ", ".join(f"self.{name}" for name in field_names)
if len(field_names) == 1:
# Needs a trailing comma for only 1 argument
# to make a tuple
vals += ","

code = f"def {funcname}(self):\n return hash(({vals}))\n"
globs = {}
return GeneratedCode(code, globs)


def _counter_hash_generator(argcount, *, funcname="__hash__"):
field_names = [
f"{REPLACE_NAME}{i}_" for i in range(argcount)
]
return generic_hash_generator(field_names, funcname=funcname)


def hash_generator(cls, funcname="__hash__"):
field_names = [name for name, attrib in get_fields(cls).items() if attrib.compare]
return generic_hash_generator(field_names, funcname=funcname)


# As only the __get__ method refers to the class we can use the same
# Descriptor instances for every class.
init_maker = MethodMaker("__init__", init_generator)
repr_maker = MethodMaker(
"__repr__",
class_repr_generator,
cached_generator=counter_to_class_generator(
cached_generator=convert_to_class_generator(
_counter_repr_generator,
get_repr_args,
cache=repr_cache,
Expand All @@ -900,7 +914,7 @@ def hash_generator(cls, funcname="__hash__"):
eq_maker = MethodMaker(
"__eq__",
class_eq_generator,
cached_generator=counter_to_class_generator(
cached_generator=convert_to_class_generator(
_counter_eq_generator,
get_compare_args,
cache=eq_cache,
Expand All @@ -909,48 +923,63 @@ def hash_generator(cls, funcname="__hash__"):
lt_maker = MethodMaker(
"__lt__",
class_lt_generator,
cached_generator=counter_to_class_generator(
cached_generator=convert_to_class_generator(
_counter_lt_generator,
get_compare_args,
),
)
le_maker = MethodMaker(
"__le__",
class_le_generator,
cached_generator=counter_to_class_generator(
cached_generator=convert_to_class_generator(
_counter_le_generator,
get_compare_args,
),
)
gt_maker = MethodMaker(
"__gt__",
class_gt_generator,
cached_generator=counter_to_class_generator(
cached_generator=convert_to_class_generator(
_counter_gt_generator,
get_compare_args,
),
)
ge_maker = MethodMaker(
"__ge__",
class_ge_generator,
cached_generator=counter_to_class_generator(
cached_generator=convert_to_class_generator(
_counter_ge_generator,
get_compare_args,
),
)
replace_maker = MethodMaker(
"__replace__",
class_replace_generator,
cached_generator=counter_to_class_generator(
cached_generator=convert_to_class_generator(
_counter_replace_generator,
get_replace_args,
cache=replace_cache,
replace_strings=True,
),
)
frozen_setattr_maker = MethodMaker("__setattr__", frozen_setattr_generator)
frozen_delattr_maker = MethodMaker("__delattr__", frozen_delattr_generator)
hash_maker = MethodMaker("__hash__", hash_generator)
frozen_delattr_maker = MethodMaker(
"__delattr__",
frozen_delattr_generator,
cached_generator=convert_to_class_generator(
generic_frozen_delattr_generator,
lambda cls: (),
cache=delattr_cache,
)
)
hash_maker = MethodMaker(
"__hash__",
hash_generator,
cached_generator=convert_to_class_generator(
_counter_hash_generator,
get_compare_args,
)
)
default_methods = frozenset({init_maker, repr_maker, eq_maker})

# Special `__init__` maker for 'Field' subclasses - needs its own NOTHING option
Expand Down
10 changes: 7 additions & 3 deletions src/ducktools/classbuilder/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import typing_extensions

__lazy_modules__: list[str]

from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Mapping
from types import MappingProxyType

if sys.version_info >= (3, 14):
Expand All @@ -32,7 +32,7 @@ REPLACE_NAME: str

@typing.type_check_only
class GetFieldsProtocol(typing.Protocol):
def __call__(self, cls: type, *, local: bool = ...) -> dict[str, Field]: ...
def __call__(self, cls: type, *, local: bool = ...) -> Mapping[str, Field]: ...

def get_fields(cls: type, *, local: bool = ...) -> dict[str, Field]: ...

Expand Down Expand Up @@ -124,7 +124,7 @@ def get_compare_args(cls: type) -> tuple[str, ...]: ...
def get_repr_args(cls: type) -> tuple[str, ...]: ...
def get_replace_args(cls: type) -> tuple[str, ...]: ...

def counter_to_class_generator(
def convert_to_class_generator(
generic_generator: _ArgcountCodegenType,
argument_getter: Callable[[type], tuple],
cache: None | dict[str, types.FunctionType] = ...,
Expand Down Expand Up @@ -156,7 +156,11 @@ def generic_replace_generator(field_pairs: list[tuple[str, str]], *, funcname: s
def class_replace_generator(cls: type, funcname: str = ...) -> GeneratedCode: ...

def frozen_setattr_generator(cls: type, funcname: str = ...) -> GeneratedCode: ...

def generic_frozen_delattr_generator(*, funcname: str = ...) -> GeneratedCode: ...
def frozen_delattr_generator(cls: type, funcname: str = ...) -> GeneratedCode: ...

def generic_hash_generator(field_names: list[str], *, funcname: str = ...) -> GeneratedCode: ...
def hash_generator(cls: type, funcname: str = ...) -> GeneratedCode: ...

init_maker: MethodMaker
Expand Down
56 changes: 56 additions & 0 deletions src/ducktools/classbuilder/_cached_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,59 @@ def _replace_10(self, /, **changes):
(10,): _replace_10,
}

def _hash_0(self):
return hash(())

def _hash_1(self):
return hash((self._classbuilder_cache_names_0_,))

def _hash_2(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_))

def _hash_3(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_, self._classbuilder_cache_names_2_))

def _hash_4(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_, self._classbuilder_cache_names_2_, self._classbuilder_cache_names_3_))

def _hash_5(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_, self._classbuilder_cache_names_2_, self._classbuilder_cache_names_3_, self._classbuilder_cache_names_4_))

def _hash_6(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_, self._classbuilder_cache_names_2_, self._classbuilder_cache_names_3_, self._classbuilder_cache_names_4_, self._classbuilder_cache_names_5_))

def _hash_7(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_, self._classbuilder_cache_names_2_, self._classbuilder_cache_names_3_, self._classbuilder_cache_names_4_, self._classbuilder_cache_names_5_, self._classbuilder_cache_names_6_))

def _hash_8(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_, self._classbuilder_cache_names_2_, self._classbuilder_cache_names_3_, self._classbuilder_cache_names_4_, self._classbuilder_cache_names_5_, self._classbuilder_cache_names_6_, self._classbuilder_cache_names_7_))

def _hash_9(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_, self._classbuilder_cache_names_2_, self._classbuilder_cache_names_3_, self._classbuilder_cache_names_4_, self._classbuilder_cache_names_5_, self._classbuilder_cache_names_6_, self._classbuilder_cache_names_7_, self._classbuilder_cache_names_8_))

def _hash_10(self):
return hash((self._classbuilder_cache_names_0_, self._classbuilder_cache_names_1_, self._classbuilder_cache_names_2_, self._classbuilder_cache_names_3_, self._classbuilder_cache_names_4_, self._classbuilder_cache_names_5_, self._classbuilder_cache_names_6_, self._classbuilder_cache_names_7_, self._classbuilder_cache_names_8_, self._classbuilder_cache_names_9_))

hash_cache = {
(0,): _hash_0,
(1,): _hash_1,
(2,): _hash_2,
(3,): _hash_3,
(4,): _hash_4,
(5,): _hash_5,
(6,): _hash_6,
(7,): _hash_7,
(8,): _hash_8,
(9,): _hash_9,
(10,): _hash_10,
}

def _delattr(self, name):
raise TypeError(
f"{type(self).__name__!r} object does not support attribute deletion"
)

delattr_cache = {
(): _delattr,
}

Loading