Skip to content

Commit d1dadfe

Browse files
igerberclaude
andcommitted
Fix DDD _snap_n floor mismatch in explicit n_range path
Hoist abs_min before the bracket branch so explicit n_range snaps use the DDD realizable floor (16) instead of the registry profile floor (64). Also pass floor=abs_min to bisection midpoint snap. Adds regression test for low-range DDD sample-size search. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f1112ac commit d1dadfe

2 files changed

Lines changed: 29 additions & 3 deletions

File tree

diff_diff/power.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2316,8 +2316,11 @@ def _power_at_n(n: int) -> float:
23162316
return pwr
23172317

23182318
# --- Bracket ---
2319+
abs_min = 16 if is_ddd_grid else 4
23192320
if n_range is not None:
2320-
lo, hi = _snap_n(n_range[0], "up"), _snap_n(n_range[1], "down")
2321+
lo, hi = _snap_n(n_range[0], "up", floor=abs_min), _snap_n(
2322+
n_range[1], "down", floor=abs_min
2323+
)
23212324
if lo > hi:
23222325
lo = hi # collapsed bracket — evaluate single point
23232326
power_lo = _power_at_n(lo)
@@ -2351,7 +2354,6 @@ def _power_at_n(n: int) -> float:
23512354
if power_lo >= power:
23522355
# Floor achieves target — search downward for true minimum
23532356
hi = lo
2354-
abs_min = 16 if is_ddd_grid else 4
23552357
found_lower = False
23562358
probe = _snap_n(max(abs_min, lo // 2), floor=abs_min)
23572359
for _ in range(8):
@@ -2413,7 +2415,7 @@ def _power_at_n(n: int) -> float:
24132415
for _ in range(max_steps):
24142416
if hi - lo <= convergence_threshold:
24152417
break
2416-
mid = _snap_n((lo + hi) // 2)
2418+
mid = _snap_n((lo + hi) // 2, floor=abs_min)
24172419
if mid <= lo or mid >= hi:
24182420
break
24192421
pwr = _power_at_n(mid)

tests/test_power.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,30 @@ def test_ddd_sample_size_grid_aligned(self):
11451145
result.required_n % 8 == 0
11461146
), f"DDD required_n={result.required_n} is not a multiple of 8"
11471147

1148+
@pytest.mark.slow
1149+
def test_ddd_sample_size_low_range(self):
1150+
"""DDD sample-size search with low n_range stays within bracket."""
1151+
result = simulate_sample_size(
1152+
TripleDifference(),
1153+
n_periods=2,
1154+
treatment_period=1,
1155+
treatment_effect=0.5,
1156+
sigma=5.0,
1157+
n_simulations=5,
1158+
n_range=(16, 56),
1159+
seed=42,
1160+
progress=False,
1161+
)
1162+
assert (
1163+
result.required_n % 8 == 0
1164+
), f"DDD required_n={result.required_n} is not a multiple of 8"
1165+
assert (
1166+
16 <= result.required_n <= 56
1167+
), f"DDD required_n={result.required_n} outside requested bracket [16, 56]"
1168+
assert (
1169+
len(result.search_path) > 2
1170+
), f"Bisection should explore >2 points, got {len(result.search_path)}"
1171+
11481172
@pytest.mark.slow
11491173
def test_trop(self):
11501174
result = simulate_power(

0 commit comments

Comments
 (0)