Skip to content

Commit 539d2f3

Browse files
hombit and Copilot authored
Support fixed-size and large lists (#418)
* Support fixed-size lists

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Support fixed-size lists

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 27a7e28 commit 539d2f3

File tree

13 files changed

+448
-50
lines changed

13 files changed

+448
-50
lines changed

docs/reference/ext_array.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,4 @@ Functions
4141
series.ext_array.NestedExtensionArray.set_list_field
4242
series.ext_array.NestedExtensionArray.fill_field_lists
4343
series.ext_array.NestedExtensionArray.pop_fields
44+
series.ext_array.NestedExtensionArray.is_input_pa_type_supported

src/nested_pandas/nestedframe/core.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
_subexprs_by_nest,
2525
)
2626
from nested_pandas.series.dtype import NestedDtype
27+
from nested_pandas.series.ext_array import NestedExtensionArray
2728
from nested_pandas.series.nestedseries import NestedSeries
2829
from nested_pandas.series.packer import pack, pack_lists, pack_sorted_df_into_struct
29-
from nested_pandas.series.utils import is_pa_type_a_list
3030

3131
pd.set_option("display.max_rows", 30)
3232
pd.set_option("display.min_rows", 5)
@@ -68,13 +68,11 @@ def _cast_cols_to_nested(self, *, struct_list: bool) -> None:
6868
if not isinstance(dtype, pd.ArrowDtype):
6969
continue
7070
pa_type = dtype.pyarrow_dtype
71-
if not is_pa_type_a_list(pa_type) and not (struct_list and pa.types.is_struct(pa_type)):
71+
if pa.types.is_struct(pa_type) and not struct_list:
7272
continue
73-
try:
74-
nested_dtype = NestedDtype(pa_type)
75-
except (TypeError, ValueError):
73+
if not NestedExtensionArray.is_input_pa_type_supported(pa_type):
7674
continue
77-
self[column] = self[column].astype(nested_dtype)
75+
self[column] = NestedExtensionArray(pa.array(self[column]))
7876

7977
@property
8078
def _constructor(self) -> Self: # type: ignore[name-defined] # noqa: F821

src/nested_pandas/nestedframe/io.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from pyarrow.lib import ArrowInvalid
1212
from upath import UPath
1313

14-
from ..series.dtype import NestedDtype
14+
from ..series.ext_array import NestedExtensionArray
1515
from ..series.packer import pack_lists
16-
from ..series.utils import table_to_struct_array
16+
from ..series.utils import is_pa_type_a_list, table_to_struct_array
1717
from .core import NestedFrame
1818

1919
# Use smaller block size for these FSSPEC filesystems.
@@ -148,7 +148,7 @@ def read_parquet(
148148
# if any of the columns are not list type, reject the cast
149149
# and remove the column from the list of nested structures if
150150
# it was added
151-
if not pa.types.is_list(table.schema[i].type):
151+
if not is_pa_type_a_list(table.schema[i].type):
152152
reject_nesting.append(nested_col)
153153
if nested_col in nested_structures:
154154
# remove the column from the list of nested structures
@@ -455,28 +455,25 @@ def _cast_struct_cols_to_nested(df, reject_nesting):
455455
"""cast struct columns to nested dtype"""
456456
# Attempt to cast struct columns to NestedDTypes
457457
for col, dtype in df.dtypes.items():
458-
# First validate the dtype
459-
# will return valueerror when not a struct-list
460-
valid_dtype = True
458+
if col in reject_nesting:
459+
continue
460+
461+
if not NestedExtensionArray.is_input_pa_type_supported(dtype.pyarrow_dtype):
462+
continue
463+
461464
try:
462-
NestedDtype._validate_dtype(dtype.pyarrow_dtype)
463-
except ValueError:
464-
valid_dtype = False
465-
466-
if valid_dtype and col not in reject_nesting:
467-
try:
468-
# Attempt to cast Struct to NestedDType
469-
df = df.astype({col: NestedDtype(dtype.pyarrow_dtype)})
470-
except ValueError as err:
471-
# If cast fails, the struct likely does not fit nested-pandas
472-
# criteria for a valid nested column
473-
raise ValueError(
474-
f"Column '{col}' is a Struct, but an attempt to cast it to a NestedDType failed. "
475-
"This is likely due to the struct not meeting the requirements for a nested column "
476-
"(all fields should be equal length). To proceed, you may add the column to the "
477-
"`reject_nesting` argument of the read_parquet function to skip the cast attempt:"
478-
f" read_parquet(..., reject_nesting=['{col}'])"
479-
) from err
465+
# Attempt to cast Struct to NestedDType
466+
df[col] = NestedExtensionArray(pa.array(df[col]))
467+
except ValueError as err:
468+
# If cast fails, the struct likely does not fit nested-pandas
469+
# criteria for a valid nested column
470+
raise ValueError(
471+
f"Column '{col}' is a Struct, but an attempt to cast it to a NestedDType failed. "
472+
"This is likely due to the struct not meeting the requirements for a nested column "
473+
"(all fields should be equal length). To proceed, you may add the column to the "
474+
"`reject_nesting` argument of the read_parquet function to skip the cast attempt:"
475+
f" read_parquet(..., reject_nesting=['{col}'])"
476+
) from err
480477
return df
481478

482479

src/nested_pandas/series/_storage/list_struct_storage.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
import pyarrow as pa
66

7-
from nested_pandas.series.utils import transpose_struct_list_chunked, validate_list_struct_type
7+
from nested_pandas.series.utils import (
8+
normalize_list_array,
9+
transpose_struct_list_chunked,
10+
validate_list_struct_type,
11+
)
812

913
if TYPE_CHECKING:
1014
from nested_pandas.series._storage.struct_list_storage import StructListStorage
@@ -22,11 +26,14 @@ class ListStructStorage:
2226

2327
_data: pa.ChunkedArray
2428

25-
def __init__(self, array: pa.ListArray | pa.ChunkedArray) -> None:
29+
def __init__(
30+
self, array: pa.ListArray | pa.FixedSizeListArray | pa.LargeListArray | pa.ChunkedArray
31+
) -> None:
2632
if isinstance(array, pa.ListArray):
2733
array = pa.chunked_array([array])
2834
if not isinstance(array, pa.ChunkedArray):
2935
raise ValueError("array must be of type pa.ChunkedArray")
36+
array = normalize_list_array(array)
3037
validate_list_struct_type(array.type)
3138
self._data = array
3239

src/nested_pandas/series/_storage/struct_list_storage.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from nested_pandas.series.utils import (
99
align_chunked_struct_list_offsets,
10+
normalize_struct_list_array,
1011
table_to_struct_array,
1112
transpose_list_struct_chunked,
1213
)
@@ -39,6 +40,7 @@ def __init__(self, array: pa.StructArray | pa.ChunkedArray, *, validate: bool =
3940
raise ValueError("array must be a StructArray or ChunkedArray")
4041

4142
if validate:
43+
array = normalize_struct_list_array(array)
4244
array = align_chunked_struct_list_offsets(array)
4345

4446
self._data = array

src/nested_pandas/series/dtype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ExtensionArray:
186186

187187
pyarrow_dtype: pa.StructType
188188

189-
def __init__(self, pyarrow_dtype: pa.DataType) -> None:
189+
def __init__(self, pyarrow_dtype: pa.DataType | Mapping) -> None:
190190
# Allow pd.ArrowDtypes on init
191191
if isinstance(pyarrow_dtype, pd.ArrowDtype):
192192
pyarrow_dtype = pyarrow_dtype.pyarrow_dtype

src/nested_pandas/series/ext_array.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
from nested_pandas.series.utils import (
6666
chunk_lengths,
6767
is_pa_type_a_list,
68+
normalize_list_array,
69+
normalize_struct_list_type,
6870
rechunk,
6971
struct_field_names,
7072
transpose_struct_list_type,
@@ -253,7 +255,7 @@ def __init__(self, values: pa.Array | pa.ChunkedArray, *, validate: bool = True)
253255
list_struct_storage = ListStructStorage.from_struct_list_storage(struct_list_storage)
254256

255257
self._storage = list_struct_storage
256-
self._dtype = NestedDtype(values.type)
258+
self._dtype = NestedDtype(self._storage.type)
257259

258260
# End of Constructor and initialized attributes #
259261

@@ -669,6 +671,53 @@ def __getstate__(self):
669671

670672
# End of Additional magic methods #
671673

674+
@classmethod
675+
def is_input_pa_type_supported(cls, pa_type: pa.DataType) -> bool:
676+
"""Check whether a pyarrow data type is supported by the constructor.
677+
678+
Calling this method is cheaper than trying to construct the array,
679+
because data transformations are avoided.
680+
681+
Parameters
682+
----------
683+
pa_type : pyarrow.DataType
684+
The pyarrow data type to check for compatibility with
685+
NestedExtensionArray.
686+
687+
Returns
688+
-------
689+
bool
690+
``True`` if ``pa_type`` is a supported input type for
691+
NestedExtensionArray, ``False`` otherwise.
692+
693+
Examples
694+
--------
695+
Check support for a list-of-structs type::
696+
697+
>>> import pyarrow as pa
698+
>>> value_type = pa.struct([("a", pa.int64())])
699+
>>> pa_type = pa.list_(value_type)
700+
>>> NestedExtensionArray.is_input_pa_type_supported(pa_type)
701+
True
702+
703+
Check support for an unsupported primitive type::
704+
705+
>>> NestedExtensionArray.is_input_pa_type_supported(pa.int64())
706+
False
707+
"""
708+
if is_pa_type_a_list(pa_type):
709+
list_type = cast(
710+
pa.ListType | pa.LargeListType | pa.FixedSizeListType,
711+
pa_type,
712+
)
713+
return pa.types.is_struct(list_type.value_type)
714+
715+
try:
716+
_ = normalize_struct_list_type(pa_type)
717+
except ValueError:
718+
return False
719+
return True
720+
672721
@classmethod
673722
def _box_pa_scalar(cls, value, *, pa_type: pa.DataType | None) -> pa.Scalar:
674723
"""Convert a value to a PyArrow scalar with the specified type."""
@@ -1098,8 +1147,7 @@ def set_list_field(self, field: str, value: ArrayLike, *, keep_dtype: bool = Fal
10981147
"or NestedExtensionArray.set_list_field(..., keep_dtype=False) instead."
10991148
) from e
11001149

1101-
if not is_pa_type_a_list(pa_array.type):
1102-
raise ValueError(f"Expected a list array, got {pa_array.type}")
1150+
pa_array = normalize_list_array(pa_array)
11031151

11041152
if len(pa_array) != len(self):
11051153
raise ValueError("The length of the list-array must be equal to the length of the series")

0 commit comments

Comments
 (0)