Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/nemo_safe_synthesizer/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,14 @@ def common_setup(
synthesis_overrides = merge_dicts(synthesis_overrides, dataset_info.overrides or dict())
df = dataset_info.fetch()
elif resume:
# For generate-only runs without --data-source, verify cached dataset exists
# For generate-only runs without --data-source, verify cached dataset exists.
# test.csv may legitimately be absent when holdout=0.
cached_training: Path = workdir.source_dataset.training # type: ignore[assignment]
cached_test: Path = workdir.source_dataset.test # type: ignore[assignment]
if not cached_training.exists() or not cached_test.exists():
if not cached_training.exists():
raise click.ClickException(
f"No cached dataset found in workdir: {workdir.source_dataset.path}\n\n"
"Either provide --data-source to load a dataset, or ensure the workdir "
"contains cached training/test data from a previous run."
"contains cached training data from a previous run."
)
run_logger.info(f"Using cached dataset from: {workdir.source_dataset.path}")
# df is None - SafeSynthesizer.load_from_save_path() will load from cached files
Expand Down
11 changes: 7 additions & 4 deletions src/nemo_safe_synthesizer/sdk/library_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,16 @@ def load_from_save_path(self) -> SafeSynthesizer:
# Only fall back to with_data_source() data if cached files are missing.
training_path = self._workdir.source_dataset.training
test_path = self._workdir.source_dataset.test
if training_path.exists() and test_path.exists():
if training_path.exists():
logger.info("Loading cached train/test split from training run")
# training_path persists the original training split for evaluation.
self._original_train_df = pd.read_csv(training_path)
self._test_df = pd.read_csv(test_path)
# test.csv may not exist (holdout=0) or may be empty (old runs with holdout=0).
if test_path.exists() and test_path.stat().st_size > 0:
self._test_df = pd.read_csv(test_path)
else:
logger.info("No test split loaded (holdout was disabled for this run)")
self._test_df = None
# Mark that we have fully loaded from the saved run, including cached splits.
self._loaded_from_save_path = True
elif self._data_source is not None:
Expand Down Expand Up @@ -296,8 +301,6 @@ def process_data(self) -> SafeSynthesizer:
self._train_df.to_csv(self._workdir.dataset.transformed_training, index=False)
if self._test_df is not None:
self._test_df.to_csv(self._workdir.dataset.test, index=False)
else:
self._workdir.dataset.test.touch()
return self

@traced("SafeSynthesizer.train", category=LogCategory.RUNTIME)
Expand Down
130 changes: 130 additions & 0 deletions tests/sdk/test_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,136 @@ def test_run_after_load_from_save_path_raises(
builder.run()


# ---------------------------------------------------------------------------
# Tests: load_from_save_path with holdout=0
# ---------------------------------------------------------------------------


class TestLoadFromSavePathHoldoutZero:
    """Behavior of ``load_from_save_path`` and ``process_data`` when holdout=0.

    With ``holdout=0`` the train/test split yields ``test_df=None`` and no
    test set is produced. Historically, ``process_data`` would ``touch()`` an
    empty ``test.csv`` and ``load_from_save_path`` would unconditionally call
    ``pd.read_csv`` on it, which raised ``EmptyDataError``.

    Covered here:

    * ``process_data`` writes no ``test.csv`` when the split produces no
      test set.
    * ``load_from_save_path`` works when ``test.csv`` is absent entirely
      (new runs created with ``holdout=0``).
    * ``load_from_save_path`` works when ``test.csv`` exists but is a 0-byte
      file (backward compatibility with runs created before the fix).
    """

    def _prepare_workdir_no_holdout(
        self,
        tmp_path: Path,
        fixture_sample_patient_dataframe: pd.DataFrame,
    ) -> tuple[Workdir, pd.DataFrame]:
        """Build a workdir containing only training.csv, as a holdout=0 run would."""
        wd = Workdir(base_path=tmp_path, config_name="test", dataset_name="data")
        wd.ensure_directories()

        expected_train = fixture_sample_patient_dataframe.copy()
        expected_train.to_csv(wd.dataset.training, index=False)

        # Persist a minimal config so the saved-run layout is complete.
        cfg_file = wd.config
        cfg_file.parent.mkdir(parents=True, exist_ok=True)
        cfg_file.write_text(SafeSynthesizerParameters().model_dump_json())

        # Metadata content is irrelevant here; its loader is mocked in tests.
        meta_file = wd.metadata_file
        meta_file.parent.mkdir(parents=True, exist_ok=True)
        meta_file.write_text("{}")

        return wd, expected_train

    @patch("nemo_safe_synthesizer.sdk.library_builder.ModelMetadata")
    @patch("nemo_safe_synthesizer.sdk.library_builder.AutoConfigResolver")
    @patch("nemo_safe_synthesizer.sdk.library_builder.Holdout")
    def test_process_data_no_test_csv_when_holdout_zero(
        self,
        mock_holdout_cls,
        mock_resolver_cls,
        mock_metadata_cls,
        fixture_workdir,
        fixture_sample_patient_dataframe,
    ):
        """``process_data`` must not create ``test.csv`` for a test-less split.

        ``Holdout.train_test_split`` is stubbed to return ``(train_df, None)``
        -- the same value it yields for ``holdout=0`` -- and afterwards the
        dataset directory must contain no ``test.csv`` while ``_test_df``
        remains ``None``.
        """
        synthesizer = SafeSynthesizer(config=SafeSynthesizerParameters(), workdir=fixture_workdir)
        synthesizer._data_source = fixture_sample_patient_dataframe
        train_only = fixture_sample_patient_dataframe.copy()

        mock_metadata_cls.from_config.return_value = MagicMock()
        mock_resolver_cls.return_value.return_value = synthesizer._nss_config
        mock_holdout_cls.return_value.train_test_split.return_value = (train_only, None)

        synthesizer.process_data()

        assert synthesizer._test_df is None
        assert not fixture_workdir.dataset.test.exists()

    @patch("nemo_safe_synthesizer.sdk.library_builder.ModelMetadata")
    def test_load_succeeds_without_test_csv(
        self,
        mock_metadata_cls,
        tmp_path,
        fixture_sample_patient_dataframe,
    ):
        """Resuming from a workdir that has no ``test.csv`` on disk must work.

        The workdir holds only ``training.csv``, simulating a fresh run with
        ``holdout=0``. ``load_from_save_path`` must load the training split,
        leave ``_test_df`` as ``None``, and flag the load as complete.
        """
        wd, expected_train = self._prepare_workdir_no_holdout(tmp_path, fixture_sample_patient_dataframe)
        mock_metadata_cls.from_metadata_json.return_value = MagicMock()

        synthesizer = SafeSynthesizer(config=SafeSynthesizerParameters(), workdir=wd)
        synthesizer.load_from_save_path()

        assert synthesizer._loaded_from_save_path is True
        assert synthesizer._test_df is None
        pd.testing.assert_frame_equal(synthesizer._original_train_df, expected_train)

    @patch("nemo_safe_synthesizer.sdk.library_builder.ModelMetadata")
    def test_load_handles_empty_test_csv_from_old_runs(
        self,
        mock_metadata_cls,
        tmp_path,
        fixture_sample_patient_dataframe,
    ):
        """Resuming must tolerate a 0-byte ``test.csv`` left by older runs.

        Before the fix, ``process_data`` ``touch()``-ed an empty ``test.csv``
        when there was no holdout; saved run directories from those runs still
        carry the empty file. ``load_from_save_path`` must treat it like a
        missing file instead of crashing with ``EmptyDataError``.
        """
        wd, expected_train = self._prepare_workdir_no_holdout(tmp_path, fixture_sample_patient_dataframe)
        # Recreate the legacy artifact: an empty 0-byte test.csv on disk.
        wd.dataset.test.touch()

        mock_metadata_cls.from_metadata_json.return_value = MagicMock()

        synthesizer = SafeSynthesizer(config=SafeSynthesizerParameters(), workdir=wd)
        synthesizer.load_from_save_path()

        assert synthesizer._loaded_from_save_path is True
        assert synthesizer._test_df is None
        pd.testing.assert_frame_equal(synthesizer._original_train_df, expected_train)


# ---------------------------------------------------------------------------
# Tests: config validation at process_data entry
# ---------------------------------------------------------------------------
Expand Down
Loading