Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions python/pyarrow/src/arrow/python/python_to_arrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,14 @@ class PyConverter : public Converter<PyObject*, PyConversionOptions> {
}
};

// Returns the scalar that should be handed to a storage-typed builder.
//
// - If `scalar` is not of an extension type, it is returned unchanged.
// - If `scalar` is of an extension type, the scalar held in its `value`
//   member (the storage scalar) is returned instead.
//
// The result is a reference to either the argument itself or the object
// referenced by the argument's `value` member; no copy is made, so the
// result should not be used after `scalar` has been destroyed.
//
// NOTE(review): `value` is dereferenced without a null check; this assumes
// an extension scalar always carries a storage scalar -- confirm with callers.
const Scalar& GetStorageScalar(const Scalar& scalar) {
  // Common case first: plain scalars need no unwrapping.
  if (scalar.type->id() != Type::EXTENSION) {
    return scalar;
  }
  const auto& extension_scalar = checked_cast<const ExtensionScalar&>(scalar);
  return *extension_scalar.value;
}

Comment on lines +587 to +594
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using Result here?

Suggested change
// Helper function to unwrap extension scalar to its storage scalar
const Scalar& GetStorageScalar(const Scalar& scalar) {
if (scalar.type->id() == Type::EXTENSION) {
return *checked_cast<const ExtensionScalar&>(scalar).value;
}
return scalar;
}
// Helper function to unwrap extension scalar to its storage scalar
Result<const Scalar*> GetStorageScalar(const Scalar& scalar) {
if (scalar.type->id() != Type::EXTENSION) {
return &scalar;
}
const auto& extension_scalar = checked_cast<const ExtensionScalar&>(scalar);
return extension_scalar.value.get();
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps Result is not needed here - if that kind of failure path is unlikely, then returning const Scalar& might be better?
See the check at line 757: if (PyValue::IsNull(this->options_, value)) { ...
(For an (unlikely?) non-None extension scalar whose internal storage .value is a null shared_ptr, neither approach guards against it?)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Result is probably not needed indeed, since no error can be returned here. As you prefer @tadeja.

template <typename T, typename Enable = void>
class PyPrimitiveConverter;

Expand Down Expand Up @@ -663,7 +671,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
Comment thread
tadeja marked this conversation as resolved.
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -684,7 +693,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -710,7 +720,8 @@ class PyPrimitiveConverter<T, enable_if_t<std::is_same<T, FixedSizeBinaryType>::
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
Expand Down Expand Up @@ -747,7 +758,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
Expand Down Expand Up @@ -791,7 +803,7 @@ class PyDictionaryConverter<U, enable_if_has_c_type<U>>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
return this->value_builder_->AppendScalar(GetStorageScalar(*scalar), 1);
} else {
ARROW_ASSIGN_OR_RAISE(auto converted,
PyValue::Convert(this->value_type_, this->options_, value));
Expand All @@ -810,7 +822,7 @@ class PyDictionaryConverter<U, enable_if_has_string_view<U>>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
return this->value_builder_->AppendScalar(GetStorageScalar(*scalar), 1);
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->value_type_, this->options_, value, view_));
Expand Down Expand Up @@ -983,7 +995,7 @@ class PyStructConverter : public StructConverter<PyConverter, PyConverterTrait>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->struct_builder_->AppendScalar(*scalar);
return this->struct_builder_->AppendScalar(GetStorageScalar(*scalar));
}
switch (input_kind_) {
case InputKind::DICT:
Expand Down
65 changes: 65 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,71 @@ def bytes(self):
pa.scalar(bad)


Comment thread
tadeja marked this conversation as resolved.
def test_array_from_extension_scalars():
import datetime
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can move this stdlib import to the top of file.

from decimal import Decimal

builtin_cases = [
(pa.uuid(), [b"0123456789abcdef"]),
(pa.bool8(), [0, 1]),
(pa.json_(pa.string()), ['{"a":1}', '{"b":2}']),
(pa.opaque(pa.binary(), "t", "v"), [b"x", b"y"]),
]
for ext_type, values in builtin_cases:
scalars = [pa.scalar(v, type=ext_type) for v in values]
result = pa.array(scalars, type=ext_type)
expected = pa.array(values, type=ext_type)
assert result.equals(expected)

# Custom extension types requiring registration
custom_cases = [
(TinyIntType(), [1, 2]),
(IntegerType(), [100, 200]),
(LabelType(), ["a", "b"]),
(MyStructType(), [{"left": 1, "right": 2}]),
(AnnotatedType(pa.timestamp("us"), "ts"),
[datetime.datetime(2023, 1, 1)]),
(AnnotatedType(pa.duration("s"), "dur"),
[datetime.timedelta(seconds=100)]),
(AnnotatedType(pa.date32(), "date"),
[datetime.date(2023, 1, 1)]),
(AnnotatedType(pa.float64(), "f"), [1.5, 2.5]),
(AnnotatedType(pa.bool_(), "b"), [True, False]),
(AnnotatedType(pa.binary(), "bin"), [b"x", b"y"]),
(AnnotatedType(pa.decimal128(10, 2), "dec"),
[Decimal("1.50"), Decimal("2.75")]),
(AnnotatedType(pa.large_string(), "lstr"), ["hello", "world"]),
(AnnotatedType(pa.large_binary(), "lbin"), [b"ab", b"cd"]),
]
for ext_type, values in custom_cases:
with registered_extension_type(ext_type):
scalars = [pa.scalar(v, type=ext_type) for v in values]
result = pa.array(scalars, type=ext_type)
expected = pa.array(values, type=ext_type)
assert result.equals(expected)

# Null handling
uuid_type = pa.uuid()
scalars = [pa.scalar(b"0123456789abcdef", type=uuid_type),
pa.scalar(None, type=uuid_type)]
result = pa.array(scalars, type=uuid_type)
assert result[0].is_valid and not result[1].is_valid

Comment thread
tadeja marked this conversation as resolved.
# Type inference without explicit type
u = uuid4()
scalars = [pa.scalar(u, type=pa.uuid()), None]
result = pa.array(scalars)
assert result.type == pa.uuid()
assert result[0].as_py() == u
assert not result[1].is_valid

# Mixed extension scalars and raw Python objects
u1, u2 = uuid4(), uuid4()
result = pa.array([pa.scalar(u1, type=pa.uuid()), u2], type=pa.uuid())
expected = pa.array([u1, u2], type=pa.uuid())
assert result.equals(expected)


def test_tensor_type():
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
Expand Down
Loading