@@ -146,8 +146,14 @@ THSTensor *THSTensor_(newWithTensorAndSize)(THLongTensor *indices, THTensor *val
146
146
THSTensor *self = (THSTensor *)THAlloc (sizeof (THSTensor));
147
147
THSTensor_ (rawInit)(self);
148
148
149
- nDimI = THLongTensor_size (indices, 0 );
150
- nDimV = THTensor_ (nDimension)(values) - 1 ;
149
+ // TODO: we may need to special case when only one of these are empty.
150
+ if (THLongTensor_nDimension (indices) == 0 && THTensor_ (nDimension)(values) == 0 && sizes != NULL ) {
151
+ nDimI = 0 ;
152
+ nDimV = THLongStorage_size (sizes);
153
+ } else {
154
+ nDimI = THLongTensor_size (indices, 0 );
155
+ nDimV = THTensor_ (nDimension)(values) - 1 ;
156
+ }
151
157
if (!sizes) {
152
158
ignore = THLongTensor_new ();
153
159
THLongTensor *computed_indices_sizes = THLongTensor_new ();
@@ -169,27 +175,30 @@ THSTensor *THSTensor_(newWithTensorAndSize)(THLongTensor *indices, THTensor *val
169
175
THArgCheck (THLongStorage_size (sizes) == nDimI + nDimV, 2 ,
170
176
" number of dimensions must be nDimI + nDimV" );
171
177
172
- THLongTensor *max_indices = THLongTensor_new ();
173
- ignore = THLongTensor_new ();
174
- THLongTensor_max (max_indices, ignore, indices, 1 , 0 );
175
- THLongTensor_free (ignore);
176
- for (int d = 0 ; d < nDimI; d++) {
177
- int64_t max_index_in_dim = THTensor_fastGet1d (max_indices, d);
178
- int64_t dim_size = sizes->data [d];
179
- THArgCheck (max_index_in_dim <= dim_size, 2 ,
180
- " sizes is inconsistent with indices: for dim %d, size is %lld but found index %lld" ,
181
- d, (long long )dim_size, (long long )max_index_in_dim);
182
- }
183
- for (int d = 0 ; d < nDimV; d++) {
184
- int64_t values_size = THTensor_ (size)(values, d + 1 );
185
- int64_t specified_size = sizes->data [nDimI + d];
186
- THArgCheck (values_size <= specified_size, 2 ,
187
- " values and sizes are inconsistent: sizes[%d] is %lld but values.size(%d) is %lld" ,
188
- d + nDimI, (long long )specified_size, d + 1 , (long long )values_size);
178
+ // TODO: we may need to special case when only one of these are empty.
179
+ if (!(THLongTensor_nDimension (indices) == 0 && THTensor_ (nDimension)(values) == 0 && sizes != NULL )) {
180
+ THLongTensor *max_indices = THLongTensor_new ();
181
+ ignore = THLongTensor_new ();
182
+ THLongTensor_max (max_indices, ignore, indices, 1 , 0 );
183
+ THLongTensor_free (ignore);
184
+ for (int d = 0 ; d < nDimI; d++) {
185
+ int64_t max_index_in_dim = THTensor_fastGet1d (max_indices, d);
186
+ int64_t dim_size = sizes->data [d];
187
+ THArgCheck (max_index_in_dim <= dim_size, 2 ,
188
+ " sizes is inconsistent with indices: for dim %d, size is %lld but found index %lld" ,
189
+ d, (long long )dim_size, (long long )max_index_in_dim);
190
+ }
191
+ for (int d = 0 ; d < nDimV; d++) {
192
+ int64_t values_size = THTensor_ (size)(values, d + 1 );
193
+ int64_t specified_size = sizes->data [nDimI + d];
194
+ THArgCheck (values_size <= specified_size, 2 ,
195
+ " values and sizes are inconsistent: sizes[%d] is %lld but values.size(%d) is %lld" ,
196
+ d + nDimI, (long long )specified_size, d + 1 , (long long )values_size);
197
+ }
198
+ THLongTensor_free (max_indices);
189
199
}
190
200
191
201
THSTensor_ (rawResize)(self, nDimI, nDimV, THLongStorage_data (sizes));
192
- THLongTensor_free (max_indices);
193
202
}
194
203
// NB: by default, we do NOT clone indices/values into the sparse tensor.
195
204
// Efficient API by default!
0 commit comments