1717import numpy as np
1818import pandas as pd
1919
20+ from diff_diff .linalg import compute_r_squared , compute_robust_vcov , solve_ols
2021from diff_diff .results import DiDResults , MultiPeriodDiDResults , PeriodEffect
2122from diff_diff .utils import (
2223 WildBootstrapResults ,
2324 compute_confidence_interval ,
2425 compute_p_value ,
25- compute_robust_se ,
2626 validate_binary ,
2727 wild_bootstrap_se ,
2828)
@@ -261,8 +261,11 @@ def fit(
261261 X = np .column_stack ([X , dummies [col ].values .astype (float )])
262262 var_names .append (col )
263263
264- # Fit OLS
265- coefficients , residuals , fitted , r_squared = self ._fit_ols (X , y )
264+ # Fit OLS using unified backend
265+ coefficients , residuals , fitted , vcov = solve_ols (
266+ X , y , return_fitted = True , return_vcov = False
267+ )
268+ r_squared = compute_r_squared (y , residuals )
266269
267270 # Extract ATT (coefficient on interaction term)
268271 att_idx = 3 # Index of interaction term
@@ -285,13 +288,13 @@ def fit(
285288 )
286289 elif self .cluster is not None :
287290 cluster_ids = data [self .cluster ].values
288- vcov = compute_robust_se (X , residuals , cluster_ids )
291+ vcov = compute_robust_vcov (X , residuals , cluster_ids )
289292 se = np .sqrt (vcov [att_idx , att_idx ])
290293 t_stat = att / se
291294 p_value = compute_p_value (t_stat , df = df )
292295 conf_int = compute_confidence_interval (att , se , self .alpha , df = df )
293296 elif self .robust :
294- vcov = compute_robust_se (X , residuals )
297+ vcov = compute_robust_vcov (X , residuals )
295298 se = np .sqrt (vcov [att_idx , att_idx ])
296299 t_stat = att / se
297300 p_value = compute_p_value (t_stat , df = df )
@@ -300,7 +303,7 @@ def fit(
300303 # Classical OLS standard errors
301304 n = len (y )
302305 k = X .shape [1 ]
303- mse = np .sum (residuals ** 2 ) / (n - k )
306+ mse = np .sum (residuals ** 2 ) / (n - k )
304307 # Use solve() instead of inv() for numerical stability
305308 # solve(A, B) computes X where AX=B, so this yields (X'X)^{-1} * mse
306309 vcov = np .linalg .solve (X .T @ X , mse * np .eye (k ))
@@ -352,10 +355,15 @@ def fit(
352355
353356 return self .results_
354357
355- def _fit_ols (self , X : np .ndarray , y : np .ndarray ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , float ]:
358+ def _fit_ols (
359+ self , X : np .ndarray , y : np .ndarray
360+ ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , float ]:
356361 """
357362 Fit OLS regression.
358363
364+ This method is kept for backwards compatibility. Internally uses the
365+ unified solve_ols from diff_diff.linalg for optimized computation.
366+
359367 Parameters
360368 ----------
361369 X : np.ndarray
@@ -367,32 +375,12 @@ def _fit_ols(self, X: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray
367375 -------
368376 tuple
369377 (coefficients, residuals, fitted_values, r_squared)
370-
371- Raises
372- ------
373- ValueError
374- If design matrix is rank-deficient (perfect multicollinearity).
375378 """
376- # Check for rank deficiency (perfect multicollinearity)
377- rank = np .linalg .matrix_rank (X )
378- if rank < X .shape [1 ]:
379- raise ValueError (
380- f"Design matrix is rank-deficient (rank { rank } < { X .shape [1 ]} columns). "
381- "This indicates perfect multicollinearity. Check your fixed effects "
382- "and covariates for linear dependencies."
383- )
384-
385- # Solve normal equations: β = (X'X)^(-1) X'y
386- coefficients = np .linalg .lstsq (X , y , rcond = None )[0 ]
387-
388- # Compute fitted values and residuals
389- fitted = X @ coefficients
390- residuals = y - fitted
391-
392- # Compute R-squared
393- ss_res = np .sum (residuals ** 2 )
394- ss_tot = np .sum ((y - np .mean (y )) ** 2 )
395- r_squared = 1 - (ss_res / ss_tot ) if ss_tot > 0 else 0.0
379+ # Use unified OLS backend
380+ coefficients , residuals , fitted , _ = solve_ols (
381+ X , y , return_fitted = True , return_vcov = False
382+ )
383+ r_squared = compute_r_squared (y , residuals )
396384
397385 return coefficients , residuals , fitted , r_squared
398386
@@ -442,7 +430,7 @@ def _run_wild_bootstrap_inference(
442430 t_stat = bootstrap_results .t_stat_original
443431
444432 # Also compute vcov for storage (using cluster-robust for consistency)
445- vcov = compute_robust_se (X , residuals , cluster_ids )
433+ vcov = compute_robust_vcov (X , residuals , cluster_ids )
446434
447435 return se , p_value , conf_int , t_stat , vcov , bootstrap_results
448436
@@ -889,8 +877,11 @@ def fit( # type: ignore[override]
889877 X = np .column_stack ([X , dummies [col ].values .astype (float )])
890878 var_names .append (col )
891879
892- # Fit OLS
893- coefficients , residuals , fitted , r_squared = self ._fit_ols (X , y )
880+ # Fit OLS using unified backend
881+ coefficients , residuals , fitted , _ = solve_ols (
882+ X , y , return_fitted = True , return_vcov = False
883+ )
884+ r_squared = compute_r_squared (y , residuals )
894885
895886 # Degrees of freedom
896887 df = len (y ) - X .shape [1 ] - n_absorbed_effects
@@ -900,13 +891,13 @@ def fit( # type: ignore[override]
900891 # For now, we use analytical inference even if inference="wild_bootstrap"
901892 if self .cluster is not None :
902893 cluster_ids = data [self .cluster ].values
903- vcov = compute_robust_se (X , residuals , cluster_ids )
894+ vcov = compute_robust_vcov (X , residuals , cluster_ids )
904895 elif self .robust :
905- vcov = compute_robust_se (X , residuals )
896+ vcov = compute_robust_vcov (X , residuals )
906897 else :
907898 n = len (y )
908899 k = X .shape [1 ]
909- mse = np .sum (residuals ** 2 ) / (n - k )
900+ mse = np .sum (residuals ** 2 ) / (n - k )
910901 # Use solve() instead of inv() for numerical stability
911902 # solve(A, B) computes X where AX=B, so this yields (X'X)^{-1} * mse
912903 vcov = np .linalg .solve (X .T @ X , mse * np .eye (k ))
0 commit comments