@@ -239,8 +239,7 @@ class FlattenToVecOp : public Operator<Context> {
239
239
bool RunOnDevice () override {
240
240
auto & input = Input (0 );
241
241
auto * output = Output (0 );
242
- CAFFE_ENFORCE_GE (
243
- input.dim (), 1 , " The rank of the tensor must be >= 1." );
242
+ CAFFE_ENFORCE_GE (input.dim (), 1 , " The rank of the tensor must be >= 1." );
244
243
output->Resize (input.numel ());
245
244
246
245
context_.CopyItemsSameDevice (
@@ -280,7 +279,7 @@ class SumOp : public Operator<Context> {
280
279
USE_OPERATOR_CONTEXT_FUNCTIONS;
281
280
USE_SIMPLE_CTOR_DTOR (SumOp);
282
281
283
- template <typename T, typename M >
282
+ template <typename T>
284
283
bool DoRunWithType () {
285
284
auto & input0 = Input (0 );
286
285
@@ -331,16 +330,8 @@ class SumOp : public Operator<Context> {
331
330
}
332
331
333
332
bool RunOnDevice () override {
334
- if (Input (0 ).template IsType <float >()) {
335
- return DoRunWithType<float , float >();
336
- } else if (Input (0 ).template IsType <int >()) {
337
- return DoRunWithType<int , int >();
338
- } else {
339
- CAFFE_THROW (
340
- " Sum operator only supports 32-bit float and ints, but" ,
341
- " input was of type " ,
342
- Input (0 ).dtype ().name ());
343
- }
333
+ return DispatchHelper<TensorTypes<float , int32_t , int64_t >>::call (
334
+ this , Input (0 ));
344
335
}
345
336
};
346
337
@@ -369,7 +360,8 @@ class WeightedSumOp : public Operator<Context> {
369
360
template <typename T>
370
361
bool DoRunWithType () {
371
362
// the code is written this way because of 10.1 + gcc 7.3.1 compiler bug
372
- // as discussed at https://devtalk.nvidia.com/default/topic/1048037/linux/cuda-10-1-nvidia-you-re-now-quot-fixing-quot-gcc-bugs-that-gcc-doesn-t-even-have/
363
+ // as discussed at
364
+ // https://devtalk.nvidia.com/default/topic/1048037/linux/cuda-10-1-nvidia-you-re-now-quot-fixing-quot-gcc-bugs-that-gcc-doesn-t-even-have/
373
365
const int input_size = (*this ).InputSize ();
374
366
CAFFE_ENFORCE_EQ (input_size % 2 , 0 );
375
367
const auto & X0 = Input (0 );
@@ -751,14 +743,14 @@ class ScatterOp : public Operator<CPUContext> {
751
743
template <class ... Args>
752
744
explicit ScatterOp (Args&&... args)
753
745
: Operator<CPUContext>(std::forward<Args>(args)...),
754
- OP_SINGLE_ARG(int , " axis" , axis_, 1 ) {
755
- }
746
+ OP_SINGLE_ARG(int , " axis" , axis_, 1 ) {}
756
747
757
748
virtual ~ScatterOp () noexcept override {}
758
749
759
750
bool RunOnDevice () override {
760
-
761
- TORCH_CHECK (Context::GetDeviceType () == kCPU , " ScatterOp currently only supports CPU." )
751
+ TORCH_CHECK (
752
+ Context::GetDeviceType () == kCPU ,
753
+ " ScatterOp currently only supports CPU." )
762
754
763
755
return DispatchHelper<TensorTypes<int32_t , int64_t >>::call (
764
756
this , this ->template Input <Tensor>(INDICES, CPU));
@@ -775,7 +767,8 @@ class ScatterOp : public Operator<CPUContext> {
775
767
// ONNX allows negative axis to index from the back, valid range: [-r, r].
776
768
axis_ = data.canonical_axis_index (axis_);
777
769
778
- CAFFE_ENFORCE_GE (data.dim (), axis_ + 1 , " DATA should be at least [axis+1]-D" );
770
+ CAFFE_ENFORCE_GE (
771
+ data.dim (), axis_ + 1 , " DATA should be at least [axis+1]-D" );
779
772
CAFFE_ENFORCE_GE (axis_, 0 , " Axis should be non-negative" );
780
773
CAFFE_ENFORCE_LT (axis_, data.dim (), " Axis out of range" );
781
774
@@ -818,14 +811,20 @@ class ScatterOp : public Operator<CPUContext> {
818
811
// src offset can be computed as i * J_src * K + j * K + k.
819
812
// dst offset can be computed as i * J_dst * K + idxs[idxs_offset] * K + K
820
813
// Note that idxs and src should have the same rank and shape.
821
- // dst should have the same rank as idxs and src, but the dimension of dim axis can be different.
822
- // That is why in the above equation, there is the difference of J_src and J_dst.
823
- for (int64_t outer_batch = 0 ; outer_batch < outer_dims_product; ++outer_batch) {
814
+ // dst should have the same rank as idxs and src, but the dimension of dim
815
+ // axis can be different. That is why in the above equation, there is the
816
+ // difference of J_src and J_dst.
817
+ for (int64_t outer_batch = 0 ; outer_batch < outer_dims_product;
818
+ ++outer_batch) {
824
819
for (int64_t i = 0 ; i < N; ++i) {
825
- for (int64_t inner_batch = 0 ; inner_batch < idxs_block_size; ++inner_batch) {
826
- auto idxs_elem_idx = outer_batch * idxs_batch_size + i * idxs_block_size + inner_batch;
827
- auto src_elem_idx = outer_batch * src_batch_size + i * src_block_size + inner_batch;
828
- auto dst_elem_idx = outer_batch * dst_batch_size + idxs[idxs_elem_idx] * dst_block_size + inner_batch;
820
+ for (int64_t inner_batch = 0 ; inner_batch < idxs_block_size;
821
+ ++inner_batch) {
822
+ auto idxs_elem_idx =
823
+ outer_batch * idxs_batch_size + i * idxs_block_size + inner_batch;
824
+ auto src_elem_idx =
825
+ outer_batch * src_batch_size + i * src_block_size + inner_batch;
826
+ auto dst_elem_idx = outer_batch * dst_batch_size +
827
+ idxs[idxs_elem_idx] * dst_block_size + inner_batch;
829
828
830
829
auto src = src_base + src_elem_idx * item_bytesize;
831
830
auto dst = out + dst_elem_idx * item_bytesize;
@@ -1401,7 +1400,8 @@ class RangeOp : public Operator<Context> {
1401
1400
T step = 1 ;
1402
1401
1403
1402
for (int i = 0 ; i < InputSize (); ++i) {
1404
- CAFFE_ENFORCE_EQ (Input (i).numel (), 1 , " All inputs must be scalar/1D tensor." );
1403
+ CAFFE_ENFORCE_EQ (
1404
+ Input (i).numel (), 1 , " All inputs must be scalar/1D tensor." );
1405
1405
}
1406
1406
1407
1407
switch (InputSize ()) {
0 commit comments