@@ -74,7 +74,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
74
74
explicit OperatorBase (
75
75
const c10::FunctionSchema& schema,
76
76
std::vector<c10::IValue> inputs,
77
- c10::List<at ::Tensor> outputs);
77
+ std::vector<caffe2 ::Tensor> outputs);
78
78
#endif
79
79
80
80
virtual ~OperatorBase () noexcept ;
@@ -250,15 +250,12 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
250
250
}
251
251
#if defined(EXPOSE_C2_OPS) || \
252
252
!defined (CAFFE2_IS_XPLAT_BUILD) && !defined (C10_MOBILE)
253
- at::Tensor output = newstyle_outputs_ [idx];
254
- if (!output.defined () || caffe2::Tensor ( output) .GetDeviceType () != type) {
253
+ auto & output = output_tensors_ [idx];
254
+ if (!output.defined () || output.GetDeviceType () != type) {
255
255
// Fix tensor type
256
- Tensor tensor = Tensor (type);
257
- output = at::Tensor (std::move (tensor.getIntrusivePtr ()));
256
+ output = Tensor (type);
258
257
}
259
- output_tensors_[idx] = caffe2::Tensor (output);
260
- newstyle_outputs_[idx] = std::move (output);
261
- return &output_tensors_[idx];
258
+ return &output;
262
259
#else
263
260
CAFFE_THROW (" Non-legacy operators are not legal in xplat/caffe2" );
264
261
#endif
@@ -280,9 +277,6 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
280
277
if (!isLegacyOperator ()) {
281
278
#if defined(EXPOSE_C2_OPS) || \
282
279
!defined (CAFFE2_IS_XPLAT_BUILD) && !defined (C10_MOBILE)
283
- newstyle_outputs_[idx] = at::Tensor (tensor);
284
-
285
- // also update the tensor in the hack
286
280
output_tensors_[idx] = std::move (tensor);
287
281
#else
288
282
CAFFE_THROW (" Non-legacy operators are not legal in xplat/caffe2" );
@@ -310,16 +304,12 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
310
304
}
311
305
#if defined(EXPOSE_C2_OPS) || \
312
306
!defined (CAFFE2_IS_XPLAT_BUILD) && !defined (C10_MOBILE)
313
- at::Tensor output = newstyle_outputs_ [idx];
314
- Tensor tensor = output.defined ()
315
- ? GetSizedTensorWithOptions (caffe2::Tensor (output), dims, options)
307
+ auto & output = output_tensors_ [idx];
308
+ output = output.defined ()
309
+ ? GetSizedTensorWithOptions (std::move (output), dims, options)
316
310
: caffe2::empty (dims, options);
317
- // assign it back in case it changed
318
- output = at::Tensor (std::move (tensor.getIntrusivePtr ()));
319
311
320
- output_tensors_[idx] = caffe2::Tensor (output);
321
- newstyle_outputs_[idx] = std::move (output);
322
- return &output_tensors_[idx];
312
+ return &output;
323
313
#else
324
314
CAFFE_THROW (" Non-legacy operators are not legal in xplat/caffe2" );
325
315
#endif
@@ -434,7 +424,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
434
424
}
435
425
#if defined(EXPOSE_C2_OPS) || \
436
426
!defined (CAFFE2_IS_XPLAT_BUILD) && !defined (C10_MOBILE)
437
- return newstyle_outputs_ .size ();
427
+ return output_tensors_ .size ();
438
428
#else
439
429
CAFFE_THROW (" Non-legacy operators are not legal in xplat/caffe2" );
440
430
#endif
@@ -599,8 +589,8 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
599
589
600
590
#if defined(EXPOSE_C2_OPS) || \
601
591
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
602
- c10::List<at ::Tensor> move_newstyle_outputs () && {
603
- return std::move (newstyle_outputs_ );
592
+ std::vector<caffe2 ::Tensor> move_output_tensors () && {
593
+ return std::move (output_tensors_ );
604
594
}
605
595
#endif
606
596
@@ -620,7 +610,6 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
620
610
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
621
611
std::unique_ptr<const c10::FunctionSchema> fn_schema_;
622
612
vector<c10::IValue> newstyle_inputs_;
623
- c10::List<at::Tensor> newstyle_outputs_;
624
613
#endif
625
614
// HACK
626
615
// We preserve the fact that Output() returns Tensor*
@@ -819,7 +808,7 @@ class Operator : public OperatorBase {
819
808
explicit Operator (
820
809
const c10::FunctionSchema& fn_schema,
821
810
std::vector<c10::IValue> inputs,
822
- c10::List<at ::Tensor> outputs,
811
+ std::vector<caffe2 ::Tensor> outputs,
823
812
StreamId stream = 0 )
824
813
: OperatorBase(fn_schema, std::move(inputs), std::move(outputs)) {
825
814
// In the constructor, we switch to the device so that the child class
0 commit comments