Skip to content

Commit 736fdea

Browse files
igerberclaude
andcommitted
Address PR #350 CI review round 5: P2 to_dict JSON-safe
`HeterogeneousAdoptionDiDEventStudyResults.to_dict()` promised JSON-serializable output but previously returned raw numpy scalars via `list(ndarray)`, which `json.dumps` can't serialize. The `F` field and `filter_info.F_last` could also hold numpy scalars or pandas Timestamps that break serialization. Fix: - Per-horizon arrays use `.tolist()` (unwraps numpy scalars to native Python). - New `_json_safe_scalar` helper coerces numpy scalars via `.item()` and pandas Timestamp/Timedelta via `.isoformat()`; everything else passes through. - New `_json_safe_filter_info` helper applies `_json_safe_scalar` to `F_last` and each element of `dropped_cohorts`, and casts counts to native `int`. - `to_dict()` now applies these helpers consistently. **Test added:** `test_to_dict_json_serializable` asserts `json.dumps(result.to_dict())` succeeds and the round-trip values parse back as native Python types (int, float, list). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1f044e7 commit 736fdea

2 files changed

Lines changed: 89 additions & 16 deletions

File tree

diff_diff/had.py

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,53 @@
124124
}
125125

126126

127+
# =============================================================================
128+
# JSON-serialization helpers
129+
# =============================================================================
130+
131+
132+
def _json_safe_scalar(x: Any) -> Any:
133+
"""Coerce a scalar to a JSON-serializable type.
134+
135+
- NumPy scalars (``np.int64``, ``np.float64``, ``np.bool_``) ->
136+
native Python via ``.item()``
137+
- ``pd.Timestamp`` / ``pd.Timedelta`` -> ISO 8601 string via
138+
``.isoformat()``
139+
- Everything else returned as-is.
140+
141+
The ``to_dict`` methods use this to keep the returned dict
142+
serializable via ``json.dumps`` regardless of the underlying
143+
pandas/numpy dtype of the time / first_treat columns.
144+
"""
145+
if isinstance(x, (pd.Timestamp, pd.Timedelta)):
146+
return x.isoformat()
147+
if hasattr(x, "item") and callable(getattr(x, "item")):
148+
try:
149+
return x.item()
150+
except (AttributeError, ValueError, TypeError):
151+
return x
152+
return x
153+
154+
155+
def _json_safe_filter_info(
    filter_info: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Normalize a ``filter_info`` dict to JSON-safe scalars.

    Parameters
    ----------
    filter_info : dict or None
        Mapping that may carry the keys ``F_last``, ``n_kept``,
        ``n_dropped`` and ``dropped_cohorts``.

    Returns
    -------
    dict or None
        ``None`` when ``filter_info`` is ``None``. Otherwise a new dict
        with exactly those four keys: ``F_last`` and every element of
        ``dropped_cohorts`` passed through :func:`_json_safe_scalar`, and
        the two counts cast to native ``int``. A key that is missing *or*
        explicitly ``None`` falls back to ``0`` (counts) or ``[]``
        (``dropped_cohorts``).
    """
    if filter_info is None:
        return None

    # Compare against None explicitly instead of using ``or []``: an
    # ndarray of cohorts has no unambiguous truth value.
    cohorts = filter_info.get("dropped_cohorts")
    if cohorts is None:
        cohorts = []

    return {
        "F_last": _json_safe_scalar(filter_info.get("F_last")),
        # ``or 0`` also covers a key that is present but set to None.
        "n_kept": int(filter_info.get("n_kept") or 0),
        "n_dropped": int(filter_info.get("n_dropped") or 0),
        "dropped_cohorts": [_json_safe_scalar(c) for c in cohorts],
    }
172+
173+
127174
# =============================================================================
128175
# Results dataclass
129176
# =============================================================================
@@ -591,29 +638,34 @@ def print_summary(self) -> None:
591638
def to_dict(self) -> Dict[str, Any]:
    """Return the results as a plain, JSON-serializable dict.

    The per-horizon arrays are emitted first, each converted with
    ``ndarray.tolist()`` so their elements become native ``int`` /
    ``float`` values. The scalar fields follow: numeric ones are cast
    with ``float`` / ``int``, ``F`` goes through ``_json_safe_scalar``,
    and ``filter_info`` goes through ``_json_safe_filter_info``. The
    remaining fields are passed through unchanged.
    """
    per_horizon = (
        "event_times",
        "att",
        "se",
        "t_stat",
        "p_value",
        "conf_int_low",
        "conf_int_high",
        "n_obs_per_horizon",
    )
    out: Dict[str, Any] = {name: getattr(self, name).tolist() for name in per_horizon}

    out["alpha"] = float(self.alpha)
    out["design"] = self.design
    out["target_parameter"] = self.target_parameter
    out["d_lower"] = float(self.d_lower)
    out["dose_mean"] = float(self.dose_mean)
    out["F"] = _json_safe_scalar(self.F)
    out["n_units"] = int(self.n_units)
    out["inference_method"] = self.inference_method
    out["vcov_type"] = self.vcov_type
    out["cluster_name"] = self.cluster_name
    out["filter_info"] = _json_safe_filter_info(self.filter_info)
    return out
618670

619671
def to_dataframe(self) -> pd.DataFrame:

tests/test_had.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,6 +2576,27 @@ def test_to_dict_shape(self):
25762576
assert d["design"] == result.design
25772577
assert d["F"] == result.F
25782578

2579+
def test_to_dict_json_serializable(self):
    """``to_dict()`` output must be JSON-serializable via ``json.dumps``.

    Covers CI reviewer round 5 P2: previously the per-horizon arrays
    contained numpy scalars that tripped ``json.dumps``. Besides the
    round-trip itself, this checks every per-horizon field and the
    ``F`` / ``filter_info`` entries that ``to_dict()`` coerces.
    """
    import json

    result = self._fit()
    d = result.to_dict()
    # Should not raise.
    payload = json.dumps(d)
    assert isinstance(payload, str)

    # Round-trip: values should parse back as native Python types.
    parsed = json.loads(payload)
    per_horizon_keys = (
        "event_times",
        "att",
        "se",
        "t_stat",
        "p_value",
        "conf_int_low",
        "conf_int_high",
        "n_obs_per_horizon",
    )
    for key in per_horizon_keys:
        assert isinstance(parsed[key], list), key

    # Guard the [0] lookups so an empty fit fails with a clear message.
    assert parsed["event_times"], "expected at least one horizon"
    assert isinstance(parsed["event_times"][0], int)
    assert isinstance(parsed["att"][0], float)
    assert isinstance(parsed["alpha"], float)
    assert isinstance(parsed["n_units"], int)

    # Fields routed through the JSON-safety helpers must survive the trip.
    assert "F" in parsed
    assert parsed["filter_info"] is None or isinstance(parsed["filter_info"], dict)
2599+
25792600
def test_summary_renders(self):
25802601
result = self._fit()
25812602
summary = result.summary()

0 commit comments

Comments
 (0)