@@ -55,15 +55,30 @@ SparseTensor& mul_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, con
55
55
AT_ASSERT (t.is_sparse ());
56
56
AT_ASSERT (value.dim () == 0 );
57
57
58
+ // Resolve a possibly sparse COO value to a strided tensor.
59
+ Tensor value_;
60
+ if (value.is_sparse ()) {
61
+ if (value._nnz () == 0 ) {
62
+ r.resize_as_ (t);
63
+ return r.zero_ ();
64
+ }
65
+ value_ = value.values ();
66
+ } else {
67
+ value_ = value;
68
+ }
69
+ // With broadcasting in action, value_ may be a 1-D tensor as long
70
+ // as its shape is (1,).
71
+ AT_ASSERT (value_.numel () == 1 );
72
+
58
73
if (is_same_tensor (r, t)) {
59
- r._values ().mul_ (value );
74
+ r._values ().mul_ (value_ );
60
75
} else {
61
76
r.resize_as_ (t);
62
77
auto indices = r._indices ();
63
78
indices.resize_as_ (t._indices ());
64
79
indices.copy_ (t._indices ());
65
80
Tensor r_values = r._values (); // Sigh... needed because mul_out takes Tensor&
66
- at::mul_out (r_values, t._values (), value );
81
+ at::mul_out (r_values, t._values (), value_ );
67
82
get_sparse_impl (r)->set_nnz_and_narrow (t._nnz ());
68
83
r._coalesced_ (t.is_coalesced ());
69
84
}
@@ -213,7 +228,7 @@ SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, SparseT
213
228
214
229
Tensor div_sparse (const Tensor& self, const Tensor& value, c10::optional<c10::string_view> rounding_mode) {
215
230
auto commonDtype = at::result_type (self, value);
216
- if (c10::isIntegralType (commonDtype, /* include_bool =*/ true ) && !rounding_mode.has_value ()) {
231
+ if (c10::isIntegralType (commonDtype, /* includeBool =*/ true ) && !rounding_mode.has_value ()) {
217
232
commonDtype = typeMetaToScalarType (at::get_default_dtype ());
218
233
}
219
234
Tensor result = at::empty ({0 }, self.options ().dtype (commonDtype));
0 commit comments