Skip to content

Commit a5c080c

Browse files
igerber and claude committed
Address review P2s: snippet guard, pickle snapshot drop
- Practitioner LOO snippet now guards on `_loo_unit_ids` availability, not just variance_method. Single-treated-unit and single-effective-control designs legitimately return empty jackknife output; a user copy-pasting the prior snippet would hit a ValueError from get_loo_effects_df(). The else branch now describes the actual requirements.
- SyntheticDiDResults.__getstate__ drops _fit_snapshot on pickle so generic pickle.dumps() no longer carries outcome matrices, unit IDs, or survey weights to wherever the pickled bytes are sent. The live session is unaffected; unpickled results raise the existing "re-fit to enable" message from in_time_placebo / sensitivity_to_zeta_omega.

Tests cover:
- Snippet degrades gracefully on single-treated jackknife fit (exec()'s the snippet, asserts the "LOO not available" message).
- Pickle round-trip drops the snapshot and leaves public fields intact.
- in_time_placebo / sensitivity_to_zeta_omega raise after unpickle.
- __getstate__ does not mutate the live instance's snapshot.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0843481 commit a5c080c

3 files changed

Lines changed: 105 additions & 2 deletions

File tree

diff_diff/practitioner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,15 @@ def _handle_synthetic(results: Any):
545545
"which units move the ATT the most."
546546
),
547547
code=(
548-
"if results.variance_method == 'jackknife':\n"
548+
"# Requires variance_method='jackknife' AND enough support for LOO\n"
549+
"# (n_treated >= 2 and >= 2 effective-weight controls).\n"
550+
"if getattr(results, '_loo_unit_ids', None) is not None:\n"
549551
" loo_df = results.get_loo_effects_df()\n"
550552
" print(loo_df.head(10))\n"
551553
"else:\n"
552-
" print('Re-fit with variance_method=\"jackknife\" to see LOO.')"
554+
" print('LOO not available - re-fit with '\n"
555+
" 'variance_method=\"jackknife\" and ensure >=2 treated units '\n"
556+
" 'with positive effective support.')"
553557
),
554558
priority="medium",
555559
step_name="sensitivity",

diff_diff/results.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,23 @@ def __repr__(self) -> str:
793793
f"p={self.p_value:.4f})"
794794
)
795795

796+
def __getstate__(self) -> Dict[str, Any]:
797+
"""Exclude the internal fit snapshot from pickling.
798+
799+
The snapshot retains outcome matrices, unit IDs, and survey weights
800+
to support post-hoc diagnostics (`in_time_placebo`,
801+
`sensitivity_to_zeta_omega`). Serialization would otherwise carry
802+
that panel state to wherever the pickle is sent, which is a privacy
803+
hazard for survey-weighted or sensitive fits.
804+
805+
Unpickled results keep the public fields (ATT, weights, trajectories,
806+
etc.); calling a diagnostic method that needs the snapshot raises a
807+
ValueError directing the user to re-fit.
808+
"""
809+
state = self.__dict__.copy()
810+
state["_fit_snapshot"] = None
811+
return state
812+
796813
@property
797814
def coef_var(self) -> float:
798815
"""Coefficient of variation: SE / |ATT|. NaN when ATT is 0 or SE non-finite."""

tests/test_methodology_sdid.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,3 +1994,85 @@ def test_snippets_parse_as_python(self):
19941994
steps, _ = _handle_synthetic(res)
19951995
for step in steps:
19961996
ast.parse(step["code"])
1997+
1998+
def test_jackknife_loo_snippet_handles_unavailable_loo(self):
1999+
"""When variance_method='jackknife' but LOO is unavailable
2000+
(e.g., n_treated=1 returns empty jackknife array), the LOO snippet
2001+
should degrade gracefully instead of raising."""
2002+
from diff_diff.practitioner import _handle_synthetic
2003+
2004+
df = _make_panel(n_control=10, n_treated=1, n_pre=5, n_post=3, seed=97)
2005+
with warnings.catch_warnings():
2006+
warnings.simplefilter("ignore")
2007+
sdid = SyntheticDiD(variance_method="jackknife", seed=97)
2008+
res = sdid.fit(df, outcome="outcome", treatment="treated",
2009+
unit="unit", time="period",
2010+
post_periods=list(range(5, 8)))
2011+
assert res.variance_method == "jackknife"
2012+
assert res._loo_unit_ids is None # LOO intentionally unavailable
2013+
2014+
steps, _ = _handle_synthetic(res)
2015+
loo_snippet = next(
2016+
s["code"] for s in steps if "get_loo_effects_df" in s["code"]
2017+
)
2018+
# Executing the snippet against this result must not raise.
2019+
import io
2020+
import contextlib
2021+
captured = io.StringIO()
2022+
with contextlib.redirect_stdout(captured):
2023+
exec(loo_snippet, {"results": res})
2024+
assert "LOO not available" in captured.getvalue()
2025+
2026+
2027+
class TestSyntheticDiDResultsPickle:
2028+
"""Pickle round-trip drops the fit snapshot; diagnostic methods raise
2029+
with the documented recovery message."""
2030+
2031+
def _fit(self, seed=101):
2032+
df = _make_panel(seed=seed)
2033+
sdid = SyntheticDiD(variance_method="jackknife", seed=seed)
2034+
return sdid.fit(df, outcome="outcome", treatment="treated",
2035+
unit="unit", time="period")
2036+
2037+
def test_snapshot_dropped_on_pickle(self):
2038+
import pickle
2039+
2040+
res = self._fit()
2041+
assert res._fit_snapshot is not None # present pre-pickle
2042+
2043+
restored = pickle.loads(pickle.dumps(res))
2044+
assert restored._fit_snapshot is None
2045+
# Public fields survive
2046+
assert restored.att == res.att
2047+
assert restored.se == res.se
2048+
assert np.allclose(
2049+
restored.synthetic_pre_trajectory, res.synthetic_pre_trajectory
2050+
)
2051+
2052+
def test_in_time_placebo_raises_after_pickle(self):
2053+
import pickle
2054+
2055+
res = self._fit(seed=103)
2056+
restored = pickle.loads(pickle.dumps(res))
2057+
with pytest.raises(ValueError, match="fit snapshot"):
2058+
restored.in_time_placebo()
2059+
2060+
def test_sensitivity_raises_after_pickle(self):
2061+
import pickle
2062+
2063+
res = self._fit(seed=105)
2064+
restored = pickle.loads(pickle.dumps(res))
2065+
with pytest.raises(ValueError, match="fit snapshot"):
2066+
restored.sensitivity_to_zeta_omega()
2067+
2068+
def test_live_instance_snapshot_untouched_by_getstate(self):
2069+
"""__getstate__ must not mutate the live object's snapshot —
2070+
only the returned state dict carries the nulled field."""
2071+
res = self._fit(seed=107)
2072+
snap_before = res._fit_snapshot
2073+
assert snap_before is not None
2074+
_ = res.__getstate__()
2075+
# Live instance unchanged after __getstate__ call
2076+
assert res._fit_snapshot is snap_before
2077+
# Diagnostics still work in the live session
2078+
_ = res.in_time_placebo(fake_treatment_periods=[2])

0 commit comments

Comments
 (0)