@@ -1144,6 +1144,102 @@ def test_cpp_extension_in_python(tmpdir):
11441144 assert reconstructed_array == array
11451145
11461146
1147+ def test_tensor_type ():
1148+ tensor_type = pa .fixed_shape_tensor (pa .int8 (), [2 , 3 ])
1149+ assert tensor_type .extension_name == "arrow.fixed_shape_tensor"
1150+ assert tensor_type .storage_type == pa .list_ (pa .int8 (), 6 )
1151+ assert tensor_type .shape == [2 , 3 ]
1152+ assert tensor_type .dim_names is None
1153+ assert tensor_type .permutation is None
1154+
1155+ tensor_type = pa .fixed_shape_tensor (pa .float64 (), [2 , 2 , 3 ],
1156+ permutation = [0 , 2 , 1 ])
1157+ assert tensor_type .extension_name == "arrow.fixed_shape_tensor"
1158+ assert tensor_type .storage_type == pa .list_ (pa .float64 (), 12 )
1159+ assert tensor_type .shape == [2 , 2 , 3 ]
1160+ assert tensor_type .dim_names is None
1161+ assert tensor_type .permutation == [0 , 2 , 1 ]
1162+
1163+ tensor_type = pa .fixed_shape_tensor (pa .bool_ (), [2 , 2 , 3 ],
1164+ dim_names = ['C' , 'H' , 'W' ])
1165+ assert tensor_type .extension_name == "arrow.fixed_shape_tensor"
1166+ assert tensor_type .storage_type == pa .list_ (pa .bool_ (), 12 )
1167+ assert tensor_type .shape == [2 , 2 , 3 ]
1168+ assert tensor_type .dim_names == ['C' , 'H' , 'W' ]
1169+ assert tensor_type .permutation is None
1170+
1171+
1172+ def test_tensor_class_methods ():
1173+ tensor_type = pa .fixed_shape_tensor (pa .float32 (), [2 , 3 ])
1174+ storage = pa .array ([[1 , 2 , 3 , 4 , 5 , 6 ], [1 , 2 , 3 , 4 , 5 , 6 ]],
1175+ pa .list_ (pa .float32 (), 6 ))
1176+ arr = pa .ExtensionArray .from_storage (tensor_type , storage )
1177+ expected = np .array (
1178+ [[[1 , 2 , 3 ], [4 , 5 , 6 ]], [[1 , 2 , 3 ], [4 , 5 , 6 ]]], dtype = np .float32 )
1179+ result = arr .to_numpy_ndarray ()
1180+ np .testing .assert_array_equal (result , expected )
1181+
1182+ arr = np .array (
1183+ [[[1 , 2 , 3 ], [4 , 5 , 6 ]], [[1 , 2 , 3 ], [4 , 5 , 6 ]]],
1184+ dtype = np .float32 , order = "C" )
1185+ tensor_array_from_numpy = pa .FixedShapeTensorArray .from_numpy_ndarray (arr )
1186+ assert isinstance (tensor_array_from_numpy .type , pa .FixedShapeTensorType )
1187+ assert tensor_array_from_numpy .type .value_type == pa .float32 ()
1188+ assert tensor_array_from_numpy .type .shape == [2 , 3 ]
1189+
1190+ arr = np .array (
1191+ [[[1 , 2 , 3 ], [4 , 5 , 6 ]], [[1 , 2 , 3 ], [4 , 5 , 6 ]]],
1192+ dtype = np .float32 , order = "F" )
1193+ with pytest .raises (ValueError , match = "C-style contiguous segment" ):
1194+ pa .FixedShapeTensorArray .from_numpy_ndarray (arr )
1195+
1196+ tensor_type = pa .fixed_shape_tensor (pa .int8 (), [2 , 2 , 3 ], permutation = [0 , 2 , 1 ])
1197+ storage = pa .array ([[1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 ]], pa .list_ (pa .int8 (), 12 ))
1198+ arr = pa .ExtensionArray .from_storage (tensor_type , storage )
1199+ with pytest .raises (ValueError , match = "non-permuted tensors" ):
1200+ arr .to_numpy_ndarray ()
1201+
1202+
1203+ @pytest .mark .parametrize ("tensor_type" , (
1204+ pa .fixed_shape_tensor (pa .int8 (), [2 , 2 , 3 ]),
1205+ pa .fixed_shape_tensor (pa .int8 (), [2 , 2 , 3 ], permutation = [0 , 2 , 1 ]),
1206+ pa .fixed_shape_tensor (pa .int8 (), [2 , 2 , 3 ], dim_names = ['C' , 'H' , 'W' ])
1207+ ))
1208+ def test_tensor_type_ipc (tensor_type ):
1209+ storage = pa .array ([[1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 ]], pa .list_ (pa .int8 (), 12 ))
1210+ arr = pa .ExtensionArray .from_storage (tensor_type , storage )
1211+ batch = pa .RecordBatch .from_arrays ([arr ], ["ext" ])
1212+
1213+ # check the built array has exactly the expected clss
1214+ tensor_class = tensor_type .__arrow_ext_class__ ()
1215+ assert type (arr ) == tensor_class
1216+
1217+ buf = ipc_write_batch (batch )
1218+ del batch
1219+ batch = ipc_read_batch (buf )
1220+
1221+ result = batch .column (0 )
1222+ # check the deserialized array class is the expected one
1223+ assert type (result ) == tensor_class
1224+ assert result .type .extension_name == "arrow.fixed_shape_tensor"
1225+ assert arr .storage .to_pylist () == [[1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 ]]
1226+
1227+ # we get back an actual TensorType
1228+ assert isinstance (result .type , pa .FixedShapeTensorType )
1229+ assert result .type .value_type == pa .int8 ()
1230+ assert result .type .shape == [2 , 2 , 3 ]
1231+
1232+
1233+ def test_tensor_type_equality ():
1234+ tensor_type = pa .fixed_shape_tensor (pa .int8 (), [2 , 2 , 3 ])
1235+ assert tensor_type .extension_name == "arrow.fixed_shape_tensor"
1236+
1237+ tensor_type2 = pa .fixed_shape_tensor (pa .int8 (), [2 , 2 , 3 ])
1238+ tensor_type3 = pa .fixed_shape_tensor (pa .uint8 (), [2 , 2 , 3 ])
1239+ assert tensor_type == tensor_type2
1240+ assert not tensor_type == tensor_type3
1241+
1242+
11471243@pytest .mark .pandas
11481244def test_extension_to_pandas_storage_type (registered_period_type ):
11491245 period_type , _ = registered_period_type
0 commit comments