Skip to content

Commit b8f3f8a

Browse files
authored
Merge pull request #759 from pyiron/loading
Better loading
2 parents 529c2d7 + a31e284 commit b8f3f8a

File tree

3 files changed

+229
-26
lines changed

3 files changed

+229
-26
lines changed

pyiron_workflow/node.py

Lines changed: 101 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pyiron_snippets.colors import SeabornColors
1919
from pyiron_snippets.dotdict import DotDict
2020

21+
from pyiron_workflow import overloading
2122
from pyiron_workflow.channels import (
2223
AccumulatingInputSignal,
2324
Channel,
@@ -978,15 +979,56 @@ def save_checkpoint(self, backend: BackendIdentifier | StorageInterface = "pickl
978979
"""
979980
self.graph_root.save(backend=backend)
980981

981-
def load(
982-
self,
982+
@classmethod
983+
def _new_instance_from_storage(
984+
cls,
983985
backend: BackendIdentifier | StorageInterface = "pickle",
984986
only_requested=False,
985987
filename: str | Path | None = None,
988+
_node: Node | None = None,
986989
**kwargs,
987990
):
988991
"""
992+
Loads a node from file returns its instance.
989993
994+
Args:
995+
backend (str | StorageInterface): The interface to use for serializing the
996+
node. (Default is "pickle", which loads the standard pickling back end.)
997+
only_requested (bool): Whether to _only_ try loading from the specified
998+
backend, or to loop through all available backends. (Default is False,
999+
try to load whatever you can find.)
1000+
filename (str | Path | None): The name of the file (without extensions)
1001+
from which to load the node. (Default is None, which uses the node's
1002+
lexical path.)
1003+
**kwargs: back end-specific arguments (only likely to work in combination
1004+
with :param:`only_requested`, otherwise there's nothing to be specific
1005+
_to_.)
1006+
1007+
Raises:
1008+
FileNotFoundError: when nothing got loaded.
1009+
"""
1010+
inst = None
1011+
for selected_backend in available_backends(
1012+
backend=backend, only_requested=only_requested
1013+
):
1014+
inst = selected_backend.load(node=_node, filename=filename, **kwargs)
1015+
if inst is not None:
1016+
break
1017+
if inst is None:
1018+
raise FileNotFoundError(
1019+
f"Could not find saved content at {filename} using backend={backend} "
1020+
f"using only_request={only_requested}."
1021+
)
1022+
return inst
1023+
1024+
def _update_instance_from_storage(
1025+
self,
1026+
backend: BackendIdentifier | StorageInterface = "pickle",
1027+
only_requested=False,
1028+
filename: str | Path | None = None,
1029+
**kwargs,
1030+
):
1031+
"""
9901032
Loads the node file and set the loaded state as the node's own.
9911033
9921034
Args:
@@ -995,8 +1037,8 @@ def load(
9951037
only_requested (bool): Whether to _only_ try loading from the specified
9961038
backend, or to loop through all available backends. (Default is False,
9971039
try to load whatever you can find.)
998-
filename (str | Path | None): The name of the file (without extensions) at
999-
which to save the node. (Default is None, which uses the node's
1040+
filename (str | Path | None): The name of the file (without extensions)
1041+
from which to load the node. (Default is None, which uses the node's
10001042
lexical path.)
10011043
**kwargs: back end-specific arguments (only likely to work in combination
10021044
with :param:`only_requested`, otherwise there's nothing to be specific
@@ -1012,16 +1054,13 @@ def load(
10121054
"is the correct thing to do, you can set `self.running=True` where "
10131055
"`self` is this node object."
10141056
)
1015-
for selected_backend in available_backends(
1016-
backend=backend, only_requested=only_requested
1017-
):
1018-
inst = selected_backend.load(
1019-
node=self if filename is None else None, filename=filename, **kwargs
1020-
)
1021-
if inst is not None:
1022-
break
1023-
if inst is None:
1024-
raise FileNotFoundError(f"{self.label} could not find saved content.")
1057+
inst = self.__class__._new_instance_from_storage(
1058+
backend=backend,
1059+
only_requested=only_requested,
1060+
filename=filename,
1061+
_node=self if filename is None else None,
1062+
**kwargs,
1063+
)
10251064

10261065
if inst.__class__ != self.__class__:
10271066
raise TypeError(
@@ -1031,6 +1070,44 @@ def load(
10311070
)
10321071
self.__setstate__(inst.__getstate__())
10331072

1073+
@overloading.overloaded_classmethod(class_method=_new_instance_from_storage)
1074+
def load(
1075+
self,
1076+
backend: BackendIdentifier | StorageInterface = "pickle",
1077+
only_requested=False,
1078+
filename: str | Path | None = None,
1079+
**kwargs,
1080+
):
1081+
"""
1082+
Load a node from storage, either as a new instance (when used as a class
1083+
method) or by updating the current instance (when called as a regular instance
1084+
method).
1085+
1086+
Args:
1087+
backend (str | StorageInterface): The interface to use for serializing the
1088+
node. (Default is "pickle", which loads the standard pickling back end.)
1089+
only_requested (bool): Whether to _only_ try loading from the specified
1090+
backend, or to loop through all available backends. (Default is False,
1091+
try to load whatever you can find.)
1092+
filename (str | Path | None): The name of the file (without extensions)
1093+
from which to load the node. (Default is None, which uses the node's
1094+
lexical path.)
1095+
**kwargs: back end-specific arguments (only likely to work in combination
1096+
with :param:`only_requested`, otherwise there's nothing to be specific
1097+
_to_.)
1098+
1099+
Raises:
1100+
FileNotFoundError: when nothing got loaded.
1101+
TypeError: when loading into an exisiting instance and the saved node has a
1102+
different class name.
1103+
"""
1104+
return self._update_instance_from_storage(
1105+
backend=backend,
1106+
only_requested=only_requested,
1107+
filename=filename,
1108+
**kwargs,
1109+
)
1110+
10341111
load.__doc__ = cast(str, load.__doc__) + _save_load_warnings
10351112

10361113
def delete_storage(
@@ -1048,12 +1125,11 @@ def delete_storage(
10481125
Args:
10491126
backend (str | StorageInterface): The interface to use for serializing the
10501127
node. (Default is "pickle", which loads the standard pickling back end.)
1051-
only_requested (bool): Whether to _only_ try loading from the specified
1052-
backend, or to loop through all available backends. (Default is False,
1053-
try to load whatever you can find.)
1054-
filename (str | Path | None): The name of the file (without extensions) at
1055-
which to save the node. (Default is None, which uses the node's
1056-
lexical path.)
1128+
only_requested (bool): Whether to _only_ search for files using the
1129+
specifiedmbackend, or to loop through all available backends. (Default
1130+
is False, try to remove whatever you can find.)
1131+
filename (str | Path | None): The name of the file (without extensions) to
1132+
remove. (Default is None, which uses the node's lexical path.)
10571133
delete_even_if_not_empty (bool): Whether to delete the file even if it is
10581134
not empty. (Default is False, which will only delete the file if it is
10591135
empty, i.e. has no content in it.)
@@ -1084,12 +1160,11 @@ def has_saved_content(
10841160
Args:
10851161
backend (str | StorageInterface): The interface to use for serializing the
10861162
node. (Default is "pickle", which loads the standard pickling back end.)
1087-
only_requested (bool): Whether to _only_ try loading from the specified
1088-
backend, or to loop through all available backends. (Default is False,
1089-
try to load whatever you can find.)
1090-
filename (str | Path | None): The name of the file (without extensions) at
1091-
which to save the node. (Default is None, which uses the node's
1092-
lexical path.)
1163+
only_requested (bool): Whether to _only_ search for files using the
1164+
specified backend, or to loop through all available backends. (Default
1165+
is False, try to finding whatever you can find.)
1166+
filename (str | Path | None): The name of the file (without extensions) to
1167+
look for. (Default is None, which uses the node's lexical path.)
10931168
**kwargs: back end-specific arguments (only likely to work in combination
10941169
with :param:`only_requested`, otherwise there's nothing to be specific
10951170
_to_.)

pyiron_workflow/overloading.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import functools
2+
3+
4+
def overloaded_classmethod(class_method):
5+
"""
6+
Decorator to define a method that behaves like both a classmethod and an
7+
instancemethod under the same name.
8+
9+
Args:
10+
instance_method: A method defined on the same object as the decorated method,
11+
to be used when an instance of the object calls the decorated method (
12+
instead of a class call)
13+
14+
Returns
15+
-------
16+
descriptor
17+
A descriptor that dispatches to the classmethod when accessed
18+
via the class, and to the given instance method when accessed
19+
via an instance.
20+
21+
Examples:
22+
>>> class Foo:
23+
... def __init__(self, y):
24+
... self.y = y
25+
...
26+
... @classmethod
27+
... def _doit_classmethod(cls, x):
28+
... return f"Class {cls.__name__} doing {x}"
29+
...
30+
... @overloaded_classmethod(class_method=_doit_classmethod)
31+
... def doit(self, x):
32+
... return f"Instance of type {type(self).__name__} doing {x} + {self.y}"
33+
...
34+
>>> Foo.doit(10)
35+
'Class Foo doing 10'
36+
>>> Foo(5).doit(20)
37+
'Instance of type Foo doing 20 + 5'
38+
"""
39+
40+
class Overloaded:
41+
def __init__(self, f_instance, f_class):
42+
self.f_instance = f_instance
43+
self.f_class = f_class
44+
functools.update_wrapper(self, f_instance)
45+
46+
def __get__(self, obj, cls):
47+
if obj is None:
48+
f_class = (
49+
cls.__dict__[self.f_class]
50+
if isinstance(self.f_class, str)
51+
else self.f_class
52+
)
53+
54+
if isinstance(f_class, classmethod):
55+
f_class = f_class.__func__
56+
57+
@functools.wraps(self.f_class)
58+
def bound(*args, **kwargs):
59+
return f_class(cls, *args, **kwargs)
60+
61+
return bound
62+
else:
63+
64+
@functools.wraps(self.f_instance)
65+
def bound(*args, **kwargs):
66+
return self.f_instance(obj, *args, **kwargs)
67+
68+
return bound
69+
70+
def wrapper(f_instance):
71+
return Overloaded(f_instance, class_method)
72+
73+
return wrapper

tests/unit/test_overloading.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import unittest
2+
3+
from pyiron_workflow.overloading import (
4+
overloaded_classmethod,
5+
) # replace with actual module path
6+
7+
8+
def class_string(obj, x):
9+
return f"Class {obj.__name__} doing {x}"
10+
11+
12+
def instance_string(obj, x):
13+
return f"Instance of type {type(obj).__name__} doing {x}"
14+
15+
16+
class Foo:
17+
@overloaded_classmethod(class_method="_pseudo_classmethod")
18+
def undecorated_string(self, x):
19+
return instance_string(self, x)
20+
21+
@overloaded_classmethod(class_method="_classmethod")
22+
def decorated_string(self, x):
23+
return instance_string(self, x)
24+
25+
def _pseudo_classmethod(cls, x):
26+
return class_string(cls, x)
27+
28+
@classmethod
29+
def _classmethod(cls, y):
30+
return class_string(cls, y)
31+
32+
@overloaded_classmethod(class_method=_pseudo_classmethod)
33+
def undecorated_direct(self, x):
34+
return instance_string(self, x)
35+
36+
@overloaded_classmethod(class_method=_classmethod)
37+
def decorated_direct(self, x):
38+
return instance_string(self, x)
39+
40+
41+
class TestOverloadedClassMethod(unittest.TestCase):
42+
def test_instance_and_class_calls(self):
43+
self.assertEqual(Foo.undecorated_string(1), class_string(Foo, 1))
44+
self.assertEqual(Foo.decorated_string(2), class_string(Foo, 2))
45+
self.assertEqual(Foo.undecorated_direct(3), class_string(Foo, 3))
46+
self.assertEqual(Foo.decorated_direct(4), class_string(Foo, 4))
47+
48+
self.assertEqual(Foo().undecorated_string(1), instance_string(Foo(), 1))
49+
self.assertEqual(Foo().decorated_string(2), instance_string(Foo(), 2))
50+
self.assertEqual(Foo().undecorated_direct(3), instance_string(Foo(), 3))
51+
self.assertEqual(Foo().decorated_direct(4), instance_string(Foo(), 4))
52+
53+
54+
if __name__ == "__main__":
55+
unittest.main()

0 commit comments

Comments
 (0)