Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/user-guide/troubleshooting.md
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,25 @@ safe-synthesizer config validate --config config.yaml
`data.group_training_examples_by`, config validation will fail. Ordering only
makes sense within groups.

`group_training_examples_by` with comma-separated column names:

: Setting `data.group_training_examples_by: col1,col2` in YAML is parsed as the
single string `"col1,col2"`, not as two separate columns. The pipeline will
fail with a `ParameterError` when it tries to find a column literally named
`"col1,col2"` in your data:

```text
ParameterError: Group by column 'patient_id,event_id' not found in the input data.
The column name contains a comma -- multi-column grouping is not supported.
Use a single column name.
```

Only a single column name is supported; multi-column grouping is not
currently available. As a workaround, create a composite column in your
data before running the pipeline (e.g. concatenate `patient_id` and
`event_id` into a new `patient_event_id` column) and group by that
column instead.

Unsupported file extensions:

: The `url` parameter accepts `.csv`, `.json`, `.jsonl`, `.parquet`, and `.txt`
Expand Down
4 changes: 2 additions & 2 deletions src/nemo_safe_synthesizer/config/autoconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def choose_num_input_records_to_sample(rope_scaling_factor: int) -> int:
return rope_scaling_factor * 25_000


def get_max_token_count(data: pd.DataFrame, group_by: list[str] | str | None) -> int:
def get_max_token_count(data: pd.DataFrame, group_by: str | None) -> int:
"""Estimate the maximum tokens per training example.

Accounts for prompt overhead (~40 tokens), column names (repeated in JSON
Expand All @@ -56,7 +56,7 @@ def get_max_token_count(data: pd.DataFrame, group_by: list[str] | str | None) ->

Args:
data: Training dataframe to analyze.
group_by: Column(s) used to group records into single training examples.
group_by: Column used to group records into single training examples.
When set, grouped records are concatenated before token estimation.

Returns:
Expand Down
10 changes: 7 additions & 3 deletions src/nemo_safe_synthesizer/data_processing/assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,10 @@ def _validate_columns(self, dataset: Dataset) -> None:
ParameterError: If group or order column is not found in dataset.
"""
if self.group_by_column not in dataset.column_names:
raise ParameterError(f"Group by column '{self.group_by_column}' not found in dataset.")
msg = f"Group by column {self.group_by_column!r} not found in dataset."
if "," in self.group_by_column:
msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name."
raise ParameterError(msg)

if self.order_by_column not in dataset.column_names:
raise ParameterError(f"Order by column '{self.order_by_column}' not found in dataset.")
Expand Down Expand Up @@ -1312,11 +1315,12 @@ def __init__(
if keep_columns:
required_columns = list(set(required_columns + keep_columns))

# We need to split the dataset first so that the grouping column(s) are still present when we invoke
# We need to split the dataset first so that the grouping column is still present when we invoke
# `utils.grouped_train_test_split`. After the split we tokenize and perform the (potentially expensive) grouping step independently for
# train and test.
if test_size is not None and test_size > 0:
df_dataset = cast(pd.DataFrame, dataset.to_pandas())
df_dataset = dataset.to_pandas()
assert isinstance(df_dataset, pd.DataFrame)
train_raw, test_raw = grouped_train_test_split(
df_dataset,
group_by=self.group_by[0],
Expand Down
10 changes: 4 additions & 6 deletions src/nemo_safe_synthesizer/generation/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ class GroupedDataProcessor(Processor):
"""Processor for grouped data generation tasks.

Used when training examples are grouped (and optionally ordered) by
one or more columns. Validates that each group has a unique
``group_by`` value and respects the ``order_by`` ordering.
a column. Validates that each group has a unique ``group_by`` value
and respects the ``order_by`` ordering.

Args:
schema: JSON schema as a dictionary.
Expand All @@ -199,13 +199,11 @@ def __init__(
config: ValidationParameters,
bos_token: str,
eos_token: str,
group_by: str | list[str],
group_by: str,
order_by: str | None = None,
):
super().__init__(schema=schema, config=config)
if isinstance(group_by, str):
group_by = [group_by]
self.group_by = group_by
self.group_by: list[str] = [group_by]
self.order_by = order_by
self.bos_token = bos_token
self.eos_token = eos_token
Expand Down
7 changes: 5 additions & 2 deletions src/nemo_safe_synthesizer/holdout/holdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def naive_train_test_split(


def grouped_train_test_split(
df: pd.DataFrame, test_size: float | int, group_by: str | list[str], random_state: int | None = None
df: pd.DataFrame, test_size: int | float, group_by: str, random_state: int | None = None
) -> DataFrameOptionalTuple:
"""Split a dataframe so that all rows sharing a group stay in the same fold.

Expand Down Expand Up @@ -189,7 +189,10 @@ def train_test_split(self, input_df: pd.DataFrame) -> DataFrameOptionalTuple:
)

if self.group_by is not None and self.group_by not in input_df.columns:
logger.warning(f"Group By column {self.group_by} not found in input Dataset columns! Doing a normal split.")
msg = f"Group by column {self.group_by!r} not found in input dataset columns. Doing a normal split."
if "," in self.group_by:
msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name."
logger.warning(msg)
self.group_by = None
if self.group_by is not None and input_df[self.group_by].isna().any():
raise ValueError(f"Group by column '{self.group_by}' has missing values. Please remove/replace them.")
Expand Down
4 changes: 3 additions & 1 deletion src/nemo_safe_synthesizer/training/huggingface_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,9 @@ def _validate_groupby_column(self, df: pd.DataFrame) -> None:
return

if col not in df.columns:
msg = f"Group by column '{col}' not found in the input data."
msg = f"Group by column {col!r} not found in the input data."
if "," in col:
msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name."
logger.error(msg)
raise ParameterError(msg)

Expand Down
5 changes: 3 additions & 2 deletions src/nemo_safe_synthesizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def time_closure(*args: Any, **kwargs: Any) -> Any:
def grouped_train_test_split(
dataset: Dataset,
test_size: float,
group_by: str | list[str],
group_by: str,
seed: int | None = None,
) -> tuple[DataFrame, DataFrame | None]:
"""Split a HuggingFace Dataset preserving group membership.
Expand All @@ -198,7 +198,7 @@ def grouped_train_test_split(
Args:
dataset: The HuggingFace ``Dataset`` to split.
test_size: Fraction or absolute number of test rows.
group_by: Column name or list of column names defining groups.
group_by: Column name defining groups.
seed: Random state for reproducibility.

Returns:
Expand All @@ -207,6 +207,7 @@ def grouped_train_test_split(
"""
# Convert to pandas for group operations
df = dataset.to_pandas()
assert isinstance(df, pd.DataFrame)
# importing like this to avoid a dep for testing on the sdk side
from .holdout import holdout as nss_holdout

Expand Down
18 changes: 18 additions & 0 deletions tests/config/test_nss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,21 @@ def test_time_series_configuration_passes_validation(self):
def test_read_from_yaml(self, yaml_config_str):
p = SafeSynthesizerParameters.from_yaml_str(yaml_config_str)
assert p.get("gradient_accumulation_steps") == 8


class TestGroupTrainingExamplesBy:
def test_single_column_string_accepted(self):
params = DataParameters(group_training_examples_by="patient_id")
assert params.group_training_examples_by == "patient_id"

def test_none_accepted(self):
params = DataParameters(group_training_examples_by=None)
assert params.group_training_examples_by is None

def test_list_rejected_by_pydantic(self):
with pytest.raises(ValidationError):
DataParameters(group_training_examples_by=["patient_id", "event_id"]) # ty: ignore[invalid-argument-type]

def test_comma_separated_string_accepted_by_pydantic(self):
params = DataParameters(group_training_examples_by="patient_id,event_id")
assert params.group_training_examples_by == "patient_id,event_id"
43 changes: 0 additions & 43 deletions tests/generation/test_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,49 +182,6 @@ def test_grouped_data_processor_out_of_order_records(
assert response.errors[-1] == ("Group not ordered", "groupby")


# Purpose: Composite group_by works when the tuple uniquely identifies groups.
# Data: Grouped by ["petal.width", "variety"].
# Asserts: 5 valid; 0 invalid/errors.
def test_grouped_data_processor_multiple_group_by(
fixture_valid_iris_dataset_jsonl_and_schema,
fixture_validation_config: ValidationParameters,
):
jsonl_str, jsonl_schema = fixture_valid_iris_dataset_jsonl_and_schema
groups_jsonl_str = BOS + jsonl_str + EOS
response = GroupedDataProcessor(
schema=jsonl_schema,
config=fixture_validation_config,
group_by=["petal.width", "variety"],
bos_token=BOS,
eos_token=EOS,
)(1, groups_jsonl_str)
assert len(response.valid_records) == 5
assert len(response.invalid_records) == 0
assert len(response.errors) == 0


# Purpose: Composite group_by should fail when the tuple does not uniquely identify groups.
# Data: Grouped by ["sepal.length", "variety"] (non-unique).
# Asserts: 0 valid; 5 invalid/errors; last error indicates non-unique group value.
def test_grouped_data_processor_multiple_group_by_error(
fixture_valid_iris_dataset_jsonl_and_schema,
fixture_validation_config: ValidationParameters,
):
jsonl_str, jsonl_schema = fixture_valid_iris_dataset_jsonl_and_schema
groups_jsonl_str = BOS + jsonl_str + EOS
response = GroupedDataProcessor(
schema=jsonl_schema,
config=fixture_validation_config,
group_by=["sepal.length", "variety"],
bos_token=BOS,
eos_token=EOS,
)(1, groups_jsonl_str)
assert len(response.valid_records) == 0
assert len(response.invalid_records) == 5
assert len(response.errors) == 5
assert response.errors[-1] == ("Groupby value is not unique", "groupby")


# Purpose: group_by_accept_no_delineator=True treats raw JSONL (no BOS/EOS) as a single group.
# Data: Raw JSONL without BOS/EOS (same as test_grouped_data_processor_with_no_groups).
# Asserts: 5 valid; 0 invalid; 0 errors.
Expand Down
8 changes: 7 additions & 1 deletion tests/training/test_huggingface_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,13 @@ def test_raises_when_column_missing(self, backend, sample_dataframe):
"""Test that ParameterError is raised when column is missing."""
backend.params.data.group_training_examples_by = "nonexistent_col"

with pytest.raises(ParameterError, match="Group by column 'nonexistent_col' not found"):
with pytest.raises(ParameterError, match="not found in the input data"):
backend._validate_groupby_column(sample_dataframe)

def test_raises_with_comma_hint_when_column_has_comma(self, backend, sample_dataframe):
backend.params.data.group_training_examples_by = "patient_id,event_id"

with pytest.raises(ParameterError, match="multi-column grouping is not supported"):
backend._validate_groupby_column(sample_dataframe)

def test_raises_when_column_has_nulls(self, backend, dataframe_with_null_group):
Expand Down
Loading