
Commit 5ad885b

Neha Shah authored and facebook-github-bot committed
[Caffe2][Pruning] Make the caffe2 Sum operator support long types (pytorch#40379)
Summary:
Pull Request resolved: pytorch#40379

The current Sum operator doesn't support Long (int64) inputs; modify the code so it does.

Test Plan: Added a test case (test_sum in caffe2/python/operator_test/utility_ops_test.py).

Reviewed By: jspark1105, yinghai

Differential Revision: D21917365

fbshipit-source-id: b37d2c100c70d17d2f89c309e40360ddfab584ee
1 parent b623bde commit 5ad885b

15 files changed (+184, -65 lines)
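
As a quick illustration of what this change enables, the sketch below feeds int64 ("Long") tensors to the Sum operator through the Caffe2 Python API. It is not part of the PR; blob names, shapes, and values are illustrative only.

# Minimal sketch (assumes a working caffe2.python install); names are illustrative.
import numpy as np
from caffe2.python import core, workspace

# Two int64 ("Long") tensors -- the dtype this change adds support for.
workspace.FeedBlob("x0", np.arange(6, dtype=np.int64).reshape(2, 3))
workspace.FeedBlob("x1", np.ones((2, 3), dtype=np.int64))

# Sum now dispatches over float, int32, and int64 inputs.
op = core.CreateOperator("Sum", ["x0", "x1"], ["y"])
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("y"))  # elementwise sum of x0 and x1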

caffe2/operators/utility_ops.cu

Lines changed: 2 additions & 10 deletions
@@ -26,16 +26,8 @@ bool WeightedSumOp<CUDAContext>::RunOnDevice() {
 
 template <>
 bool SumOp<CUDAContext>::RunOnDevice() {
-  if (Input(0).IsType<float>()) {
-    return DoRunWithType<float, float>();
-  } else if (Input(0).IsType<at::Half>()) {
-    return DoRunWithType<at::Half, at::Half>();
-  } else if (Input(0).IsType<int32_t>()) {
-    return DoRunWithType<int32_t, int32_t>();
-  } else {
-    CAFFE_THROW("Unsupported inputs");
-  }
-  return false;
+  return DispatchHelper<TensorTypes<float, int32_t, int64_t>>::call(
+      this, Input(0));
 }
 
 REGISTER_CUDA_OPERATOR(Print, PrintOp<CUDAContext>);
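
The CUDA RunOnDevice above now dispatches over float, int32_t, and int64_t via DispatchHelper. A hedged sketch of exercising that path from Python follows; it assumes a CUDA-enabled Caffe2 build with at least one GPU and is not code from this PR.

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

gpu = core.DeviceOption(caffe2_pb2.CUDA, 0)  # assumes a CUDA build and GPU 0
workspace.FeedBlob("x0", np.arange(4, dtype=np.int64), device_option=gpu)
workspace.FeedBlob("x1", np.ones(4, dtype=np.int64), device_option=gpu)

# With this change, int64 inputs take the DispatchHelper path instead of
# hitting CAFFE_THROW("Unsupported inputs").
op = core.CreateOperator("Sum", ["x0", "x1"], ["y"], device_option=gpu)
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("y"))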

caffe2/operators/utility_ops.h

Lines changed: 27 additions & 27 deletions
@@ -239,8 +239,7 @@ class FlattenToVecOp : public Operator<Context> {
   bool RunOnDevice() override {
     auto& input = Input(0);
     auto* output = Output(0);
-    CAFFE_ENFORCE_GE(
-        input.dim(), 1, "The rank of the tensor must be >= 1.");
+    CAFFE_ENFORCE_GE(input.dim(), 1, "The rank of the tensor must be >= 1.");
     output->Resize(input.numel());
 
     context_.CopyItemsSameDevice(
@@ -280,7 +279,7 @@ class SumOp : public Operator<Context> {
   USE_OPERATOR_CONTEXT_FUNCTIONS;
   USE_SIMPLE_CTOR_DTOR(SumOp);
 
-  template <typename T, typename M>
+  template <typename T>
   bool DoRunWithType() {
     auto& input0 = Input(0);
 
@@ -331,16 +330,8 @@ class SumOp : public Operator<Context> {
   }
 
   bool RunOnDevice() override {
-    if (Input(0).template IsType<float>()) {
-      return DoRunWithType<float, float>();
-    } else if (Input(0).template IsType<int>()) {
-      return DoRunWithType<int, int>();
-    } else {
-      CAFFE_THROW(
-          "Sum operator only supports 32-bit float and ints, but",
-          " input was of type ",
-          Input(0).dtype().name());
-    }
+    return DispatchHelper<TensorTypes<float, int32_t, int64_t>>::call(
+        this, Input(0));
   }
 };
 
@@ -369,7 +360,8 @@ class WeightedSumOp : public Operator<Context> {
   template <typename T>
   bool DoRunWithType() {
     // the code is written this way because of 10.1 + gcc 7.3.1 compiler bug
-    // as discussed at https://devtalk.nvidia.com/default/topic/1048037/linux/cuda-10-1-nvidia-you-re-now-quot-fixing-quot-gcc-bugs-that-gcc-doesn-t-even-have/
+    // as discussed at
+    // https://devtalk.nvidia.com/default/topic/1048037/linux/cuda-10-1-nvidia-you-re-now-quot-fixing-quot-gcc-bugs-that-gcc-doesn-t-even-have/
     const int input_size = (*this).InputSize();
     CAFFE_ENFORCE_EQ(input_size % 2, 0);
     const auto& X0 = Input(0);
@@ -751,14 +743,14 @@ class ScatterOp : public Operator<CPUContext> {
   template <class... Args>
   explicit ScatterOp(Args&&... args)
       : Operator<CPUContext>(std::forward<Args>(args)...),
-        OP_SINGLE_ARG(int, "axis", axis_, 1) {
-  }
+        OP_SINGLE_ARG(int, "axis", axis_, 1) {}
 
   virtual ~ScatterOp() noexcept override {}
 
   bool RunOnDevice() override {
-
-    TORCH_CHECK(Context::GetDeviceType() == kCPU, "ScatterOp currently only supports CPU.")
+    TORCH_CHECK(
+        Context::GetDeviceType() == kCPU,
+        "ScatterOp currently only supports CPU.")
 
     return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
         this, this->template Input<Tensor>(INDICES, CPU));
@@ -775,7 +767,8 @@ class ScatterOp : public Operator<CPUContext> {
     // ONNX allows negative axis to index from the back, valid range: [-r, r].
     axis_ = data.canonical_axis_index(axis_);
 
-    CAFFE_ENFORCE_GE(data.dim(), axis_ + 1, "DATA should be at least [axis+1]-D");
+    CAFFE_ENFORCE_GE(
+        data.dim(), axis_ + 1, "DATA should be at least [axis+1]-D");
     CAFFE_ENFORCE_GE(axis_, 0, "Axis should be non-negative");
     CAFFE_ENFORCE_LT(axis_, data.dim(), "Axis out of range");
 
@@ -818,14 +811,20 @@ class ScatterOp : public Operator<CPUContext> {
     // src offset can be computed as i * J_src * K + j * K + k.
     // dst offset can be computed as i * J_dst * K + idxs[idxs_offset] * K + K
     // Note that idxs and src should have the same rank and shape.
-    // dst should have the same rank as idxs and src, but the dimension of dim axis can be different.
-    // That is why in the above equation, there is the difference of J_src and J_dst.
-    for (int64_t outer_batch = 0; outer_batch < outer_dims_product; ++outer_batch) {
+    // dst should have the same rank as idxs and src, but the dimension of dim
+    // axis can be different. That is why in the above equation, there is the
+    // difference of J_src and J_dst.
+    for (int64_t outer_batch = 0; outer_batch < outer_dims_product;
+         ++outer_batch) {
       for (int64_t i = 0; i < N; ++i) {
-        for (int64_t inner_batch = 0; inner_batch < idxs_block_size; ++inner_batch) {
-          auto idxs_elem_idx = outer_batch * idxs_batch_size + i * idxs_block_size + inner_batch;
-          auto src_elem_idx = outer_batch * src_batch_size + i * src_block_size + inner_batch;
-          auto dst_elem_idx = outer_batch * dst_batch_size + idxs[idxs_elem_idx] * dst_block_size + inner_batch;
+        for (int64_t inner_batch = 0; inner_batch < idxs_block_size;
+             ++inner_batch) {
+          auto idxs_elem_idx =
+              outer_batch * idxs_batch_size + i * idxs_block_size + inner_batch;
+          auto src_elem_idx =
+              outer_batch * src_batch_size + i * src_block_size + inner_batch;
+          auto dst_elem_idx = outer_batch * dst_batch_size +
+              idxs[idxs_elem_idx] * dst_block_size + inner_batch;
 
           auto src = src_base + src_elem_idx * item_bytesize;
           auto dst = out + dst_elem_idx * item_bytesize;
@@ -1401,7 +1400,8 @@ class RangeOp : public Operator<Context> {
     T step = 1;
 
     for (int i = 0; i < InputSize(); ++i) {
-      CAFFE_ENFORCE_EQ(Input(i).numel(), 1, "All inputs must be scalar/1D tensor.");
+      CAFFE_ENFORCE_EQ(
+          Input(i).numel(), 1, "All inputs must be scalar/1D tensor.");
    }
 
    switch (InputSize()) {

caffe2/python/operator_test/utility_ops_test.py

Lines changed: 39 additions & 1 deletion
@@ -12,7 +12,6 @@
 import numpy as np
 import random
 import six
-import unittest
 
 
 class TestUtilityOps(serial.SerializedTestCase):
@@ -270,6 +269,45 @@ def mx_grad(a):
         )
         self.assertDeviceChecks(dc, op, inputs, [0, 1, 2])
 
+    @serial.given(
+        n=st.integers(1, 8), m=st.integers(1, 10), d=st.integers(1, 4),
+        in_place=st.booleans(), engine=st.sampled_from(["", "CUDNN"]),
+        seed=st.integers(min_value=0, max_value=65535),
+        dtype=st.sampled_from([np.int32, np.int64, np.float32]),
+        **hu.gcs)
+    def test_sum(
+            self, n, m, d, in_place, engine, seed, dtype, gc, dc):
+        input_names = []
+        input_vars = []
+        np.random.seed(seed)
+        for i in range(m):
+            X_name = 'X' + str(i)
+            input_names.extend([X_name])
+            var = np.random.rand(n, d).astype(dtype)
+            vars()[X_name] = var
+            input_vars.append(var)
+
+        def sum_op_ref(*args):
+            res = np.zeros((n, d))
+            for i in range(m):
+                res = res + args[i]
+            return (res, )
+
+        op = core.CreateOperator(
+            "Sum",
+            input_names,
+            [input_names[0]] if in_place else ['Y'],
+            engine=engine,
+        )
+
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=input_vars,
+            reference=sum_op_ref,
+        )
+        self.assertDeviceChecks(dc, op, input_vars, [0])
+
     @serial.given(
         inputs=hu.lengths_tensor().flatmap(
             lambda pair: st.tuples(
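
For readers who want to poke at the behavior the new test covers without Hypothesis, the stripped-down sketch below does by hand what assertReferenceChecks does in test_sum, including the in-place case where the first input doubles as the output. Shapes, values, and blob names are illustrative, not from the PR.

import numpy as np
from caffe2.python import core, workspace

# Three random int64 inputs, summed both out-of-place and in place.
inputs = [np.random.randint(0, 10, size=(3, 4)).astype(np.int64) for _ in range(3)]
for i, x in enumerate(inputs):
    workspace.FeedBlob("X%d" % i, x)

expected = sum(inputs)  # NumPy reference, mirrors sum_op_ref

# Out-of-place: write the result to a fresh blob Y.
workspace.RunOperatorOnce(core.CreateOperator("Sum", ["X0", "X1", "X2"], ["Y"]))
np.testing.assert_array_equal(workspace.FetchBlob("Y"), expected)

# In-place (in_place=True in the test): X0 is reused as the output blob.
workspace.RunOperatorOnce(core.CreateOperator("Sum", ["X0", "X1", "X2"], ["X0"]))
np.testing.assert_array_equal(workspace.FetchBlob("X0"), expected)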
