Skip to content

Commit 673867d

Browse files
committed
dedupe tests
Signed-off-by: nina-xu <19981858+nina-xu@users.noreply.github.com>
1 parent 879dc0f commit 673867d

2 files changed

Lines changed: 16 additions & 39 deletions

File tree

tests/data_processing/test_validation.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,23 @@ def test_validate_groupby_column_noop_when_groupby_is_none() -> None:
1616
validate_groupby_column(df, None)
1717

1818

19+
def test_validate_groupby_column_passes_when_column_exists() -> None:
20+
df = pd.DataFrame(
21+
{
22+
"col1": [1, 2, 3, 4, 5],
23+
"col2": ["a", "b", "c", "d", "e"],
24+
"group_col": ["g1", "g1", "g2", "g2", "g3"],
25+
}
26+
)
27+
validate_groupby_column(df, "group_col")
28+
29+
1930
def test_validate_groupby_column_raises_for_missing_column() -> None:
2031
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
21-
with pytest.raises(ParameterError, match="disable grouping"):
32+
with pytest.raises(
33+
ParameterError,
34+
match=r"Group by column 'missing_group' not found in input dataset columns.*disable grouping",
35+
):
2236
validate_groupby_column(df, "missing_group")
2337

2438

tests/training/test_huggingface_backend.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
SafeSynthesizerParameters,
1919
TrainingHyperparams,
2020
)
21-
from nemo_safe_synthesizer.data_processing.validation import validate_groupby_column
22-
from nemo_safe_synthesizer.errors import DataError, ParameterError
21+
from nemo_safe_synthesizer.errors import ParameterError
2322
from nemo_safe_synthesizer.training.huggingface_backend import (
2423
HuggingFaceBackend,
2524
compute_metrics,
@@ -209,17 +208,6 @@ def sample_dataframe():
209208
)
210209

211210

212-
@pytest.fixture
213-
def dataframe_with_null_group():
214-
"""Create a DataFrame with null values in the group column."""
215-
return pd.DataFrame(
216-
{
217-
"col1": [1, 2, 3],
218-
"group_col": ["g1", None, "g2"],
219-
}
220-
)
221-
222-
223211
class TestFilterModelKwargs:
224212
def test_filters_trainer_specific_keys(self, backend):
225213
"""Test that trainer-specific keys are filtered out."""
@@ -561,31 +549,6 @@ def test_uses_provided_data_collator(self, backend):
561549
assert "data_collator" not in training_args
562550

563551

564-
class TestValidateGroupbyColumn:
565-
def test_does_nothing_when_no_groupby(self, sample_dataframe):
566-
"""Test that nothing happens when groupby is None."""
567-
validate_groupby_column(sample_dataframe, None) # Should not raise
568-
569-
def test_passes_when_column_exists(self, sample_dataframe):
570-
"""Test that validation passes when column exists."""
571-
validate_groupby_column(sample_dataframe, "group_col") # Should not raise
572-
573-
def test_raises_when_column_missing(self, sample_dataframe):
574-
"""Test that ParameterError is raised when column is missing."""
575-
with pytest.raises(ParameterError, match="Group by column 'nonexistent_col' not found"):
576-
validate_groupby_column(sample_dataframe, "nonexistent_col")
577-
578-
def test_raises_with_comma_hint_when_column_has_comma(self, sample_dataframe):
579-
"""Test that ParameterError is raised when column name has a comma."""
580-
with pytest.raises(ParameterError, match="multi-column grouping is not supported"):
581-
validate_groupby_column(sample_dataframe, "patient_id,event_id")
582-
583-
def test_raises_when_column_has_nulls(self, dataframe_with_null_group):
584-
"""Test that DataError is raised when column has null values."""
585-
with pytest.raises(DataError, match="has missing values"):
586-
validate_groupby_column(dataframe_with_null_group, "group_col")
587-
588-
589552
class TestValidateOrderbyColumn:
590553
def test_does_nothing_when_no_orderby(self, backend, sample_dataframe):
591554
"""Test that nothing happens when orderby is None."""

0 commit comments

Comments
 (0)