Skip to content

Commit 93a9bb9

Browse files
authored
Don't override Tensor, Storage macros defined outside torch/csrc in t… (pytorch#8243)
* Don't override Tensor, Storage macros defined outside torch/csrc in torch/csrc. This PR does the following: 1) Removes THSTensor macros in torch/csrc, which aren't used. 2) For macros defined outside of torch/csrc (THTensor, THTensor_, THStorage, THStorage_): a) No longer override them, i.e. previously THTensor could actually be THCTensor if a generic file was included from a file including THCP.h. b) Instead, introduce new macros THW* (e.g. THWTensor) to represent a (potentially empty) wildcard character. In addition to making this code easier to read and codemod, this allows us to more freely change TH/THC; for example: currently in the THC random code, the state is casted to THByteTensor*; this happens to work because the macros don't happen to override THByteTensor. But if THByteTensor just becomes an alias of THTensor (which is the plan for a single tensor type), then this no longer works. The whole thing is a bit of a mess previously because you really have to understand which macros and redefined and which aren't. We could also rename the macros that live in torch/csrc (e.g. the THPTensor macros), but since that is more self contained, I punted for now. * Don't change the plugin.
1 parent a466c12 commit 93a9bb9

16 files changed

+159
-177
lines changed

torch/csrc/THP.h

+5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424
#define LIBRARY_STATE_TYPE
2525
#define LIBRARY_STATE_TYPE_NOARGS
2626

27+
#define THWStorage THStorage
28+
#define THWStorage_(NAME) THStorage_(NAME)
29+
#define THWTensor THTensor
30+
#define THWTensor_(NAME) THTensor_(NAME)
31+
2732
#include "PtrWrapper.h"
2833
#include "Exceptions.h"
2934
#include "Generator.h"

torch/csrc/cuda/Storage.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include <THC/THCGenerateAllTypes.h>
2020

2121
template<>
22-
void THPPointer<THStorage>::free() {
22+
void THPPointer<THCStorage>::free() {
2323
if (ptr)
2424
THCStorage_free(LIBRARY_STATE ptr);
2525
}

torch/csrc/cuda/override_macros.h

+6-9
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
#include "undef_macros.h"
22

3-
#define THStoragePtr THCStoragePtr
3+
#define THWStoragePtr THCStoragePtr
44
#define THPStoragePtr THCPStoragePtr
5-
#define THTensorPtr THCTensorPtr
5+
#define THWTensorPtr THCTensorPtr
66
#define THPTensorPtr THCPTensorPtr
77

8-
#define THStorage THCStorage
9-
#define THStorage_(NAME) THCStorage_(NAME)
10-
#define THTensor THCTensor
11-
#define THTensor_(NAME) THCTensor_(NAME)
8+
#define THWStorage THCStorage
9+
#define THWStorage_(NAME) THCStorage_(NAME)
10+
#define THWTensor THCTensor
11+
#define THWTensor_(NAME) THCTensor_(NAME)
1212

1313
#define THPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)
1414
#define THPStorage THCPStorage
@@ -29,10 +29,7 @@
2929
#define THPTensorStateless THCPTensorStateless
3030

3131

32-
#define THSTensorPtr THCSTensorPtr
3332
#define THSPTensorPtr THCSPTensorPtr
34-
#define THSTensor THCSTensor
35-
#define THSTensor_(NAME) THCSTensor_(NAME)
3633

3734
#define THSPTensor_(NAME) TH_CONCAT_4(THCSP,Real,Tensor_,NAME)
3835
#define THSPTensor_stateless_(NAME) TH_CONCAT_4(THCSP,Real,Tensor_stateless_,NAME)

torch/csrc/cuda/restore_macros.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
#define THTensor TH_CONCAT_3(TH,Real,Tensor)
3-
#define THTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME)
2+
#define THWTensor TH_CONCAT_3(TH,Real,Tensor)
3+
#define THWTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME)
44

55
#define THPTensor TH_CONCAT_3(THP,Real,Tensor)
66
#define THPTensorStr TH_CONCAT_STRING_3(torch.,Real,Tensor)
@@ -13,8 +13,8 @@
1313
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)
1414

1515
#ifdef _THP_CORE
16-
#define THStoragePtr TH_CONCAT_3(TH,Real,StoragePtr)
17-
#define THTensorPtr TH_CONCAT_3(TH,Real,TensorPtr)
16+
#define THWStoragePtr TH_CONCAT_3(TH,Real,StoragePtr)
17+
#define THWTensorPtr TH_CONCAT_3(TH,Real,TensorPtr)
1818
#define THPStoragePtr TH_CONCAT_3(THP,Real,StoragePtr)
1919
#define THPTensorPtr TH_CONCAT_3(THP,Real,TensorPtr)
2020
#endif

torch/csrc/cuda/undef_macros.h

+6-9
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
#undef THPStorageClass
2323
#undef THPStorageType
2424

25-
#undef THStorage
26-
#undef THStorage_
27-
#undef THTensor
28-
#undef THTensor_
25+
#undef THWStorage
26+
#undef THWStorage_
27+
#undef THWTensor
28+
#undef THWTensor_
2929

30-
#undef THStoragePtr
30+
#undef THWStoragePtr
3131
#undef THPStoragePtr
32-
#undef THTensorPtr
32+
#undef THWTensorPtr
3333
#undef THPTensorPtr
3434

3535

@@ -44,9 +44,6 @@
4444
#undef THSPTensorStateless
4545
#undef THSPTensorType
4646

47-
#undef THSTensor
48-
#undef THSTensor_
49-
#undef THSTensorPtr
5047
#undef THSPTensorPtr
5148

5249

torch/csrc/distributed/override_macros.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
#include "undef_macros.h"
22

3-
#define THStoragePtr THDStoragePtr
3+
#define THWStoragePtr THDStoragePtr
44
#define THPStoragePtr THDPStoragePtr
5-
#define THTensorPtr THDTensorPtr
5+
#define THWTensorPtr THDTensorPtr
66
#define THPTensorPtr THDPTensorPtr
77

8-
#define THStorage THDStorage
9-
#define THStorage_(NAME) THDStorage_(NAME)
10-
#define THTensor THDTensor
11-
#define THTensor_(NAME) THDTensor_(NAME)
8+
#define THWStorage THDStorage
9+
#define THWStorage_(NAME) THDStorage_(NAME)
10+
#define THWTensor THDTensor
11+
#define THWTensor_(NAME) THDTensor_(NAME)
1212

1313
#define THPStorage_(NAME) TH_CONCAT_4(THDP,Real,Storage_,NAME)
1414
#define THPStorage THDPStorage

torch/csrc/distributed/undef_macros.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
#undef THPStorageClass
2121
#undef THPStorageType
2222

23-
#undef THStorage
24-
#undef THStorage_
25-
#undef THTensor
26-
#undef THTensor_
23+
#undef THWStorage
24+
#undef THWStorage_
25+
#undef THWTensor
26+
#undef THWTensor_
2727

28-
#undef THStoragePtr
28+
#undef THWStoragePtr
2929
#undef THPStoragePtr
30-
#undef THTensorPtr
30+
#undef THWTensorPtr
3131
#undef THPTensorPtr
3232

3333
#undef THHostTensor

torch/csrc/generic/Storage.cpp

+41-41
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,32 @@
44

55
PyObject *THPStorageClass = NULL;
66

7-
PyObject * THPStorage_(New)(THStorage *ptr)
7+
PyObject * THPStorage_(New)(THWStorage *ptr)
88
{
99
TORCH_ASSERT(ptr);
1010
PyTypeObject *type = (PyTypeObject *)THPStorageClass;
1111
PyObject *obj = type->tp_alloc(type, 0);
1212
if (obj) {
1313
((THPStorage *)obj)->cdata = ptr;
1414
} else {
15-
THStorage_(free)(LIBRARY_STATE ptr);
15+
THWStorage_(free)(LIBRARY_STATE ptr);
1616
}
1717
return obj;
1818
}
1919

2020
static void THPStorage_(dealloc)(THPStorage* self)
2121
{
22-
THStorage_(free)(LIBRARY_STATE self->cdata);
22+
THWStorage_(free)(LIBRARY_STATE self->cdata);
2323
Py_TYPE(self)->tp_free((PyObject*)self);
2424
}
2525

26-
static THStorage* THPStorage_(newWithAllocator)(int64_t size, THAllocator* allocator)
26+
static THWStorage* THPStorage_(newWithAllocator)(int64_t size, THAllocator* allocator)
2727
{
2828
#if defined(THC_GENERIC_FILE) || defined(THD_GENERIC_FILE)
2929
THPUtils_setError(THPStorageStr " does not support custom allocators");
3030
return NULL;
3131
#else
32-
return THStorage_(newWithAllocator)(LIBRARY_STATE size, allocator, NULL);
32+
return THWStorage_(newWithAllocator)(LIBRARY_STATE size, allocator, NULL);
3333
#endif
3434
}
3535

@@ -55,7 +55,7 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
5555
if (num_args == 0) {
5656
PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata");
5757
if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) {
58-
THStorage *ptr = (THStorage*)PyLong_AsVoidPtr(cdata_ptr);
58+
THWStorage *ptr = (THWStorage*)PyLong_AsVoidPtr(cdata_ptr);
5959
self->cdata = ptr;
6060
return (PyObject*)self.release();
6161
}
@@ -68,7 +68,7 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
6868
if (allocator) {
6969
self->cdata = THPStorage_(newWithAllocator)(0, allocator);
7070
} else {
71-
self->cdata = THStorage_(new)(LIBRARY_STATE_NOARGS);
71+
self->cdata = THWStorage_(new)(LIBRARY_STATE_NOARGS);
7272
}
7373
return (PyObject*)self.release();
7474
}
@@ -81,7 +81,7 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
8181
if (allocator) {
8282
self->cdata = THPStorage_(newWithAllocator)(size, allocator);
8383
} else {
84-
self->cdata = THStorage_(newWithSize)(LIBRARY_STATE size);
84+
self->cdata = THWStorage_(newWithSize)(LIBRARY_STATE size);
8585
}
8686
return (PyObject*)self.release();
8787
}
@@ -117,11 +117,11 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
117117
"%" PRId64 ", but the viewed storage has only %" PRId64 " element(s) after offset %" PRId64,
118118
size, numel - offset, offset);
119119

120-
real *data_ptr = THStorage_(data)(LIBRARY_STATE storage_arg->cdata) + offset;
121-
THStoragePtr storage(THStorage_(newWithData)(LIBRARY_STATE data_ptr, size));
120+
real *data_ptr = THWStorage_(data)(LIBRARY_STATE storage_arg->cdata) + offset;
121+
THWStoragePtr storage(THWStorage_(newWithData)(LIBRARY_STATE data_ptr, size));
122122
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
123123
storage->view = storage_arg->cdata;
124-
THStorage_(retain)(LIBRARY_STATE storage_arg->cdata);
124+
THWStorage_(retain)(LIBRARY_STATE storage_arg->cdata);
125125
self->cdata = storage.release();
126126
return (PyObject*)self.release();
127127
#endif
@@ -135,7 +135,7 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
135135
Py_ssize_t length = PySequence_Length(first_arg);
136136
THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
137137
THPUtils_typename(first_arg));
138-
self->cdata = THStorage_(newWithSize)(LIBRARY_STATE length);
138+
self->cdata = THWStorage_(newWithSize)(LIBRARY_STATE length);
139139
THPObjectPtr item;
140140
try {
141141
for (Py_ssize_t i = 0; i < length; i++) {
@@ -177,7 +177,7 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
177177
static Py_ssize_t THPStorage_(length)(THPStorage *self)
178178
{
179179
HANDLE_TH_ERRORS
180-
return THStorage_(size)(LIBRARY_STATE self->cdata);
180+
return THWStorage_(size)(LIBRARY_STATE self->cdata);
181181
END_HANDLE_TH_ERRORS_RET(-1)
182182
}
183183

@@ -188,13 +188,13 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
188188
if (THPUtils_checkLong(index)) {
189189
int64_t nindex = THPUtils_unpackLong(index);
190190
if (nindex < 0)
191-
nindex += THStorage_(size)(LIBRARY_STATE self->cdata);
191+
nindex += THWStorage_(size)(LIBRARY_STATE self->cdata);
192192
if (nindex < 0 || nindex >= self->cdata->size) {
193193
PyErr_Format(PyExc_IndexError, "index %" PRId64 " out of range for storage of "
194194
"size %" PRId64, (int64_t) nindex, (int64_t) self->cdata->size);
195195
return NULL;
196196
}
197-
real value = THStorage_(get)(LIBRARY_STATE self->cdata, nindex);
197+
real value = THWStorage_(get)(LIBRARY_STATE self->cdata, nindex);
198198
return THPUtils_(newReal)(value);
199199
/* Slice index */
200200
} else if (PySlice_Check(index)) {
@@ -203,7 +203,7 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
203203
return NULL;
204204
#else
205205
Py_ssize_t start, stop, slicelength, step;
206-
int64_t len = THStorage_(size)(LIBRARY_STATE self->cdata);
206+
int64_t len = THWStorage_(size)(LIBRARY_STATE self->cdata);
207207
if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
208208
return NULL;
209209
if (step != 1) {
@@ -212,11 +212,11 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
212212
return NULL;
213213
}
214214

215-
real *data = THStorage_(data)(LIBRARY_STATE self->cdata);
216-
THStoragePtr new_storage(THStorage_(newWithData)(LIBRARY_STATE data + start, slicelength));
215+
real *data = THWStorage_(data)(LIBRARY_STATE self->cdata);
216+
THWStoragePtr new_storage(THWStorage_(newWithData)(LIBRARY_STATE data + start, slicelength));
217217
new_storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
218218
new_storage->view = self->cdata;
219-
THStorage_(retain)(LIBRARY_STATE self->cdata);
219+
THWStorage_(retain)(LIBRARY_STATE self->cdata);
220220

221221
PyObject *_ret = THPStorage_(New)(new_storage);
222222
new_storage.release();
@@ -242,11 +242,11 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
242242
real rvalue = THPUtils_(unpackReal)(value);
243243
if (THPUtils_checkLong(index)) {
244244
int64_t nindex = THPUtils_unpackLong(index);
245-
THStorage_(set)(LIBRARY_STATE self->cdata, nindex, rvalue);
245+
THWStorage_(set)(LIBRARY_STATE self->cdata, nindex, rvalue);
246246
return 0;
247247
} else if (PySlice_Check(index)) {
248248
Py_ssize_t start, stop, slicelength, step;
249-
int64_t len = THStorage_(size)(LIBRARY_STATE self->cdata);
249+
int64_t len = THWStorage_(size)(LIBRARY_STATE self->cdata);
250250
if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
251251
return -1;
252252
if (step != 1) {
@@ -257,7 +257,7 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
257257
// TODO: check the bounds only once
258258
// TODO: fill?
259259
for (;start < stop; start++)
260-
THStorage_(set)(LIBRARY_STATE self->cdata, start, rvalue);
260+
THWStorage_(set)(LIBRARY_STATE self->cdata, start, rvalue);
261261
return 0;
262262
}
263263
THPUtils_setError("can't index a " THPStorageStr " with %s",
@@ -319,33 +319,33 @@ static struct PyMemberDef THPStorage_(members)[] = {
319319
{NULL}
320320
};
321321

322-
extern THPCopyList THStorage_(copy_functions);
323-
THPCopyList THStorage_(copy_functions);
322+
extern THPCopyList THWStorage_(copy_functions);
323+
THPCopyList THWStorage_(copy_functions);
324324

325325
void THPStorage_(initCopyMethods)()
326326
{
327327
#ifndef THD_GENERIC_FILE
328-
auto& h = THStorage_(copy_functions);
328+
auto& h = THWStorage_(copy_functions);
329329
// copy from CPU types
330-
THPInsertStorageCopyFunction<THPStorage, THPByteStorage>(&THPByteStorageType, h, &THStorage_(copyByte));
331-
THPInsertStorageCopyFunction<THPStorage, THPCharStorage>(&THPCharStorageType, h, &THStorage_(copyChar));
332-
THPInsertStorageCopyFunction<THPStorage, THPShortStorage>(&THPShortStorageType, h, &THStorage_(copyShort));
333-
THPInsertStorageCopyFunction<THPStorage, THPIntStorage>(&THPIntStorageType, h, &THStorage_(copyInt));
334-
THPInsertStorageCopyFunction<THPStorage, THPLongStorage>(&THPLongStorageType, h, &THStorage_(copyLong));
335-
THPInsertStorageCopyFunction<THPStorage, THPHalfStorage>(&THPHalfStorageType, h, &THStorage_(copyHalf));
336-
THPInsertStorageCopyFunction<THPStorage, THPFloatStorage>(&THPFloatStorageType, h, &THStorage_(copyFloat));
337-
THPInsertStorageCopyFunction<THPStorage, THPDoubleStorage>(&THPDoubleStorageType, h, &THStorage_(copyDouble));
330+
THPInsertStorageCopyFunction<THPStorage, THPByteStorage>(&THPByteStorageType, h, &THWStorage_(copyByte));
331+
THPInsertStorageCopyFunction<THPStorage, THPCharStorage>(&THPCharStorageType, h, &THWStorage_(copyChar));
332+
THPInsertStorageCopyFunction<THPStorage, THPShortStorage>(&THPShortStorageType, h, &THWStorage_(copyShort));
333+
THPInsertStorageCopyFunction<THPStorage, THPIntStorage>(&THPIntStorageType, h, &THWStorage_(copyInt));
334+
THPInsertStorageCopyFunction<THPStorage, THPLongStorage>(&THPLongStorageType, h, &THWStorage_(copyLong));
335+
THPInsertStorageCopyFunction<THPStorage, THPHalfStorage>(&THPHalfStorageType, h, &THWStorage_(copyHalf));
336+
THPInsertStorageCopyFunction<THPStorage, THPFloatStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
337+
THPInsertStorageCopyFunction<THPStorage, THPDoubleStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
338338
#ifdef THC_GENERIC_FILE
339339
// copy from GPU types
340-
THPInsertStorageCopyFunction<THPStorage, THCPByteStorage>(&THCPByteStorageType, h, &THStorage_(copyCudaByte));
341-
THPInsertStorageCopyFunction<THPStorage, THCPCharStorage>(&THCPCharStorageType, h, &THStorage_(copyCudaChar));
342-
THPInsertStorageCopyFunction<THPStorage, THCPShortStorage>(&THCPShortStorageType, h, &THStorage_(copyCudaShort));
343-
THPInsertStorageCopyFunction<THPStorage, THCPIntStorage>(&THCPIntStorageType, h, &THStorage_(copyCudaInt));
344-
THPInsertStorageCopyFunction<THPStorage, THCPLongStorage>(&THCPLongStorageType, h, &THStorage_(copyCudaLong));
345-
THPInsertStorageCopyFunction<THPStorage, THCPFloatStorage>(&THCPFloatStorageType, h, &THStorage_(copyCudaFloat));
346-
THPInsertStorageCopyFunction<THPStorage, THCPDoubleStorage>(&THCPDoubleStorageType, h, &THStorage_(copyCudaDouble));
340+
THPInsertStorageCopyFunction<THPStorage, THCPByteStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
341+
THPInsertStorageCopyFunction<THPStorage, THCPCharStorage>(&THCPCharStorageType, h, &THWStorage_(copyCudaChar));
342+
THPInsertStorageCopyFunction<THPStorage, THCPShortStorage>(&THCPShortStorageType, h, &THWStorage_(copyCudaShort));
343+
THPInsertStorageCopyFunction<THPStorage, THCPIntStorage>(&THCPIntStorageType, h, &THWStorage_(copyCudaInt));
344+
THPInsertStorageCopyFunction<THPStorage, THCPLongStorage>(&THCPLongStorageType, h, &THWStorage_(copyCudaLong));
345+
THPInsertStorageCopyFunction<THPStorage, THCPFloatStorage>(&THCPFloatStorageType, h, &THWStorage_(copyCudaFloat));
346+
THPInsertStorageCopyFunction<THPStorage, THCPDoubleStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
347347
#ifdef CUDA_HALF_TENSOR
348-
THPInsertStorageCopyFunction<THPStorage, THCPHalfStorage>(&THCPHalfStorageType, h, &THStorage_(copyCudaHalf));
348+
THPInsertStorageCopyFunction<THPStorage, THCPHalfStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
349349
#endif
350350
// add CPU <- GPU copies to base type
351351
#define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)

torch/csrc/generic/Storage.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
struct THPStorage {
66
PyObject_HEAD
7-
THStorage *cdata;
7+
THWStorage *cdata;
88
};
99

10-
THP_API PyObject * THPStorage_(New)(THStorage *ptr);
10+
THP_API PyObject * THPStorage_(New)(THWStorage *ptr);
1111
extern PyObject *THPStorageClass;
1212

1313
#ifdef _THP_CORE

0 commit comments

Comments
 (0)