Skip to content

Commit 3ce5d33

Browse files
committed
test ValueException when validation_time > treatment_time
1 parent bb0b937 commit 3ce5d33

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,27 @@ def test_its(validation_time):
349349
result.summary()
350350

351351

352+
def test_its_with_invalid_validation_time():
353+
"""
354+
Test that we get a ValueError when validation_time is greater than validation_time.
355+
"""
356+
df = (
357+
cp.load_data("its")
358+
.assign(date=lambda x: pd.to_datetime(x["date"]))
359+
.set_index("date")
360+
)
361+
treatment_time = pd.to_datetime("2017-01-01")
362+
validation_time = pd.to_datetime("2018-01-01")
363+
with pytest.raises(ValueError):
364+
_ = cp.pymc_experiments.InterruptedTimeSeries(
365+
df,
366+
treatment_time,
367+
validation_time=validation_time,
368+
formula="y ~ 1 + t + C(month)",
369+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
370+
)
371+
372+
352373
@pytest.mark.integration
353374
def test_its_covid():
354375
"""
@@ -409,6 +430,23 @@ def test_sc(validation_time):
409430
result.summary()
410431

411432

433+
def test_sc_with_invalid_validation_time():
434+
"""
435+
Test that we get a ValueError when validation_time is greater than validation_time.
436+
"""
437+
df = cp.load_data("sc")
438+
treatment_time = 70
439+
validation_time = 80
440+
with pytest.raises(ValueError):
441+
_ = cp.pymc_experiments.SyntheticControl(
442+
df,
443+
treatment_time,
444+
validation_time=validation_time,
445+
formula="actual ~ 0 + a + b + c + d + e + f + g",
446+
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
447+
)
448+
449+
412450
@pytest.mark.integration
413451
def test_sc_brexit():
414452
"""

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)