Skip to content

Commit 8214bc6

Browse files
committed
some python refactoring
1 parent 0eef0a3 commit 8214bc6

4 files changed

Lines changed: 227 additions & 55 deletions

File tree

python/pyarrow/array.pxi

Lines changed: 93 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4902,6 +4902,32 @@ def _validate_uniform_shape(uniform_shape, ndim):
49024902
"uniform_shape must contain non-negative values")
49034903

49044904

4905+
def _infer_uniform_shape(shape_rows, ndim):
4906+
if len(shape_rows) == 0:
4907+
return None
4908+
inferred = []
4909+
for i in range(ndim):
4910+
axis_size = shape_rows[0][i]
4911+
if all(shape[i] == axis_size for shape in shape_rows):
4912+
inferred.append(axis_size)
4913+
else:
4914+
inferred.append(None)
4915+
if all(x is None for x in inferred):
4916+
return None
4917+
return inferred
4918+
4919+
4920+
def _permutation_from_strides(arr):
4921+
"""Infer the dimension permutation from array strides.
4922+
4923+
Note: for arrays with size-1 dimensions, the inferred permutation
4924+
may be unreliable since size-1 strides are unconstrained. Callers
4925+
should skip permutation validation for such arrays.
4926+
"""
4927+
return [int(x) for x in
4928+
(-np.array(arr.strides, dtype=np.int64)).argsort(kind="stable")]
4929+
4930+
49054931
cdef class VariableShapeTensorArray(ExtensionArray):
49064932
"""
49074933
Concrete class for variable shape tensor extension arrays.
@@ -4974,13 +5000,11 @@ cdef class VariableShapeTensorArray(ExtensionArray):
49745000
cdef:
49755001
list arrays
49765002
list shape_rows
4977-
list inferred_uniform_shape
49785003
int array_ndim
49795004
int i
49805005
object base_dtype
49815006
DataType arrow_type
49825007
object normalized_permutation
4983-
object ndarray_permutation
49845008
object permutation_metadata
49855009
object shape_type
49865010
object values
@@ -4997,7 +5021,7 @@ cdef class VariableShapeTensorArray(ExtensionArray):
49975021
if value_type is not None and not isinstance(value_type, DataType):
49985022
try:
49995023
value_type = from_numpy_dtype(np.dtype(value_type))
5000-
except Exception as exc:
5024+
except (TypeError, ValueError) as exc:
50015025
raise TypeError(
50025026
"value_type must be a pyarrow.DataType or numpy dtype"
50035027
) from exc
@@ -5056,19 +5080,27 @@ cdef class VariableShapeTensorArray(ExtensionArray):
50565080
_validate_dim_names(dim_names, ndim)
50575081
normalized_permutation = _validate_permutation(permutation, ndim)
50585082

5083+
# Infer permutation if not provided by the user. Prefer arrays
5084+
# without size-1 dimensions since their strides are unambiguous.
5085+
if normalized_permutation is None:
5086+
for arr in arrays:
5087+
if all(s > 1 for s in arr.shape):
5088+
normalized_permutation = _permutation_from_strides(arr)
5089+
break
5090+
else:
5091+
# All arrays have size-1 dims; use first array's strides
5092+
normalized_permutation = _permutation_from_strides(arrays[0])
5093+
5094+
# Validate permutation consistency for arrays without size-1
5095+
# dims (size-1 strides are unconstrained, so skip those).
50595096
for i, arr in enumerate(arrays):
5060-
ndarray_permutation = (-np.array(arr.strides)).argsort(kind="stable")
5061-
ndarray_permutation_list = [int(x) for x in ndarray_permutation]
5062-
if normalized_permutation is None:
5063-
if i == 0:
5064-
normalized_permutation = ndarray_permutation_list
5065-
elif not np.array_equal(ndarray_permutation, normalized_permutation):
5066-
raise ValueError(
5067-
"All numpy arrays must have matching permutation.")
5068-
elif not np.array_equal(ndarray_permutation, normalized_permutation):
5097+
if any(s <= 1 for s in arr.shape):
5098+
continue
5099+
ndarray_permutation_list = _permutation_from_strides(arr)
5100+
if ndarray_permutation_list != normalized_permutation:
50695101
raise ValueError(
5070-
(f"obj[{i}] has permutation {ndarray_permutation_list}; expected "
5071-
f"{list(normalized_permutation)}"))
5102+
(f"obj[{i}] has permutation {ndarray_permutation_list}; "
5103+
f"expected {list(normalized_permutation)}"))
50725104

50735105
shape_rows = [
50745106
[int(x) for x in np.take(arr.shape, normalized_permutation)]
@@ -5084,17 +5116,24 @@ cdef class VariableShapeTensorArray(ExtensionArray):
50845116
(f"uniform_shape[{i}]={value} does not match input shape "
50855117
f"dimension values"))
50865118
else:
5087-
inferred_uniform_shape = []
5088-
for i in range(ndim):
5089-
axis_size = shape_rows[0][i]
5090-
if all(shape[i] == axis_size for shape in shape_rows):
5091-
inferred_uniform_shape.append(axis_size)
5092-
else:
5093-
inferred_uniform_shape.append(None)
5094-
if all(x is None for x in inferred_uniform_shape):
5095-
uniform_shape = None
5096-
else:
5097-
uniform_shape = inferred_uniform_shape
5119+
uniform_shape = _infer_uniform_shape(shape_rows, ndim)
5120+
5121+
# Verify that ravel(order="K") + inferred permutation are consistent
5122+
# by round-tripping the first non-empty array.
5123+
for arr in arrays:
5124+
if arr.size > 0:
5125+
raveled = np.ravel(arr, order="K")
5126+
physical_shape = tuple(
5127+
np.take(arr.shape, normalized_permutation))
5128+
reconstructed = raveled.reshape(physical_shape)
5129+
inv_perm = list(np.argsort(normalized_permutation))
5130+
reconstructed_logical = np.transpose(reconstructed, inv_perm)
5131+
if not np.array_equal(reconstructed_logical, arr):
5132+
raise ValueError(
5133+
"Array memory layout is incompatible with variable "
5134+
"shape tensor representation. Consider making the "
5135+
"array contiguous first with np.ascontiguousarray().")
5136+
break
50985137

50995138
values = array([np.ravel(arr, order="K") for arr in arrays], list_(arrow_type))
51005139
shapes = array(shape_rows, list_(int32(), list_size=ndim))
@@ -5123,7 +5162,7 @@ cdef class VariableShapeTensorArray(ExtensionArray):
51235162
list
51245163
List containing one ndarray per valid element and None for null elements.
51255164
"""
5126-
return [x.to_numpy_ndarray() if x.is_valid else None for x in self]
5165+
return [x.to_numpy() if x.is_valid else None for x in self]
51275166

51285167
def to_row_splits(self):
51295168
"""
@@ -5140,7 +5179,7 @@ cdef class VariableShapeTensorArray(ExtensionArray):
51405179
base = offsets[0].as_py()
51415180
if base == 0:
51425181
return offsets
5143-
return array([x - base for x in offsets.to_pylist()], type=offsets.type)
5182+
return _pc().subtract(offsets, base)
51445183

51455184
def to_offsets(self):
51465185
"""
@@ -5186,13 +5225,12 @@ cdef class VariableShapeTensorArray(ExtensionArray):
51865225
object data
51875226
object struct_arr
51885227
object ext_type
5189-
list inferred_uniform_shape
51905228
list shape_rows
51915229

51925230
if value_type is not None and not isinstance(value_type, DataType):
51935231
try:
51945232
value_type = from_numpy_dtype(np.dtype(value_type))
5195-
except Exception as exc:
5233+
except (TypeError, ValueError) as exc:
51965234
raise TypeError(
51975235
"value_type must be a pyarrow.DataType or numpy dtype"
51985236
) from exc
@@ -5240,6 +5278,9 @@ cdef class VariableShapeTensorArray(ExtensionArray):
52405278
elif ndim < 0:
52415279
raise ValueError("ndim must be non-negative")
52425280

5281+
_validate_dim_names(dim_names, ndim)
5282+
permutation = _validate_permutation(permutation, ndim)
5283+
52435284
shape_type = list_(int32(), list_size=ndim)
52445285
shape_arr = asarray(shapes, type=shape_type)
52455286
if isinstance(shape_arr, ChunkedArray):
@@ -5251,18 +5292,29 @@ cdef class VariableShapeTensorArray(ExtensionArray):
52515292
(f"shapes length ({len(shape_arr)}) must equal number of rows "
52525293
f"({len(splits) - 1})"))
52535294

5254-
if uniform_shape is None:
5255-
shape_rows = shape_arr.to_pylist()
5256-
if len(shape_rows) > 0:
5257-
inferred_uniform_shape = []
5258-
for i in range(ndim):
5259-
axis_size = shape_rows[0][i]
5260-
if all(shape[i] == axis_size for shape in shape_rows):
5261-
inferred_uniform_shape.append(axis_size)
5262-
else:
5263-
inferred_uniform_shape.append(None)
5264-
if any(x is not None for x in inferred_uniform_shape):
5265-
uniform_shape = inferred_uniform_shape
5295+
shape_rows = shape_arr.to_pylist()
5296+
5297+
# Validate that each row's shape product matches its segment length
5298+
for i in range(len(splits) - 1):
5299+
expected_size = 1
5300+
for dim in shape_rows[i]:
5301+
expected_size *= dim
5302+
actual_size = splits[i + 1] - splits[i]
5303+
if expected_size != actual_size:
5304+
raise ValueError(
5305+
(f"shapes[{i}] product ({expected_size}) does not match "
5306+
f"row_splits interval ({actual_size})"))
5307+
5308+
if uniform_shape is not None:
5309+
_validate_uniform_shape(uniform_shape, ndim)
5310+
for i, value in enumerate(uniform_shape):
5311+
if value is not None:
5312+
if any(shape[i] != value for shape in shape_rows):
5313+
raise ValueError(
5314+
(f"uniform_shape[{i}]={value} does not match "
5315+
f"shape dimension values"))
5316+
else:
5317+
uniform_shape = _infer_uniform_shape(shape_rows, ndim)
52665318

52675319
data = ListArray.from_arrays(row_splits_arr, values_arr)
52685320
struct_arr = StructArray.from_arrays([data, shape_arr], names=["data", "shape"])

python/pyarrow/scalar.pxi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,16 +1545,18 @@ cdef class VariableShapeTensorScalar(ExtensionScalar):
15451545
Concrete class for variable shape tensor extension scalar.
15461546
"""
15471547

1548-
def to_numpy_ndarray(self):
1548+
def to_numpy(self):
    """
    Convert variable shape tensor extension scalar to a numpy.ndarray.

    The scalar is first converted with ``to_tensor()`` and the resulting
    tensor is then converted with its ``to_numpy()`` method.

    The conversion is zero-copy if data is primitive numeric and without nulls.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    ValueError
        If the scalar is null (``is_valid`` is false).
    """
    if self.is_valid:
        tensor = self.to_tensor()
        return tensor.to_numpy()
    raise ValueError("Cannot convert null scalar to numpy array")
15591561

15601562
def to_tensor(self):
@@ -1565,6 +1567,8 @@ cdef class VariableShapeTensorScalar(ExtensionScalar):
15651567
-------
15661568
tensor : pyarrow.Tensor
15671569
"""
1570+
if not self.is_valid:
1571+
raise ValueError("Cannot convert null scalar to Tensor")
15681572
cdef:
15691573
CVariableShapeTensorType* c_type = static_pointer_cast[CVariableShapeTensorType, CDataType](
15701574
self.wrapped.get().type).get()

0 commit comments

Comments
 (0)