Torch for Numpy users
Peter O'Connor edited this page Aug 25, 2018 · 17 revisions
Torch equivalents of Numpy functions

| Numpy | Torch |
|---|---|
| np.ndarray | torch.Tensor |
| np.float32 | torch.FloatTensor |
| np.float64 | torch.DoubleTensor |
| np.int8 | torch.CharTensor |
| np.uint8 | torch.ByteTensor |
| np.int16 | torch.ShortTensor |
| np.int32 | torch.IntTensor |
| np.int64 | torch.LongTensor |
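
Each Torch tensor class stands for one element type, so converting between types means calling a per-type method instead of something like `astype`. The sketch below assumes a standard Lua Torch7 install where `require 'torch'` works and the default tensor type has not been changed with `torch.setdefaulttensortype`.

```lua
require 'torch'

-- torch.Tensor is the default tensor type (torch.DoubleTensor unless changed)
local x = torch.Tensor({1, 2, 3})
print(x:type())           -- torch.DoubleTensor, i.e. float64

-- Per-type conversion methods, roughly x.astype(np.float32) / x.astype(np.int64)
local xf = x:float()      -- torch.FloatTensor
local xl = x:long()       -- torch.LongTensor

-- Conversion by type name
local xi = x:type('torch.IntTensor')

print(xf:type(), xl:type(), xi:type())
```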

| Numpy | Torch |
|---|---|
| np.empty([2,2]) | torch.Tensor(2,2) |
| np.empty_like(x) | x.new(x:size()) |
| np.eye | torch.eye |
| np.identity | torch.eye |
| np.ones | torch.ones |
| np.ones_like | torch.ones(x:size()) |
| np.zeros | torch.zeros |
| np.zeros_like | torch.zeros(x:size()) |
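
Note that `torch.ones(x:size())` and `torch.zeros(x:size())` return the default tensor type, not necessarily the type of `x`. A minimal sketch of one type-preserving alternative, assuming Lua Torch7: copy `x` with `clone()` and then overwrite it in place. The variable names are only illustrative.

```lua
require 'torch'

local x = torch.IntTensor(2, 3):fill(7)

-- np.ones_like(x) by shape only: result has the default type (DoubleTensor)
local o = torch.ones(x:size())

-- Type-preserving variants: copy x, then overwrite the copy in place
local o_same = x:clone():fill(1)   -- IntTensor of ones
local z_same = x:clone():zero()    -- IntTensor of zeros

print(o:type(), o_same:type(), z_same:type())
```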

| Numpy | Torch |
|---|---|
| np.array([ [1,2],[3,4] ]) | torch.Tensor({{1,2},{3,4}}) |
| np.ascontiguousarray(x) | x:contiguous() |
| np.copy(x) | x:clone() |
| np.fromfile(file) | torch.Tensor(torch.Storage(file)) |
| np.frombuffer | ??? |
| np.fromfunction | ??? |
| np.fromiter | ??? |
| np.fromstring | ??? |
| np.loadtxt | ??? |
| np.concatenate | torch.cat |
| np.multiply | torch.cmul |
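
A short sketch of copying, concatenating and element-wise multiplying, assuming Lua Torch7. Keep in mind that dimension arguments are 1-based, so Numpy's `axis=0` becomes dimension `1`.

```lua
require 'torch'

local a = torch.Tensor({{1, 2}, {3, 4}})

-- np.copy(a): clone() makes an independent copy
local b = a:clone()
b[1][1] = 100                        -- a is unchanged

-- np.concatenate((a, a), axis=0) and axis=1
local rows = torch.cat(a, a, 1)      -- 4x2
local cols = torch.cat(a, a, 2)      -- 2x4

-- np.multiply(a, a): element-wise product (torch.mm would be a matrix product)
local sq = torch.cmul(a, a)          -- {{1, 4}, {9, 16}}
```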

| Numpy | Torch |
|---|---|
| np.arange(10) | torch.range(0,9) |
| np.arange(2, 3, 0.1) | torch.linspace(2, 2.9, 10) |
| np.linspace(1, 4, 6) | torch.linspace(1, 4, 6) |
| np.logspace | torch.logspace |
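
The range rows differ in one detail: `np.arange` excludes its stop value, while `torch.range` includes it. A sketch assuming Lua Torch7; for fractional steps it asks `linspace` for an explicit count so floating-point rounding cannot change the number of elements.

```lua
require 'torch'

-- np.arange(10): stop is exclusive in Numpy, inclusive in torch.range
local r = torch.range(0, 9)            -- 0, 1, ..., 9 (10 elements)

-- np.arange(2, 3, 0.1): give the element count explicitly
local l = torch.linspace(2, 2.9, 10)   -- 2.0, 2.1, ..., 2.9

-- np.linspace(1, 4, 6): same argument order (start, stop, count)
local m = torch.linspace(1, 4, 6)      -- 1.0, 1.6, 2.2, 2.8, 3.4, 4.0

print(r:nElement(), l:nElement(), m:nElement())
```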

| Numpy | Torch |
|---|---|
| np.diag | torch.diag |
| np.tril | torch.tril |
| np.triu | torch.triu |
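
As with Numpy, `torch.diag` builds a diagonal matrix from a 1-D tensor and extracts the diagonal from a 2-D one. A minimal sketch, assuming Lua Torch7; the optional second argument of `torch.triu` is the diagonal offset.

```lua
require 'torch'

local v = torch.Tensor({1, 2, 3})
local D = torch.diag(v)                 -- 3x3 with v on the main diagonal
local d = torch.diag(D)                 -- back to the 1-D diagonal {1, 2, 3}

local A = torch.range(1, 9):view(3, 3)  -- {{1,2,3},{4,5,6},{7,8,9}}
local L = torch.tril(A)                 -- zeros above the main diagonal
local U = torch.triu(A, 1)              -- zeros on and below the main diagonal
```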

| Numpy | Torch |
|---|---|
| x.shape | x:size() |
| x.strides | x:stride() (in elements; Numpy strides are in bytes) |
| x.ndim | x:dim() |
| x.data | x:data() |
| x.size | x:nElement() |
| x.shape == y.shape | x:isSameSizeAs(y) |
| x.dtype | x:type() |
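
The attribute rows in practice, as a sketch assuming Lua Torch7. It shows that strides count elements rather than bytes, and that `isSameSizeAs` compares shapes, not just element counts.

```lua
require 'torch'

local x = torch.Tensor(3, 4):zero()
local y = torch.Tensor(4, 3):zero()

print(x:size())          -- LongStorage of sizes, like x.shape
print(x:dim())           -- 2, like x.ndim
print(x:nElement())      -- 12, like x.size
print(x:stride(1))       -- 4 elements (Numpy would report 32 bytes for float64)

-- Same number of elements but a different shape
print(x:nElement() == y:nElement())  -- true
print(x:isSameSizeAs(y))             -- false
```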

| Numpy | Torch |
|---|---|
| x.reshape | x:reshape |
| x.resize | x:resize |
| ??? | x:resizeAs |
| x.transpose | x:transpose(dim1, dim2) (x:t() for 2-D) |
| x.flatten | x:view(x:nElement()) (x must be contiguous) |
| x.squeeze | x:squeeze |
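
A sketch of the shape rows, assuming Lua Torch7. A transposed tensor shares storage with the original and is not contiguous, so it is copied with `contiguous()` before `view` flattens it.

```lua
require 'torch'

local x = torch.range(1, 6):view(2, 3)   -- {{1,2,3},{4,5,6}}

-- x.T for a 2-D array; same as x:transpose(1, 2), no copy is made
local xt = x:t()
print(xt:isContiguous())                  -- false

-- x.flatten(): make a contiguous copy first, then view it as 1-D
local flat = xt:contiguous():view(xt:nElement())   -- 1, 4, 2, 5, 3, 6

-- x.squeeze(): drop singleton dimensions
local s = torch.Tensor(1, 3, 1):zero():squeeze()   -- 1-D, size 3
```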

| Numpy | Torch |
|---|---|
| np.take(a, indices) | a[indices] |
| x[:,0] | x[{{},1}] (Torch indices are 1-based) |
| np.put | ??? |
| np.tile | x:repeatTensor |
| x.fill | x:fill |
| np.choose | ??? |
| np.sort | sorted, indices = torch.sort(x, [dim]) |
| np.argsort | sorted, indices = torch.sort(x, [dim]) |
| np.nonzero | torch.find(x:ne(0), 1) (torchx) |
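
A sketch of the selection rows, assuming Lua Torch7. `torch.sort` sorts along the last dimension by default and returns 1-based indices, so the second return value plays the role of `np.argsort`. `repeatTensor` tiles the whole tensor, which is why it corresponds to `np.tile` rather than `np.repeat`.

```lua
require 'torch'

local x = torch.Tensor({{3, 1, 2}, {6, 5, 4}})

-- x[:, 0] in Numpy is column 1 in Torch
local col = x[{{}, 1}]                 -- 1-D tensor {3, 6}

-- np.sort and np.argsort from a single call
local sorted, idx = torch.sort(x)
-- sorted: {{1,2,3},{4,5,6}}   idx: {{2,3,1},{3,2,1}}

-- np.tile(v, 2); np.repeat(v, 2) would give 1, 1, 2, 2 instead
local v = torch.Tensor({1, 2})
local tiled = v:repeatTensor(2)        -- 1, 2, 1, 2
```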

| Numpy | Torch |
|---|---|
| ndarray.min | mins, indices = torch.min(x, [dim]) |
| ndarray.argmin | mins, indices = torch.min(x, [dim]) |
| ndarray.max | maxs, indices = torch.max(x, [dim]) |
| ndarray.argmax | maxs, indices = torch.max(x, [dim]) |
| ndarray.clip | torch.clamp |
| ndarray.round | torch.round |
| ndarray.trace | torch.trace |
| ndarray.sum | torch.sum |
| ndarray.cumsum | torch.cumsum |
| ndarray.mean | torch.mean |
| ndarray.std | torch.std (normalizes by n-1 by default; Numpy's default is n) |
| ndarray.prod | torch.prod |
| ndarray.dot | torch.mm (matrix-matrix), torch.mv (matrix-vector), torch.dot (vector-vector) |
| ndarray.cumprod | torch.cumprod |
| ndarray.all | ??? |
| ndarray.any | ??? |
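
How the min/max and dot rows behave, sketched under the assumption of Lua Torch7. Without a dimension, `torch.max` returns a plain number. With a dimension, it returns both values and 1-based indices and keeps the reduced dimension with size 1.

```lua
require 'torch'

local x = torch.Tensor({{4, 9, 2}, {3, 5, 7}})

-- x.max(): a single Lua number
local m = torch.max(x)                 -- 9

-- x.max(axis=1) and x.argmax(axis=1) together
local vals, inds = torch.max(x, 2)     -- vals: {{9},{7}}  inds: {{2},{3}}

-- ndarray.dot depends on operand shapes; Torch has one function per case
local a = torch.Tensor({{1, 2}, {3, 4}})
local v = torch.Tensor({1, 1})
local mm = torch.mm(a, a)              -- matrix @ matrix: {{7,10},{15,22}}
local mv = torch.mv(a, v)              -- matrix @ vector: {3, 7}
local d = torch.dot(v, v)              -- vector . vector: 2
```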

| Numpy | Torch |
|---|---|
| np.less (x < y) | torch.lt |
| np.less_equal (x <= y) | torch.le |
| np.greater (x > y) | torch.gt |
| np.greater_equal (x >= y) | torch.ge |
| np.equal (x == y) | torch.eq |
| np.not_equal (x != y) | torch.ne |
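
In Torch the comparison functions return a `torch.ByteTensor` of 0/1 values rather than a boolean array, and that tensor can be used as a mask. A minimal sketch, assuming Lua Torch7:

```lua
require 'torch'

local x = torch.Tensor({1, 5, 3, 8})

-- x > 3 in Numpy; same as x:gt(3)
local mask = torch.gt(x, 3)            -- ByteTensor {0, 1, 0, 1}

-- x[x > 3]: mask indexing returns the selected elements as a 1-D tensor
local selected = x[mask]               -- {5, 8}

-- np.count_nonzero(x > 3)
local count = mask:sum()               -- 2
```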