
Commit dd25111

peterbell10 authored and pytorchmergebot committed

[caffe2] Remove OperatorBase::newstyle_outputs_ (pytorch#67093)

`OperatorBase` maintains both `output_tensors_` and `newstyle_outputs_`, which hold the same list of tensors, except that one is a `vector<caffe2::Tensor>` and the other is a `List<at::Tensor>`. This change keeps only `output_tensors_` and handles the conversions inside export_caffe2_op_to_c10.

Differential Revision: [D32289811](https://our.internmc.facebook.com/intern/diff/D32289811)
Pull Request resolved: pytorch#67093
Approved by: https://github.com/dagitses, https://github.com/malfet

1 parent: e137dcc
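
For illustration, here is a minimal sketch (not part of the commit) of the two conversions that export_caffe2_op_to_c10.h now performs at the c10 boundary instead of `OperatorBase` keeping a second `List<at::Tensor>`. The helper names `to_caffe2` and `to_aten` are invented for this example; both directions only re-wrap the same underlying `TensorImpl`, so no tensor data is copied.

```cpp
// Hedged sketch of the boundary conversion, assuming a caffe2 build where
// these headers are available. to_caffe2/to_aten are illustrative names only.
#include <vector>

#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <caffe2/core/tensor.h>

// Unwrap the preallocated c10::List<at::Tensor> into the
// std::vector<caffe2::Tensor> form that OperatorBase now stores in
// output_tensors_.
std::vector<caffe2::Tensor> to_caffe2(c10::List<at::Tensor>&& outputs) {
  std::vector<caffe2::Tensor> outputs_c2(outputs.size());
  for (const auto i : c10::irange(outputs.size())) {
    // caffe2::Tensor wraps the same TensorImpl; no data copy happens here.
    outputs_c2[i] = caffe2::Tensor(outputs.extract(i));
  }
  return outputs_c2;
}

// Re-wrap the results as at::Tensor before they go back onto the JIT stack,
// mirroring the loops added to _call_caffe2_op_from_c10 below.
c10::List<at::Tensor> to_aten(std::vector<caffe2::Tensor>&& outputs_c2) {
  c10::List<at::Tensor> outputs;
  outputs.reserve(outputs_c2.size());
  for (auto& t : outputs_c2) {
    outputs.push_back(at::Tensor(std::move(t)));
  }
  return outputs;
}
```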

File tree (4 files changed: +39 −45 lines)

caffe2/contrib/aten/aten_op_template.h
caffe2/core/export_caffe2_op_to_c10.h
caffe2/core/operator.cc
caffe2/core/operator.h


caffe2/contrib/aten/aten_op_template.h

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 #pragma once
 #include <unordered_map>
 #include <string>
-#include <ATen/ATen.h>
+#include <ATen/Functions.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/irange.h>
 #include <caffe2/core/context.h>

caffe2/core/export_caffe2_op_to_c10.h

Lines changed: 23 additions & 13 deletions
@@ -12,6 +12,7 @@
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/frontend/function_schema_parser.h>
 #include <torch/library.h>
+#include <caffe2/core/tensor.h>
 #include <vector>
 
 namespace caffe2 {
@@ -20,19 +21,19 @@ namespace detail {
 constexpr const char* PREALLOCATED_OUTPUT_ARGNAME =
     "_caffe2_preallocated_outputs";
 
-using _CallCaffe2OpFunc = c10::List<at::Tensor>(
+using _CallCaffe2OpFunc = std::vector<caffe2::Tensor>(
     const c10::FunctionSchema& schema,
-    std::vector<c10::IValue>&& inputs,
-    c10::List<at::Tensor>&& outputs);
+    std::vector<c10::IValue> &&inputs,
+    std::vector<caffe2::Tensor> &&outputs);
 
 template <class Caffe2Operator>
-inline c10::List<at::Tensor> _call_caffe2_op(
+inline std::vector<caffe2::Tensor> _call_caffe2_op(
     const c10::FunctionSchema& schema,
-    std::vector<c10::IValue>&& inputs,
-    c10::List<at::Tensor>&& outputs) {
+    std::vector<c10::IValue> &&inputs,
+    std::vector<caffe2::Tensor> &&outputs) {
   Caffe2Operator op(schema, std::move(inputs), std::move(outputs), -1);
   op.Run(-1);
-  return std::move(op).move_newstyle_outputs();
+  return std::move(op).move_output_tensors();
 }
 
 // This function is inline in the hope that compilers optimizing for speed will
@@ -62,7 +63,6 @@ inline void _call_caffe2_op_from_c10(
           *OptionalType::create(ListType::ofTensors())));
   IValue preallocated_outputs = torch::jit::pop(*stack);
 
-  const size_t num_outputs = schema.returns().size();
   const size_t num_inputs = schema.arguments().size() -
       1; // -1 because the last argument is the list of preallocated tensors
 
@@ -71,7 +71,7 @@
     // either the schema doesn't support preallocated outputs or it does but
     // they haven't been passed in. Pass a list of uninitialized tensors to
    // the caffe2 operator as preallocated outputs.
-    outputs.resize(num_outputs);
+    outputs.resize(schema.returns().size());
   } else {
     AT_ASSERT(preallocated_outputs.isTensorList());
     outputs = std::move(preallocated_outputs).toTensorList();
@@ -81,7 +81,15 @@
   // instances in the cache.
   std::vector<IValue> inputs = torch::jit::pop(*stack, num_inputs);
 
-  outputs = (*call_op)(schema, std::move(inputs), std::move(outputs));
+  // Convert outputs to caffe2::Tensor
+  const size_t num_outputs = outputs.size();
+  std::vector<caffe2::Tensor> outputs_c2(num_outputs);
+  for (auto i : c10::irange(num_outputs)) {
+    outputs_c2[i] = caffe2::Tensor(outputs.extract(i));
+  }
+
+  outputs_c2 = (*call_op)(schema, std::move(inputs), std::move(outputs_c2));
+  TORCH_INTERNAL_ASSERT(num_outputs == outputs_c2.size());
 
   bool return_tensor_list = false;
   if (schema.returns().size() == 1) {
@@ -93,11 +101,13 @@
     }
   }
   if (return_tensor_list) {
-    // We should not unwrap the list if we expect tensor list in the schema.
+    for (const auto i : c10::irange(num_outputs)) {
+      outputs.set(i, at::Tensor(std::move(outputs_c2[i])));
+    }
     torch::jit::push(*stack, outputs);
   } else {
-    for (const auto i : c10::irange(outputs.size())) {
-      torch::jit::push(*stack, outputs.extract(i));
+    for (const auto i : c10::irange(num_outputs)) {
+      torch::jit::push(*stack, at::Tensor(std::move(outputs_c2[i])));
     }
   }
 
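
Taken together, the changes in this file reduce the non-legacy dispatch path to the calling convention sketched below. This is a hedged illustration rather than code from the commit; `MyCpuOp` stands in for any concrete `Caffe2Operator` type and is hypothetical.

```cpp
// Hedged usage sketch of the non-legacy calling convention implemented by
// _call_caffe2_op above. MyCpuOp is a hypothetical operator type.
#include <vector>

#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <caffe2/core/tensor.h>

template <class MyCpuOp>
std::vector<caffe2::Tensor> run_once(
    const c10::FunctionSchema& schema,
    std::vector<c10::IValue> inputs,
    std::vector<caffe2::Tensor> preallocated_outputs) {
  // Outputs travel as std::vector<caffe2::Tensor> end to end and are stored
  // directly in output_tensors_; there is no c10::List<at::Tensor> mirror.
  MyCpuOp op(schema, std::move(inputs), std::move(preallocated_outputs), -1);
  op.Run(-1);
  // move_output_tensors() is the renamed accessor that replaces
  // move_newstyle_outputs(); it moves output_tensors_ out of the operator.
  return std::move(op).move_output_tensors();
}
```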

caffe2/core/operator.cc

Lines changed: 2 additions & 7 deletions
@@ -59,10 +59,6 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
       device_option_(
           operator_def.has_device_option() ? operator_def.device_option()
                                            : DeviceOption()),
-#if defined(EXPOSE_C2_OPS) || \
-    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-      newstyle_outputs_(),
-#endif
       input_size_(operator_def.input_size()),
       event_(std::make_unique<Event>(device_option_)) {
   static GlobalInitIsCalledGuard guard;
@@ -124,14 +120,13 @@ compute_input_size_(const std::vector<c10::IValue>& inputs) {
 OperatorBase::OperatorBase(
     const c10::FunctionSchema& fn_schema,
     std::vector<c10::IValue> inputs,
-    c10::List<at::Tensor> outputs)
+    std::vector<caffe2::Tensor> outputs)
     // NOLINTNEXTLINE(performance-move-const-arg)
     : fn_schema_(make_unique<c10::FunctionSchema>(std::move(fn_schema))),
       newstyle_inputs_(std::move(inputs)),
-      newstyle_outputs_(std::move(outputs)),
+      output_tensors_(std::move(outputs)),
       input_size_(compute_input_size_(newstyle_inputs_)) {
   input_tensors_.resize(input_size_);
-  output_tensors_.resize(newstyle_outputs_.size());
 }
 #endif
 

caffe2/core/operator.h

Lines changed: 13 additions & 24 deletions
@@ -74,7 +74,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
   explicit OperatorBase(
       const c10::FunctionSchema& schema,
       std::vector<c10::IValue> inputs,
-      c10::List<at::Tensor> outputs);
+      std::vector<caffe2::Tensor> outputs);
 #endif
 
   virtual ~OperatorBase() noexcept;
@@ -250,15 +250,12 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
     }
 #if defined(EXPOSE_C2_OPS) || \
     !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-    at::Tensor output = newstyle_outputs_[idx];
-    if (!output.defined() || caffe2::Tensor(output).GetDeviceType() != type) {
+    auto &output = output_tensors_[idx];
+    if (!output.defined() || output.GetDeviceType() != type) {
       // Fix tensor type
-      Tensor tensor = Tensor(type);
-      output = at::Tensor(std::move(tensor.getIntrusivePtr()));
+      output = Tensor(type);
     }
-    output_tensors_[idx] = caffe2::Tensor(output);
-    newstyle_outputs_[idx] = std::move(output);
-    return &output_tensors_[idx];
+    return &output;
 #else
     CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
 #endif
@@ -280,9 +277,6 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
     if (!isLegacyOperator()) {
 #if defined(EXPOSE_C2_OPS) || \
     !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-      newstyle_outputs_[idx] = at::Tensor(tensor);
-
-      // also update the tensor in the hack
       output_tensors_[idx] = std::move(tensor);
 #else
       CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
@@ -310,16 +304,12 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
     }
 #if defined(EXPOSE_C2_OPS) || \
     !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-    at::Tensor output = newstyle_outputs_[idx];
-    Tensor tensor = output.defined()
-        ? GetSizedTensorWithOptions(caffe2::Tensor(output), dims, options)
+    auto &output = output_tensors_[idx];
+    output = output.defined()
+        ? GetSizedTensorWithOptions(std::move(output), dims, options)
         : caffe2::empty(dims, options);
-    // assign it back in case it changed
-    output = at::Tensor(std::move(tensor.getIntrusivePtr()));
 
-    output_tensors_[idx] = caffe2::Tensor(output);
-    newstyle_outputs_[idx] = std::move(output);
-    return &output_tensors_[idx];
+    return &output;
 #else
     CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
 #endif
@@ -434,7 +424,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
   }
 #if defined(EXPOSE_C2_OPS) || \
     !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-  return newstyle_outputs_.size();
+  return output_tensors_.size();
 #else
   CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
 #endif
@@ -599,8 +589,8 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
 
 #if defined(EXPOSE_C2_OPS) || \
     !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-  c10::List<at::Tensor> move_newstyle_outputs() && {
-    return std::move(newstyle_outputs_);
+  std::vector<caffe2::Tensor> move_output_tensors() && {
+    return std::move(output_tensors_);
   }
 #endif
 
@@ -620,7 +610,6 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
     !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
   std::unique_ptr<const c10::FunctionSchema> fn_schema_;
   vector<c10::IValue> newstyle_inputs_;
-  c10::List<at::Tensor> newstyle_outputs_;
 #endif
   // HACK
   // We preserve the fact that Output() returns Tensor*
@@ -819,7 +808,7 @@ class Operator : public OperatorBase {
   explicit Operator(
       const c10::FunctionSchema& fn_schema,
       std::vector<c10::IValue> inputs,
-      c10::List<at::Tensor> outputs,
+      std::vector<caffe2::Tensor> outputs,
       StreamId stream = 0)
       : OperatorBase(fn_schema, std::move(inputs), std::move(outputs)) {
     // In the constructor, we switch to the device so that the child class
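
The accessor hunks above collapse the old double bookkeeping (update `newstyle_outputs_[idx]`, then copy into `output_tensors_[idx]`) into a single in-place fix-up of the stored `caffe2::Tensor`. A hedged sketch of that pattern, with an invented helper name:

```cpp
// Hedged sketch of the simplified output-slot fix-up used by the updated
// accessors in operator.h. materialize_output_slot is an invented name; the
// real code operates on output_tensors_[idx] in place.
#include <caffe2/core/tensor.h>

caffe2::Tensor& materialize_output_slot(
    caffe2::Tensor& slot, caffe2::DeviceType type) {
  // Replace the slot only if it is undefined or lives on the wrong device
  // type; otherwise return the existing tensor unchanged.
  if (!slot.defined() || slot.GetDeviceType() != type) {
    slot = caffe2::Tensor(type);
  }
  return slot;
}
```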
