
Commit 24a0467

qqaatw authored and pytorchmergebot committed
Add opset16 onnx support for torch.scatter_add (pytorch#79103)
Fixes pytorch#32960

Pull Request resolved: pytorch#79103
Approved by: https://github.com/BowenBao
1 parent 5ca9253 commit 24a0467
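
With this commit, a module that calls torch.Tensor.scatter_add can be exported at opset 16, where it lowers to a single ONNX ScatterElements node with reduction="add". A minimal export sketch (the module class and output file name are illustrative, not part of the commit):

import torch

class ScatterAddModule(torch.nn.Module):
    def forward(self, data, index, src):
        # Accumulate src into data along dim=1 at the positions in index;
        # duplicate indices are summed, which opset 16 can now express in ONNX.
        return data.scatter_add(1, index, src)

data = torch.zeros(3, 3)
index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
src = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])

# opset_version=16 is required for the ScatterElements reduction="add" lowering.
torch.onnx.export(
    ScatterAddModule(),
    (data, index, src),
    "scatter_add.onnx",
    opset_version=16,
)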

File tree

4 files changed: +164 -0 lines changed

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+ir_version: 8
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    input: "onnx::ScatterElements_0"
+    input: "onnx::ScatterElements_1"
+    input: "onnx::ScatterElements_2"
+    output: "3"
+    name: "ScatterElements_0"
+    op_type: "ScatterElements"
+    attribute {
+      name: "axis"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "reduction"
+      s: "add"
+      type: STRING
+    }
+  }
+  name: "torch_jit"
+  input {
+    name: "onnx::ScatterElements_0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "onnx::ScatterElements_1"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "onnx::ScatterElements_2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "3"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 16
+}

test/onnx/test_operators.py

Lines changed: 10 additions & 0 deletions
@@ -847,6 +847,16 @@ def test_scatter_add_opset11(self):
             opset_version=11,
         )
 
+    def test_scatter_add_opset16(self):
+        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+        indices = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
+        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
+        self.assertONNX(
+            lambda data, index: data.scatter_add(1, indices, values),
+            (data, (indices, values)),
+            opset_version=16,
+        )
+
     def test_master_opset(self):
         x = torch.randn(2, 3).float()
         y = torch.randn(2, 3).float()

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 25 additions & 0 deletions
@@ -3892,6 +3892,31 @@ def forward(self, src, index):
         index = torch.tensor([[0, 1], [0, 1], [0, 1]], dtype=torch.int64)
         self.run_test(ScatterModel(), (src, index))
 
+    @skipIfUnsupportedMinOpsetVersion(16)
+    def test_scatter_add_index_not_unique(self):
+        class ScatterModel(torch.nn.Module):
+            def forward(self, input, indices, values):
+                return input.scatter_add(1, indices, values)
+
+        input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+        indices = torch.tensor([[0, 0], [1, 1], [2, 2]], dtype=torch.int64)
+        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
+        self.run_test(ScatterModel(), input_args=(input, indices, values))
+
+        @torch.jit.script
+        def scatter_sum(src: Tensor, index: Tensor):
+            size = src.size()
+            out = torch.zeros(size, dtype=src.dtype)
+            return out.scatter_add_(1, index, src)
+
+        class ScatterModel(torch.nn.Module):
+            def forward(self, src, index):
+                return scatter_sum(src, index)
+
+        src = torch.rand(3, 2)
+        index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
+        self.run_test(ScatterModel(), (src, index))
+
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_bucketize(self):
         class BucketModel(torch.nn.Module):
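
Outside of the test harness, parity with eager PyTorch can also be checked by running the exported model under onnxruntime; a sketch, assuming onnxruntime is installed and reusing the illustrative scatter_add.onnx export shown earlier:

import numpy as np
import onnxruntime
import torch

data = torch.zeros(3, 3)
index = torch.tensor([[0, 0], [1, 1], [2, 2]], dtype=torch.int64)
src = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])

# Eager PyTorch reference: duplicate indices accumulate into the same slot.
expected = data.scatter_add(1, index, src)

# Feed the same tensors to the model exported with opset_version=16.
session = onnxruntime.InferenceSession(
    "scatter_add.onnx", providers=["CPUExecutionProvider"]
)
input_names = [i.name for i in session.get_inputs()]
feeds = dict(zip(input_names, (data.numpy(), index.numpy(), src.numpy())))
(actual,) = session.run(None, feeds)

np.testing.assert_allclose(actual, expected.numpy(), rtol=1e-5, atol=1e-7)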

torch/onnx/symbolic_opset16.py

Lines changed: 38 additions & 0 deletions
@@ -47,3 +47,41 @@ def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
         mode_s=mode_s,
         padding_mode_s=padding_mode_s,
     )
+
+
+@symbolic_helper.parse_args("v", "i", "v", "v")
+def scatter_add(g, self, dim, index, src):
+    if symbolic_helper.is_caffe2_aten_fallback():
+        return g.at("scatter", self, dim, index, src, overload_name="src")
+
+    src_type = src.type().scalarType()
+    src_sizes = symbolic_helper._get_tensor_sizes(src)
+    index_sizes = symbolic_helper._get_tensor_sizes(index)
+
+    if src_sizes != index_sizes:
+        return symbolic_helper._unimplemented(
+            "scatter_add",
+            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
+        )
+
+    src = symbolic_helper._maybe_get_scalar(src)
+    if symbolic_helper._is_value(src):
+        return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
+    else:
+        # Check if scalar "src" has same type as self (PyTorch allows different
+        # type for scalar src (but not when src is tensor)). If not, insert Cast node.
+        if self.type().scalarType() != src_type:
+            src = g.op(
+                "Cast",
+                src,
+                to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
+            )
+
+        return g.op(
+            "ScatterElements",
+            self,
+            index,
+            src,
+            axis_i=dim,
+            reduction_s="add",
+        )
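
For reference, ScatterElements with reduction="add" and axis=1, as emitted by this symbolic, matches the semantics of torch.Tensor.scatter_add along dim 1. A small sketch of that equivalence (the reference function below is illustrative, not part of the commit):

import torch

def scatter_elements_add_reference(data, index, src):
    # ScatterElements(reduction="add", axis=1) for 2-D tensors:
    # out[i][index[i][j]] += src[i][j], so duplicate indices accumulate.
    out = data.clone()
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            out[i, index[i, j]] += src[i, j]
    return out

data = torch.zeros(3, 3)
index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
src = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])

assert torch.allclose(
    scatter_elements_add_reference(data, index, src),
    data.scatter_add(1, index, src),
)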

0 commit comments
