@@ -4902,6 +4902,32 @@ def _validate_uniform_shape(uniform_shape, ndim):
49024902 " uniform_shape must contain non-negative values" )
49034903
49044904
4905+ def _infer_uniform_shape (shape_rows , ndim ):
4906+ if len (shape_rows) == 0 :
4907+ return None
4908+ inferred = []
4909+ for i in range (ndim):
4910+ axis_size = shape_rows[0 ][i]
4911+ if all (shape[i] == axis_size for shape in shape_rows):
4912+ inferred.append(axis_size)
4913+ else :
4914+ inferred.append(None )
4915+ if all (x is None for x in inferred):
4916+ return None
4917+ return inferred
4918+
4919+
4920+ def _permutation_from_strides (arr ):
4921+ """ Infer the dimension permutation from array strides.
4922+
4923+ Note: for arrays with size-1 dimensions, the inferred permutation
4924+ may be unreliable since size-1 strides are unconstrained. Callers
4925+ should skip permutation validation for such arrays.
4926+ """
4927+ return [int (x) for x in
4928+ (- np.array(arr.strides, dtype = np.int64)).argsort(kind = " stable" )]
4929+
4930+
49054931cdef class VariableShapeTensorArray(ExtensionArray):
49064932 """
49074933 Concrete class for variable shape tensor extension arrays.
@@ -4974,13 +5000,11 @@ cdef class VariableShapeTensorArray(ExtensionArray):
49745000 cdef:
49755001 list arrays
49765002 list shape_rows
4977- list inferred_uniform_shape
49785003 int array_ndim
49795004 int i
49805005 object base_dtype
49815006 DataType arrow_type
49825007 object normalized_permutation
4983- object ndarray_permutation
49845008 object permutation_metadata
49855009 object shape_type
49865010 object values
@@ -4997,7 +5021,7 @@ cdef class VariableShapeTensorArray(ExtensionArray):
49975021 if value_type is not None and not isinstance (value_type, DataType):
49985022 try :
49995023 value_type = from_numpy_dtype(np.dtype(value_type))
5000- except Exception as exc:
5024+ except ( TypeError , ValueError ) as exc:
50015025 raise TypeError (
50025026 " value_type must be a pyarrow.DataType or numpy dtype"
50035027 ) from exc
@@ -5056,19 +5080,27 @@ cdef class VariableShapeTensorArray(ExtensionArray):
50565080 _validate_dim_names(dim_names, ndim)
50575081 normalized_permutation = _validate_permutation(permutation, ndim)
50585082
5083+ # Infer permutation if not provided by the user. Prefer arrays
5084+ # without size-1 dimensions since their strides are unambiguous.
5085+ if normalized_permutation is None :
5086+ for arr in arrays:
5087+ if all (s > 1 for s in arr.shape):
5088+ normalized_permutation = _permutation_from_strides(arr)
5089+ break
5090+ else :
5091+ # All arrays have size-1 dims; use first array's strides
5092+ normalized_permutation = _permutation_from_strides(arrays[0 ])
5093+
5094+ # Validate permutation consistency for arrays without size-1
5095+ # dims (size-1 strides are unconstrained, so skip those).
50595096 for i, arr in enumerate (arrays):
5060- ndarray_permutation = (- np.array(arr.strides)).argsort(kind = " stable" )
5061- ndarray_permutation_list = [int (x) for x in ndarray_permutation]
5062- if normalized_permutation is None :
5063- if i == 0 :
5064- normalized_permutation = ndarray_permutation_list
5065- elif not np.array_equal(ndarray_permutation, normalized_permutation):
5066- raise ValueError (
5067- " All numpy arrays must have matching permutation." )
5068- elif not np.array_equal(ndarray_permutation, normalized_permutation):
5097+ if any (s <= 1 for s in arr.shape):
5098+ continue
5099+ ndarray_permutation_list = _permutation_from_strides(arr)
5100+ if ndarray_permutation_list != normalized_permutation:
50695101 raise ValueError (
5070- (f" obj[{i}] has permutation {ndarray_permutation_list}; expected "
5071- f" {list(normalized_permutation)}" ))
5102+ (f" obj[{i}] has permutation {ndarray_permutation_list}; "
5103+ f" expected {list(normalized_permutation)}" ))
50725104
50735105 shape_rows = [
50745106 [int (x) for x in np.take(arr.shape, normalized_permutation)]
@@ -5084,17 +5116,24 @@ cdef class VariableShapeTensorArray(ExtensionArray):
50845116 (f" uniform_shape[{i}]={value} does not match input shape "
50855117 f" dimension values" ))
50865118 else :
5087- inferred_uniform_shape = []
5088- for i in range (ndim):
5089- axis_size = shape_rows[0 ][i]
5090- if all (shape[i] == axis_size for shape in shape_rows):
5091- inferred_uniform_shape.append(axis_size)
5092- else :
5093- inferred_uniform_shape.append(None )
5094- if all (x is None for x in inferred_uniform_shape):
5095- uniform_shape = None
5096- else :
5097- uniform_shape = inferred_uniform_shape
5119+ uniform_shape = _infer_uniform_shape(shape_rows, ndim)
5120+
5121+ # Verify that ravel(order="K") + inferred permutation are consistent
5122+ # by round-tripping the first non-empty array.
5123+ for arr in arrays:
5124+ if arr.size > 0 :
5125+ raveled = np.ravel(arr, order = " K" )
5126+ physical_shape = tuple (
5127+ np.take(arr.shape, normalized_permutation))
5128+ reconstructed = raveled.reshape(physical_shape)
5129+ inv_perm = list (np.argsort(normalized_permutation))
5130+ reconstructed_logical = np.transpose(reconstructed, inv_perm)
5131+ if not np.array_equal(reconstructed_logical, arr):
5132+ raise ValueError (
5133+ " Array memory layout is incompatible with variable "
5134+ " shape tensor representation. Consider making the "
5135+ " array contiguous first with np.ascontiguousarray()." )
5136+ break
50985137
50995138 values = array([np.ravel(arr, order = " K" ) for arr in arrays], list_(arrow_type))
51005139 shapes = array(shape_rows, list_(int32(), list_size = ndim))
@@ -5123,7 +5162,7 @@ cdef class VariableShapeTensorArray(ExtensionArray):
51235162 list
51245163 List containing one ndarray per valid element and None for null elements.
51255164 """
5126- return [x.to_numpy_ndarray () if x.is_valid else None for x in self ]
5165+ return [x.to_numpy () if x.is_valid else None for x in self ]
51275166
51285167 def to_row_splits (self ):
51295168 """
@@ -5140,7 +5179,7 @@ cdef class VariableShapeTensorArray(ExtensionArray):
51405179 base = offsets[0 ].as_py()
51415180 if base == 0 :
51425181 return offsets
5143- return array([x - base for x in offsets.to_pylist()], type = offsets.type )
5182+ return _pc().subtract(offsets, base )
51445183
51455184 def to_offsets (self ):
51465185 """
@@ -5186,13 +5225,12 @@ cdef class VariableShapeTensorArray(ExtensionArray):
51865225 object data
51875226 object struct_arr
51885227 object ext_type
5189- list inferred_uniform_shape
51905228 list shape_rows
51915229
51925230 if value_type is not None and not isinstance (value_type, DataType):
51935231 try :
51945232 value_type = from_numpy_dtype(np.dtype(value_type))
5195- except Exception as exc:
5233+ except ( TypeError , ValueError ) as exc:
51965234 raise TypeError (
51975235 " value_type must be a pyarrow.DataType or numpy dtype"
51985236 ) from exc
@@ -5240,6 +5278,9 @@ cdef class VariableShapeTensorArray(ExtensionArray):
52405278 elif ndim < 0 :
52415279 raise ValueError (" ndim must be non-negative" )
52425280
5281+ _validate_dim_names(dim_names, ndim)
5282+ permutation = _validate_permutation(permutation, ndim)
5283+
52435284 shape_type = list_(int32(), list_size = ndim)
52445285 shape_arr = asarray(shapes, type = shape_type)
52455286 if isinstance (shape_arr, ChunkedArray):
@@ -5251,18 +5292,29 @@ cdef class VariableShapeTensorArray(ExtensionArray):
52515292 (f" shapes length ({len(shape_arr)}) must equal number of rows "
52525293 f" ({len(splits) - 1})" ))
52535294
5254- if uniform_shape is None :
5255- shape_rows = shape_arr.to_pylist()
5256- if len (shape_rows) > 0 :
5257- inferred_uniform_shape = []
5258- for i in range (ndim):
5259- axis_size = shape_rows[0 ][i]
5260- if all (shape[i] == axis_size for shape in shape_rows):
5261- inferred_uniform_shape.append(axis_size)
5262- else :
5263- inferred_uniform_shape.append(None )
5264- if any (x is not None for x in inferred_uniform_shape):
5265- uniform_shape = inferred_uniform_shape
5295+ shape_rows = shape_arr.to_pylist()
5296+
5297+ # Validate that each row's shape product matches its segment length
5298+ for i in range (len (splits) - 1 ):
5299+ expected_size = 1
5300+ for dim in shape_rows[i]:
5301+ expected_size *= dim
5302+ actual_size = splits[i + 1 ] - splits[i]
5303+ if expected_size != actual_size:
5304+ raise ValueError (
5305+ (f" shapes[{i}] product ({expected_size}) does not match "
5306+ f" row_splits interval ({actual_size})" ))
5307+
5308+ if uniform_shape is not None :
5309+ _validate_uniform_shape(uniform_shape, ndim)
5310+ for i, value in enumerate (uniform_shape):
5311+ if value is not None :
5312+ if any (shape[i] != value for shape in shape_rows):
5313+ raise ValueError (
5314+ (f" uniform_shape[{i}]={value} does not match "
5315+ f" shape dimension values" ))
5316+ else :
5317+ uniform_shape = _infer_uniform_shape(shape_rows, ndim)
52665318
52675319 data = ListArray.from_arrays(row_splits_arr, values_arr)
52685320 struct_arr = StructArray.from_arrays([data, shape_arr], names = [" data" , " shape" ])
0 commit comments