Skip to content

Commit 3d33680

Browse files
authored
feat: support 1-to-many FileSystemSeedReader hydration (#424)
1 parent 164db0a commit 3d33680

File tree

3 files changed

+435
-16
lines changed

3 files changed

+435
-16
lines changed

packages/data-designer-engine/src/data_designer/engine/resources/seed_reader.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from __future__ import annotations
55

66
from abc import ABC, abstractmethod
7-
from collections.abc import Callable, Sequence
7+
from collections.abc import Callable, Iterable, Sequence
88
from dataclasses import dataclass
99
from fnmatch import fnmatchcase
1010
from pathlib import Path, PurePosixPath
@@ -77,13 +77,13 @@ def create_seed_reader_output_dataframe(
7777
continue
7878

7979
message_parts: list[str] = [
80-
f"Hydrated row at index {row_index} does not match output_columns {output_columns!r}."
80+
f"Hydrated record at index {row_index} does not match output_columns {output_columns!r}."
8181
]
8282
if missing_columns:
8383
message_parts.append(f"Missing columns: {missing_columns!r}.")
8484
if extra_columns:
8585
message_parts.append(f"Undeclared columns: {extra_columns!r}.")
86-
message_parts.append("Ensure hydrate_row() returns exactly the declared output schema.")
86+
message_parts.append("Ensure each record emitted by hydrate_row() matches the declared output schema.")
8787
raise SeedReaderError(" ".join(message_parts))
8888

8989
return lazy.pd.DataFrame(records, columns=output_columns)
@@ -118,18 +118,32 @@ def __init__(
118118
manifest_batch_reader: SeedReaderBatchReader,
119119
hydrate_records: Callable[[list[dict[str, Any]]], list[dict[str, Any]]],
120120
output_columns: list[str],
121+
no_rows_error_message: str,
121122
) -> None:
122123
self._manifest_batch_reader = manifest_batch_reader
123124
self._hydrate_records = hydrate_records
124125
self._output_columns = output_columns
126+
self._no_rows_error_message = no_rows_error_message
127+
self._has_emitted_records = False
125128

126129
def read_next_batch(self) -> SeedReaderBatch:
127-
manifest_batch = self._manifest_batch_reader.read_next_batch()
128-
manifest_records = manifest_batch.to_pandas().to_dict(orient="records")
129-
hydrated_records = self._hydrate_records(manifest_records)
130-
return PandasSeedReaderBatch(
131-
create_seed_reader_output_dataframe(records=hydrated_records, output_columns=self._output_columns)
132-
)
130+
while True:
131+
try:
132+
manifest_batch = self._manifest_batch_reader.read_next_batch()
133+
except StopIteration:
134+
if self._has_emitted_records:
135+
raise
136+
raise SeedReaderError(self._no_rows_error_message) from None
137+
138+
manifest_records = manifest_batch.to_pandas().to_dict(orient="records")
139+
hydrated_records = self._hydrate_records(manifest_records)
140+
if not hydrated_records:
141+
continue
142+
143+
self._has_emitted_records = True
144+
return PandasSeedReaderBatch(
145+
create_seed_reader_output_dataframe(records=hydrated_records, output_columns=self._output_columns)
146+
)
133147

134148

135149
SourceT = TypeVar("SourceT", bound=SeedSource)
@@ -342,11 +356,12 @@ class FileSystemSeedReader(SeedReader[FileSystemSourceT], ABC):
342356
343357
Plugin authors implement `build_manifest(...)` to describe the cheap logical
344358
rows available under the configured filesystem root. Readers that need
345-
expensive enrichment can optionally override `hydrate_row(...)`. When
346-
`hydrate_row(...)` changes the manifest schema, `output_columns` must declare
347-
the exact hydrated output schema. The framework owns attachment-scoped
348-
filesystem context reuse, manifest sampling, partitioning, randomization,
349-
batching, and DuckDB registration details.
359+
expensive enrichment can optionally override `hydrate_row(...)` to emit one
360+
record dict or an iterable of record dicts per manifest row. When emitted
361+
records change the manifest schema, `output_columns` must declare the exact
362+
hydrated output schema for each emitted record. The framework owns
363+
attachment-scoped filesystem context reuse, manifest sampling, partitioning,
364+
randomization, batching, and DuckDB registration details.
350365
"""
351366

352367
output_columns: ClassVar[list[str] | None] = None
@@ -379,7 +394,7 @@ def hydrate_row(
379394
*,
380395
manifest_row: dict[str, Any],
381396
context: SeedReaderFileSystemContext,
382-
) -> dict[str, Any]:
397+
) -> dict[str, Any] | Iterable[dict[str, Any]]:
383398
return manifest_row
384399

385400
def get_column_names(self) -> list[str]:
@@ -416,6 +431,7 @@ def create_batch_reader(
416431
context=context,
417432
),
418433
output_columns=self.get_output_column_names(),
434+
no_rows_error_message=self._get_empty_selected_manifest_rows_error_message(),
419435
)
420436

421437
def _get_row_manifest_dataframe(self) -> pd.DataFrame:
@@ -468,6 +484,9 @@ def _build_internal_table_name(self, suffix: str) -> str:
468484
seed_type = self.get_seed_type().replace("-", "_")
469485
return f"seed_reader_{seed_type}_{suffix}"
470486

487+
def _get_empty_selected_manifest_rows_error_message(self) -> str:
488+
return f"Selected manifest rows for seed source at {self.source.path} did not produce any rows after hydration"
489+
471490
def _normalize_rows_to_dataframe(self, rows: pd.DataFrame | list[dict[str, Any]]) -> pd.DataFrame:
472491
if isinstance(rows, lazy.pd.DataFrame):
473492
return rows.copy()
@@ -479,7 +498,15 @@ def _hydrate_rows(
479498
manifest_rows: list[dict[str, Any]],
480499
context: SeedReaderFileSystemContext,
481500
) -> list[dict[str, Any]]:
482-
return [self.hydrate_row(manifest_row=manifest_row, context=context) for manifest_row in manifest_rows]
501+
hydrated_records: list[dict[str, Any]] = []
502+
for manifest_row_index, manifest_row in enumerate(manifest_rows):
503+
hydrated_records.extend(
504+
_normalize_hydrated_row_output(
505+
hydrated_row_output=self.hydrate_row(manifest_row=manifest_row, context=context),
506+
manifest_row_index=manifest_row_index,
507+
)
508+
)
509+
return hydrated_records
483510

484511

485512
class DirectorySeedReader(FileSystemSeedReader[DirectorySeedSource]):
@@ -584,3 +611,30 @@ def _build_metadata_record(
584611

585612
def _normalize_relative_path(path: str) -> str:
586613
return path.lstrip("/")
614+
615+
616+
def _normalize_hydrated_row_output(
617+
*,
618+
hydrated_row_output: dict[str, Any] | Iterable[dict[str, Any]],
619+
manifest_row_index: int,
620+
) -> list[dict[str, Any]]:
621+
if isinstance(hydrated_row_output, dict):
622+
return [hydrated_row_output]
623+
624+
if not isinstance(hydrated_row_output, Iterable):
625+
raise SeedReaderError(
626+
"hydrate_row() must return a record dict or an iterable of record dicts. "
627+
f"Manifest row index {manifest_row_index} returned {type(hydrated_row_output).__name__}."
628+
)
629+
630+
hydrated_records = list(hydrated_row_output)
631+
for hydrated_record in hydrated_records:
632+
if isinstance(hydrated_record, dict):
633+
continue
634+
raise SeedReaderError(
635+
"hydrate_row() must return a record dict or an iterable of record dicts. "
636+
f"Manifest row index {manifest_row_index} returned an iterable containing "
637+
f"{type(hydrated_record).__name__}."
638+
)
639+
640+
return hydrated_records

0 commit comments

Comments
 (0)