Skip to content

Commit 7e4d799

Browse files
AlenkaFjorisvandenbosscherok
authored andcommitted
apacheGH-34882: [Python] Binding for FixedShapeTensorType (apache#34883)
### Rationale for this change In the C++ the fixed shape tensor canonical extension type is implementated apache#8510 so we can add bindings to the extension type in Python. ### What changes are included in this PR? Binding for fixed shape tensor canonical extension type. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * Closes: apache#34882 Lead-authored-by: Alenka Frim <frim.alenka@gmail.com> Co-authored-by: Alenka Frim <AlenkaF@users.noreply.github.com> Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com> Co-authored-by: Rok Mihevc <rok@mihevc.org> Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
1 parent a501621 commit 7e4d799

File tree

8 files changed

+418
-2
lines changed

8 files changed

+418
-2
lines changed

docs/source/format/CanonicalExtensions.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ same rules as laid out above, and provide backwards compatibility guarantees.
7272
Official List
7373
=============
7474

75+
.. _fixed_shape_tensor_extension:
76+
7577
Fixed shape tensor
7678
==================
7779

python/pyarrow/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def print_entry(label, value):
170170
union, sparse_union, dense_union,
171171
dictionary,
172172
run_end_encoded,
173+
fixed_shape_tensor,
173174
field,
174175
type_for_alias,
175176
DataType, DictionaryType, StructType,
@@ -178,7 +179,7 @@ def print_entry(label, value):
178179
TimestampType, Time32Type, Time64Type, DurationType,
179180
FixedSizeBinaryType, Decimal128Type, Decimal256Type,
180181
BaseExtensionType, ExtensionType,
181-
RunEndEncodedType,
182+
RunEndEncodedType, FixedShapeTensorType,
182183
PyExtensionType, UnknownExtensionType,
183184
register_extension_type, unregister_extension_type,
184185
DictionaryMemo,
@@ -209,7 +210,7 @@ def print_entry(label, value):
209210
Time32Array, Time64Array, DurationArray,
210211
MonthDayNanoIntervalArray,
211212
Decimal128Array, Decimal256Array, StructArray, ExtensionArray,
212-
RunEndEncodedArray,
213+
RunEndEncodedArray, FixedShapeTensorArray,
213214
scalar, NA, _NULL as NULL, Scalar,
214215
NullScalar, BooleanScalar,
215216
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,

python/pyarrow/array.pxi

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3075,6 +3075,115 @@ cdef class ExtensionArray(Array):
30753075
return Array._to_pandas(self.storage, options, **kwargs)
30763076

30773077

3078+
class FixedShapeTensorArray(ExtensionArray):
3079+
"""
3080+
Concrete class for fixed shape tensor extension arrays.
3081+
3082+
Examples
3083+
--------
3084+
Define the extension type for tensor array
3085+
3086+
>>> import pyarrow as pa
3087+
>>> tensor_type = pa.fixed_shape_tensor(pa.int32(), [2, 2])
3088+
3089+
Create an extension array
3090+
3091+
>>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
3092+
>>> storage = pa.array(arr, pa.list_(pa.int32(), 4))
3093+
>>> pa.ExtensionArray.from_storage(tensor_type, storage)
3094+
<pyarrow.lib.FixedShapeTensorArray object at ...>
3095+
[
3096+
[
3097+
1,
3098+
2,
3099+
3,
3100+
4
3101+
],
3102+
[
3103+
10,
3104+
20,
3105+
30,
3106+
40
3107+
],
3108+
[
3109+
100,
3110+
200,
3111+
300,
3112+
400
3113+
]
3114+
]
3115+
"""
3116+
3117+
def to_numpy_ndarray(self):
3118+
"""
3119+
Convert fixed shape tensor extension array to a numpy array (with dim+1).
3120+
3121+
Note: ``permutation`` should be trivial (``None`` or ``[0, 1, ..., len(shape)-1]``).
3122+
"""
3123+
if self.type.permutation is None or self.type.permutation == list(range(len(self.type.shape))):
3124+
np_flat = np.asarray(self.storage.values)
3125+
numpy_tensor = np_flat.reshape((len(self),) + tuple(self.type.shape))
3126+
return numpy_tensor
3127+
else:
3128+
raise ValueError(
3129+
'Only non-permuted tensors can be converted to numpy tensors.')
3130+
3131+
@staticmethod
3132+
def from_numpy_ndarray(obj):
3133+
"""
3134+
Convert numpy tensors (ndarrays) to a fixed shape tensor extension array.
3135+
The first dimension of ndarray will become the length of the fixed
3136+
shape tensor array.
3137+
3138+
Numpy array needs to be C-contiguous in memory
3139+
(``obj.flags["C_CONTIGUOUS"]==True``).
3140+
3141+
Parameters
3142+
----------
3143+
obj : numpy.ndarray
3144+
3145+
Examples
3146+
--------
3147+
>>> import pyarrow as pa
3148+
>>> import numpy as np
3149+
>>> arr = np.array(
3150+
... [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
3151+
... dtype=np.float32)
3152+
>>> pa.FixedShapeTensorArray.from_numpy_ndarray(arr)
3153+
<pyarrow.lib.FixedShapeTensorArray object at ...>
3154+
[
3155+
[
3156+
1,
3157+
2,
3158+
3,
3159+
4,
3160+
5,
3161+
6
3162+
],
3163+
[
3164+
1,
3165+
2,
3166+
3,
3167+
4,
3168+
5,
3169+
6
3170+
]
3171+
]
3172+
"""
3173+
if not obj.flags["C_CONTIGUOUS"]:
3174+
raise ValueError('The data in the numpy array need to be in a single, '
3175+
'C-style contiguous segment.')
3176+
3177+
arrow_type = from_numpy_dtype(obj.dtype)
3178+
shape = obj.shape[1:]
3179+
size = obj.size / obj.shape[0]
3180+
3181+
return ExtensionArray.from_storage(
3182+
fixed_shape_tensor(arrow_type, shape),
3183+
FixedSizeListArray.from_arrays(np.ravel(obj, order='C'), size)
3184+
)
3185+
3186+
30783187
cdef dict _array_classes = {
30793188
_Type_NA: NullArray,
30803189
_Type_BOOL: BooleanArray,

python/pyarrow/includes/libarrow.pxd

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2622,6 +2622,27 @@ cdef extern from "arrow/extension_type.h" namespace "arrow":
26222622
shared_ptr[CArray] storage()
26232623

26242624

2625+
cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extension":
2626+
cdef cppclass CFixedShapeTensorType \
2627+
" arrow::extension::FixedShapeTensorType"(CExtensionType):
2628+
2629+
@staticmethod
2630+
CResult[shared_ptr[CDataType]] Make(const shared_ptr[CDataType]& value_type,
2631+
const vector[int64_t]& shape,
2632+
const vector[int64_t]& permutation,
2633+
const vector[c_string]& dim_names)
2634+
2635+
CResult[shared_ptr[CDataType]] Deserialize(const shared_ptr[CDataType] storage_type,
2636+
const c_string& serialized_data) const
2637+
2638+
c_string Serialize() const
2639+
2640+
const shared_ptr[CDataType] value_type()
2641+
const vector[int64_t] shape()
2642+
const vector[int64_t] permutation()
2643+
const vector[c_string] dim_names()
2644+
2645+
26252646
cdef extern from "arrow/util/compression.h" namespace "arrow" nogil:
26262647
cdef enum CCompressionType" arrow::Compression::type":
26272648
CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED"

python/pyarrow/lib.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ cdef class ExtensionType(BaseExtensionType):
199199
const CPyExtensionType* cpy_ext_type
200200

201201

202+
cdef class FixedShapeTensorType(BaseExtensionType):
203+
cdef:
204+
const CFixedShapeTensorType* tensor_ext_type
205+
206+
202207
cdef class PyExtensionType(ExtensionType):
203208
pass
204209

python/pyarrow/public-api.pxi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ cdef api object pyarrow_wrap_data_type(
118118
cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type)
119119
if cpy_ext_type != nullptr:
120120
return cpy_ext_type.GetInstance()
121+
elif ext_type.extension_name() == b"arrow.fixed_shape_tensor":
122+
out = FixedShapeTensorType.__new__(FixedShapeTensorType)
121123
else:
122124
out = BaseExtensionType.__new__(BaseExtensionType)
123125
else:

python/pyarrow/tests/test_extension_type.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,102 @@ def test_cpp_extension_in_python(tmpdir):
11441144
assert reconstructed_array == array
11451145

11461146

1147+
def test_tensor_type():
1148+
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])
1149+
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
1150+
assert tensor_type.storage_type == pa.list_(pa.int8(), 6)
1151+
assert tensor_type.shape == [2, 3]
1152+
assert tensor_type.dim_names is None
1153+
assert tensor_type.permutation is None
1154+
1155+
tensor_type = pa.fixed_shape_tensor(pa.float64(), [2, 2, 3],
1156+
permutation=[0, 2, 1])
1157+
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
1158+
assert tensor_type.storage_type == pa.list_(pa.float64(), 12)
1159+
assert tensor_type.shape == [2, 2, 3]
1160+
assert tensor_type.dim_names is None
1161+
assert tensor_type.permutation == [0, 2, 1]
1162+
1163+
tensor_type = pa.fixed_shape_tensor(pa.bool_(), [2, 2, 3],
1164+
dim_names=['C', 'H', 'W'])
1165+
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
1166+
assert tensor_type.storage_type == pa.list_(pa.bool_(), 12)
1167+
assert tensor_type.shape == [2, 2, 3]
1168+
assert tensor_type.dim_names == ['C', 'H', 'W']
1169+
assert tensor_type.permutation is None
1170+
1171+
1172+
def test_tensor_class_methods():
1173+
tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
1174+
storage = pa.array([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]],
1175+
pa.list_(pa.float32(), 6))
1176+
arr = pa.ExtensionArray.from_storage(tensor_type, storage)
1177+
expected = np.array(
1178+
[[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=np.float32)
1179+
result = arr.to_numpy_ndarray()
1180+
np.testing.assert_array_equal(result, expected)
1181+
1182+
arr = np.array(
1183+
[[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
1184+
dtype=np.float32, order="C")
1185+
tensor_array_from_numpy = pa.FixedShapeTensorArray.from_numpy_ndarray(arr)
1186+
assert isinstance(tensor_array_from_numpy.type, pa.FixedShapeTensorType)
1187+
assert tensor_array_from_numpy.type.value_type == pa.float32()
1188+
assert tensor_array_from_numpy.type.shape == [2, 3]
1189+
1190+
arr = np.array(
1191+
[[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
1192+
dtype=np.float32, order="F")
1193+
with pytest.raises(ValueError, match="C-style contiguous segment"):
1194+
pa.FixedShapeTensorArray.from_numpy_ndarray(arr)
1195+
1196+
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1])
1197+
storage = pa.array([[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]], pa.list_(pa.int8(), 12))
1198+
arr = pa.ExtensionArray.from_storage(tensor_type, storage)
1199+
with pytest.raises(ValueError, match="non-permuted tensors"):
1200+
arr.to_numpy_ndarray()
1201+
1202+
1203+
@pytest.mark.parametrize("tensor_type", (
1204+
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
1205+
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1]),
1206+
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], dim_names=['C', 'H', 'W'])
1207+
))
1208+
def test_tensor_type_ipc(tensor_type):
1209+
storage = pa.array([[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]], pa.list_(pa.int8(), 12))
1210+
arr = pa.ExtensionArray.from_storage(tensor_type, storage)
1211+
batch = pa.RecordBatch.from_arrays([arr], ["ext"])
1212+
1213+
# check the built array has exactly the expected clss
1214+
tensor_class = tensor_type.__arrow_ext_class__()
1215+
assert type(arr) == tensor_class
1216+
1217+
buf = ipc_write_batch(batch)
1218+
del batch
1219+
batch = ipc_read_batch(buf)
1220+
1221+
result = batch.column(0)
1222+
# check the deserialized array class is the expected one
1223+
assert type(result) == tensor_class
1224+
assert result.type.extension_name == "arrow.fixed_shape_tensor"
1225+
assert arr.storage.to_pylist() == [[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]]
1226+
1227+
# we get back an actual TensorType
1228+
assert isinstance(result.type, pa.FixedShapeTensorType)
1229+
assert result.type.value_type == pa.int8()
1230+
assert result.type.shape == [2, 2, 3]
1231+
1232+
1233+
def test_tensor_type_equality():
1234+
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3])
1235+
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
1236+
1237+
tensor_type2 = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3])
1238+
tensor_type3 = pa.fixed_shape_tensor(pa.uint8(), [2, 2, 3])
1239+
assert tensor_type == tensor_type2
1240+
assert not tensor_type == tensor_type3
1241+
1242+
11471243
@pytest.mark.pandas
11481244
def test_extension_to_pandas_storage_type(registered_period_type):
11491245
period_type, _ = registered_period_type

0 commit comments

Comments
 (0)