Skip to content

Commit e7602a1

Browse files
pearufacebook-github-bot
authored andcommitted
Fix multiplication of 0-D sparse tensors (pytorch#70749)
Summary: Pull Request resolved: pytorch#70749 Fixes pytorch#65396 and a clang-tidy error. cc nikitaved pearu cpuhrsch Test Plan: Imported from OSS Reviewed By: jbschlosser Differential Revision: D33439136 Pulled By: cpuhrsch fbshipit-source-id: 45ec58de7c18db183f891431d4a26e98fd0e924a
1 parent 4fa70a2 commit e7602a1

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

aten/src/ATen/native/sparse/SparseTensorMath.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,30 @@ SparseTensor& mul_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, con
5555
AT_ASSERT(t.is_sparse());
5656
AT_ASSERT(value.dim() == 0);
5757

58+
// Resolve a possibly sparse COO value to a strided tensor.
59+
Tensor value_;
60+
if (value.is_sparse()) {
61+
if (value._nnz() == 0) {
62+
r.resize_as_(t);
63+
return r.zero_();
64+
}
65+
value_ = value.values();
66+
} else {
67+
value_ = value;
68+
}
69+
// With broadcasting in action, value_ may be a 1-D tensor as long
70+
// as its shape is (1,).
71+
AT_ASSERT(value_.numel() == 1);
72+
5873
if (is_same_tensor(r, t)) {
59-
r._values().mul_(value);
74+
r._values().mul_(value_);
6075
} else {
6176
r.resize_as_(t);
6277
auto indices = r._indices();
6378
indices.resize_as_(t._indices());
6479
indices.copy_(t._indices());
6580
Tensor r_values = r._values(); // Sigh... needed because mul_out takes Tensor&
66-
at::mul_out(r_values, t._values(), value);
81+
at::mul_out(r_values, t._values(), value_);
6782
get_sparse_impl(r)->set_nnz_and_narrow(t._nnz());
6883
r._coalesced_(t.is_coalesced());
6984
}
@@ -213,7 +228,7 @@ SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, SparseT
213228

214229
Tensor div_sparse(const Tensor& self, const Tensor& value, c10::optional<c10::string_view> rounding_mode) {
215230
auto commonDtype = at::result_type(self, value);
216-
if (c10::isIntegralType(commonDtype, /*include_bool=*/true) && !rounding_mode.has_value()) {
231+
if (c10::isIntegralType(commonDtype, /*includeBool=*/true) && !rounding_mode.has_value()) {
217232
commonDtype = typeMetaToScalarType(at::get_default_dtype());
218233
}
219234
Tensor result = at::empty({0}, self.options().dtype(commonDtype));

test/test_sparse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,6 +1622,7 @@ def _test_basic_ops():
16221622
self._test_basic_ops_shape(9, 0, [10, 10, 10], [], dtype, device, coalesced)
16231623
self._test_basic_ops_shape(0, 0, [10, 10, 10], [], dtype, device, coalesced)
16241624
self._test_basic_ops_shape(0, 0, [10, 10, 0], [], dtype, device, coalesced)
1625+
self._test_basic_ops_shape(0, 0, [], [], dtype, device, coalesced)
16251626

16261627
def _test_basic_ops_hybrid():
16271628
self._test_basic_ops_shape(9, 12, [5, 6], [2, 3], dtype, device, coalesced)

0 commit comments

Comments
 (0)