Skip to content

Commit 52b8a52

Browse files
wat3rBrofacebook-github-bot
authored andcommitted
move AliasWithNameOp to caffe2/operators
Summary: Pull Request resolved: pytorch#31281 Reviewed By: houseroad Differential Revision: D19053453 fbshipit-source-id: 350bfd5c001db9c17916dcae7ade8f56db1e9841
1 parent 0e548a7 commit 52b8a52

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

caffe2/operators/alias_with_name.cc

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include "caffe2/operators/alias_with_name.h"
2+
3+
namespace caffe2 {
4+
5+
REGISTER_CPU_OPERATOR(AliasWithName, AliasWithNameOp<CPUContext>);
6+
7+
OPERATOR_SCHEMA(AliasWithName)
8+
.NumInputs(1)
9+
.NumOutputs(1)
10+
.AllowInplace({{0, 0}})
11+
.IdenticalTypeAndShape()
12+
.SetDoc(R"DOC(
13+
Similar with AliasOp, storing the alias name as operator argument.
14+
)DOC")
15+
.Arg("name", "name of the aliasing")
16+
.Arg("is_backward", "weather or not to alias forward or backward")
17+
.Input(0, "input", "Input tensor whose storage will be shared.")
18+
.Output(0, "output", "Tensor of same shape as input, sharing its storage.");
19+
20+
} // namespace caffe2
21+
22+
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
23+
AliasWithName,
24+
"_caffe2::AliasWithName(Tensor input, str name, bool is_backward = False) -> (Tensor output)",
25+
caffe2::AliasWithNameOp<caffe2::CPUContext>);

caffe2/operators/alias_with_name.cu

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#include "caffe2/core/context_gpu.h"
2+
#include "caffe2/operators/alias_with_name.h"
3+
4+
namespace caffe2 {
5+
6+
REGISTER_CUDA_OPERATOR(AliasWithName, AliasWithNameOp<CUDAContext>);
7+
8+
} // namespace caffe2
9+
10+
C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(
11+
AliasWithName,
12+
caffe2::AliasWithNameOp<caffe2::CUDAContext>);

caffe2/operators/alias_with_name.h

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#ifndef ALIAS_WITH_NAME_OP_H_
2+
#define ALIAS_WITH_NAME_OP_H_
3+
4+
#include "caffe2/core/context.h"
5+
#include "caffe2/core/export_caffe2_op_to_c10.h"
6+
#include "caffe2/core/operator.h"
7+
8+
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(AliasWithName)
9+
10+
namespace caffe2 {
11+
12+
template <class Context>
13+
class AliasWithNameOp final : public Operator<Context> {
14+
public:
15+
USE_OPERATOR_CONTEXT_FUNCTIONS;
16+
template <class... Args>
17+
explicit AliasWithNameOp(Args&&... args)
18+
: Operator<Context>(std::forward<Args>(args)...),
19+
name_(this->template GetSingleArgument<std::string>(
20+
"name",
21+
"invalid_name")),
22+
is_backward_(
23+
this->template GetSingleArgument<bool>("is_backward", false)) {
24+
CAFFE_ENFORCE(
25+
OperatorBase::HasArgument("name"), "You have to specify argument name");
26+
}
27+
28+
bool RunOnDevice() override {
29+
auto& input = Input(0);
30+
CAFFE_ENFORCE_GE(input.numel(), 0, "Tensor is not initialized");
31+
32+
// This doesn't work anymore as this is "newstyle" operator
33+
// OutputTensorAlias(0, input);
34+
35+
OperatorBase::SetOutputTensor(0, input.Alias());
36+
return true;
37+
}
38+
39+
protected:
40+
std::string name_;
41+
bool is_backward_;
42+
};
43+
44+
} // namespace caffe2
45+
46+
#endif // ALIAS_WITH_NAME_OP_H_
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/usr/bin/env python3
2+
3+
import caffe2.python.hypothesis_test_util as hu
4+
import hypothesis.strategies as st
5+
import numpy as np
6+
from caffe2.python import core, utils
7+
from hypothesis import given
8+
9+
10+
class TestAliasWithNameOp(hu.HypothesisTestCase):
11+
@given(
12+
shape=st.lists(st.integers(0, 5), min_size=1, max_size=3),
13+
dtype=st.sampled_from([np.float32, np.int64]),
14+
**hu.gcs
15+
)
16+
def test_alias_with_name_op(self, shape, dtype, dc, gc):
17+
test_input = (100 * np.random.random(shape)).astype(dtype)
18+
test_inputs = [test_input]
19+
20+
alias_op = core.CreateOperator(
21+
"AliasWithName",
22+
["input"],
23+
["output"],
24+
device_option=gc,
25+
)
26+
alias_op.arg.add().CopyFrom(utils.MakeArgument("name", "whatever_name"))
27+
28+
def reference_func(x):
29+
return (x,)
30+
31+
self.assertReferenceChecks(gc, alias_op, test_inputs, reference_func)

caffe2/python/operator_test/torch_integration_test.py

+9
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,15 @@ def _piecewise_linear_ref(X):
710710

711711
torch.testing.assert_allclose(torch.tensor(expected_output), actual_output)
712712

713+
def test_alias_with_name_is_in_place(self):
714+
device = "cuda" if workspace.has_cuda_support else "cpu"
715+
x = torch.Tensor([3, 42]).to(device)
716+
y = torch.ops._caffe2.AliasWithName(x, "new_name")
717+
x[1] = 6
718+
torch.testing.assert_allclose(x, torch.Tensor([3, 6]).to(device))
719+
# y should also change because y is alias of x
720+
torch.testing.assert_allclose(y, torch.Tensor([3, 6]).to(device))
721+
713722

714723
if __name__ == '__main__':
715724
unittest.main()

0 commit comments

Comments
 (0)