Skip to content

Commit dfa914a

Browse files
c00wfacebook-github-bot
authored andcommitted
Modify lazy_dyndep loading to trigger inside workspace. (pytorch#41687)
Summary: Pull Request resolved: pytorch#41687 Specifically, this makes a new library (lazy), which can be used from both core and workspace. This allows workspace.Createnet to trigger lazy loading of dyndep dependencies. Test Plan: Added a unit test specifically for workspace.CreateNet Reviewed By: dzhulgakov Differential Revision: D22441877 fbshipit-source-id: 3a9d1af9962585d08ea2566c9c85bec7377d39f2
1 parent af5d0bf commit dfa914a

File tree

5 files changed

+38
-19
lines changed

5 files changed

+38
-19
lines changed

caffe2/python/core.py

+6-17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from caffe2.proto import caffe2_pb2
1515
from caffe2.python import scope, utils, workspace
16+
from caffe2.python.lazy import TriggerLazyImport
1617
from caffe2.python.control_ops_grad import \
1718
gen_do_gradient, gen_if_gradient, gen_while_gradient, disambiguate_grad_if_op_output
1819

@@ -49,18 +50,6 @@ def _InitDataType():
4950

5051
_InitDataType()
5152

52-
_import_lazy_calls = []
53-
54-
def RegisterLazyImport(lazy):
55-
global _import_lazy_calls
56-
_import_lazy_calls += [lazy]
57-
58-
59-
def _import_lazy():
60-
global _import_lazy_calls
61-
for lazy in _import_lazy_calls:
62-
lazy()
63-
6453

6554
def _GetRegisteredOperators():
6655
return set(workspace.RegisteredOperators())
@@ -71,7 +60,7 @@ def _GetRegisteredOperators():
7160

7261
def RefreshRegisteredOperators(trigger_lazy=True):
7362
if trigger_lazy:
74-
_import_lazy()
63+
TriggerLazyImport()
7564
global _REGISTERED_OPERATORS
7665
_REGISTERED_OPERATORS = _GetRegisteredOperators()
7766

@@ -80,7 +69,7 @@ def RefreshRegisteredOperators(trigger_lazy=True):
8069

8170

8271
def GlobalInit(args):
83-
_import_lazy()
72+
TriggerLazyImport()
8473
_GLOBAL_INIT_ARGS.extend(args[1:])
8574
C.global_init(args)
8675

@@ -94,7 +83,7 @@ def IsOperator(op_type):
9483

9584

9685
def IsOperatorWithEngine(op_type, engine):
97-
_import_lazy()
86+
TriggerLazyImport()
9887
return C.op_registry_key(op_type, engine) in _REGISTERED_OPERATORS
9988

10089

@@ -294,7 +283,7 @@ def __getattr__(self, op_type):
294283
op_type, *args, **kwargs)
295284

296285
def __dir__(self):
297-
_import_lazy()
286+
TriggerLazyImport()
298287
additional_methods = [
299288
op
300289
for op in _REGISTERED_OPERATORS
@@ -2228,7 +2217,7 @@ def __getattr__(self, op_type):
22282217
op_type, *args, **kwargs)
22292218

22302219
def __dir__(self):
2231-
_import_lazy()
2220+
TriggerLazyImport()
22322221
additional_methods = [
22332222
op
22342223
for op in _REGISTERED_OPERATORS

caffe2/python/lazy.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
## @package workspace
2+
# Module caffe2.python.lazy
3+
4+
_import_lazy_calls = []
5+
6+
def RegisterLazyImport(lazy):
7+
global _import_lazy_calls
8+
_import_lazy_calls += [lazy]
9+
10+
11+
def TriggerLazyImport():
12+
global _import_lazy_calls
13+
for lazy in _import_lazy_calls:
14+
lazy()

caffe2/python/lazy_dyndep.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import unicode_literals
77

88
import os
9-
from caffe2.python import core, dyndep
9+
from caffe2.python import dyndep, lazy
1010

1111

1212
def RegisterOpsLibrary(name):
@@ -81,4 +81,4 @@ def _import_lazy():
8181
finally:
8282
_LAZY_IMPORTED_DYNDEPS.remove(name)
8383

84-
core.RegisterLazyImport(_import_lazy)
84+
lazy.RegisterLazyImport(_import_lazy)

caffe2/python/lazy_dyndep_test.py

+14
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,20 @@ def handlernoop(e):
113113
lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
114114
core.RefreshRegisteredOperators()
115115

116+
def test_workspacecreatenet(self):
117+
from caffe2.python import workspace, lazy_dyndep
118+
import tempfile
119+
120+
with tempfile.NamedTemporaryFile() as f:
121+
lazy_dyndep.RegisterOpsLibrary(f.name)
122+
called = False
123+
124+
def handler(e):
125+
raise ValueError("test")
126+
lazy_dyndep.SetErrorHandler(handler)
127+
with self.assertRaises(ValueError, msg="test"):
128+
workspace.CreateNet("fake")
129+
116130

117131
if __name__ == "__main__":
118132
unittest.main()

caffe2/python/workspace.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from caffe2.proto import caffe2_pb2
2121
from caffe2.python import scope, utils
22+
from caffe2.python.lazy import TriggerLazyImport
2223

2324
import caffe2.python._import_c_extension as C
2425

@@ -172,6 +173,7 @@ def ResetWorkspace(root_folder=None):
172173

173174

174175
def CreateNet(net, overwrite=False, input_blobs=None):
176+
TriggerLazyImport()
175177
if input_blobs is None:
176178
input_blobs = []
177179
for input_blob in input_blobs:

0 commit comments

Comments
 (0)