Skip to content

Commit b86051c

Browse files
committed
PR feedback: add early fail to order by check, remove unnecessary internal method, clean up docstrings, etc.
Signed-off-by: nina-xu <19981858+nina-xu@users.noreply.github.com>
1 parent 53769a0 commit b86051c

File tree

11 files changed

+111
-72
lines changed

11 files changed

+111
-72
lines changed

src/nemo_safe_synthesizer/data_processing/assembler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RunningStatistics,
3131
Statistics,
3232
)
33+
from ..data_processing.validation import MISSING_GROUP_BY_COLUMN_ERROR, MISSING_ORDER_BY_COLUMN_ERROR
3334
from ..defaults import (
3435
DEFAULT_CACHE_PREFIX,
3536
PSEUDO_GROUP_COLUMN,
@@ -864,10 +865,10 @@ def _validate_columns(self, dataset: Dataset) -> None:
864865
ParameterError: If group or order column is not found in dataset.
865866
"""
866867
if self.group_by_column not in dataset.column_names:
867-
raise ParameterError(f"Group by column '{self.group_by_column}' not found in dataset.")
868+
raise ParameterError(MISSING_GROUP_BY_COLUMN_ERROR.format(group_by=self.group_by_column))
868869

869870
if self.order_by_column not in dataset.column_names:
870-
raise ParameterError(f"Order by column '{self.order_by_column}' not found in dataset.")
871+
raise ParameterError(MISSING_ORDER_BY_COLUMN_ERROR.format(order_by=self.order_by_column))
871872

872873
def _reorder_columns(self, dataset: Dataset) -> Dataset:
873874
"""Reorder columns: group_by first, order_by second, then the rest.

src/nemo_safe_synthesizer/data_processing/validation.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
MISSING_GROUP_BY_COLUMN_ERROR = (
1313
"Group by column '{group_by}' not found in input dataset columns. "
14-
"Please set `data.group_training_examples_by` to an existing column or disable grouping."
14+
"Please set `data.group_training_examples_by` to an existing column or to `null`/`None` to disable grouping."
1515
)
1616
MISSING_GROUP_BY_VALUES_ERROR = "Group by column '{group_by}' has missing values. Please remove/replace them."
17+
MISSING_ORDER_BY_COLUMN_ERROR = "Order by column '{order_by}' not found in the input data."
1718

1819

1920
def validate_groupby_column(df: pd.DataFrame, group_by: str | None) -> None:
@@ -35,3 +36,20 @@ def validate_groupby_column(df: pd.DataFrame, group_by: str | None) -> None:
3536

3637
if df[group_by].isna().any():
3738
raise DataError(MISSING_GROUP_BY_VALUES_ERROR.format(group_by=group_by))
39+
40+
41+
def validate_orderby_column(df: pd.DataFrame, order_by: str | None) -> None:
42+
"""Validate that the configured order-by column exists.
43+
44+
Args:
45+
df: Dataframe to validate.
46+
order_by: Name of the configured ordering column.
47+
48+
Raises:
49+
ParameterError: If ``order_by`` is configured but not present in ``df``.
50+
"""
51+
if order_by is None:
52+
return
53+
54+
if order_by not in df.columns:
55+
raise ParameterError(MISSING_ORDER_BY_COLUMN_ERROR.format(order_by=order_by))

src/nemo_safe_synthesizer/holdout/holdout.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from ..config.data import DEFAULT_HOLDOUT, MIN_HOLDOUT
1515
from ..config.parameters import SafeSynthesizerParameters
16-
from ..data_processing.validation import MISSING_GROUP_BY_COLUMN_ERROR, validate_groupby_column
16+
from ..data_processing.validation import validate_groupby_column
1717
from ..observability import get_logger
1818

1919
MIN_RECORDS_FOR_TEXT_AND_PRIVACY_METRICS = 200
@@ -24,20 +24,11 @@
2424
INPUT_DATA_TOO_SMALL_ERROR = (
2525
f"Dataset must have at least {MIN_RECORDS_FOR_TEXT_AND_PRIVACY_METRICS} records to use holdout."
2626
)
27+
2728
logger = get_logger(__name__)
2829

2930
DataFrameOptionalTuple = tuple[pd.DataFrame, pd.DataFrame] | tuple[pd.DataFrame, None]
3031

31-
__all__ = [
32-
"HOLDOUT_TOO_SMALL_ERROR",
33-
"INPUT_DATA_TOO_SMALL_ERROR",
34-
"MISSING_GROUP_BY_COLUMN_ERROR",
35-
"MIN_RECORDS_FOR_TEXT_AND_PRIVACY_METRICS",
36-
"Holdout",
37-
"grouped_train_test_split",
38-
"naive_train_test_split",
39-
]
40-
4132

4233
def naive_train_test_split(df, test_size, random_state=None) -> DataFrameOptionalTuple:
4334
"""Split a dataframe into train and test sets with a random shuffle.
@@ -83,12 +74,10 @@ def grouped_train_test_split(df, test_size, group_by, random_state=None) -> Data
8374
grouped split could be produced.
8475
8576
Raises:
86-
ValueError: If the ``group_by`` column contains missing values.
77+
ParameterError: If the ``group_by`` column is not present in ``df``.
78+
DataError: If the ``group_by`` column contains missing values.
8779
"""
88-
# Do not continue the split process if the groupby column has missing values.
89-
if df[group_by].isna().any():
90-
msg = f"Group by column '{group_by}' has missing values. Please remove/replace them."
91-
raise ValueError(msg)
80+
validate_groupby_column(df, group_by)
9281

9382
if test_size > df.groupby(group_by).ngroups or test_size == 1 or test_size == 0:
9483
logger.info(

src/nemo_safe_synthesizer/sdk/library_builder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
SafeSynthesizerParameters,
2424
)
2525
from ..config.autoconfig import AutoConfigResolver
26-
from ..data_processing.validation import validate_groupby_column
26+
from ..data_processing.validation import validate_groupby_column, validate_orderby_column
2727
from ..evaluation.evaluator import Evaluator
2828
from ..generation.timeseries_backend import TimeseriesBackend
2929
from ..generation.vllm_backend import VllmBackend
@@ -242,9 +242,11 @@ def load_from_save_path(self) -> SafeSynthesizer:
242242
def process_data(self) -> SafeSynthesizer:
243243
"""Perform train/test split, auto-config resolution, and optional PII replacement.
244244
245-
Splits the data via ``Holdout``, runs ``AutoConfigResolver`` to
246-
resolve ``"auto"`` parameters, applies PII replacement to the
247-
training set when enabled, and persists the splits to the workdir.
245+
Validates configured grouping/ordering columns against the input
246+
dataset, splits the data via ``Holdout``, runs
247+
``AutoConfigResolver`` to resolve ``"auto"`` parameters, applies
248+
PII replacement to the training set when enabled, and persists the
249+
splits to the workdir.
248250
249251
Returns:
250252
Self for method chaining.
@@ -267,6 +269,7 @@ def process_data(self) -> SafeSynthesizer:
267269
assert isinstance(self._data_source, pd.DataFrame)
268270

269271
validate_groupby_column(self._data_source, self._nss_config.data.group_training_examples_by)
272+
validate_orderby_column(self._data_source, self._nss_config.data.order_training_examples_by)
270273

271274
holdout = Holdout(self._nss_config)
272275
original_train_df, self._test_df = holdout.train_test_split(self._data_source)

src/nemo_safe_synthesizer/training/backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,11 @@ def __subclasshook__(cls, subclass):
220220
def prepare_training_data(self):
221221
"""Load, validate, and tokenize the training dataset.
222222
223-
Runs auto-config resolution, validates groupby/orderby columns,
224-
applies time-series processing and ``action_executor`` preprocessing,
225-
then assembles tokenized training examples. Populates
226-
``training_examples``, ``dataset_schema``, ``df_train``, and
227-
``data_fraction``.
223+
Validates grouping/ordering columns (where applicable), resolves
224+
auto-config values, applies time-series processing and
225+
``action_executor`` preprocessing, then assembles tokenized training
226+
examples. Populates ``training_examples``, ``dataset_schema``,
227+
``df_train``, and ``data_fraction``.
228228
"""
229229
...
230230

src/nemo_safe_synthesizer/training/huggingface_backend.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from ..config.autoconfig import AutoConfigResolver
3939
from ..data_processing.assembler import TrainingExampleAssembler
4040
from ..data_processing.dataset import make_json_schema
41-
from ..data_processing.validation import validate_groupby_column
41+
from ..data_processing.validation import validate_groupby_column, validate_orderby_column
4242
from ..defaults import (
4343
DEFAULT_VALID_RECORD_EVAL_BATCH_SIZE,
4444
EVAL_STEPS,
@@ -529,19 +529,6 @@ def prepare_params(self, **training_args):
529529
self.trainer = self._create_trainer(self.train_args, data_collator)
530530
self._configure_trainer_callbacks(self.trainer, training_args)
531531

532-
def _validate_groupby_column(self, df) -> None:
533-
"""Validate the groupby column exists and has no missing values.
534-
535-
Args:
536-
df: The DataFrame to validate.
537-
538-
Raises:
539-
ParameterError: If the groupby column doesn't exist.
540-
DataError: If the groupby column has missing values.
541-
"""
542-
col = self.params.data.group_training_examples_by
543-
validate_groupby_column(df, col)
544-
545532
def _validate_orderby_column(self, df) -> None:
546533
"""Validate the orderby column exists in the dataset.
547534
@@ -558,10 +545,7 @@ def _validate_orderby_column(self, df) -> None:
558545
if self.params.time_series.is_timeseries and self.params.time_series.timestamp_column is None:
559546
return
560547

561-
if orderby_col and orderby_col not in df.columns:
562-
msg = f"Order by column '{orderby_col}' not found in the input data."
563-
logger.error(msg)
564-
raise ParameterError(msg)
548+
validate_orderby_column(df, orderby_col)
565549

566550
def _apply_preprocessing(self, df):
567551
"""Apply action_executor preprocessing if available.
@@ -642,8 +626,8 @@ def _log_dataset_statistics(self, assembler) -> None:
642626
def prepare_training_data(self):
643627
"""Validate, preprocess, and tokenize the training dataset.
644628
645-
Runs auto-config resolution, time-series processing, groupby /
646-
orderby validation, and assembles tokenized training examples.
629+
Validates groupby/orderby columns, resolves auto-config values,
630+
runs time-series preprocessing, and assembles tokenized training examples.
647631
Populates ``training_examples``, ``dataset_schema``,
648632
``df_train``, and ``data_fraction``.
649633
@@ -660,7 +644,7 @@ def prepare_training_data(self):
660644
raise DataError("Expected DataFrame from to_pandas(), got an iterator")
661645

662646
# Validate groupby/orderby parameters as a preprocessing step.
663-
self._validate_groupby_column(df_all)
647+
validate_groupby_column(df_all, self.params.data.group_training_examples_by)
664648
self._validate_orderby_column(df_all)
665649
self.params = AutoConfigResolver(df_all, self.params).resolve()
666650

tests/data_processing/test_assembler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def test_sequential_assembler_raises_for_missing_group_column(
670670
fixture_sequential_metadata: ModelMetadata,
671671
):
672672
"""Test that SequentialExampleAssembler raises for missing group column."""
673-
with pytest.raises(ParameterError, match="Group by column.*not found in dataset"):
673+
with pytest.raises(ParameterError, match="Group by column.*not found"):
674674
SequentialExampleAssembler(
675675
dataset=fixture_iris_dataset,
676676
tokenizer=fixture_tokenizer,
@@ -689,7 +689,7 @@ def test_sequential_assembler_raises_for_missing_order_column(
689689
fixture_sequential_metadata: ModelMetadata,
690690
):
691691
"""Test that SequentialExampleAssembler raises for missing order column."""
692-
with pytest.raises(ParameterError, match="Order by column.*not found in dataset"):
692+
with pytest.raises(ParameterError, match="Order by column.*not found"):
693693
SequentialExampleAssembler(
694694
dataset=fixture_iris_dataset,
695695
tokenizer=fixture_tokenizer,

tests/data_processing/test_validation.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from nemo_safe_synthesizer.data_processing.validation import (
88
MISSING_GROUP_BY_COLUMN_ERROR,
99
MISSING_GROUP_BY_VALUES_ERROR,
10+
MISSING_ORDER_BY_COLUMN_ERROR,
1011
validate_groupby_column,
12+
validate_orderby_column,
1113
)
1214
from nemo_safe_synthesizer.errors import DataError, ParameterError
1315

@@ -20,12 +22,24 @@ def test_validate_groupby_column_noop_when_groupby_is_none() -> None:
2022
def test_validate_groupby_column_raises_for_missing_column() -> None:
2123
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
2224
with pytest.raises(ParameterError) as excinfo:
23-
validate_groupby_column(df, "missing")
24-
assert str(excinfo.value) == MISSING_GROUP_BY_COLUMN_ERROR.format(group_by="missing")
25+
validate_groupby_column(df, "missing_group")
26+
assert str(excinfo.value) == MISSING_GROUP_BY_COLUMN_ERROR.format(group_by="missing_group")
2527

2628

2729
def test_validate_groupby_column_raises_for_missing_values() -> None:
2830
df = pd.DataFrame({"group": ["x", None], "value": [1, 2]})
2931
with pytest.raises(DataError) as excinfo:
3032
validate_groupby_column(df, "group")
3133
assert str(excinfo.value) == MISSING_GROUP_BY_VALUES_ERROR.format(group_by="group")
34+
35+
36+
def test_validate_orderby_column_noop_when_orderby_is_none() -> None:
37+
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
38+
validate_orderby_column(df, None)
39+
40+
41+
def test_validate_orderby_column_raises_for_missing_column() -> None:
42+
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
43+
with pytest.raises(ParameterError) as excinfo:
44+
validate_orderby_column(df, "missing_order")
45+
assert str(excinfo.value) == MISSING_ORDER_BY_COLUMN_ERROR.format(order_by="missing_order")

tests/holdout/test_holdout.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import pytest
77

88
from nemo_safe_synthesizer.config.parameters import SafeSynthesizerParameters
9+
from nemo_safe_synthesizer.errors import DataError, ParameterError
910
from nemo_safe_synthesizer.holdout.holdout import (
1011
HOLDOUT_TOO_SMALL_ERROR,
1112
INPUT_DATA_TOO_SMALL_ERROR,
12-
MISSING_GROUP_BY_COLUMN_ERROR,
1313
Holdout,
1414
naive_train_test_split,
1515
)
@@ -106,10 +106,16 @@ def test_does_group_by_holdout(df):
106106

107107
def test_raises_on_group_by_holdout_with_bad_column(df):
108108
holdout = Holdout(SafeSynthesizerParameters.from_params(group_training_examples_by="dne"))
109-
with pytest.raises(ValueError) as excinfo:
109+
with pytest.raises(ParameterError, match="Group by column 'dne' not found"):
110110
holdout.train_test_split(df)
111111

112-
assert str(excinfo.value) == MISSING_GROUP_BY_COLUMN_ERROR.format(group_by="dne")
112+
113+
def test_raises_on_group_by_holdout_with_missing_values(df):
114+
df_with_missing_group = df.copy()
115+
df_with_missing_group.loc[0, "big_cat"] = None
116+
holdout = Holdout(SafeSynthesizerParameters.from_params(group_training_examples_by="big_cat"))
117+
with pytest.raises(DataError, match="Group by column 'big_cat' has missing values"):
118+
holdout.train_test_split(df_with_missing_group)
113119

114120

115121
def test_complains_when_training_dataset_is_too_small():

tests/sdk/test_process_data.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,9 @@ class TestProcessDataConfigValidation:
585585
methods after construction are not visible to the Pydantic validator
586586
until ``_resolve_nss_config()`` is called. ``process_data`` must
587587
call it at the top of the method so invalid configs are caught
588-
immediately -- before holdout split, PII replacement, or any disk I/O.
588+
immediately. It also validates configured group/order columns against
589+
the input dataset before holdout split, autoconfig resolution, PII
590+
replacement, or any disk I/O.
589591
"""
590592

591593
def test_dp_and_explicit_unsloth_raises_at_process_data(self, fixture_workdir: Workdir) -> None:
@@ -615,11 +617,37 @@ def test_invalid_groupby_raises_before_holdout(
615617
downstream ``KeyError``.
616618
"""
617619
ss = SafeSynthesizer(
618-
config=SafeSynthesizerParameters.from_params(group_training_examples_by="non_existent"),
620+
config=SafeSynthesizerParameters.from_params(group_training_examples_by="non_existent_group"),
619621
workdir=fixture_workdir,
620622
).with_data_source(fixture_sample_patient_dataframe)
621623

622-
with pytest.raises(ParameterError, match="Group by column 'non_existent' not found"):
624+
with pytest.raises(ParameterError, match="Group by column 'non_existent_group' not found"):
625+
ss.process_data()
626+
627+
mock_holdout_cls.assert_not_called()
628+
629+
@patch("nemo_safe_synthesizer.sdk.library_builder.Holdout")
630+
def test_invalid_orderby_raises_before_holdout(
631+
self,
632+
mock_holdout_cls,
633+
fixture_workdir: Workdir,
634+
fixture_sample_patient_dataframe: pd.DataFrame,
635+
) -> None:
636+
"""Missing order-by column raises immediately during ``process_data``.
637+
638+
This catches invalid ``order_training_examples_by`` before holdout split
639+
or autoconfig runs, ensuring a clear ``ParameterError`` instead of a
640+
downstream pandas error.
641+
"""
642+
ss = SafeSynthesizer(
643+
config=SafeSynthesizerParameters.from_params(
644+
group_training_examples_by="patient_name",
645+
order_training_examples_by="non_existent_order",
646+
),
647+
workdir=fixture_workdir,
648+
).with_data_source(fixture_sample_patient_dataframe)
649+
650+
with pytest.raises(ParameterError, match="Order by column 'non_existent_order' not found"):
623651
ss.process_data()
624652

625653
mock_holdout_cls.assert_not_called()

0 commit comments

Comments (0)