Skip to content

Commit 2102b3d

Browse files
igerberclaude
andcommitted
Implement full Wooldridge covariate basis (D_g × X, f_t × X)
Add automatic cohort × covariate (D_g × X) and time × covariate (f_t × X) interaction blocks to the ETWFE design matrix, completing the W2025 Eq. 5.3 specification. Previously only cell × demeaned-X and raw X were included, silently fitting a restricted model. D_g × X auto-generated for exovar/xtvar (xgvar already has these in _prepare_covariates). f_t × X generated for all covariates, drop first time period for identification. ASF and delta-method gradient code unchanged — new blocks are nuisance parameters not zeroed in the counterfactual. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c847eab commit 2102b3d

3 files changed

Lines changed: 151 additions & 17 deletions

File tree

diff_diff/wooldridge.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,14 +446,47 @@ def fit(
446446
cov_label = cov_names_list[j] if j < len(cov_names_list) else f"cov{j}"
447447
interact_names.append(f"{gt_name}_x_{cov_label}")
448448

449+
# Cohort × covariate interactions (W2025 Eq. 5.3: D_g × X)
450+
# exovar/xtvar get automatic D_g × X; xgvar already has D_g × X
451+
cov_cols_for_dg = list(exovar or []) + list(xtvar or [])
452+
cohort_cov_cols = []
453+
cohort_cov_names = []
454+
if cov_cols_for_dg:
455+
cohort_vals_arr = sample[cohort].values
456+
for g in groups:
457+
g_ind = (cohort_vals_arr == g).astype(float)
458+
for col in cov_cols_for_dg:
459+
cohort_cov_cols.append(g_ind * sample[col].values.astype(float))
460+
cohort_cov_names.append(f"D{g}_x_{col}")
461+
462+
# Time × covariate interactions (W2025 Eq. 5.3: f_t × X)
463+
# All covariates get f_t × X, drop first time for identification
464+
all_cov_cols = list(exovar or []) + list(xtvar or []) + list(xgvar or [])
465+
times_sorted = sorted(sample[time].unique())
466+
time_cov_cols = []
467+
time_cov_names = []
468+
time_vals_arr = sample[time].values
469+
for t in times_sorted[1:]: # drop first
470+
t_ind = (time_vals_arr == t).astype(float)
471+
for col in all_cov_cols:
472+
time_cov_cols.append(t_ind * sample[col].values.astype(float))
473+
time_cov_names.append(f"ft{t}_x_{col}")
474+
475+
# Assemble: [cell_indicators, cell×cov, D_g×X, f_t×X, raw_cov]
476+
blocks = [X_int]
449477
if interact_cols:
450-
X_interact = np.column_stack(interact_cols)
451-
X_design = np.hstack([X_int, X_interact, X_cov])
478+
blocks.append(np.column_stack(interact_cols))
452479
all_regressors.extend(interact_names)
453-
else:
454-
X_design = np.hstack([X_int, X_cov])
480+
if cohort_cov_cols:
481+
blocks.append(np.column_stack(cohort_cov_cols))
482+
all_regressors.extend(cohort_cov_names)
483+
if time_cov_cols:
484+
blocks.append(np.column_stack(time_cov_cols))
485+
all_regressors.extend(time_cov_names)
486+
blocks.append(X_cov)
455487
for i in range(X_cov.shape[1]):
456488
all_regressors.append(f"_cov_{i}")
489+
X_design = np.hstack(blocks)
457490
else:
458491
X_design = X_int
459492

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1136,7 +1136,7 @@ where `g(·)` is the link inverse (logistic or exp), `η_i` is the individual li
11361136
- `exovar`: Time-invariant covariates, added without demeaning (corresponds to W2025 Eq. 5.2 `x_i`)
11371137
- `xtvar`: Time-varying covariates, demeaned within cohort×period cells when `demean_covariates=True` (corresponds to W2025 Eq. 10.2 `x_hat_itgs = x_it - x_bar_gs`)
11381138
- `xgvar`: Covariates interacted with each cohort indicator
1139-
- **Note:** Covariates are included as both main effects and treatment × demeaned-covariate interactions (W2025 Eq. 5.3), allowing ATT to vary with covariates within each (g,t) cell.
1139+
- **Note:** Covariate-adjusted ETWFE includes the full W2025 Eq. 5.3 basis: raw X, cohort × X (D_g × X for treated cohorts, auto-generated for `exovar`/`xtvar`), time × X (f_t × X, drop first period), and cell × demeaned X (D_{g,t} × X̃). Variables in `xgvar` already contribute D_g × X via `_prepare_covariates`; `exovar`/`xtvar` get automatic D_g × X generation.
11401140
- **Note:** `xtvar` demeaning operates at the cohort×period level (W2025 Eq. 10.2), not the cohort level (W2025 Eq. 5.2). These are identical for time-constant covariates but differ for time-varying covariates.
11411141

11421142
*Control groups:*

tests/test_wooldridge.py

Lines changed: 113 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,8 @@ def test_poisson_rank_deficient_design(self):
941941
assert cell["se"] >= 0
942942

943943
def test_logit_with_covariates(self):
944-
"""Logit with covariates should produce finite ATT/SE."""
944+
"""Logit with covariates should produce finite ATT/SE and differ from
945+
no-covariate fit (confirming covariates are used)."""
945946
rng = np.random.default_rng(42)
946947
rows = []
947948
for u in range(60):
@@ -953,14 +954,20 @@ def test_logit_with_covariates(self):
953954
y = int(rng.random() < 1 / (1 + np.exp(-eta)))
954955
rows.append({"unit": u, "time": t, "cohort": cohort, "y": y, "x1": x1})
955956
df = pd.DataFrame(rows)
956-
est = WooldridgeDiD(method="logit")
957-
r = est.fit(df, outcome="y", unit="unit", time="time", cohort="cohort", exovar=["x1"])
958-
assert np.isfinite(r.overall_att)
959-
assert np.isfinite(r.overall_se)
960-
assert r.overall_se > 0
957+
r_cov = WooldridgeDiD(method="logit").fit(
958+
df, outcome="y", unit="unit", time="time", cohort="cohort", exovar=["x1"]
959+
)
960+
r_nocov = WooldridgeDiD(method="logit").fit(
961+
df, outcome="y", unit="unit", time="time", cohort="cohort"
962+
)
963+
assert np.isfinite(r_cov.overall_att)
964+
assert np.isfinite(r_cov.overall_se)
965+
assert r_cov.overall_se > 0
966+
assert r_cov.overall_att != r_nocov.overall_att, "Covariates should affect ATT"
961967

962968
def test_poisson_with_covariates(self):
963-
"""Poisson with covariates should produce finite ATT/SE."""
969+
"""Poisson with covariates should produce finite ATT/SE and differ from
970+
no-covariate fit (confirming covariates are used)."""
964971
rng = np.random.default_rng(7)
965972
rows = []
966973
for u in range(60):
@@ -972,11 +979,16 @@ def test_poisson_with_covariates(self):
972979
y = rng.poisson(mu)
973980
rows.append({"unit": u, "time": t, "cohort": cohort, "y": float(y), "x1": x1})
974981
df = pd.DataFrame(rows)
975-
est = WooldridgeDiD(method="poisson")
976-
r = est.fit(df, outcome="y", unit="unit", time="time", cohort="cohort", exovar=["x1"])
977-
assert np.isfinite(r.overall_att)
978-
assert np.isfinite(r.overall_se)
979-
assert r.overall_se > 0
982+
r_cov = WooldridgeDiD(method="poisson").fit(
983+
df, outcome="y", unit="unit", time="time", cohort="cohort", exovar=["x1"]
984+
)
985+
r_nocov = WooldridgeDiD(method="poisson").fit(
986+
df, outcome="y", unit="unit", time="time", cohort="cohort"
987+
)
988+
assert np.isfinite(r_cov.overall_att)
989+
assert np.isfinite(r_cov.overall_se)
990+
assert r_cov.overall_se > 0
991+
assert r_cov.overall_att != r_nocov.overall_att, "Covariates should affect ATT"
980992

981993

982994
class TestCohortTimeInvariance:
@@ -1206,3 +1218,92 @@ def test_ols_never_treated_still_has_pre_treatment(self):
12061218
# OLS never_treated should have pre-treatment cells
12071219
pre_treatment = [(g, t) for (g, t) in r.group_time_effects if t < g]
12081220
assert len(pre_treatment) > 0, "OLS never_treated lost pre-treatment placebo cells"
1221+
1222+
1223+
class TestFullCovariateBasis:
1224+
"""Regression: covariate-adjusted ETWFE includes full W2025 Eq. 5.3 basis
1225+
(D_g × X, f_t × X, D_{g,t} × X̃, raw X)."""
1226+
1227+
@pytest.fixture
1228+
def cov_data(self):
1229+
rng = np.random.RandomState(42)
1230+
rows = []
1231+
for u in range(30):
1232+
g = 3 if u < 10 else (4 if u < 20 else 0)
1233+
x1 = rng.normal()
1234+
for t in range(1, 6):
1235+
effect = 0.5 if g > 0 and t >= g else 0.0
1236+
y = rng.normal() + effect + 0.3 * x1
1237+
rows.append({"unit": u, "time": t, "cohort": g, "y": y, "x1": x1})
1238+
return pd.DataFrame(rows)
1239+
1240+
def test_ols_covariate_parity_with_full_basis_dummy_ols(self, cov_data):
1241+
"""OLS with exovar should match explicit-dummy OLS with full basis."""
1242+
from diff_diff.linalg import solve_ols
1243+
1244+
df = cov_data
1245+
r = WooldridgeDiD(control_group="not_yet_treated").fit(
1246+
df, outcome="y", unit="unit", time="time", cohort="cohort", exovar=["x1"]
1247+
)
1248+
1249+
# Build explicit-dummy regression with full basis
1250+
sample = _filter_sample(df, "unit", "time", "cohort", "not_yet_treated", 0)
1251+
X_int, _, gt_keys = _build_interaction_matrix(
1252+
sample, "cohort", "time", 0, "not_yet_treated", "ols"
1253+
)
1254+
n_int = X_int.shape[1]
1255+
x1_raw = sample["x1"].values.astype(float)
1256+
1257+
# Cell × demeaned-X interactions
1258+
groups = sorted(g for g in sample["cohort"].unique() if g > 0)
1259+
x1_demeaned = x1_raw.copy()
1260+
for g in groups:
1261+
mask = sample["cohort"].values == g
1262+
if mask.any():
1263+
x1_demeaned[mask] -= x1_raw[mask].mean()
1264+
cell_cov = np.column_stack([X_int[:, i] * x1_demeaned for i in range(n_int)])
1265+
1266+
# D_g × X (cohort × covariate)
1267+
cohort_cov = np.column_stack([
1268+
(sample["cohort"].values == g).astype(float) * x1_raw for g in groups
1269+
])
1270+
1271+
# f_t × X (time × covariate, drop first)
1272+
times = sorted(sample["time"].unique())
1273+
time_cov = np.column_stack([
1274+
(sample["time"].values == t).astype(float) * x1_raw for t in times[1:]
1275+
])
1276+
1277+
# Full design: intercept + cells + cell×cov + D_g×X + f_t×X + raw_X + unit + time dummies
1278+
unit_dummies = pd.get_dummies(sample["unit"], drop_first=True).values.astype(float)
1279+
time_dummies = pd.get_dummies(sample["time"], drop_first=True).values.astype(float)
1280+
intercept = np.ones((len(sample), 1))
1281+
X_full = np.hstack([
1282+
intercept, X_int, cell_cov, cohort_cov, time_cov,
1283+
x1_raw.reshape(-1, 1), unit_dummies, time_dummies,
1284+
])
1285+
y = sample["y"].values
1286+
coefs_dummy, _, _ = solve_ols(X_full, y, rank_deficient_action="silent")
1287+
1288+
# Compare ATT coefficients (positions 1..n_int in dummy OLS)
1289+
for i, (g, t) in enumerate(gt_keys):
1290+
if (g, t) in r.group_time_effects:
1291+
np.testing.assert_allclose(
1292+
r.group_time_effects[(g, t)]["att"],
1293+
coefs_dummy[1 + i],
1294+
atol=1e-5,
1295+
err_msg=f"Covariate ATT mismatch at cell ({g},{t})",
1296+
)
1297+
1298+
def test_covariates_affect_ols_att(self, cov_data):
1299+
"""OLS with covariates should produce different ATT than without."""
1300+
df = cov_data
1301+
r_cov = WooldridgeDiD().fit(
1302+
df, outcome="y", unit="unit", time="time", cohort="cohort", exovar=["x1"]
1303+
)
1304+
r_nocov = WooldridgeDiD().fit(
1305+
df, outcome="y", unit="unit", time="time", cohort="cohort"
1306+
)
1307+
assert r_cov.overall_att != r_nocov.overall_att, (
1308+
"Covariate-adjusted ATT should differ from unadjusted"
1309+
)

0 commit comments

Comments
 (0)