Skip to content

Commit 57aa7ce

Browse files
committed
Introduce an overloading decorator
To give class and regular methods the same name but different behaviour Signed-off-by: liamhuber <[email protected]>
1 parent 65d4654 commit 57aa7ce

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

pyiron_workflow/overloading.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import functools
2+
3+
4+
def overloaded_classmethod(class_method):
5+
"""
6+
Decorator to define a method that behaves like both a classmethod and an
7+
instancemethod under the same name.
8+
9+
Args:
10+
instance_method: A method defined on the same object as the decorated method,
11+
to be used when an instance of the object calls the decorated method (
12+
instead of a class call)
13+
14+
Returns
15+
-------
16+
descriptor
17+
A descriptor that dispatches to the classmethod when accessed
18+
via the class, and to the given instance method when accessed
19+
via an instance.
20+
21+
Examples:
22+
>>> class Foo:
23+
... def __init__(self, y):
24+
... self.y = y
25+
...
26+
... @classmethod
27+
... def _doit_classmethod(cls, x):
28+
... return f"Class {cls.__name__} doing {x}"
29+
...
30+
... @overloaded_classmethod(class_method=_doit_classmethod)
31+
... def doit(self, x):
32+
... return f"Instance of type {type(self).__name__} doing {x} + {self.y}"
33+
...
34+
>>> Foo.doit(10)
35+
'Class Foo doing 10'
36+
>>> Foo(5).doit(20)
37+
'Instance of type Foo doing 20 + 5'
38+
"""
39+
40+
class Overloaded:
41+
def __init__(self, f_instance, f_class):
42+
self.f_instance = f_instance
43+
self.f_class = f_class
44+
functools.update_wrapper(self, f_instance)
45+
46+
def __get__(self, obj, cls):
47+
if obj is None:
48+
f_class = (
49+
cls.__dict__[self.f_class]
50+
if isinstance(self.f_class, str)
51+
else self.f_class
52+
)
53+
54+
if isinstance(f_class, classmethod):
55+
f_class = f_class.__func__
56+
57+
@functools.wraps(self.f_class)
58+
def bound(*args, **kwargs):
59+
return f_class(cls, *args, **kwargs)
60+
61+
return bound
62+
else:
63+
64+
@functools.wraps(self.f_instance)
65+
def bound(*args, **kwargs):
66+
return self.f_instance(obj, *args, **kwargs)
67+
68+
return bound
69+
70+
def wrapper(f_instance):
71+
return Overloaded(f_instance, class_method)
72+
73+
return wrapper

tests/unit/test_overloading.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import unittest
2+
3+
from pyiron_workflow.overloading import (
4+
overloaded_classmethod,
5+
) # replace with actual module path
6+
7+
8+
def class_string(obj, x):
9+
return f"Class {obj.__name__} doing {x}"
10+
11+
12+
def instance_string(obj, x):
13+
return f"Instance of type {type(obj).__name__} doing {x}"
14+
15+
16+
class Foo:
17+
@overloaded_classmethod(class_method="_pseudo_classmethod")
18+
def undecorated_string(self, x):
19+
return instance_string(self, x)
20+
21+
@overloaded_classmethod(class_method="_classmethod")
22+
def decorated_string(self, x):
23+
return instance_string(self, x)
24+
25+
def _pseudo_classmethod(cls, x):
26+
return class_string(cls, x)
27+
28+
@classmethod
29+
def _classmethod(cls, y):
30+
return class_string(cls, y)
31+
32+
@overloaded_classmethod(class_method=_pseudo_classmethod)
33+
def undecorated_direct(self, x):
34+
return instance_string(self, x)
35+
36+
@overloaded_classmethod(class_method=_classmethod)
37+
def decorated_direct(self, x):
38+
return instance_string(self, x)
39+
40+
41+
class TestOverloadedClassMethod(unittest.TestCase):
42+
def test_instance_and_class_calls(self):
43+
self.assertEqual(Foo.undecorated_string(1), class_string(Foo, 1))
44+
self.assertEqual(Foo.decorated_string(2), class_string(Foo, 2))
45+
self.assertEqual(Foo.undecorated_direct(3), class_string(Foo, 3))
46+
self.assertEqual(Foo.decorated_direct(4), class_string(Foo, 4))
47+
48+
self.assertEqual(Foo().undecorated_string(1), instance_string(Foo(), 1))
49+
self.assertEqual(Foo().decorated_string(2), instance_string(Foo(), 2))
50+
self.assertEqual(Foo().undecorated_direct(3), instance_string(Foo(), 3))
51+
self.assertEqual(Foo().decorated_direct(4), instance_string(Foo(), 4))
52+
53+
54+
if __name__ == "__main__":
55+
unittest.main()

0 commit comments

Comments
 (0)