Skip to content

Commit bdc8292

Browse files
igerberclaude
andcommitted
Introduce ConleyMetric type alias; fix public conley_metric signature
Address P2 Code Quality finding from CI Codex review of PR #411 on 0500909. Public signatures advertised `conley_metric: str` but the runtime + docs + tests + REGISTRY all accept a callable `(coords1, coords2) -> n×n` for custom (e.g. network) distance metrics. Static type checkers and IDE autocomplete were therefore lying about the public API contract. Introduce a shared `ConleyMetric` type alias in `diff_diff/conley.py`: ConleyMetric = Union[ Literal["haversine", "euclidean"], Callable[[np.ndarray, np.ndarray], np.ndarray], ] Sweep all 11 public/internal annotation sites: - `solve_ols` overloads (4) + impl (1) - `_solve_ols_numpy` overloads (4) + impl (1) - `compute_robust_vcov` - `_compute_robust_vcov_numpy` - `LinearRegression.__init__` Internal helpers in `conley.py` (`_pairwise_distance_matrix`, `_validate_conley_kwargs`, `_compute_conley_vcov`) also annotated. Moved the `from diff_diff.conley import ...` from late-file (line ~1192 with `noqa: E402`) to top-of-module so the `ConleyMetric` name is in scope for the early function signatures (linalg.py has no `from __future__ import annotations`, so types evaluate eagerly). Verified locally: 72 Conley tests pass, ruff clean, sphinx -W exits 0. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0500909 commit bdc8292

3 files changed

Lines changed: 40 additions & 20 deletions

File tree

diff_diff/conley.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,22 @@
2222
from __future__ import annotations
2323

2424
import warnings
25-
from typing import Optional
25+
from typing import Callable, Literal, Optional, Union
2626

2727
import numpy as np
2828

29+
# Public type alias for the ``conley_metric`` parameter accepted by
30+
# ``compute_robust_vcov``, ``solve_ols``, and ``LinearRegression``. The
31+
# implementation accepts the two named strings as well as a user-supplied
32+
# callable ``(coords1, coords2) -> n×n distance matrix`` for custom
33+
# (e.g. network) distance metrics. Exported so the public signatures
34+
# in :mod:`diff_diff.linalg` can advertise the full accepted type to
35+
# static checkers and IDEs.
36+
ConleyMetric = Union[
37+
Literal["haversine", "euclidean"],
38+
Callable[[np.ndarray, np.ndarray], np.ndarray],
39+
]
40+
2941
# Earth's mean radius (km), matching R conleyreg's haversine convention
3042
# (Düsterhöft 2021, conleyreg::haversine_dist in src/distance_functions.cpp,
3143
# CRAN v0.1.9). WGS-84 equatorial radius is 6378.137 km; the 0.01 km delta
@@ -63,7 +75,7 @@ def _haversine_km(
6375
return _CONLEY_EARTH_RADIUS_KM * 2.0 * np.arcsin(np.sqrt(a))
6476

6577

66-
def _pairwise_distance_matrix(coords: np.ndarray, metric) -> np.ndarray:
78+
def _pairwise_distance_matrix(coords: np.ndarray, metric: ConleyMetric) -> np.ndarray:
6779
"""Build the dense n×n pairwise distance matrix.
6880
6981
``metric`` is one of ``"haversine"`` (lat/lon in degrees, distance in km),
@@ -116,7 +128,7 @@ def _uniform_kernel(u: np.ndarray) -> np.ndarray:
116128
def _validate_conley_kwargs(
117129
coords: Optional[np.ndarray],
118130
cutoff: Optional[float],
119-
metric,
131+
metric: ConleyMetric,
120132
kernel: str,
121133
n: int,
122134
) -> None:
@@ -197,7 +209,7 @@ def _compute_conley_vcov(
197209
residuals: np.ndarray,
198210
coords: np.ndarray,
199211
cutoff: float,
200-
metric,
212+
metric: ConleyMetric,
201213
kernel: str,
202214
bread_matrix: np.ndarray,
203215
) -> np.ndarray:

diff_diff/linalg.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,16 @@
4949
_rust_solve_ols,
5050
)
5151

52+
# Conley (1999) spatial HAC helpers live in diff_diff.conley to keep this
53+
# module focused on linear-algebra primitives. Imported at the top so the
54+
# `ConleyMetric` type alias is in scope for the public function signatures
55+
# below (which advertise `conley_metric: ConleyMetric`).
56+
from diff_diff.conley import (
57+
ConleyMetric,
58+
_compute_conley_vcov,
59+
_validate_conley_kwargs,
60+
)
61+
5262
# =============================================================================
5363
# Utility Functions
5464
# =============================================================================
@@ -352,7 +362,7 @@ def solve_ols(
352362
vcov_type: str = ...,
353363
conley_coords: Optional[np.ndarray] = ...,
354364
conley_cutoff_km: Optional[float] = ...,
355-
conley_metric: str = ...,
365+
conley_metric: ConleyMetric = ...,
356366
conley_kernel: str = ...,
357367
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: ...
358368

@@ -374,7 +384,7 @@ def solve_ols(
374384
vcov_type: str = ...,
375385
conley_coords: Optional[np.ndarray] = ...,
376386
conley_cutoff_km: Optional[float] = ...,
377-
conley_metric: str = ...,
387+
conley_metric: ConleyMetric = ...,
378388
conley_kernel: str = ...,
379389
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: ...
380390

@@ -396,7 +406,7 @@ def solve_ols(
396406
vcov_type: str = ...,
397407
conley_coords: Optional[np.ndarray] = ...,
398408
conley_cutoff_km: Optional[float] = ...,
399-
conley_metric: str = ...,
409+
conley_metric: ConleyMetric = ...,
400410
conley_kernel: str = ...,
401411
) -> Union[
402412
Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
@@ -454,7 +464,7 @@ def solve_ols(
454464
vcov_type: str = "hc1",
455465
conley_coords: Optional[np.ndarray] = None,
456466
conley_cutoff_km: Optional[float] = None,
457-
conley_metric: str = "haversine",
467+
conley_metric: ConleyMetric = "haversine",
458468
conley_kernel: str = "bartlett",
459469
) -> Union[
460470
Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
@@ -818,7 +828,7 @@ def _solve_ols_numpy(
818828
vcov_type: str = ...,
819829
conley_coords: Optional[np.ndarray] = ...,
820830
conley_cutoff_km: Optional[float] = ...,
821-
conley_metric: str = ...,
831+
conley_metric: ConleyMetric = ...,
822832
conley_kernel: str = ...,
823833
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: ...
824834

@@ -838,7 +848,7 @@ def _solve_ols_numpy(
838848
vcov_type: str = ...,
839849
conley_coords: Optional[np.ndarray] = ...,
840850
conley_cutoff_km: Optional[float] = ...,
841-
conley_metric: str = ...,
851+
conley_metric: ConleyMetric = ...,
842852
conley_kernel: str = ...,
843853
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: ...
844854

@@ -858,7 +868,7 @@ def _solve_ols_numpy(
858868
vcov_type: str = ...,
859869
conley_coords: Optional[np.ndarray] = ...,
860870
conley_cutoff_km: Optional[float] = ...,
861-
conley_metric: str = ...,
871+
conley_metric: ConleyMetric = ...,
862872
conley_kernel: str = ...,
863873
) -> Union[
864874
Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
@@ -880,7 +890,7 @@ def _solve_ols_numpy(
880890
vcov_type: str = "hc1",
881891
conley_coords: Optional[np.ndarray] = None,
882892
conley_cutoff_km: Optional[float] = None,
883-
conley_metric: str = "haversine",
893+
conley_metric: ConleyMetric = "haversine",
884894
conley_kernel: str = "bartlett",
885895
) -> Union[
886896
Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
@@ -1175,11 +1185,8 @@ def resolve_vcov_type(
11751185
return vcov_type
11761186

11771187

1178-
# Conley (1999) spatial HAC helpers live in diff_diff.conley to keep linalg.py
1179-
# focused on linear-algebra primitives. Imported here so the dispatch in
1180-
# `_compute_robust_vcov_numpy` can route `vcov_type="conley"` without a
1181-
# late/local import.
1182-
from diff_diff.conley import _compute_conley_vcov, _validate_conley_kwargs # noqa: E402
1188+
# Conley helpers are imported at module top — see the from-import near the
1189+
# header of this file.
11831190

11841191

11851192
def compute_robust_vcov(
@@ -1193,7 +1200,7 @@ def compute_robust_vcov(
11931200
*,
11941201
conley_coords: Optional[np.ndarray] = None,
11951202
conley_cutoff_km: Optional[float] = None,
1196-
conley_metric: str = "haversine",
1203+
conley_metric: ConleyMetric = "haversine",
11971204
conley_kernel: str = "bartlett",
11981205
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
11991206
"""
@@ -1640,7 +1647,7 @@ def _compute_robust_vcov_numpy(
16401647
*,
16411648
conley_coords: Optional[np.ndarray] = None,
16421649
conley_cutoff_km: Optional[float] = None,
1643-
conley_metric: str = "haversine",
1650+
conley_metric: ConleyMetric = "haversine",
16441651
conley_kernel: str = "bartlett",
16451652
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
16461653
"""
@@ -2478,7 +2485,7 @@ def __init__(
24782485
vcov_type: Optional[str] = None,
24792486
conley_coords: Optional[np.ndarray] = None,
24802487
conley_cutoff_km: Optional[float] = None,
2481-
conley_metric: str = "haversine",
2488+
conley_metric: ConleyMetric = "haversine",
24822489
conley_kernel: str = "bartlett",
24832490
):
24842491
self.include_intercept = include_intercept

docs/api/_autosummary/diff_diff.ChaisemartinDHaultfoeuilleResults.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
~ChaisemartinDHaultfoeuilleResults.p_value
4242
~ChaisemartinDHaultfoeuilleResults.path_cumulated_event_study
4343
~ChaisemartinDHaultfoeuilleResults.path_effects
44+
~ChaisemartinDHaultfoeuilleResults.path_heterogeneity_effects
4445
~ChaisemartinDHaultfoeuilleResults.path_placebo_event_study
4546
~ChaisemartinDHaultfoeuilleResults.path_sup_t_bands
4647
~ChaisemartinDHaultfoeuilleResults.placebo_event_study

0 commit comments

Comments
 (0)