Commit e5fe66d

bddppq authored and facebook-github-bot committed
Add support for specifying device_option in Functional (pytorch#9619)
Summary:
e.g.
```
Functional.Add(x, y, device_option=DeviceOption(HIP, 0))
```
Pull Request resolved: pytorch#9619

Differential Revision: D8966599

Pulled By: bddppq

fbshipit-source-id: 22235e42f19278e79802642798bf0ee70a1202f6
1 parent 37fc58f commit e5fe66d
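
For orientation, here is a minimal usage sketch of the feature this commit adds, based on the example in the summary above. It is a hedged sketch, not part of the diff: it assumes the module exposes a `Functional` instance importable from `caffe2.python.functional` (as the existing tests use it) and that a GPU-enabled Caffe2 build is available; CUDA is used here in place of the HIP example. With no `device_option`, the op still runs on CPU as before.

```python
# Hedged usage sketch of the new device_option keyword on Functional ops.
# Assumes a Caffe2 build with GPU support; the CPU call works regardless.
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core
from caffe2.python.functional import Functional

x = np.random.rand(2, 3).astype(np.float32)
y = np.random.rand(2, 3).astype(np.float32)

# Default behavior: no device_option means the op runs on CPU.
cpu_sum = Functional.Add(x, y)[0]

# New in this commit: route the op to a specific device.
gpu_opt = core.DeviceOption(caffe2_pb2.CUDA, 0)
gpu_sum = Functional.Add(x, y, device_option=gpu_opt)[0]

np.testing.assert_allclose(cpu_sum, x + y, rtol=1e-6)
```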

File tree

2 files changed: +26 −16 lines


caffe2/python/functional.py

+11 −7

@@ -4,6 +4,8 @@
 from __future__ import unicode_literals
 
 from caffe2.python import core, workspace
+from caffe2.proto import caffe2_pb2
+from caffe2.python.onnx.workspace import Workspace
 from collections import namedtuple
 from six import string_types
 
@@ -28,7 +30,7 @@ def getitem(self, key):
 class _Functional(object):
     def __getattribute__(self, op_type):
         def op_func(*inputs, **args):
-            ws = workspace.C.Workspace()
+            ws = Workspace()
             schema = OpSchema.get(op_type)
             input_prefix = 'input_'
             output_prefix = 'output_'
@@ -86,16 +88,18 @@ def get_name_list(prefix, num, max_num):
             output_names = get_name_list(
                 output_prefix, max_output, max_output
             )
-            for i, input_blob in enumerate(inputs):
-                ws.create_blob(input_names[i]).feed(input_blob)
 
             op = core.CreateOperator(
                 op_type, input_names, output_names, **args
            )
-            ws._run_operator(op.SerializeToString())
-            # RunOperator
-            output_values = [ws.fetch_blob(x) for x in output_names]
-            return namedtupledict('output', output_names)(*output_values)
+            device_option = args.get('device_option', core.DeviceOption(caffe2_pb2.CPU))
+            with core.DeviceScope(device_option):
+                for i, input_blob in enumerate(inputs):
+                    ws.FeedBlob(input_names[i], input_blob)
+                # RunOperator
+                ws.RunOperatorOnce(op)
+                output_values = [ws.FetchBlob(x) for x in output_names]
+                return namedtupledict('output', output_names)(*output_values)
 
         return op_func

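To make the new control flow concrete, the sketch below reproduces the pattern the patch introduces: pull an optional `device_option` out of the kwargs (defaulting to CPU) and run the feed/run/fetch sequence inside a `DeviceScope`. This is a standalone illustration, not the patched code; it uses the module-level `caffe2.python.workspace` API rather than the ONNX `Workspace` wrapper imported in functional.py, purely to keep the example self-contained.

```python
# Minimal sketch of the pattern the patch introduces: pick up an optional
# device_option (defaulting to CPU) and do feed/run/fetch under a DeviceScope.
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace


def run_op_on_device(op_type, input_names, output_names, input_blobs, **args):
    device_option = args.get('device_option', core.DeviceOption(caffe2_pb2.CPU))
    op = core.CreateOperator(op_type, input_names, output_names, **args)
    with core.DeviceScope(device_option):
        # Blobs fed inside the scope are placed on the selected device.
        for name, value in zip(input_names, input_blobs):
            workspace.FeedBlob(name, value)
        workspace.RunOperatorOnce(op)
        return [workspace.FetchBlob(name) for name in output_names]


x = np.random.rand(3, 3).astype(np.float32)
y, = run_op_on_device('Relu', ['x'], ['y'], [x])
```
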
caffe2/python/functional_test.py

+15 −9

@@ -3,6 +3,8 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
+import unittest
+
 from caffe2.python import core
 from hypothesis import given
 import hypothesis.strategies as st
@@ -44,11 +46,11 @@ def _tensor_splits(draw, add_axis=False):
 
 
 class TestFunctional(hu.HypothesisTestCase):
-    @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]))
-    def test_relu(self, X, engine):
+    @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
+    def test_relu(self, X, engine, gc, dc):
         X += 0.02 * np.sign(X)
         X[X == 0.0] += 0.02
-        output = Functional.Relu(X)
+        output = Functional.Relu(X, device_option=gc)
         Y_l = output[0]
         Y_d = output["output_0"]
 
@@ -66,11 +68,11 @@ def test_relu(self, X, engine):
            Y_d, Y_ref, err_msg='Functional Relu result mismatch'
        )
 
-    @given(tensor_splits=_tensor_splits())
-    def test_concat(self, tensor_splits):
+    @given(tensor_splits=_tensor_splits(), **hu.gcs)
+    def test_concat(self, tensor_splits, gc, dc):
         # Input Size: 1 -> inf
         axis, _, splits = tensor_splits
-        concat_result, split_info = Functional.Concat(*splits, axis=axis)
+        concat_result, split_info = Functional.Concat(*splits, axis=axis, device_option=gc)
 
         concat_result_ref = np.concatenate(splits, axis=axis)
         split_info_ref = np.array([a.shape[axis] for a in splits])
@@ -87,8 +89,8 @@ def test_concat(self, tensor_splits):
            err_msg='Functional Concat split info mismatch'
        )
 
-    @given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans())
-    def test_split(self, tensor_splits, split_as_arg):
+    @given(tensor_splits=_tensor_splits(), split_as_arg=st.booleans(), **hu.gcs)
+    def test_split(self, tensor_splits, split_as_arg, gc, dc):
         # Output Size: 1 - inf
         axis, split_info, splits = tensor_splits
 
@@ -100,7 +102,7 @@ def test_split(self, tensor_splits, split_as_arg):
         else:
             input_tensors = [np.concatenate(splits, axis=axis), split_info]
             kwargs = dict(axis=axis, num_output=len(splits))
-        result = Functional.Split(*input_tensors, **kwargs)
+        result = Functional.Split(*input_tensors, device_option=gc, **kwargs)
 
         def split_ref(input, split=split_info):
             s = np.cumsum([0] + list(split))
@@ -114,3 +116,7 @@ def split_ref(input, split=split_info):
         np.testing.assert_array_equal(
             result[i], ref, err_msg='Functional Relu result mismatch'
         )
+
+
+if __name__ == '__main__':
+    unittest.main()

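The test changes follow the usual hypothesis_test_util pattern: expanding `**hu.gcs` into `@given` makes hypothesis draw a device option `gc` (and a device list `dc`) for each example, and the test forwards `gc` as the new `device_option`. A stripped-down test in that same style, again assuming `Functional` is importable from `caffe2.python.functional` as the existing tests do, might look like:

```python
# Hedged sketch of a device-parameterized Functional test; `gc` is the device
# hypothesis samples for this run, `dc` (unused here) lists devices intended
# for cross-device consistency checks.
import numpy as np
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
from caffe2.python.functional import Functional


class TestFunctionalDevice(hu.HypothesisTestCase):
    @given(X=hu.tensor(), **hu.gcs)
    def test_relu_on_sampled_device(self, X, gc, dc):
        output = Functional.Relu(X, device_option=gc)
        np.testing.assert_allclose(output[0], np.maximum(X, 0), rtol=1e-5)
```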