Skip to content

Commit dbf1194

Browse files
committed
Add tests and replace to_pylist with equals() post #48727
1 parent 3f255f2 commit dbf1194

File tree

2 files changed

+41
-39
lines changed

2 files changed

+41
-39
lines changed

python/pyarrow/src/arrow/python/python_to_arrow.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -585,7 +585,7 @@ class PyConverter : public Converter<PyObject*, PyConversionOptions> {
585585
};
586586

587587
// Helper function to unwrap extension scalar to its storage scalar
588-
inline const Scalar& GetStorageScalar(const Scalar& scalar) {
588+
const Scalar& GetStorageScalar(const Scalar& scalar) {
589589
if (scalar.type->id() == Type::EXTENSION) {
590590
return *checked_cast<const ExtensionScalar&>(scalar).value;
591591
}

python/pyarrow/tests/test_extension_type.py

Lines changed: 40 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -1487,67 +1487,69 @@ def bytes(self):
14871487

14881488

14891489
def test_array_from_extension_scalars():
1490-
# Test unwrap to various storage types and different converters
14911490
import datetime
1491+
from decimal import Decimal
14921492

14931493
builtin_cases = [
1494-
# fixed_size_binary[16] storage
1495-
(pa.uuid(), [b"0123456789abcdef"], [
1496-
UUID('30313233-3435-3637-3839-616263646566')]),
1497-
# int8 storage
1498-
(pa.bool8(), [0, 1], [0, 1]),
1499-
# string storage
1500-
(pa.json_(pa.string()), ['{"a":1}', '{"b":2}'], ['{"a":1}', '{"b":2}']),
1501-
# binary storage
1502-
(pa.opaque(pa.binary(), "t", "v"), [b"x", b"y"], [b"x", b"y"]),
1494+
(pa.uuid(), [b"0123456789abcdef"]),
1495+
(pa.bool8(), [0, 1]),
1496+
(pa.json_(pa.string()), ['{"a":1}', '{"b":2}']),
1497+
(pa.opaque(pa.binary(), "t", "v"), [b"x", b"y"]),
15031498
]
1504-
for ext_type, values, expected in builtin_cases:
1499+
for ext_type, values in builtin_cases:
15051500
scalars = [pa.scalar(v, type=ext_type) for v in values]
15061501
result = pa.array(scalars, type=ext_type)
1507-
assert result.type == ext_type
1508-
# TODO: make `expected` pyarrow array so `to_pylist` isn't used, check GH-48241
1509-
assert result.to_pylist() == expected
1502+
expected = pa.array(values, type=ext_type)
1503+
assert result.equals(expected)
15101504

15111505
# Custom extension types requiring registration
15121506
custom_cases = [
1513-
# int8 storage
1514-
(TinyIntType(), [1, 2], [1, 2]),
1515-
# int64 storage
1516-
(IntegerType(), [100, 200], [100, 200]),
1517-
# string storage
1518-
(LabelType(), ["a", "b"], ["a", "b"]),
1519-
# struct storage
1520-
(MyStructType(), [{"left": 1, "right": 2}], [{"left": 1, "right": 2}]),
1521-
# timestamp storage
1507+
(TinyIntType(), [1, 2]),
1508+
(IntegerType(), [100, 200]),
1509+
(LabelType(), ["a", "b"]),
1510+
(MyStructType(), [{"left": 1, "right": 2}]),
15221511
(AnnotatedType(pa.timestamp("us"), "ts"),
1523-
[datetime.datetime(2023, 1, 1)], [datetime.datetime(2023, 1, 1)]),
1524-
# duration storage
1512+
[datetime.datetime(2023, 1, 1)]),
15251513
(AnnotatedType(pa.duration("s"), "dur"),
1526-
[datetime.timedelta(seconds=100)], [datetime.timedelta(seconds=100)]),
1527-
# date storage
1514+
[datetime.timedelta(seconds=100)]),
15281515
(AnnotatedType(pa.date32(), "date"),
1529-
[datetime.date(2023, 1, 1)], [datetime.date(2023, 1, 1)]),
1530-
# float64 storage
1531-
(AnnotatedType(pa.float64(), "f"), [1.5, 2.5], [1.5, 2.5]),
1532-
# boolean storage
1533-
(AnnotatedType(pa.bool_(), "b"), [True, False], [True, False]),
1534-
# binary storage
1535-
(AnnotatedType(pa.binary(), "bin"), [b"x", b"y"], [b"x", b"y"]),
1516+
[datetime.date(2023, 1, 1)]),
1517+
(AnnotatedType(pa.float64(), "f"), [1.5, 2.5]),
1518+
(AnnotatedType(pa.bool_(), "b"), [True, False]),
1519+
(AnnotatedType(pa.binary(), "bin"), [b"x", b"y"]),
1520+
(AnnotatedType(pa.decimal128(10, 2), "dec"),
1521+
[Decimal("1.50"), Decimal("2.75")]),
1522+
(AnnotatedType(pa.large_string(), "lstr"), ["hello", "world"]),
1523+
(AnnotatedType(pa.large_binary(), "lbin"), [b"ab", b"cd"]),
15361524
]
1537-
for ext_type, values, expected in custom_cases:
1525+
for ext_type, values in custom_cases:
15381526
with registered_extension_type(ext_type):
15391527
scalars = [pa.scalar(v, type=ext_type) for v in values]
15401528
result = pa.array(scalars, type=ext_type)
1541-
assert result.type == ext_type
1542-
# TODO: make `expected` pyarrow array so `to_pylist` isn't used
1543-
assert result.to_pylist() == expected
1529+
expected = pa.array(values, type=ext_type)
1530+
assert result.equals(expected)
15441531

1532+
# Null handling
15451533
uuid_type = pa.uuid()
15461534
scalars = [pa.scalar(b"0123456789abcdef", type=uuid_type),
15471535
pa.scalar(None, type=uuid_type)]
15481536
result = pa.array(scalars, type=uuid_type)
15491537
assert result[0].is_valid and not result[1].is_valid
15501538

1539+
# Type inference without explicit type
1540+
u = uuid4()
1541+
scalars = [pa.scalar(u, type=pa.uuid()), None]
1542+
result = pa.array(scalars)
1543+
assert result.type == pa.uuid()
1544+
assert result[0].as_py() == u
1545+
assert not result[1].is_valid
1546+
1547+
# Mixed extension scalars and raw Python objects
1548+
u1, u2 = uuid4(), uuid4()
1549+
result = pa.array([pa.scalar(u1, type=pa.uuid()), u2], type=pa.uuid())
1550+
expected = pa.array([u1, u2], type=pa.uuid())
1551+
assert result.equals(expected)
1552+
15511553

15521554
def test_tensor_type():
15531555
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])

0 commit comments

Comments
 (0)