
Commit 1228c1a

Lazy read_xarray with iterables and parsing metadata. (#81)
I no longer convert the `block_slices` iterable into a list. I've also implemented a quick `_parse_schema` function to extract the `pa.Schema` from the dataset without materializing the first block. These two changes should provide solid performance improvements when working with real datasets.
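For context, a minimal, hypothetical sketch of what the change means from the caller's side, assuming `read_xarray` is importable from `xarray_sql.df` (the file touched below) and that `chunks` accepts an xarray-style dict. Because `block_slices` is no longer materialized into a list, each block is pivoted only when the returned reader yields its batch:

```python
# Hypothetical usage sketch; the import path and chunk spec are assumptions.
import numpy as np
import xarray as xr

from xarray_sql.df import read_xarray  # assumed import path

ds = xr.Dataset(
    {"temperature": (("time", "lat"), np.random.rand(8, 3))},
    coords={"time": np.arange(8), "lat": [10.0, 20.0, 30.0]},
)

reader = read_xarray(ds, chunks={"time": 4})  # pa.RecordBatchReader
for batch in reader:  # blocks are pivoted lazily, one batch at a time
    print(batch.num_rows)
```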
1 parent c3b4b00 commit 1228c1a


xarray_sql/df.py

Lines changed: 23 additions & 9 deletions
@@ -150,6 +150,23 @@ def pivot(ds: xr.Dataset) -> pd.DataFrame:
   return ds.to_dataframe().reset_index()
 
 
+def _parse_schema(ds) -> pa.Schema:
+  """Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns."""
+  columns = []
+
+  for coord_name, coord_var in ds.coords.items():
+    # Only include dimension coordinates
+    if coord_name in ds.dims:
+      pa_type = pa.from_numpy_dtype(coord_var.dtype)
+      columns.append(pa.field(coord_name, pa_type))
+
+  for var_name, var in ds.data_vars.items():
+    pa_type = pa.from_numpy_dtype(var.dtype)
+    columns.append(pa.field(var_name, pa_type))
+
+  return pa.schema(columns)
+
+
 def read_xarray(ds: xr.Dataset, chunks: Chunks = None) -> pa.RecordBatchReader:
   """Pivots an Xarray Dataset into a PyArrow Table, partitioned by chunks.
 
@@ -162,18 +179,15 @@ def read_xarray(ds: xr.Dataset, chunks: Chunks = None) -> pa.RecordBatchReader:
   Returns:
     A PyArrow Table, which is a table representation of the input Dataset.
   """
-  fst = next(iter(ds.values())).dims
-  assert all(
-      da.dims == fst for da in ds.values()
-  ), "All dimensions must be equal. Please filter data_vars in the Dataset."
-
-  blocks = list(block_slices(ds, chunks))
 
   def pivot_block(b: Block):
     return pivot(ds.isel(b))
 
-  schema = pa.Schema.from_pandas(pivot_block(blocks[0]))
-  last_schema = pa.Schema.from_pandas(pivot_block(blocks[-1]))
-  assert schema == last_schema, "Schemas must be consistent across blocks!"
+  fst = next(iter(ds.values())).dims
+  assert all(
+      da.dims == fst for da in ds.values()
+  ), "All dimensions must be equal. Please filter data_vars in the Dataset."
 
+  schema = _parse_schema(ds)
+  blocks = block_slices(ds, chunks)
   return from_map_batched(pivot_block, blocks, schema=schema)
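As a sanity check on the new schema logic, a hedged sketch (assuming `pivot` and the module-private `_parse_schema` are importable from `xarray_sql.df`): the parsed schema should match the one the old code inferred by pivoting a block through pandas.

```python
# Hypothetical consistency check; the imports are assumptions.
import numpy as np
import pyarrow as pa
import xarray as xr

from xarray_sql.df import _parse_schema, pivot  # assumed import path

ds = xr.Dataset(
    {"pressure": (("time",), np.zeros(3, dtype=np.float64))},
    coords={"time": np.arange(3, dtype=np.int64)},
)

# Old approach: pivot to pandas, then infer the Arrow schema.
old_schema = pa.Schema.from_pandas(pivot(ds))
# New approach: read dtypes straight off the Dataset, nothing materialized.
new_schema = _parse_schema(ds)

assert new_schema == old_schema  # same field names, order, and types
```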
