Commit 5ee3358
python 2 support
1 parent 6954783 commit 5ee3358

15 files changed: +197 -64 lines

.gitignore (+1)

@@ -4,3 +4,4 @@ torch.egg-info/
 */**/__pycache__
 torch/__init__.py
 torch/csrc/generic/TensorMethods.cpp
+*/**/*.pyc

setup.py (+18 -2)

@@ -1,5 +1,7 @@
 from setuptools import setup, Extension
+from os.path import expanduser
 from tools.cwrap import cwrap
+import platform
 
 ################################################################################
 # Generate __init__.py from templates
@@ -49,6 +51,16 @@
 ################################################################################
 # Declare the package
 ################################################################################
+extra_link_args = []
+
+# TODO: remove and properly submodule TH in the repo itself
+th_path = expanduser("~/torch/install/")
+th_header_path = th_path + "include"
+th_lib_path = th_path + "lib"
+if platform.system() == 'Darwin':
+    extra_link_args.append('-L' + th_lib_path)
+    extra_link_args.append('-Wl,-rpath,' + th_lib_path)
+
 sources = [
     "torch/csrc/Module.cpp",
     "torch/csrc/Tensor.cpp",
@@ -59,9 +71,13 @@
     libraries=['TH'],
     sources=sources,
     language='c++',
-    include_dirs=["torch/csrc"])
+    include_dirs=(["torch/csrc", th_header_path]),
+    extra_link_args = extra_link_args,
+    )
 
 
 setup(name="torch", version="0.1",
     ext_modules=[C],
-    packages=['torch'])
+    packages=['torch'],
+    )
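
Note on the Darwin-specific flags above: -L points the linker at the TH libraries when the extension is built, and -Wl,-rpath records that same directory inside the compiled module so the dynamic loader can find libTH at import time without DYLD_LIBRARY_PATH. A minimal, self-contained sketch of the pattern (the "example" module name and source file are illustrative; only the ~/torch/install layout and the TH library name come from this commit):

    # sketch.py -- illustrative only; mirrors the platform-conditional rpath idea above
    import platform
    from os.path import expanduser
    from setuptools import setup, Extension

    th_path = expanduser("~/torch/install/")   # TH install prefix used by this commit
    extra_link_args = []
    if platform.system() == 'Darwin':
        # -L: library search path at link time; -Wl,-rpath: path baked in for load time
        extra_link_args.append('-L' + th_path + 'lib')
        extra_link_args.append('-Wl,-rpath,' + th_path + 'lib')

    ext = Extension('example.C',
                    sources=['example.cpp'],   # hypothetical source file
                    libraries=['TH'],
                    language='c++',
                    include_dirs=[th_path + 'include'],
                    extra_link_args=extra_link_args)

    setup(name='example', version='0.1', ext_modules=[ext])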

test/smoke.py (+24)

@@ -0,0 +1,24 @@
+import torch
+
+a = torch.FloatTensor(4, 3)
+b = torch.FloatTensor(3, 4)
+
+a.add(b)
+
+c = a.storage()
+
+d = a.select(0, 1)
+
+print(c)
+print(a)
+print(b)
+print(d)
+
+
+a.fill(0)
+
+print(a[1])
+
+print(a.ge(long(0)))
+print(a.ge(0))
+
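
The last two prints exercise both of Python 2's integer types: long() is a builtin only on Python 2 (it was folded into int in Python 3), so a.ge(long(0)) and a.ge(0) check that the bindings accept either kind of integer. A version-agnostic sketch of the same check (the integer_types alias is illustrative, not part of this commit):

    # Sketch: feed every integer type the running Python has to ge(); assumes 'a' from above.
    import sys

    if sys.version_info[0] == 2:
        integer_types = (int, long)   # noqa: F821 -- 'long' is a builtin only on Python 2
    else:
        integer_types = (int,)

    for t in integer_types:
        print(a.ge(t(0)))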

tools/__init__.py

Whitespace-only changes.

tools/cwrap.py (+8 -5)

@@ -151,7 +151,7 @@ def signature_hash(self):
     'THStorage': Template('return THPStorage_(newObject)($expr)'),
     'THLongStorage': Template('return THPLongStorage_newObject($expr)'),
     'bool': Template('return PyBool_FromLong($expr)'),
-    'long': Template('return PyLong_FromLong($expr)'),
+    'long': Template('return PyInt_FromLong($expr)'),
     'double': Template('return PyFloat_FromDouble($expr)'),
     'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'),
     # TODO
@@ -397,16 +397,19 @@ def argfilter():
     CONSTANT arguments are literals.
     Repeated arguments do not need to be specified twice.
     """
-    provided = set()
+    # use class rather than nonlocal to maintain 2.7 compat
+    # see http://stackoverflow.com/questions/3190706/nonlocal-keyword-in-python-2-x
+    # TODO: check this works
+    class context:
+        provided = set()
     def is_already_provided(arg):
-        nonlocal provided
         ret = False
         ret |= arg.name == 'self'
         ret |= arg.name == '_res_new'
         ret |= arg.type == 'CONSTANT'
         ret |= arg.type == 'EXPRESSION'
-        ret |= arg.name in provided
-        provided.add(arg.name)
+        ret |= arg.name in context.provided
+        context.provided.add(arg.name)
         return ret
     return is_already_provided

torch/Storage.py (+1 -1)

@@ -7,5 +7,5 @@ def __repr__(self):
         return str(self)
 
     def __iter__(self):
-        return map(lambda i: self[i], range(self.size()))
+        return iter(map(lambda i: self[i], range(self.size())))

torch/Tensor.py (+1 -1)

@@ -42,4 +42,4 @@ def __str__(self):
         return _printing.printTensor(self)
 
     def __iter__(self):
-        return map(lambda i: self.select(0, i), range(self.size(0)))
+        return iter(map(lambda i: self.select(0, i), range(self.size(0))))
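
The __iter__ changes in torch/Storage.py and torch/Tensor.py both wrap map(...) in iter(...) because the two Pythons disagree about what map returns: on Python 3 it is already a lazy iterator, but on Python 2 it is a plain list, and __iter__ must return an iterator, not just an iterable. Wrapping the result in iter() satisfies the protocol on both. A small illustrative sketch (the Seq class is not part of the commit):

    # Without iter(), iterating this on Python 2 raises
    # "TypeError: iter() returned non-iterator of type 'list'".
    class Seq(object):
        def __init__(self, n):
            self.n = n

        def __iter__(self):
            return iter(map(lambda i: i * i, range(self.n)))

    for x in Seq(4):
        print(x)    # prints 0, 1, 4, 9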

torch/csrc/Module.cpp (+34 -16)

@@ -5,7 +5,11 @@
 
 #include "THP.h"
 
+#if PY_MAJOR_VERSION == 2
+#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;}
+#else
 #define ASSERT_TRUE(cmd) if (!(cmd)) return NULL
+#endif
 
 static PyObject* module;
 static PyObject* tensor_classes;
@@ -34,21 +38,21 @@ static bool THPModule_loadClasses(PyObject *self)
   PyObject *torch_module = PyImport_ImportModule("torch");
   PyObject* module_dict = PyModule_GetDict(torch_module);
 
-  THPDoubleStorageClass = PyMapping_GetItemString(module_dict, "DoubleStorage");
-  THPFloatStorageClass = PyMapping_GetItemString(module_dict, "FloatStorage");
-  THPLongStorageClass = PyMapping_GetItemString(module_dict, "LongStorage");
-  THPIntStorageClass = PyMapping_GetItemString(module_dict, "IntStorage");
-  THPShortStorageClass = PyMapping_GetItemString(module_dict, "ShortStorage");
-  THPCharStorageClass = PyMapping_GetItemString(module_dict, "CharStorage");
-  THPByteStorageClass = PyMapping_GetItemString(module_dict, "ByteStorage");
-
-  THPDoubleTensorClass = PyMapping_GetItemString(module_dict, "DoubleTensor");
-  THPFloatTensorClass = PyMapping_GetItemString(module_dict, "FloatTensor");
-  THPLongTensorClass = PyMapping_GetItemString(module_dict, "LongTensor");
-  THPIntTensorClass = PyMapping_GetItemString(module_dict, "IntTensor");
-  THPShortTensorClass = PyMapping_GetItemString(module_dict, "ShortTensor");
-  THPCharTensorClass = PyMapping_GetItemString(module_dict, "CharTensor");
-  THPByteTensorClass = PyMapping_GetItemString(module_dict, "ByteTensor");
+  THPDoubleStorageClass = PyMapping_GetItemString(module_dict,(char*)"DoubleStorage");
+  THPFloatStorageClass = PyMapping_GetItemString(module_dict,(char*)"FloatStorage");
+  THPLongStorageClass = PyMapping_GetItemString(module_dict,(char*)"LongStorage");
+  THPIntStorageClass = PyMapping_GetItemString(module_dict,(char*)"IntStorage");
+  THPShortStorageClass = PyMapping_GetItemString(module_dict,(char*)"ShortStorage");
+  THPCharStorageClass = PyMapping_GetItemString(module_dict,(char*)"CharStorage");
+  THPByteStorageClass = PyMapping_GetItemString(module_dict,(char*)"ByteStorage");
+
+  THPDoubleTensorClass = PyMapping_GetItemString(module_dict,(char*)"DoubleTensor");
+  THPFloatTensorClass = PyMapping_GetItemString(module_dict,(char*)"FloatTensor");
+  THPLongTensorClass = PyMapping_GetItemString(module_dict,(char*)"LongTensor");
+  THPIntTensorClass = PyMapping_GetItemString(module_dict,(char*)"IntTensor");
+  THPShortTensorClass = PyMapping_GetItemString(module_dict,(char*)"ShortTensor");
+  THPCharTensorClass = PyMapping_GetItemString(module_dict,(char*)"CharTensor");
+  THPByteTensorClass = PyMapping_GetItemString(module_dict,(char*)"ByteTensor");
   PySet_Add(tensor_classes, THPDoubleTensorClass);
   PySet_Add(tensor_classes, THPFloatTensorClass);
   PySet_Add(tensor_classes, THPLongTensorClass);
@@ -314,13 +318,15 @@ static PyMethodDef TorchMethods[] = {
   {NULL, NULL, 0, NULL}
 };
 
+#if PY_MAJOR_VERSION != 2
 static struct PyModuleDef torchmodule = {
   PyModuleDef_HEAD_INIT,
   "torch.C",
   NULL,
   -1,
   TorchMethods
 };
+#endif
 
 static void errorHandler(const char *msg, void *data)
 {
@@ -338,10 +344,17 @@ static void updateErrorHandlers()
   THSetArgErrorHandler(errorHandlerArg, NULL);
 }
 
+#if PY_MAJOR_VERSION == 2
+PyMODINIT_FUNC initC()
+#else
 PyMODINIT_FUNC PyInit_C()
+#endif
 {
+#if PY_MAJOR_VERSION == 2
+  ASSERT_TRUE(module = Py_InitModule("torch.C", TorchMethods));
+#else
   ASSERT_TRUE(module = PyModule_Create(&torchmodule));
-
+#endif
   ASSERT_TRUE(tensor_classes = PySet_New(NULL));
   ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0);
 
@@ -363,5 +376,10 @@ PyMODINIT_FUNC PyInit_C()
 
   updateErrorHandlers();
 
+#if PY_MAJOR_VERSION == 2
+#else
   return module;
+#endif
 }
+
+#undef ASSERT_TRUE

torch/csrc/THP.h (+9)

@@ -1,6 +1,15 @@
 #include <stdbool.h>
 #include <TH/TH.h>
 
+// Back-compatibility macros, Thanks to http://cx-oracle.sourceforge.net/
+// define PyInt_* macros for Python 3.x
+#ifndef PyInt_Check
+#define PyInt_Check PyLong_Check
+#define PyInt_FromLong PyLong_FromLong
+#define PyInt_AsLong PyLong_AsLong
+#define PyInt_Type PyLong_Type
+#endif
+
 #include "Exceptions.h"
 #include "utils.h"

torch/csrc/generic/Storage.cpp (+23 -11)

@@ -9,6 +9,7 @@ PyObject * THPStorage_(newObject)(THStorage *ptr)
   // TODO: error checking
   PyObject *args = PyTuple_New(0);
   PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));
+
   PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs);
   Py_DECREF(args);
   Py_DECREF(kwargs);
@@ -30,17 +31,17 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
 {
   HANDLE_TH_ERRORS
   static const char *keywords[] = {"cdata", NULL};
-  PyObject *number_arg = NULL;
-  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!", (char **)keywords, &PyLong_Type, &number_arg))
+  void* number_arg = NULL;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O&", (char **)keywords,
+          THPUtils_getLong, &number_arg))
     return NULL;
-
   THPStorage *self = (THPStorage *)type->tp_alloc(type, 0);
   if (self != NULL) {
     if (kwargs) {
-      self->cdata = (THStorage*)PyLong_AsVoidPtr(number_arg);
+      self->cdata = (THStorage*)number_arg;
       THStorage_(retain)(self->cdata);
     } else if (/* !kwargs && */ number_arg) {
-      self->cdata = THStorage_(newWithSize)(PyLong_AsLong(number_arg));
+      self->cdata = THStorage_(newWithSize)((long) number_arg);
     } else {
       self->cdata = THStorage_(new)();
     }
@@ -66,8 +67,9 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
 {
   HANDLE_TH_ERRORS
   /* Integer index */
-  if (PyLong_Check(index)) {
-    long nindex = PyLong_AsLong(index);
+  long nindex;
+  if ((PyLong_Check(index) || PyInt_Check(index))
+      && THPUtils_getLong(index, &nindex) == 1 ) {
     if (nindex < 0)
       nindex += THStorage_(size)(self->cdata);
 #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
@@ -89,7 +91,11 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
     THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength);
     return THPStorage_(newObject)(new_storage);
   }
-  PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported");
+  char err_string[512];
+  snprintf (err_string, 512,
+      "%s %s", "Only indexing with integers and slices supported, but got type: ",
+      index->ob_type->tp_name);
+  PyErr_SetString(PyExc_RuntimeError, err_string);
   return NULL;
   END_HANDLE_TH_ERRORS
 }
@@ -101,8 +107,10 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
   if (!THPUtils_(parseReal)(value, &rvalue))
     return -1;
 
-  if (PyLong_Check(index)) {
-    THStorage_(set)(self->cdata, PyLong_AsSize_t(index), rvalue);
+  long nindex;
+  if ((PyLong_Check(index) || PyInt_Check(index))
+      && THPUtils_getLong(index, &nindex) == 1) {
+    THStorage_(set)(self->cdata, nindex, rvalue);
     return 0;
   } else if (PySlice_Check(index)) {
     Py_ssize_t start, stop, len;
@@ -114,7 +122,11 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
     THStorage_(set)(self->cdata, start, rvalue);
     return 0;
   }
-  PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported at the moment");
+  char err_string[512];
+  snprintf (err_string, 512, "%s %s",
+      "Only indexing with integers and slices supported, but got type: ",
+      index->ob_type->tp_name);
+  PyErr_SetString(PyExc_RuntimeError, err_string);
   return -1;
   END_HANDLE_TH_ERRORS_RET(-1)
 }

torch/csrc/generic/StorageMethods.cpp (+1 -1)

@@ -40,7 +40,7 @@ static PyObject * THPStorage_(resize)(THPStorage *self, PyObject *number_arg)
   HANDLE_TH_ERRORS
   if (!PyLong_Check(number_arg))
     return NULL;
-  size_t newsize = PyLong_AsSize_t(number_arg);
+  long newsize = PyLong_AsLong(number_arg);
   if (PyErr_Occurred())
     return NULL;
   THStorage_(resize)(self->cdata, newsize);
