1414from sdv .metadata .metadata import Metadata
1515from sdv .metadata .single_table import SingleTableMetadata
1616from sdv .sampling import Condition
17- from sdv .sequential .par import PARSynthesizer
17+ from sdv .sequential .par import PARSynthesizer , _diff_and_bfill
1818from sdv .single_table .base import BaseSynthesizer
1919from sdv .single_table .copulas import GaussianCopulaSynthesizer
2020
2121
def test__diff_and_bfill():
    """Test that ``_diff_and_bfill`` returns consecutive differences with the first entry back-filled."""
    # Setup
    values = pd.Series([10, 15, 20, 30])
    # The first position has no predecessor, so it is expected to take the
    # value of the next difference (15 - 10 = 5).
    expected = pd.Series([5.0, 5.0, 5.0, 10.0])

    # Run
    differences = _diff_and_bfill(values)

    # Assert
    pd.testing.assert_series_equal(differences, expected)
33+
34+
2235class TestPARSynthesizer :
2336 def get_metadata (self , add_sequence_key = True , add_sequence_index = False ):
2437 metadata = Metadata ()
@@ -283,6 +296,7 @@ def test_validate_context_columns_unique_per_sequence_key(self):
283296 with pytest .raises (InvalidDataError , match = err_msg ):
284297 instance .validate (data )
285298
299+ @pytest .mark .filterwarnings ('error::FutureWarning' )
286300 def test__transform_sequence (self ):
287301 # Setup
288302 metadata = self .get_metadata (add_sequence_index = True )
@@ -310,6 +324,7 @@ def test__transform_sequence(self):
310324 assert list (par .extended_columns .keys ()) == ['time' ]
311325 assert par .extended_columns ['time' ].enforce_min_max_values is True
312326
327+ @pytest .mark .filterwarnings ('error::FutureWarning' )
313328 def test__transform_sequence_index_single_instances (self ):
314329 # Setup
315330 metadata = self .get_metadata (add_sequence_index = True )
@@ -332,6 +347,7 @@ def test__transform_sequence_index_single_instances(self):
332347 assert list (par .extended_columns .keys ()) == ['time' ]
333348 assert par .extended_columns ['time' ].enforce_min_max_values is True
334349
350+ @pytest .mark .filterwarnings ('error::FutureWarning' )
335351 def test__transform_sequence_index_non_unique_sequence_key (self ):
336352 # Setup
337353 metadata = self .get_metadata (add_sequence_index = True )
@@ -833,6 +849,7 @@ def test__sample_from_par_with_sequence_key(self, tqdm_mock):
833849 })
834850 pd .testing .assert_frame_equal (sampled , expected_output )
835851
852+ @pytest .mark .filterwarnings ('error::FutureWarning' )
836853 @patch ('sdv.sequential.par.tqdm' )
837854 def test__sample_from_par_with_sequence_index (self , tqdm_mock ):
838855 """Test that the method handles the sequence index properly.
@@ -1245,6 +1262,9 @@ def test_sample_with_all_null_column_categorical(self):
12451262 assert result ['all_null_cat_col' ].isna ().all ()
12461263 assert len (result ) > 0
12471264
1265+ @pytest .mark .filterwarnings (
1266+ 'error:Series.__getitem__ treating keys as positions is deprecated:FutureWarning'
1267+ )
12481268 def test_sample_with_multiple_all_null_columns (self ):
12491269 """Test that sampling works correctly with multiple all-null columns."""
12501270 # Setup
@@ -1257,15 +1277,21 @@ def test_sample_with_multiple_all_null_columns(self):
12571277 'all_null_col2' : [np .nan ] * 9 ,
12581278 })
12591279
1260- metadata = Metadata ()
1261- metadata .add_table ('table' )
1262- metadata .add_column ('time' , 'table' , sdtype = 'datetime' )
1263- metadata .add_column ('gender' , 'table' , sdtype = 'categorical' )
1264- metadata .add_column ('name' , 'table' , sdtype = 'id' )
1265- metadata .add_column ('measurement' , 'table' , sdtype = 'numerical' )
1266- metadata .add_column ('all_null_col1' , 'table' , sdtype = 'numerical' )
1267- metadata .add_column ('all_null_col2' , 'table' , sdtype = 'categorical' )
1268- metadata .set_sequence_key ('name' , 'table' )
1280+ metadata = Metadata ().load_from_dict ({
1281+ 'tables' : {
1282+ 'table' : {
1283+ 'columns' : {
1284+ 'time' : {'sdtype' : 'datetime' },
1285+ 'gender' : {'sdtype' : 'categorical' },
1286+ 'name' : {'sdtype' : 'id' },
1287+ 'measurement' : {'sdtype' : 'numerical' },
1288+ 'all_null_col1' : {'sdtype' : 'numerical' },
1289+ 'all_null_col2' : {'sdtype' : 'categorical' },
1290+ },
1291+ 'sequence_key' : 'name' ,
1292+ }
1293+ }
1294+ })
12691295
12701296 # Run
12711297 synthesizer = PARSynthesizer (metadata = metadata , epochs = 1 )
0 commit comments