Skip to content

Commit b4cb6cd

Browse files
committed
misc fixes
1 parent 7f84762 commit b4cb6cd

File tree

4 files changed

+65
-24
lines changed

4 files changed

+65
-24
lines changed

python/pyarrow/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def print_entry(label, value):
234234
FixedSizeBinaryScalar, DictionaryScalar,
235235
MapScalar, StructScalar, UnionScalar,
236236
RunEndEncodedScalar, Bool8Scalar, ExtensionScalar,
237-
FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar)
237+
FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar,
238+
VariableShapeTensorScalar)
238239

239240
# Buffers, allocation
240241
from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager,

python/pyarrow/array.pxi

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4974,35 +4974,37 @@ cdef class VariableShapeTensorArray(ExtensionArray):
49744974
]
49754975
]
49764976
"""
4977-
assert isinstance(obj, list), 'obj must be a list of numpy arrays'
4977+
if not isinstance(obj, list) or len(obj) == 0:
4978+
raise TypeError('obj must be a non-empty list of numpy arrays')
49784979
numpy_type = obj[0].dtype
49794980
arrow_type = from_numpy_dtype(numpy_type)
49804981
ndim = obj[0].ndim
49814982
permutations = [(-np.array(o.strides)).argsort(kind="stable") for o in obj]
49824983
permutation = permutations[0]
49834984
shapes = [np.take(o.shape, permutation) for o in obj]
49844985

4985-
if not all([o.dtype == numpy_type for o in obj]):
4986+
if not all(o.dtype == numpy_type for o in obj):
49864987
raise TypeError('All numpy arrays must have matching dtype.')
49874988

4988-
if not all([o.ndim == ndim for o in obj]):
4989+
if not all(o.ndim == ndim for o in obj):
49894990
raise ValueError('All numpy arrays must have matching ndim.')
49904991

4991-
if not all([np.array_equal(p, permutation) for p in permutations]):
4992+
if not all(np.array_equal(p, permutation) for p in permutations):
49924993
raise ValueError('All numpy arrays must have matching permutation.')
49934994

49944995
for shape in shapes:
49954996
if len(shape) < 2:
49964997
raise ValueError(
4997-
"Cannot convert 1D array or scalar to fixed shape tensor array")
4998+
"Cannot convert 1D array or scalar to variable shape tensor array")
49984999
if np.prod(shape) == 0:
49995000
raise ValueError("Expected a non-empty ndarray")
50005001

50015002
values = array([np.ravel(o, order="K") for o in obj], list_(arrow_type))
50025003
shapes = array(shapes, list_(int32(), list_size=ndim))
50035004
struct_arr = StructArray.from_arrays([values, shapes], names=["data", "shape"])
50045005

5005-
return ExtensionArray.from_storage(variable_shape_tensor(arrow_type, ndim, permutation=permutation), struct_arr)
5006+
ext_type = variable_shape_tensor(arrow_type, ndim, permutation=permutation)
5007+
return ExtensionArray.from_storage(ext_type, struct_arr)
50065008

50075009

50085010
cdef dict _array_classes = {

python/pyarrow/tests/test_extension_type.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,8 +1639,10 @@ def test_tensor_array_from_numpy(np_type_str):
16391639

16401640

16411641
@pytest.mark.numpy
1642-
@pytest.mark.parametrize("value_type", (np.int8, np.int32, np.int64, np.float64))
1642+
@pytest.mark.parametrize("value_type", (
1643+
"int8", "int32", "int64", "float64"))
16431644
def test_variable_shape_tensor_class_methods(value_type):
1645+
value_type = getattr(np, value_type)
16441646
ndim = 2
16451647
shape_type = pa.list_(pa.int32(), ndim)
16461648
arrow_type = pa.from_numpy_dtype(value_type)
@@ -1707,8 +1709,10 @@ def test_variable_shape_tensor_class_methods(value_type):
17071709
[], dtype=value_type).reshape(shapes[1].as_py()))
17081710

17091711

1710-
@pytest.mark.parametrize("value_type", (np.int8(), np.int64(), np.float32()))
1712+
@pytest.mark.numpy
1713+
@pytest.mark.parametrize("value_type", ("int8", "int64", "float32"))
17111714
def test_variable_shape_tensor_array_from_numpy(value_type):
1715+
value_type = np.dtype(value_type).type()
17121716
arrow_type = pa.from_numpy_dtype(value_type)
17131717

17141718
arr = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
@@ -1729,38 +1733,31 @@ def test_variable_shape_tensor_array_from_numpy(value_type):
17291733
pa.VariableShapeTensorArray.from_numpy_ndarray([arr.astype(np.int32()), arr])
17301734

17311735
flat_arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=value_type)
1732-
bw = value_type.itemsize
17331736

17341737
arr = flat_arr.reshape(1, 3, 4)
17351738
tensor_array_from_numpy = pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
17361739
assert tensor_array_from_numpy.type.ndim == 3
17371740
assert tensor_array_from_numpy.type.permutation == [0, 1, 2]
17381741
assert tensor_array_from_numpy[0].to_tensor() == pa.Tensor.from_numpy(arr)
17391742

1740-
arr = as_strided(flat_arr, shape=(1, 2, 3, 2),
1741-
strides=(bw * 12, bw * 6, bw, bw * 3))
1742-
tensor_array_from_numpy = pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
1743-
assert tensor_array_from_numpy.type.ndim == 4
1744-
assert tensor_array_from_numpy.type.permutation == [0, 1, 3, 2]
1745-
assert tensor_array_from_numpy[0].to_tensor() == pa.Tensor.from_numpy(arr)
1746-
17471743
arr = flat_arr.reshape(1, 2, 3, 2)
17481744
result = pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
17491745
expected = np.array(
17501746
[[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]], dtype=value_type)
17511747
np.testing.assert_array_equal(result[0].to_numpy_ndarray(), expected)
17521748

17531749
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=value_type)
1754-
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to fixed"):
1750+
msg = "Cannot convert 1D array or scalar to variable"
1751+
with pytest.raises(ValueError, match=msg):
17551752
pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
17561753

17571754
arr = np.array(1, dtype=value_type)
1758-
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to fixed"):
1755+
with pytest.raises(ValueError, match=msg):
17591756
pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
17601757

17611758
arr = np.array([], dtype=value_type)
17621759

1763-
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to fixed"):
1760+
with pytest.raises(ValueError, match=msg):
17641761
pa.VariableShapeTensorArray.from_numpy_ndarray([arr.reshape((0))])
17651762

17661763
with pytest.raises(ValueError, match="Expected a non-empty ndarray"):
@@ -1770,6 +1767,45 @@ def test_variable_shape_tensor_array_from_numpy(value_type):
17701767
pa.VariableShapeTensorArray.from_numpy_ndarray([arr.reshape((3, 0, 2))])
17711768

17721769

1770+
@pytest.mark.numpy
1771+
@pytest.mark.parametrize("dtype", (
1772+
"int8", "int16", "int32", "int64", "float32", "float64"))
1773+
@pytest.mark.parametrize("order", ("F", "C"))
1774+
def test_variable_shape_tensor_permutation_2d(dtype, order):
1775+
"""Roundtrip 2D arrays with C and Fortran order."""
1776+
dtype = np.dtype(dtype)
1777+
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
1778+
dtype=dtype, order=order)
1779+
expected_perm = [1, 0] if order == "F" else [0, 1]
1780+
1781+
result = pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
1782+
assert result.type.permutation == expected_perm
1783+
assert result.type.ndim == 2
1784+
1785+
tensor = result[0].to_tensor()
1786+
assert tensor == pa.Tensor.from_numpy(arr)
1787+
assert list(tensor.shape) == list(arr.shape)
1788+
assert list(tensor.strides) == list(arr.strides)
1789+
np.testing.assert_array_equal(result[0].to_numpy_ndarray(), arr)
1790+
1791+
# Verify stored shape is the physical shape (permuted), per spec
1792+
stored_shape = result.storage.field("shape")[0].as_py()
1793+
assert stored_shape == list(np.take(arr.shape, expected_perm))
1794+
1795+
1796+
@pytest.mark.numpy
1797+
def test_variable_shape_tensor_permutation_multi_array():
1798+
"""Multiple variable-shape 2D arrays with the same non-trivial permutation."""
1799+
arr1 = np.array([[10, 20], [30, 40], [50, 60]], dtype=np.int64, order="F")
1800+
arr2 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64, order="F")
1801+
result = pa.VariableShapeTensorArray.from_numpy_ndarray([arr1, arr2])
1802+
assert result.type.permutation == [1, 0]
1803+
assert len(result) == 2
1804+
for i, expected in enumerate([arr1, arr2]):
1805+
np.testing.assert_array_equal(result[i].to_numpy_ndarray(), expected)
1806+
assert result[i].to_tensor() == pa.Tensor.from_numpy(expected)
1807+
1808+
17731809
@pytest.mark.parametrize("tensor_type", (
17741810
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
17751811
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1]),
@@ -1969,7 +2005,7 @@ def test_variable_shape_tensor_type_is_picklable(pickle_module):
19692005
'fixed_shape_tensor[value_type=int64, shape=[2,2,3], dim_names=[C,H,W]]'
19702006
)
19712007
])
1972-
def test_tensor_type_str(tensor_type, text, pickle_module):
2008+
def test_tensor_type_str(tensor_type, text):
19732009
tensor_type_str = tensor_type.__str__()
19742010
assert text in tensor_type_str
19752011

python/pyarrow/types.pxi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,7 +2075,7 @@ cdef class VariableShapeTensorType(BaseExtensionType):
20752075

20762076
def __reduce__(self):
20772077
return variable_shape_tensor, (self.value_type, self.ndim,
2078-
self.permutation, self.dim_names, self.uniform_shape)
2078+
self.dim_names, self.permutation, self.uniform_shape)
20792079

20802080
def __arrow_ext_scalar_class__(self):
20812081
return VariableShapeTensorScalar
@@ -5936,8 +5936,10 @@ def variable_shape_tensor(DataType value_type, ndim, dim_names=None, permutation
59365936
vector[optional[int64_t]] c_uniform_shape
59375937
shared_ptr[CDataType] c_tensor_ext_type
59385938

5939-
assert value_type is not None
5940-
assert ndim is not None
5939+
if value_type is None:
5940+
raise TypeError('value_type must not be None')
5941+
if ndim is None:
5942+
raise TypeError('ndim must not be None')
59415943

59425944
c_ndim = ndim
59435945

0 commit comments

Comments
 (0)