Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/nemo_safe_synthesizer/data_processing/assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
RunningStatistics,
Statistics,
)
from ..data_processing.validation import validate_groupby_column, validate_orderby_column
from ..defaults import (
DEFAULT_CACHE_PREFIX,
PSEUDO_GROUP_COLUMN,
Expand Down Expand Up @@ -865,14 +866,8 @@ def _validate_columns(self, dataset: Dataset) -> None:
Raises:
ParameterError: If group or order column is not found in dataset.
"""
if self.group_by_column not in dataset.column_names:
msg = f"Group by column {self.group_by_column!r} not found in dataset."
if "," in self.group_by_column:
msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name."
raise ParameterError(msg)

if self.order_by_column not in dataset.column_names:
raise ParameterError(f"Order by column '{self.order_by_column}' not found in dataset.")
validate_groupby_column(dataset.column_names, self.group_by_column)
validate_orderby_column(dataset.column_names, self.order_by_column)

def _reorder_columns(self, dataset: Dataset) -> Dataset:
"""Reorder columns: group_by first, order_by second, then the rest.
Expand Down
71 changes: 71 additions & 0 deletions src/nemo_safe_synthesizer/data_processing/validation.py
Comment thread
kendrickb-nvidia marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Data validation helpers shared across pipeline stages."""

from __future__ import annotations

from collections.abc import Collection

import pandas as pd

from ..errors import DataError, ParameterError

MISSING_GROUP_BY_COLUMN_ERROR = "Group by column '{group_by}' not found in input dataset columns. "
MISSING_GROUP_BY_VALUES_ERROR = "Group by column '{group_by}' has missing values. Please remove/replace them."
MISSING_ORDER_BY_COLUMN_ERROR = "Order by column '{order_by}' not found in the input data."


def _get_column_names(data: pd.DataFrame | Collection[str]) -> Collection[str]:
if isinstance(data, pd.DataFrame):
return data.columns
return data


def validate_groupby_column(data: pd.DataFrame | Collection[str], group_by: str | None) -> None:
    """Validate that the configured group-by column exists and has no missing values.

    Args:
        data: A DataFrame or collection of column names to validate against.
        group_by: Name of the configured grouping column, or ``None`` when
            grouping is disabled.

    Raises:
        ParameterError: If ``group_by`` is configured but not present in ``data``.
        DataError: If ``data`` is a DataFrame and ``group_by`` contains missing values.
    """
    if group_by is None:
        return

    # Accept either a DataFrame or a plain collection of column names.
    columns = data.columns if isinstance(data, pd.DataFrame) else data

    if group_by not in columns:
        # MISSING_GROUP_BY_COLUMN_ERROR already ends with a trailing space, so the
        # appended hints must not start with one (the previous leading space
        # produced a double space in the user-facing message).
        message = MISSING_GROUP_BY_COLUMN_ERROR.format(group_by=group_by)
        if "," in group_by:
            message += "The column name contains a comma -- multi-column grouping is not supported. Use a single column name."
        else:
            message += "Please set `data.group_training_examples_by` to an existing column or to `null`/`None` to disable grouping."
        raise ParameterError(message)

    # Missing-value detection is only possible when the actual data is available,
    # not when validating against a bare list of column names.
    if isinstance(data, pd.DataFrame) and data[group_by].isna().any():
        raise DataError(MISSING_GROUP_BY_VALUES_ERROR.format(group_by=group_by))


def validate_orderby_column(data: pd.DataFrame | Collection[str], order_by: str | None) -> None:
    """Validate that the configured order-by column exists.

    Args:
        data: A DataFrame or collection of column names to validate against.
        order_by: Name of the configured ordering column, or ``None`` when
            ordering is disabled.

    Raises:
        ParameterError: If ``order_by`` is configured but not present in ``data``.
    """
    if order_by is None:
        return

    # Accept either a DataFrame or a plain collection of column names.
    column_names = data.columns if isinstance(data, pd.DataFrame) else data
    if order_by not in column_names:
        raise ParameterError(MISSING_ORDER_BY_COLUMN_ERROR.format(order_by=order_by))
25 changes: 10 additions & 15 deletions src/nemo_safe_synthesizer/holdout/holdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ..config.data import DEFAULT_HOLDOUT, MIN_HOLDOUT
from ..config.parameters import SafeSynthesizerParameters
from ..data_processing.validation import validate_groupby_column
from ..observability import get_logger

MIN_RECORDS_FOR_TEXT_AND_PRIVACY_METRICS = 200
Expand Down Expand Up @@ -79,11 +80,10 @@ def grouped_train_test_split(
grouped split could be produced.

Raises:
ValueError: If the ``group_by`` column contains missing values.
ParameterError: If the ``group_by`` column is not present in ``df``.
DataError: If the ``group_by`` column contains missing values.
"""
if input_df[group_by].isna().any():
msg = f"Group by column '{group_by}' has missing values. Please remove/replace them."
raise ValueError(msg)
validate_groupby_column(input_df, group_by)

if test_size > input_df.groupby(group_by).ngroups or test_size == 1 or test_size == 0:
logger.info(
Expand Down Expand Up @@ -150,9 +150,11 @@ def train_test_split(self, input_df: pd.DataFrame) -> DataFrameOptionalTuple:

Raises:
ValueError: If the input dataframe has fewer than
``MIN_RECORDS_FOR_TEXT_AND_PRIVACY_METRICS`` rows, if the
computed holdout is smaller than ``MIN_HOLDOUT``, or if
the ``group_by`` column contains missing values.
``MIN_RECORDS_FOR_TEXT_AND_PRIVACY_METRICS`` rows or if the
computed holdout is smaller than ``MIN_HOLDOUT``.
ParameterError: If the configured ``group_by`` column is missing.
DataError: If the configured ``group_by`` column contains missing
values.
"""
if not self.holdout or not self.max_holdout:
return input_df, None
Expand Down Expand Up @@ -187,14 +189,7 @@ def train_test_split(self, input_df: pd.DataFrame) -> DataFrameOptionalTuple:
HOLDOUT_TOO_SMALL_ERROR,
)

if self.group_by is not None and self.group_by not in input_df.columns:
msg = f"Group by column {self.group_by!r} not found in input dataset columns. Doing a normal split."
if "," in self.group_by:
msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name."
logger.warning(msg)
self.group_by = None
if self.group_by is not None and input_df[self.group_by].isna().any():
raise ValueError(f"Group by column '{self.group_by}' has missing values. Please remove/replace them.")
validate_groupby_column(input_df, self.group_by)

if self.group_by:
training_df, test_df = grouped_train_test_split(
Expand Down
12 changes: 9 additions & 3 deletions src/nemo_safe_synthesizer/sdk/library_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SafeSynthesizerParameters,
)
from ..config.autoconfig import AutoConfigResolver
from ..data_processing.validation import validate_groupby_column, validate_orderby_column
from ..evaluation.evaluator import Evaluator
from ..generation.timeseries_backend import TimeseriesBackend
from ..generation.vllm_backend import VllmBackend
Expand Down Expand Up @@ -247,9 +248,11 @@ def load_from_save_path(self) -> SafeSynthesizer:
def process_data(self) -> SafeSynthesizer:
"""Perform train/test split, auto-config resolution, and optional PII replacement.

Splits the data via ``Holdout``, runs ``AutoConfigResolver`` to
resolve ``"auto"`` parameters, applies PII replacement to the
training set when enabled, and persists the splits to the workdir.
Validates configured grouping/ordering columns against the input
dataset, splits the data via ``Holdout``, runs
``AutoConfigResolver`` to resolve ``"auto"`` parameters, applies
PII replacement to the training set when enabled, and persists the
splits to the workdir.

Returns:
Self for method chaining.
Expand All @@ -271,6 +274,9 @@ def process_data(self) -> SafeSynthesizer:
assert self._nss_config is not None
assert isinstance(self._data_source, pd.DataFrame)

validate_groupby_column(self._data_source, self._nss_config.data.group_training_examples_by)
validate_orderby_column(self._data_source, self._nss_config.data.order_training_examples_by)

holdout = Holdout(self._nss_config)
original_training_df, self._test_df = holdout.train_test_split(self._data_source)

Expand Down
10 changes: 5 additions & 5 deletions src/nemo_safe_synthesizer/training/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ def __subclasshook__(cls, subclass):
def prepare_training_data(self) -> None:
"""Load, validate, and tokenize the training dataset.

Runs auto-config resolution, validates groupby/orderby columns,
applies time-series processing and ``action_executor`` preprocessing,
then assembles tokenized training examples. Populates
``training_examples``, ``dataset_schema``, ``training_df``, and
``data_fraction``.
Validates grouping/ordering columns (where applicable), resolves
auto-config values, applies time-series processing and
``action_executor`` preprocessing, then assembles tokenized training
examples. Populates ``training_examples``, ``dataset_schema``,
``training_df``, and ``data_fraction``.
"""
...

Expand Down
41 changes: 6 additions & 35 deletions src/nemo_safe_synthesizer/training/huggingface_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ..config.autoconfig import AutoConfigResolver
from ..data_processing.assembler import TrainingExampleAssembler
from ..data_processing.dataset import make_json_schema
from ..data_processing.validation import validate_groupby_column, validate_orderby_column
from ..defaults import (
DEFAULT_VALID_RECORD_EVAL_BATCH_SIZE,
EVAL_STEPS,
Expand Down Expand Up @@ -538,32 +539,6 @@ def prepare_params(self, **training_args: Any) -> None:
self.trainer = self._create_trainer(self.train_args, data_collator)
self._configure_trainer_callbacks(self.trainer, training_args)

def _validate_groupby_column(self, df: pd.DataFrame) -> None:
"""Validate the groupby column exists and has no missing values.

Args:
df: The DataFrame to validate.

Raises:
ParameterError: If the groupby column doesn't exist.
DataError: If the groupby column has missing values.
"""
col = self.params.data.group_training_examples_by
if col is None:
return

if col not in df.columns:
msg = f"Group by column {col!r} not found in the input data."
if "," in col:
msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name."
logger.error(msg)
raise ParameterError(msg)

if df[col].isnull().any():
msg = f"Group by column '{col}' has missing values. Please remove/replace them."
logger.error(msg)
raise DataError(msg)

def _validate_orderby_column(self, df: pd.DataFrame) -> None:
"""Validate the orderby column exists in the dataset.

Expand All @@ -580,10 +555,7 @@ def _validate_orderby_column(self, df: pd.DataFrame) -> None:
if self.params.time_series.is_timeseries and self.params.time_series.timestamp_column is None:
return

if orderby_col and orderby_col not in df.columns:
msg = f"Order by column '{orderby_col}' not found in the input data."
logger.error(msg)
raise ParameterError(msg)
validate_orderby_column(df, orderby_col)

def _apply_preprocessing(self, df: pd.DataFrame) -> pd.DataFrame:
"""Apply action_executor preprocessing if available.
Expand Down Expand Up @@ -664,8 +636,8 @@ def _log_dataset_statistics(self, assembler: TrainingExampleAssembler) -> None:
def prepare_training_data(self) -> None:
"""Validate, preprocess, and tokenize the training dataset.

Runs auto-config resolution, time-series processing, groupby /
orderby validation, and assembles tokenized training examples.
Validates groupby/orderby columns, resolves auto-config values,
runs time-series preprocessing, and assembles tokenized training examples.
Populates ``training_examples``, ``dataset_schema``,
``training_df``, and ``data_fraction``.

Expand All @@ -681,11 +653,10 @@ def prepare_training_data(self) -> None:
if not isinstance(training_df, pd.DataFrame):
raise DataError("Expected DataFrame from to_pandas(), got an iterator")

self.params = AutoConfigResolver(training_df, self.params).resolve()

# Validate groupby/orderby parameters as a preprocessing step.
self._validate_groupby_column(training_df)
validate_groupby_column(training_df, self.params.data.group_training_examples_by)
self._validate_orderby_column(training_df)
self.params = AutoConfigResolver(training_df, self.params).resolve()

Comment thread
kendrickb-nvidia marked this conversation as resolved.
# Process time series data (sort by timestamp, infer intervals, etc.)
training_df = self._process_timeseries(training_df)
Expand Down
4 changes: 2 additions & 2 deletions tests/data_processing/test_assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def test_sequential_assembler_raises_for_missing_group_column(
fixture_sequential_metadata: ModelMetadata,
):
"""Test that SequentialExampleAssembler raises for missing group column."""
with pytest.raises(ParameterError, match="Group by column.*not found in dataset"):
with pytest.raises(ParameterError, match="Group by column.*not found"):
SequentialExampleAssembler(
dataset=fixture_iris_dataset,
tokenizer=fixture_tokenizer,
Expand All @@ -680,7 +680,7 @@ def test_sequential_assembler_raises_for_missing_order_column(
fixture_sequential_metadata: ModelMetadata,
):
"""Test that SequentialExampleAssembler raises for missing order column."""
with pytest.raises(ParameterError, match="Order by column.*not found in dataset"):
with pytest.raises(ParameterError, match="Order by column.*not found"):
SequentialExampleAssembler(
dataset=fixture_iris_dataset,
tokenizer=fixture_tokenizer,
Expand Down
59 changes: 59 additions & 0 deletions tests/data_processing/test_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import pandas as pd
import pytest

from nemo_safe_synthesizer.data_processing.validation import (
validate_groupby_column,
validate_orderby_column,
)
from nemo_safe_synthesizer.errors import DataError, ParameterError


def test_validate_groupby_column_noop_when_groupby_is_none() -> None:
    """A ``None`` group-by setting disables validation entirely — must not raise."""
    frame = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    validate_groupby_column(frame, None)


def test_validate_groupby_column_passes_when_column_exists() -> None:
    """An existing, fully-populated group column validates silently."""
    frame = pd.DataFrame(
        {
            "col1": [1, 2, 3, 4, 5],
            "col2": list("abcde"),
            "group_col": ["g1", "g1", "g2", "g2", "g3"],
        }
    )
    validate_groupby_column(frame, "group_col")


def test_validate_groupby_column_raises_for_missing_column() -> None:
    """A configured group column absent from the data raises ParameterError with a hint."""
    frame = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    expected = r"Group by column 'missing_group' not found in input dataset columns.*disable grouping"
    with pytest.raises(ParameterError, match=expected):
        validate_groupby_column(frame, "missing_group")


def test_validate_groupby_column_raises_for_comma_in_name() -> None:
    """Comma-separated names are rejected: multi-column grouping is unsupported."""
    frame = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    with pytest.raises(ParameterError, match="multi-column grouping is not supported"):
        validate_groupby_column(frame, "col1,col2")


def test_validate_groupby_column_raises_for_missing_values() -> None:
    """NaNs in the group column raise DataError rather than silently grouping."""
    frame = pd.DataFrame({"group": ["x", None], "value": [1, 2]})
    with pytest.raises(DataError, match="missing values"):
        validate_groupby_column(frame, "group")


def test_validate_orderby_column_noop_when_orderby_is_none() -> None:
    """A ``None`` order-by setting disables validation entirely — must not raise."""
    frame = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    validate_orderby_column(frame, None)


def test_validate_orderby_column_raises_for_missing_column() -> None:
    """A configured order column absent from the data raises ParameterError."""
    frame = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    with pytest.raises(ParameterError, match="not found in the input data"):
        validate_orderby_column(frame, "missing_order")
17 changes: 12 additions & 5 deletions tests/holdout/test_holdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from nemo_safe_synthesizer.config.parameters import SafeSynthesizerParameters
from nemo_safe_synthesizer.errors import DataError, ParameterError
from nemo_safe_synthesizer.holdout.holdout import (
HOLDOUT_TOO_SMALL_ERROR,
INPUT_DATA_TOO_SMALL_ERROR,
Expand Down Expand Up @@ -103,12 +104,18 @@ def test_does_group_by_holdout(df):
assert len(test) == 100


def test_skips_group_by_holdout_with_bad_column(df):
def test_raises_on_group_by_holdout_with_bad_column(df):
holdout = Holdout(SafeSynthesizerParameters.from_params(group_training_examples_by="dne"))
train, test = holdout.train_test_split(df)
assert len(train) == 190
assert test is not None
assert len(test) == 10
with pytest.raises(ParameterError, match="Group by column 'dne' not found"):
holdout.train_test_split(df)


def test_raises_on_group_by_holdout_with_missing_values(df):
    """NaNs in the configured group column must abort the holdout split with DataError."""
    frame = df.copy()
    frame.loc[0, "big_cat"] = None
    holdout = Holdout(SafeSynthesizerParameters.from_params(group_training_examples_by="big_cat"))
    with pytest.raises(DataError, match="Group by column 'big_cat' has missing values"):
        holdout.train_test_split(frame)


def test_complains_when_training_dataset_is_too_small():
Expand Down
Loading
Loading