Skip to content

Commit 4c2e530

Browse files
igerberclaude
andcommitted
Address PR #105 feedback: fix Webb weights, lazy R fixture, test perf
- Fix Rust Webb bootstrap weights to match NumPy implementation: - Correct values: ±√(3/2), ±1, ±√(1/2) (was using wrong values) - Correct probabilities: [1,2,3,3,2,1]/12 (was uniform) - Add 3 Rust unit tests for Webb weight verification - Both backends now produce variance ≈ 0.833 - Add lazy R availability fixture to avoid import-time latency: - New tests/conftest.py with session-scoped r_available fixture - Support DIFF_DIFF_R=skip environment variable - Test collection now completes in <1s (was ~2s with subprocess) - Improve test performance: - Add @pytest.mark.slow marker for thorough bootstrap tests - Reduce bootstrap iterations from 199 to 99 where sufficient - Add slow marker definition to pyproject.toml - Documentation updates: - METHODOLOGY_REVIEW.md: Correct Webb variance to 0.833 - TODO.md: Log 7 pre-existing NaN handling issues - CLAUDE.md: Document Rust test troubleshooting (PyO3 linking) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 71110ea commit 4c2e530

7 files changed

Lines changed: 253 additions & 38 deletions

File tree

CLAUDE.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,36 @@ DIFF_DIFF_BACKEND=rust pytest
5353
pytest tests/test_rust_backend.py -v
5454
```
5555

56+
#### Troubleshooting Rust Tests (PyO3 Linking)
57+
58+
If `cargo test` fails with `library 'pythonX.Y' not found`, PyO3 cannot find the Python library. This commonly happens on macOS when using the system Python (which lacks development headers in expected locations).
59+
60+
**Solution**: Use a Python environment with proper library paths (e.g., conda, Homebrew, or pyenv):
61+
62+
```bash
63+
# Using miniconda (example path - adjust for your system)
64+
cd rust
65+
PYO3_PYTHON=/path/to/miniconda3/bin/python3 \
66+
DYLD_LIBRARY_PATH="/path/to/miniconda3/lib" \
67+
cargo test
68+
69+
# Using Homebrew Python
70+
PYO3_PYTHON=/opt/homebrew/bin/python3 \
71+
DYLD_LIBRARY_PATH="/opt/homebrew/lib" \
72+
cargo test
73+
```
74+
75+
**Environment variables:**
76+
- `PYO3_PYTHON`: Path to Python interpreter with development headers
77+
- `DYLD_LIBRARY_PATH` (macOS) / `LD_LIBRARY_PATH` (Linux): Path to `libpythonX.Y.dylib`/`.so`
78+
79+
**Verification**: All 22 Rust tests should pass, including bootstrap weight tests:
80+
```
81+
test bootstrap::tests::test_webb_variance_approx_correct ... ok
82+
test bootstrap::tests::test_webb_values_correct ... ok
83+
test bootstrap::tests::test_webb_mean_approx_zero ... ok
84+
```
85+
5686
## Architecture
5787

5888
### Module Structure

METHODOLOGY_REVIEW.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,9 @@ Each estimator in diff-diff should be periodically reviewed to ensure:
148148
**Deviations from R's did::att_gt():**
149149
1. **NaN for invalid inference**: When SE is non-finite or zero, Python returns NaN for
150150
t_stat/p_value rather than potentially erroring. This is a defensive enhancement.
151-
2. **Webb weights variance**: Webb's 6-point distribution has Var(w) ≈ 0.72, not 1.0.
152-
This is the correct theoretical variance for this distribution.
151+
2. **Webb weights variance**: Webb's 6-point distribution with values ±√(3/2), ±1, ±√(1/2)
152+
and probabilities [1,2,3,3,2,1]/12 has Var(w) ≈ 0.833 (=10/12), not 1.0.
153+
This is the correct theoretical variance matching the NumPy and Rust implementations.
153154

154155
---
155156

TODO.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,28 @@ Target: < 1000 lines per module for maintainability.
5252
| `pretrends.py` | 1160 | Acceptable |
5353
| `bacon.py` | 1027 | OK |
5454

55+
### NaN Handling for Undefined t-statistics
56+
57+
Several estimators return `0.0` for t-statistic when SE is 0 or undefined. This is incorrect—a t-stat of 0 implies a null effect, whereas `np.nan` correctly indicates undefined inference.
58+
59+
**Pattern to fix**: `t_stat = effect / se if se > 0 else 0.0``t_stat = effect / se if se > 0 else np.nan`
60+
61+
| Location | Line | Current Code |
62+
|----------|------|--------------|
63+
| `diagnostics.py` | 665 | `t_stat = original_att / se if se > 0 else 0.0` |
64+
| `diagnostics.py` | 786 | `t_stat = mean_effect / se if se > 0 else 0.0` |
65+
| `sun_abraham.py` | 603 | `overall_t = overall_att / overall_se if overall_se > 0 else 0.0` |
66+
| `sun_abraham.py` | 626 | `overall_t = overall_att / overall_se if overall_se > 0 else 0.0` |
67+
| `sun_abraham.py` | 643 | `eff_val / se_val if se_val > 0 else 0.0` |
68+
| `sun_abraham.py` | 881 | `t_stat = agg_effect / agg_se if agg_se > 0 else 0.0` |
69+
| `triple_diff.py` | 601 | `t_stat = att / se if se > 0 else 0.0` |
70+
71+
**Priority**: Medium - affects inference reporting in edge cases.
72+
73+
**Note**: CallawaySantAnna was fixed in PR #97 to use `np.nan`. These other estimators should follow the same pattern.
74+
75+
---
76+
5577
### Standard Error Consistency
5678

5779
Different estimators compute SEs differently. Consider unified interface.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ python-packages = ["diff_diff"]
7171
testpaths = ["tests"]
7272
python_files = "test_*.py"
7373
addopts = "-v --tb=short"
74+
markers = [
75+
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
76+
]
7477

7578
[tool.black]
7679
line-length = 100

rust/src/bootstrap.rs

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,29 @@ fn generate_mammen_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array
118118
/// Six-point distribution that matches additional moments:
119119
/// E[w] = 0, E[w²] = 1, E[w³] = 0, E[w⁴] = 1
120120
///
121-
/// Values: ±√(3/2), ±√(1/2), ±√(1/6) with specific probabilities
121+
/// Values: ±√(3/2), ±√(2/2)=±1, ±√(1/2) with probabilities [1,2,3,3,2,1]/12
122+
/// This matches the NumPy implementation in staggered_bootstrap.py
122123
fn generate_webb_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<f64> {
123-
// Webb 6-point values
124-
let val1 = (3.0_f64 / 2.0).sqrt(); // √(3/2) ≈ 1.225
125-
let val2 = (1.0_f64 / 2.0).sqrt(); // √(1/2) ≈ 0.707
126-
let val3 = (1.0_f64 / 6.0).sqrt(); // √(1/6) ≈ 0.408
124+
// Webb 6-point values (matching NumPy implementation)
125+
let val1 = (3.0_f64 / 2.0).sqrt(); // √(3/2) ≈ 1.2247
126+
let val2 = 1.0_f64; // √(2/2) = 1.0
127+
let val3 = (1.0_f64 / 2.0).sqrt(); // √(1/2) ≈ 0.7071
127128

128-
// Lookup table for direct index computation (replaces 6-way if-else)
129-
// Equal probability: u in [0, 1/6) -> -val1, [1/6, 2/6) -> -val2, etc.
129+
// Values in order: -val1, -val2, -val3, val3, val2, val1
130130
let weights_table = [-val1, -val2, -val3, val3, val2, val1];
131131

132+
// Cumulative probabilities for [1,2,3,3,2,1]/12
133+
// Probs: 1/12, 2/12, 3/12, 3/12, 2/12, 1/12
134+
// Cumulative: 1/12, 3/12, 6/12, 9/12, 11/12, 12/12
135+
let cum_probs = [
136+
1.0 / 12.0, // P(bucket 0) = 1/12
137+
3.0 / 12.0, // P(bucket <= 1) = 3/12
138+
6.0 / 12.0, // P(bucket <= 2) = 6/12 = 0.5
139+
9.0 / 12.0, // P(bucket <= 3) = 9/12 = 0.75
140+
11.0 / 12.0, // P(bucket <= 4) = 11/12
141+
// bucket 5 is implicit (u >= 11/12)
142+
];
143+
132144
// Pre-allocate output array - eliminates double allocation
133145
let mut weights = Array2::<f64>::zeros((n_bootstrap, n_units));
134146

@@ -142,9 +154,20 @@ fn generate_webb_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<
142154
let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(i as u64));
143155
for elem in row.iter_mut() {
144156
let u = rng.gen::<f64>();
145-
// Direct bucket computation: multiply by 6 and floor to get index 0-5
146-
// Clamp to 5 to handle edge case where u == 1.0
147-
let bucket = ((u * 6.0).floor() as usize).min(5);
157+
// Find bucket using cumulative probabilities
158+
let bucket = if u < cum_probs[0] {
159+
0
160+
} else if u < cum_probs[1] {
161+
1
162+
} else if u < cum_probs[2] {
163+
2
164+
} else if u < cum_probs[3] {
165+
3
166+
} else if u < cum_probs[4] {
167+
4
168+
} else {
169+
5
170+
};
148171
*elem = weights_table[bucket];
149172
}
150173
});
@@ -225,4 +248,59 @@ mod tests {
225248
// Different seeds should produce different results
226249
assert_ne!(weights1, weights2);
227250
}
251+
252+
#[test]
253+
fn test_webb_mean_approx_zero() {
254+
let weights = generate_webb_batch(10000, 1, 42);
255+
let mean: f64 = weights.iter().sum::<f64>() / weights.len() as f64;
256+
257+
// With 10000 samples, mean should be close to 0
258+
assert!(
259+
mean.abs() < 0.1,
260+
"Webb mean should be close to 0, got {}",
261+
mean
262+
);
263+
}
264+
265+
#[test]
266+
fn test_webb_variance_approx_correct() {
267+
// Webb's 6-point distribution with values ±√(3/2), ±1, ±√(1/2)
268+
// and probabilities [1,2,3,3,2,1]/12 should have variance close to
269+
// the theoretical value of 10/12 ≈ 0.833
270+
let weights = generate_webb_batch(10000, 100, 42);
271+
let n = weights.len() as f64;
272+
let mean: f64 = weights.iter().sum::<f64>() / n;
273+
let variance: f64 = weights.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
274+
275+
// Theoretical variance = 2 * (1/12 * 3/2 + 2/12 * 1 + 3/12 * 1/2) = 10/12 ≈ 0.833
276+
// Allow some statistical variance in the estimate
277+
assert!(
278+
(variance - 0.833).abs() < 0.05,
279+
"Webb variance should be ~0.833 (matching NumPy), got {}",
280+
variance
281+
);
282+
}
283+
284+
#[test]
285+
fn test_webb_values_correct() {
286+
// Verify that Webb weights only take the expected 6 values
287+
let weights = generate_webb_batch(100, 1000, 42);
288+
289+
let val1 = (3.0_f64 / 2.0).sqrt(); // ≈ 1.2247
290+
let val2 = 1.0_f64;
291+
let val3 = (1.0_f64 / 2.0).sqrt(); // ≈ 0.7071
292+
293+
let expected_values = [-val1, -val2, -val3, val3, val2, val1];
294+
295+
for w in weights.iter() {
296+
let matches_expected = expected_values
297+
.iter()
298+
.any(|&expected| (*w - expected).abs() < 1e-10);
299+
assert!(
300+
matches_expected,
301+
"Webb weight {} is not one of the expected values",
302+
w
303+
);
304+
}
305+
}
228306
}

tests/conftest.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""
2+
Pytest configuration and shared fixtures for diff-diff tests.
3+
4+
This module provides shared fixtures including lazy R availability checking
5+
to avoid import-time subprocess latency.
6+
"""
7+
8+
import os
9+
import subprocess
10+
11+
import pytest
12+
13+
14+
# =============================================================================
15+
# R Availability Fixtures (Lazy Loading)
16+
# =============================================================================
17+
18+
_r_available_cache = None
19+
20+
21+
def _check_r_available() -> bool:
22+
"""
23+
Check if R and the did package are available (cached).
24+
25+
This is called lazily when the r_available fixture is first used,
26+
not at module import time, to avoid subprocess latency during test collection.
27+
28+
Returns
29+
-------
30+
bool
31+
True if R and did package are available, False otherwise.
32+
"""
33+
global _r_available_cache
34+
if _r_available_cache is None:
35+
# Allow environment override (matches DIFF_DIFF_BACKEND pattern)
36+
r_env = os.environ.get("DIFF_DIFF_R", "auto").lower()
37+
if r_env == "skip":
38+
_r_available_cache = False
39+
else:
40+
try:
41+
result = subprocess.run(
42+
["Rscript", "-e", "library(did); cat('OK')"],
43+
capture_output=True,
44+
text=True,
45+
timeout=30,
46+
)
47+
_r_available_cache = result.returncode == 0 and "OK" in result.stdout
48+
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
49+
_r_available_cache = False
50+
return _r_available_cache
51+
52+
53+
@pytest.fixture(scope="session")
54+
def r_available():
55+
"""
56+
Lazy check for R availability.
57+
58+
This fixture is session-scoped and cached, so R availability is only
59+
checked once per test session, and only when a test actually needs it.
60+
61+
Returns
62+
-------
63+
bool
64+
True if R and did package are available.
65+
"""
66+
return _check_r_available()
67+
68+
69+
@pytest.fixture
70+
def require_r(r_available):
71+
"""
72+
Skip test if R is not available.
73+
74+
Use this fixture in tests that require R:
75+
76+
```python
77+
def test_comparison_with_r(require_r):
78+
# This test will be skipped if R is not available
79+
...
80+
```
81+
"""
82+
if not r_available:
83+
pytest.skip("R or did package not available")

0 commit comments

Comments
 (0)