Skip to content

Commit cbb8814

Browse files
authored
Merge pull request #339 from igerber/fix/axis-cj-closeouts
Close axis-C/J silent-failures audit: B-spline derivative + PA survey cache
2 parents 2e9447e + 0cf60dd commit cbb8814

5 files changed

Lines changed: 240 additions & 13 deletions

File tree

diff_diff/continuous_did_bspline.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
the dose-response curve estimation in ContinuousDiD.
66
"""
77

8+
import warnings
9+
810
import numpy as np
911
from scipy.interpolate import BSpline
1012

@@ -140,9 +142,12 @@ def bspline_derivative_design_matrix(x, knots, degree, include_intercept=True):
140142

141143
# Check if knot vector is degenerate (all identical, e.g. single dose)
142144
if knots[0] == knots[-1]:
143-
# All knots identical: derivatives are all zero
145+
# All knots identical: derivatives are all zero — this is a
146+
# mathematically well-defined degenerate case (single dose value
147+
# means no dose variation to differentiate), handled silently.
144148
pass
145149
else:
150+
failed_basis_indices = []
146151
for j in range(n_basis):
147152
c = np.zeros(n_basis)
148153
c[j] = 1.0
@@ -151,8 +156,29 @@ def bspline_derivative_design_matrix(x, knots, degree, include_intercept=True):
151156
deriv_j = spline_j.derivative()
152157
dB[:, j] = deriv_j(x_clamped)
153158
except ValueError:
154-
# Degenerate knot vector: derivative is zero
155-
pass
159+
# Finding #12 (axis C, silent-failures audit): silent pass
160+
# on ValueError meant a malformed knot vector (too few
161+
# knots for the degree, non-monotonic, etc.) quietly set
162+
# whole columns of the derivative design matrix to zero.
163+
# Downstream ContinuousDiD inference then used a silently
164+
# biased dPsi matrix. Track affected basis indices so we
165+
# can surface ONE aggregate warning.
166+
failed_basis_indices.append(j)
167+
168+
if failed_basis_indices:
169+
warnings.warn(
170+
f"B-spline derivative construction failed for "
171+
f"{len(failed_basis_indices)} of {n_basis} basis function(s) "
172+
f"(indices {failed_basis_indices}); their derivative columns "
173+
f"are zero. This typically indicates a malformed knot vector "
174+
f"(too few knots for the chosen degree, non-monotonic, or "
175+
f"repeated interior knots). Both ACRT point estimates and "
176+
f"analytical/bootstrap inference depend on this derivative "
177+
f"matrix, so both may be biased. Consider increasing the "
178+
f"number of distinct doses or reducing the B-spline degree.",
179+
UserWarning,
180+
stacklevel=2,
181+
)
156182

157183
if include_intercept:
158184
# Drop first column (intercept derivative = 0), prepend zeros

diff_diff/power.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,17 +167,24 @@ def __post_init__(self) -> None:
167167
)
168168

169169
def _build_survey_design(self) -> Any:
170-
"""Return cached SurveyDesign (built once, reused across simulations)."""
171-
if not hasattr(self, "_cached_survey_design"):
172-
if self.survey_design is not None:
173-
self._cached_survey_design = self.survey_design
174-
else:
175-
from diff_diff.survey import SurveyDesign
170+
"""Return a SurveyDesign for this config.
171+
172+
Reflects the live ``self.survey_design`` value every call (no
173+
caching). Finding #28 (axis J, silent-failures audit): the
174+
previous ``_cached_survey_design`` was populated on first call
175+
and never invalidated on mutation, so ``config.survey_design =
176+
other_design`` silently kept returning the original. Since the
177+
default ``SurveyDesign(...)`` construction is microseconds and
178+
user-provided designs are just reference copies, there's no cache
179+
cost worth keeping.
180+
"""
181+
if self.survey_design is not None:
182+
return self.survey_design
183+
from diff_diff.survey import SurveyDesign
176184

177-
self._cached_survey_design = SurveyDesign(
178-
weights="weight", strata="stratum", psu="psu", fpc="fpc"
179-
)
180-
return self._cached_survey_design
185+
return SurveyDesign(
186+
weights="weight", strata="stratum", psu="psu", fpc="fpc"
187+
)
181188

182189
@property
183190
def min_viable_n(self) -> int:

docs/methodology/REGISTRY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,7 @@ See `docs/methodology/continuous-did.md` Section 4 for full details.
723723
not-yet-treated controls. When `anticipation=0` (default), behavior is
724724
unchanged.
725725
- **Boundary knots**: Knots are built once from all treated doses (global, not per-cell) to ensure a common basis across (g,t) cells for aggregation. Evaluation grid is clamped to training-dose boundary knots (`range(dose)`). R's `contdid` v0.1.0 has an inconsistency where `splines2::bSpline(dvals)` uses `range(dvals)` instead of `range(dose)`, which can produce extrapolation artifacts at dose grid extremes. Our approach avoids extrapolation and is methodologically sound.
726+
- **Note:** `bspline_derivative_design_matrix` previously swallowed `ValueError` from `scipy.interpolate.BSpline` in the per-basis derivative loop, leaving affected columns of the derivative design matrix as zero with no user-facing signal. The affected columns are still left as zero, but the function now collects the failed basis indices and emits ONE `UserWarning` naming them. Both ACRT point estimates and analytical/bootstrap inference read the same `dPsi` matrix (see `continuous_did.py:1026-1046` and the bootstrap ACRT path at `continuous_did.py:1524-1561`), so both can be biased when derivative construction fails for some basis functions — the warning wording makes that explicit. The all-identical-knot degenerate case (single dose value) remains silently handled — derivatives there are mathematically zero. Axis-C finding #12 in the Phase 2 silent-failures audit.
726727

727728
### Implementation Checklist
728729

@@ -2582,6 +2583,7 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE²
25822583
- **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. When rounding occurs, all result objects (`SimulationPowerResults`, `SimulationMDEResults`, `SimulationSampleSizeResults`) set `effective_n_units` to the actual sample size used; it is `None` when no rounding occurred. `simulate_sample_size()` snaps bisection candidates to multiples of 8 so that `required_n` is always a realizable DDD sample size. Passing `n_per_cell` in `data_generator_kwargs` suppresses the effective-N rounding warning but not warnings for ignored parameters (`n_periods`, `treatment_period`, `treatment_fraction`).
25832584
- **Note:** The analytical power methods (`PowerAnalysis.power/mde/sample_size` and the `compute_power/compute_mde/compute_sample_size` convenience functions) accept a `deff` parameter (survey design effect, default 1.0). This inflates variance multiplicatively: `Var(ATT) *= deff`, and inflates required sample size: `n_total *= deff`. The `deff` parameter is **not redundant** with `rho` (intra-cluster correlation): `rho` models within-unit serial correlation in panel data via the Moulton factor `1 + (T-1)*rho`, while `deff` models the survey design effect from stratified multi-stage sampling (clustering + unequal weighting). A survey panel study may need both. Values `deff > 0` are accepted; `deff < 1.0` (net variance reduction, e.g., from stratification gain) emits a warning.
25842585
- **Note:** `simulate_power()` catches a narrow set of exception types — `ValueError`, `numpy.linalg.LinAlgError`, `KeyError`, `RuntimeError`, `ZeroDivisionError` — raised inside the per-simulation fit and result-extraction block, increments a per-effect failure counter, and skips the replicate. Programming errors (`TypeError`, `AttributeError`, `NameError`, `IndexError`, etc.) are allowed to propagate so that bugs in the estimator or custom result extractor surface loudly instead of being absorbed as simulation failures. The primary-effect failure count is surfaced on the result object as `SimulationPowerResults.n_simulation_failures`; a `UserWarning` still fires when the failure rate exceeds 10% for any effect size, and all-failed runs raise `RuntimeError`. This replaces the prior bare `except Exception` that swallowed root causes and kept the counter internal to the function (axis C — silent fallback — under the Phase 2 audit).
2586+
- **Note:** `SurveyPowerConfig._build_survey_design()` no longer caches its return value in `self._cached_survey_design`. Previously, reassigning `config.survey_design` after the first call (either replacing a user-supplied `SurveyDesign` with another, or toggling between `None` and a user-supplied design) had no effect: the method silently kept returning the stale cached design. It now returns the live `self.survey_design` (or a freshly constructed default `SurveyDesign` when it is `None`) on every call. Other config fields (`n_strata`, `icc`, `weight_variation`, etc.) never influenced the returned design, so the only staleness risk was `survey_design` reassignment. Constructing the default design takes microseconds, so the cache never justified its complexity. Axis-J finding #28 in the Phase 2 silent-failures audit.
25852587
- **Note:** The simulation-based power functions (`simulate_power/simulate_mde/simulate_sample_size`) accept a `survey_config` parameter (`SurveyPowerConfig` dataclass). When set, the simulation loop uses `generate_survey_did_data` instead of the default registry DGP, and automatically injects `SurveyDesign(weights="weight", strata="stratum", psu="psu", fpc="fpc")` into the estimator's `fit()` call. Supported estimators: DifferenceInDifferences, TwoWayFixedEffects, MultiPeriodDiD, CallawaySantAnna, SunAbraham, ImputationDiD, TwoStageDiD, StackedDiD, EfficientDiD. Unsupported (raises `ValueError`): TROP, SyntheticDiD, TripleDifference (generate_survey_did_data produces staggered cohort data incompatible with factor-model/DDD DGPs). `survey_config` and `data_generator` are mutually exclusive. `data_generator_kwargs` may not contain keys managed by `SurveyPowerConfig` (n_strata, psu_per_stratum, etc.) but may contain passthrough DGP params (unit_fe_sd, add_covariates, strata_sizes). Repeated cross-section survey power (`panel=False`) is only supported for `CallawaySantAnna(panel=False)` with a matching `data_generator_kwargs={"panel": False}`; both mismatch directions are rejected. `estimator_kwargs` may not contain `survey_design` when `survey_config` is set (use `SurveyPowerConfig(survey_design=...)` instead). Estimator settings that require a multi-cohort DGP (`control_group="not_yet_treated"`, `control_group="last_cohort"`, `clean_control="strict"`) are rejected because the survey DGP uses a single cohort; use the custom `data_generator` path for these configurations. `simulate_sample_size` raises the bisection floor to `n_strata * psu_per_stratum * 2` to ensure viable survey structure and rejects `strata_sizes` in `data_generator_kwargs` (it depends on `n_units` which varies during bisection).
25862588

25872589
**Reference implementation(s):**

tests/test_continuous_did.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,132 @@ def test_linear_basis(self):
102102
assert B.shape[1] == 2 # intercept + 1 basis fn
103103

104104

105+
# ---------------------------------------------------------------------------
106+
# Finding #12 (axis C, silent-failures audit). Previously
107+
# `bspline_derivative_design_matrix` silently swallowed ValueError in the
108+
# per-basis derivative loop, leaving affected columns of the derivative
109+
# design matrix as zero with no user-visible signal. ContinuousDiD's
110+
# analytical inference then fed a biased dPsi into downstream SE
111+
# computation. The fix aggregates failed-basis indices and emits ONE
112+
# UserWarning naming them.
113+
# ---------------------------------------------------------------------------
114+
115+
116+
class TestBSplineDerivativeDegenerateBasis:
117+
def test_single_dose_is_silent(self):
118+
"""All-identical knots (single dose value) is a well-defined
119+
degenerate case — derivatives are mathematically zero and the
120+
function returns silently. Regression-guard the existing contract."""
121+
x = np.array([3.0, 3.0, 3.0, 3.0])
122+
knots = np.array([3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]) # all identical
123+
import warnings as _w
124+
125+
with _w.catch_warnings(record=True) as caught:
126+
_w.simplefilter("always")
127+
dB = bspline_derivative_design_matrix(x, knots, degree=3, include_intercept=True)
128+
deriv_warnings = [
129+
w for w in caught if "B-spline derivative construction failed" in str(w.message)
130+
]
131+
assert deriv_warnings == [], (
132+
"All-identical knots should be handled silently (mathematically "
133+
"well-defined zero-derivative case); warning fired unexpectedly: "
134+
f"{[str(w.message) for w in deriv_warnings]}"
135+
)
136+
np.testing.assert_array_equal(dB, np.zeros_like(dB))
137+
138+
def test_valueerror_from_bspline_emits_aggregate_warning(self):
139+
"""When BSpline construction raises ValueError for some basis
140+
functions (malformed knot vector, etc.), the new aggregate
141+
UserWarning must fire naming the affected indices."""
142+
from unittest.mock import patch
143+
144+
import diff_diff.continuous_did_bspline as bspline_mod
145+
146+
dose = np.linspace(1, 5, 30)
147+
knots, deg = build_bspline_basis(dose, degree=3, num_knots=1)
148+
x = np.linspace(1.5, 4.5, 20)
149+
150+
# Force ValueError on basis indices 1 and 3 only; the rest run
151+
# through normally. This is the partial-failure mode the audit
152+
# called out.
153+
real_bspline = bspline_mod.BSpline
154+
call_counter = {"n": 0}
155+
156+
def flaky_bspline(knots, c, degree):
157+
# c is a one-hot vector; the index set to 1 is the basis j
158+
j = int(np.argmax(c))
159+
call_counter["n"] += 1
160+
if j in (1, 3):
161+
raise ValueError(f"forced test failure for basis j={j}")
162+
return real_bspline(knots, c, degree)
163+
164+
import warnings as _w
165+
166+
with patch.object(bspline_mod, "BSpline", side_effect=flaky_bspline):
167+
with _w.catch_warnings(record=True) as caught:
168+
_w.simplefilter("always")
169+
dB = bspline_derivative_design_matrix(
170+
x, knots, degree=deg, include_intercept=True
171+
)
172+
173+
deriv_warnings = [
174+
w for w in caught if "B-spline derivative construction failed" in str(w.message)
175+
]
176+
assert len(deriv_warnings) == 1, (
177+
f"Expected exactly one aggregate warning, got {len(deriv_warnings)}: "
178+
f"{[str(w.message) for w in deriv_warnings]}"
179+
)
180+
msg = str(deriv_warnings[0].message)
181+
# Message must name the failed basis indices so the user can debug.
182+
assert "[1, 3]" in msg, f"Expected indices [1, 3] in warning; got: {msg}"
183+
assert "2 of" in msg, f"Expected failure count '2 of ...' in warning; got: {msg}"
184+
# Affected columns should be zero.
185+
# With include_intercept=True, column 0 is always zero (intercept
186+
# derivative) and basis index j is at dB column j (the drop-first
187+
# then prepend-zeros logic keeps the same per-j mapping for j>=1).
188+
np.testing.assert_array_equal(dB[:, 1], np.zeros(len(x))) # failed basis j=1
189+
np.testing.assert_array_equal(dB[:, 3], np.zeros(len(x))) # failed basis j=3
190+
191+
# Unaffected columns must match the un-patched baseline exactly
192+
# (except columns 1 and 3 which were forced to zero). This guards
193+
# a regression that would zero or corrupt the entire derivative
194+
# matrix on any ValueError.
195+
dB_baseline = bspline_derivative_design_matrix(
196+
x, knots, degree=deg, include_intercept=True
197+
)
198+
for col in range(dB.shape[1]):
199+
if col in (1, 3):
200+
continue
201+
np.testing.assert_array_equal(
202+
dB[:, col],
203+
dB_baseline[:, col],
204+
err_msg=f"Unaffected column {col} diverges from baseline",
205+
)
206+
# At least one non-intercept, non-failed column must be non-zero,
207+
# confirming the function still produces meaningful derivatives.
208+
non_failed_cols = [c for c in range(1, dB.shape[1]) if c not in (1, 3)]
209+
assert any(np.any(dB[:, c] != 0) for c in non_failed_cols), (
210+
"Expected at least one unaffected non-intercept column to have "
211+
"non-zero derivatives; got all-zero dB outside failed cols."
212+
)
213+
214+
def test_clean_knots_emit_no_warning(self):
215+
"""Well-formed knot vector → no ValueError path taken → no
216+
warning. Regression-guard the happy path."""
217+
dose = np.linspace(1, 5, 50)
218+
knots, deg = build_bspline_basis(dose, degree=3, num_knots=2)
219+
x = np.linspace(1.5, 4.5, 30)
220+
import warnings as _w
221+
222+
with _w.catch_warnings(record=True) as caught:
223+
_w.simplefilter("always")
224+
bspline_derivative_design_matrix(x, knots, deg, include_intercept=True)
225+
deriv_warnings = [
226+
w for w in caught if "B-spline derivative construction failed" in str(w.message)
227+
]
228+
assert deriv_warnings == []
229+
230+
105231
class TestDoseGrid:
106232
"""Test dose grid computation."""
107233

0 commit comments

Comments
 (0)