diff --git a/docs/user-guide/troubleshooting.md b/docs/user-guide/troubleshooting.md index 8afcafff..d2a809d5 100644 --- a/docs/user-guide/troubleshooting.md +++ b/docs/user-guide/troubleshooting.md @@ -423,6 +423,25 @@ safe-synthesizer config validate --config config.yaml `data.group_training_examples_by`, config validation will fail. Ordering only makes sense within groups. +`group_training_examples_by` with comma-separated column names: + +: Setting `data.group_training_examples_by: col1,col2` in YAML is parsed as the + single string `"col1,col2"`, not as two separate columns. The pipeline will + fail with a `ParameterError` when it tries to find a column literally named + `"col1,col2"` in your data: + + ```text + ParameterError: Group by column 'patient_id,event_id' not found in the input data. + The column name contains a comma -- multi-column grouping is not supported. + Use a single column name. + ``` + + Only a single column name is supported. Multi-column grouping is not + currently available. If you need to group by multiple columns, consider + creating a composite column in your data before running the pipeline + (e.g. concatenate `patient_id` and `event_id` into a new + `patient_event_id` column). + Unsupported file extensions: : The `url` parameter accepts `.csv`, `.json`, `.jsonl`, `.parquet`, and `.txt` diff --git a/src/nemo_safe_synthesizer/config/autoconfig.py b/src/nemo_safe_synthesizer/config/autoconfig.py index c96b6bd8..7c84444f 100644 --- a/src/nemo_safe_synthesizer/config/autoconfig.py +++ b/src/nemo_safe_synthesizer/config/autoconfig.py @@ -46,7 +46,7 @@ def choose_num_input_records_to_sample(rope_scaling_factor: int) -> int: return rope_scaling_factor * 25_000 -def get_max_token_count(data: pd.DataFrame, group_by: list[str] | str | None) -> int: +def get_max_token_count(data: pd.DataFrame, group_by: str | None) -> int: """Estimate the maximum tokens per training example. Accounts for prompt overhead (~40 tokens), column names (repeated in JSON @@ -56,7 +56,7 @@ def get_max_token_count(data: pd.DataFrame, group_by: list[str] | str | None) -> Args: data: Training dataframe to analyze. - group_by: Column(s) used to group records into single training examples. + group_by: Column used to group records into single training examples. When set, grouped records are concatenated before token estimation. Returns: diff --git a/src/nemo_safe_synthesizer/data_processing/assembler.py b/src/nemo_safe_synthesizer/data_processing/assembler.py index bbc2d227..63a48ee3 100644 --- a/src/nemo_safe_synthesizer/data_processing/assembler.py +++ b/src/nemo_safe_synthesizer/data_processing/assembler.py @@ -867,7 +867,10 @@ def _validate_columns(self, dataset: Dataset) -> None: ParameterError: If group or order column is not found in dataset. """ if self.group_by_column not in dataset.column_names: - raise ParameterError(f"Group by column '{self.group_by_column}' not found in dataset.") + msg = f"Group by column {self.group_by_column!r} not found in dataset." + if "," in self.group_by_column: + msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name." + raise ParameterError(msg) if self.order_by_column not in dataset.column_names: raise ParameterError(f"Order by column '{self.order_by_column}' not found in dataset.") @@ -1312,11 +1315,12 @@ def __init__( if keep_columns: required_columns = list(set(required_columns + keep_columns)) - # We need to split the dataset first so that the grouping column(s) are still present when we invoke + # We need to split the dataset first so that the grouping column is still present when we invoke # `utils.grouped_train_test_split`. After the split we tokenize and perform the (potentially expensive) grouping step independently for # train and test. if test_size is not None and test_size > 0: - df_dataset = cast(pd.DataFrame, dataset.to_pandas()) + df_dataset = dataset.to_pandas() + assert isinstance(df_dataset, pd.DataFrame) train_raw, test_raw = grouped_train_test_split( df_dataset, group_by=self.group_by[0], diff --git a/src/nemo_safe_synthesizer/generation/processors.py b/src/nemo_safe_synthesizer/generation/processors.py index fe498eee..02adb410 100644 --- a/src/nemo_safe_synthesizer/generation/processors.py +++ b/src/nemo_safe_synthesizer/generation/processors.py @@ -179,8 +179,8 @@ class GroupedDataProcessor(Processor): """Processor for grouped data generation tasks. Used when training examples are grouped (and optionally ordered) by - one or more columns. Validates that each group has a unique - ``group_by`` value and respects the ``order_by`` ordering. + a column. Validates that each group has a unique ``group_by`` value + and respects the ``order_by`` ordering. Args: schema: JSON schema as a dictionary. @@ -199,13 +199,11 @@ def __init__( config: ValidationParameters, bos_token: str, eos_token: str, - group_by: str | list[str], + group_by: str, order_by: str | None = None, ): super().__init__(schema=schema, config=config) - if isinstance(group_by, str): - group_by = [group_by] - self.group_by = group_by + self.group_by: list[str] = [group_by] self.order_by = order_by self.bos_token = bos_token self.eos_token = eos_token diff --git a/src/nemo_safe_synthesizer/holdout/holdout.py b/src/nemo_safe_synthesizer/holdout/holdout.py index f1f94726..8011f14e 100644 --- a/src/nemo_safe_synthesizer/holdout/holdout.py +++ b/src/nemo_safe_synthesizer/holdout/holdout.py @@ -59,7 +59,7 @@ def naive_train_test_split( def grouped_train_test_split( - df: pd.DataFrame, test_size: float | int, group_by: str | list[str], random_state: int | None = None + df: pd.DataFrame, test_size: int | float, group_by: str, random_state: int | None = None ) -> DataFrameOptionalTuple: """Split a dataframe so that all rows sharing a group stay in the same fold. @@ -189,7 +189,10 @@ def train_test_split(self, input_df: pd.DataFrame) -> DataFrameOptionalTuple: ) if self.group_by is not None and self.group_by not in input_df.columns: - logger.warning(f"Group By column {self.group_by} not found in input Dataset columns! Doing a normal split.") + msg = f"Group by column {self.group_by!r} not found in input dataset columns. Doing a normal split." + if "," in self.group_by: + msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name." + logger.warning(msg) self.group_by = None if self.group_by is not None and input_df[self.group_by].isna().any(): raise ValueError(f"Group by column '{self.group_by}' has missing values. Please remove/replace them.") diff --git a/src/nemo_safe_synthesizer/training/huggingface_backend.py b/src/nemo_safe_synthesizer/training/huggingface_backend.py index 8cfb2074..a2ebb99f 100644 --- a/src/nemo_safe_synthesizer/training/huggingface_backend.py +++ b/src/nemo_safe_synthesizer/training/huggingface_backend.py @@ -553,7 +553,9 @@ def _validate_groupby_column(self, df: pd.DataFrame) -> None: return if col not in df.columns: - msg = f"Group by column '{col}' not found in the input data." + msg = f"Group by column {col!r} not found in the input data." + if "," in col: + msg += " The column name contains a comma -- multi-column grouping is not supported. Use a single column name." logger.error(msg) raise ParameterError(msg) diff --git a/src/nemo_safe_synthesizer/utils.py b/src/nemo_safe_synthesizer/utils.py index 53b49eb4..85c5ff49 100644 --- a/src/nemo_safe_synthesizer/utils.py +++ b/src/nemo_safe_synthesizer/utils.py @@ -187,7 +187,7 @@ def time_closure(*args: Any, **kwargs: Any) -> Any: def grouped_train_test_split( dataset: Dataset, test_size: float, - group_by: str | list[str], + group_by: str, seed: int | None = None, ) -> tuple[DataFrame, DataFrame | None]: """Split a HuggingFace Dataset preserving group membership. @@ -198,7 +198,7 @@ def grouped_train_test_split( Args: dataset: The HuggingFace ``Dataset`` to split. test_size: Fraction or absolute number of test rows. - group_by: Column name or list of column names defining groups. + group_by: Column name defining groups. seed: Random state for reproducibility. Returns: @@ -207,6 +207,7 @@ def grouped_train_test_split( """ # Convert to pandas for group operations df = dataset.to_pandas() + assert isinstance(df, pd.DataFrame) # importing like this to avoid a dep for testing on the sdk side from .holdout import holdout as nss_holdout diff --git a/tests/config/test_nss_config.py b/tests/config/test_nss_config.py index 8fb33989..8f51d8fd 100644 --- a/tests/config/test_nss_config.py +++ b/tests/config/test_nss_config.py @@ -180,3 +180,21 @@ def test_time_series_configuration_passes_validation(self): def test_read_from_yaml(self, yaml_config_str): p = SafeSynthesizerParameters.from_yaml_str(yaml_config_str) assert p.get("gradient_accumulation_steps") == 8 + + +class TestGroupTrainingExamplesBy: + def test_single_column_string_accepted(self): + params = DataParameters(group_training_examples_by="patient_id") + assert params.group_training_examples_by == "patient_id" + + def test_none_accepted(self): + params = DataParameters(group_training_examples_by=None) + assert params.group_training_examples_by is None + + def test_list_rejected_by_pydantic(self): + with pytest.raises(ValidationError): + DataParameters(group_training_examples_by=["patient_id", "event_id"]) # ty: ignore[invalid-argument-type] + + def test_comma_separated_string_accepted_by_pydantic(self): + params = DataParameters(group_training_examples_by="patient_id,event_id") + assert params.group_training_examples_by == "patient_id,event_id" diff --git a/tests/generation/test_processors.py b/tests/generation/test_processors.py index bbd08ee1..6f9a45f1 100644 --- a/tests/generation/test_processors.py +++ b/tests/generation/test_processors.py @@ -182,49 +182,6 @@ def test_grouped_data_processor_out_of_order_records( assert response.errors[-1] == ("Group not ordered", "groupby") -# Purpose: Composite group_by works when the tuple uniquely identifies groups. -# Data: Grouped by ["petal.width", "variety"]. -# Asserts: 5 valid; 0 invalid/errors. -def test_grouped_data_processor_multiple_group_by( - fixture_valid_iris_dataset_jsonl_and_schema, - fixture_validation_config: ValidationParameters, -): - jsonl_str, jsonl_schema = fixture_valid_iris_dataset_jsonl_and_schema - groups_jsonl_str = BOS + jsonl_str + EOS - response = GroupedDataProcessor( - schema=jsonl_schema, - config=fixture_validation_config, - group_by=["petal.width", "variety"], - bos_token=BOS, - eos_token=EOS, - )(1, groups_jsonl_str) - assert len(response.valid_records) == 5 - assert len(response.invalid_records) == 0 - assert len(response.errors) == 0 - - -# Purpose: Composite group_by should fail when the tuple does not uniquely identify groups. -# Data: Grouped by ["sepal.length", "variety"] (non-unique). -# Asserts: 0 valid; 5 invalid/errors; last error indicates non-unique group value. -def test_grouped_data_processor_multiple_group_by_error( - fixture_valid_iris_dataset_jsonl_and_schema, - fixture_validation_config: ValidationParameters, -): - jsonl_str, jsonl_schema = fixture_valid_iris_dataset_jsonl_and_schema - groups_jsonl_str = BOS + jsonl_str + EOS - response = GroupedDataProcessor( - schema=jsonl_schema, - config=fixture_validation_config, - group_by=["sepal.length", "variety"], - bos_token=BOS, - eos_token=EOS, - )(1, groups_jsonl_str) - assert len(response.valid_records) == 0 - assert len(response.invalid_records) == 5 - assert len(response.errors) == 5 - assert response.errors[-1] == ("Groupby value is not unique", "groupby") - - # Purpose: group_by_accept_no_delineator=True treats raw JSONL (no BOS/EOS) as a single group. # Data: Raw JSONL without BOS/EOS (same as test_grouped_data_processor_with_no_groups). # Asserts: 5 valid; 0 invalid; 0 errors. diff --git a/tests/training/test_huggingface_backend.py b/tests/training/test_huggingface_backend.py index 96eb5df6..e12b4b25 100644 --- a/tests/training/test_huggingface_backend.py +++ b/tests/training/test_huggingface_backend.py @@ -574,7 +574,13 @@ def test_raises_when_column_missing(self, backend, sample_dataframe): """Test that ParameterError is raised when column is missing.""" backend.params.data.group_training_examples_by = "nonexistent_col" - with pytest.raises(ParameterError, match="Group by column 'nonexistent_col' not found"): + with pytest.raises(ParameterError, match="not found in the input data"): + backend._validate_groupby_column(sample_dataframe) + + def test_raises_with_comma_hint_when_column_has_comma(self, backend, sample_dataframe): + backend.params.data.group_training_examples_by = "patient_id,event_id" + + with pytest.raises(ParameterError, match="multi-column grouping is not supported"): backend._validate_groupby_column(sample_dataframe) def test_raises_when_column_has_nulls(self, backend, dataframe_with_null_group):