Skip to content

Commit 0983c8a

Browse files
authored
PARSynthesizer: FutureWarnings in groupby.apply and Series.__getitem__ from pandas (#2707)
1 parent 8222039 commit 0983c8a

File tree

2 files changed

+56
-19
lines changed

2 files changed

+56
-19
lines changed

sdv/sequential/par.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
import tqdm
1212
from rdt.transformers import FloatFormatter
1313

14-
from sdv._utils import MODELABLE_SDTYPES, _cast_to_iterable, _groupby_list
14+
from sdv._utils import MODELABLE_SDTYPES, _cast_to_iterable, _groupby_list, _is_datetime_type
1515
from sdv.cag import ProgrammableConstraint
1616
from sdv.cag._utils import _validate_constraints_single_table
17+
from sdv.constraints.utils import cast_to_datetime64
1718
from sdv.errors import SamplingError, SynthesizerInputError
1819
from sdv.metadata.errors import InvalidMetadataError
1920
from sdv.metadata.metadata import Metadata
@@ -37,6 +38,11 @@
3738
LOGGER = logging.getLogger(__name__)
3839

3940

41+
def _diff_and_bfill(series):
    """Compute the diff of a pandas Series and backfill the first NaN.

    Args:
        series (pandas.Series):
            Values to difference.

    Returns:
        pandas.Series:
            The consecutive differences of ``series``, keeping its index. The
            first position, which ``diff`` leaves as NaN, takes the value of the
            next valid difference. If there is no later valid difference (for
            example a series with fewer than two elements) the NaN remains.
    """
    # Module-level named function rather than an inline lambda so it can be
    # passed to ``groupby(...).transform`` and unit-tested directly.
    return series.diff().bfill()
44+
45+
4046
class PARSynthesizer(LossValuesMixin, MissingModuleMixin, BaseSynthesizer):
4147
"""Synthesizer for sequential data.
4248
@@ -310,20 +316,25 @@ def _transform_sequence_index(self, data):
310316
sequence_index_context = sequence_index_context.rename(
311317
columns={self._sequence_index: f'{self._sequence_index}.context'}
312318
)
319+
320+
if _is_datetime_type(sequence_index[self._sequence_index]):
321+
sequence_index[self._sequence_index] = cast_to_datetime64(
322+
sequence_index[self._sequence_index]
323+
).astype(np.int64)
324+
313325
if all(sequence_index[self._sequence_key].nunique() == 1):
314-
sequence_index_sequence = sequence_index[[self._sequence_index]].diff().bfill()
326+
diff_series = sequence_index[self._sequence_index].diff().bfill()
315327
else:
316-
sequence_index_sequence = (
317-
sequence_index.groupby(self._sequence_key)
318-
.apply(lambda x: x[self._sequence_index].diff().bfill())
319-
.droplevel(1)
320-
.reset_index()
321-
)
328+
diff_series = sequence_index.groupby(self._sequence_key, group_keys=False)[
329+
self._sequence_index
330+
].transform(_diff_and_bfill)
322331

332+
sequence_index_sequence = diff_series.to_frame(name=self._sequence_index)
323333
if all(sequence_index_sequence[self._sequence_index].isna()):
324334
fill_value = 0
325335
else:
326336
fill_value = min(sequence_index_sequence[self._sequence_index].dropna())
337+
327338
sequence_index_sequence = sequence_index_sequence.fillna(fill_value)
328339

329340
data[self._sequence_index] = sequence_index_sequence[self._sequence_index].to_numpy()
@@ -573,7 +584,7 @@ def _sample_from_par(self, context, sequence_length=None):
573584
pd.DataFrame({self._sequence_index: diffs})
574585
)[self._sequence_index].to_numpy()
575586
start_index = context_columns.index(f'{self._sequence_index}.context')
576-
start = context_values[start_index]
587+
start = context_values.iloc[start_index]
577588
sequence[sequence_index_idx] = np.cumsum(diffs) - diffs[0] + start
578589

579590
# Reformat as a DataFrame

tests/unit/sequential/test_par.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,24 @@
1414
from sdv.metadata.metadata import Metadata
1515
from sdv.metadata.single_table import SingleTableMetadata
1616
from sdv.sampling import Condition
17-
from sdv.sequential.par import PARSynthesizer
17+
from sdv.sequential.par import PARSynthesizer, _diff_and_bfill
1818
from sdv.single_table.base import BaseSynthesizer
1919
from sdv.single_table.copulas import GaussianCopulaSynthesizer
2020

2121

22+
def test__diff_and_bfill():
    """Test the ``_diff_and_bfill`` function.

    The first difference is undefined, so it is expected to be backfilled
    with the second one.
    """
    # Setup
    data = pd.Series([10, 15, 20, 30])

    # Run
    result = _diff_and_bfill(data)

    # Assert
    expected = pd.Series([5.0, 5.0, 5.0, 10.0])
    pd.testing.assert_series_equal(result, expected)
33+
34+
2235
class TestPARSynthesizer:
2336
def get_metadata(self, add_sequence_key=True, add_sequence_index=False):
2437
metadata = Metadata()
@@ -283,6 +296,7 @@ def test_validate_context_columns_unique_per_sequence_key(self):
283296
with pytest.raises(InvalidDataError, match=err_msg):
284297
instance.validate(data)
285298

299+
@pytest.mark.filterwarnings('error::FutureWarning')
286300
def test__transform_sequence(self):
287301
# Setup
288302
metadata = self.get_metadata(add_sequence_index=True)
@@ -310,6 +324,7 @@ def test__transform_sequence(self):
310324
assert list(par.extended_columns.keys()) == ['time']
311325
assert par.extended_columns['time'].enforce_min_max_values is True
312326

327+
@pytest.mark.filterwarnings('error::FutureWarning')
313328
def test__transform_sequence_index_single_instances(self):
314329
# Setup
315330
metadata = self.get_metadata(add_sequence_index=True)
@@ -332,6 +347,7 @@ def test__transform_sequence_index_single_instances(self):
332347
assert list(par.extended_columns.keys()) == ['time']
333348
assert par.extended_columns['time'].enforce_min_max_values is True
334349

350+
@pytest.mark.filterwarnings('error::FutureWarning')
335351
def test__transform_sequence_index_non_unique_sequence_key(self):
336352
# Setup
337353
metadata = self.get_metadata(add_sequence_index=True)
@@ -833,6 +849,7 @@ def test__sample_from_par_with_sequence_key(self, tqdm_mock):
833849
})
834850
pd.testing.assert_frame_equal(sampled, expected_output)
835851

852+
@pytest.mark.filterwarnings('error::FutureWarning')
836853
@patch('sdv.sequential.par.tqdm')
837854
def test__sample_from_par_with_sequence_index(self, tqdm_mock):
838855
"""Test that the method handles the sequence index properly.
@@ -1245,6 +1262,9 @@ def test_sample_with_all_null_column_categorical(self):
12451262
assert result['all_null_cat_col'].isna().all()
12461263
assert len(result) > 0
12471264

1265+
@pytest.mark.filterwarnings(
1266+
'error:Series.__getitem__ treating keys as positions is deprecated:FutureWarning'
1267+
)
12481268
def test_sample_with_multiple_all_null_columns(self):
12491269
"""Test that sampling works correctly with multiple all-null columns."""
12501270
# Setup
@@ -1257,15 +1277,21 @@ def test_sample_with_multiple_all_null_columns(self):
12571277
'all_null_col2': [np.nan] * 9,
12581278
})
12591279

1260-
metadata = Metadata()
1261-
metadata.add_table('table')
1262-
metadata.add_column('time', 'table', sdtype='datetime')
1263-
metadata.add_column('gender', 'table', sdtype='categorical')
1264-
metadata.add_column('name', 'table', sdtype='id')
1265-
metadata.add_column('measurement', 'table', sdtype='numerical')
1266-
metadata.add_column('all_null_col1', 'table', sdtype='numerical')
1267-
metadata.add_column('all_null_col2', 'table', sdtype='categorical')
1268-
metadata.set_sequence_key('name', 'table')
1280+
metadata = Metadata().load_from_dict({
1281+
'tables': {
1282+
'table': {
1283+
'columns': {
1284+
'time': {'sdtype': 'datetime'},
1285+
'gender': {'sdtype': 'categorical'},
1286+
'name': {'sdtype': 'id'},
1287+
'measurement': {'sdtype': 'numerical'},
1288+
'all_null_col1': {'sdtype': 'numerical'},
1289+
'all_null_col2': {'sdtype': 'categorical'},
1290+
},
1291+
'sequence_key': 'name',
1292+
}
1293+
}
1294+
})
12691295

12701296
# Run
12711297
synthesizer = PARSynthesizer(metadata=metadata, epochs=1)

0 commit comments

Comments
 (0)