@@ -1074,7 +1074,9 @@ def fit(
10741074 )
10751075
10761076 # Compute overall ATT (simple aggregation)
1077- overall_att , overall_se = self ._aggregate_simple (group_time_effects , df , unit )
1077+ overall_att , overall_se = self ._aggregate_simple (
1078+ group_time_effects , influence_func_info , df , unit
1079+ )
10781080 overall_t = overall_att / overall_se if overall_se > 0 else 0.0
10791081 overall_p = compute_p_value (overall_t )
10801082 overall_ci = compute_confidence_interval (overall_att , overall_se , self .alpha )
@@ -1085,12 +1087,13 @@ def fit(
10851087
10861088 if aggregate in ["event_study" , "all" ]:
10871089 event_study_effects = self ._aggregate_event_study (
1088- group_time_effects , treatment_groups , time_periods , balance_e
1090+ group_time_effects , influence_func_info ,
1091+ treatment_groups , time_periods , balance_e
10891092 )
10901093
10911094 if aggregate in ["group" , "all" ]:
10921095 group_effects = self ._aggregate_by_group (
1093- group_time_effects , treatment_groups
1096+ group_time_effects , influence_func_info , treatment_groups
10941097 )
10951098
10961099 # Run bootstrap inference if requested
@@ -1423,42 +1426,108 @@ def _doubly_robust(
14231426 def _aggregate_simple (
14241427 self ,
14251428 group_time_effects : Dict ,
1429+ influence_func_info : Dict ,
14261430 df : pd .DataFrame ,
14271431 unit : str ,
14281432 ) -> Tuple [float , float ]:
14291433 """
14301434 Compute simple weighted average of ATT(g,t).
14311435
14321436 Weights by group size (number of treated units).
1437+
1438+ Standard errors are computed using influence function aggregation,
1439+ which properly accounts for covariances across (g,t) pairs due to
1440+ shared control units. This matches R's `did` package approach.
14331441 """
14341442 effects = []
1435- weights = []
1436- variances = []
1443+ weights_list = []
1444+ gt_pairs = []
14371445
14381446 for (g , t ), data in group_time_effects .items ():
14391447 effects .append (data ['effect' ])
1440- weights .append (data ['n_treated' ])
1441- variances .append (data [ 'se' ] ** 2 )
1448+ weights_list .append (data ['n_treated' ])
1449+ gt_pairs .append (( g , t ) )
14421450
14431451 effects = np .array (effects )
1444- weights = np .array (weights , dtype = float )
1445- variances = np .array (variances )
1452+ weights = np .array (weights_list , dtype = float )
14461453
14471454 # Normalize weights
14481455 weights = weights / np .sum (weights )
14491456
14501457 # Weighted average
14511458 overall_att = np .sum (weights * effects )
14521459
1453- # Standard error (assuming independence across g,t)
1454- overall_var = np .sum ((weights ** 2 ) * variances )
1455- overall_se = np .sqrt (overall_var )
1460+ # Compute SE using influence function aggregation
1461+ overall_se = self ._compute_aggregated_se (
1462+ gt_pairs , weights , influence_func_info
1463+ )
14561464
14571465 return overall_att , overall_se
14581466
1467+ def _compute_aggregated_se (
1468+ self ,
1469+ gt_pairs : List [Tuple [Any , Any ]],
1470+ weights : np .ndarray ,
1471+ influence_func_info : Dict ,
1472+ ) -> float :
1473+ """
1474+ Compute standard error using influence function aggregation.
1475+
1476+ This properly accounts for covariances across (g,t) pairs by
1477+ aggregating unit-level influence functions:
1478+
1479+ ψ_i(overall) = Σ_{(g,t)} w_(g,t) × ψ_i(g,t)
1480+ Var(overall) = (1/n) Σ_i [ψ_i]²
1481+
1482+ This matches R's `did` package analytical SE formula.
1483+ """
1484+ if not influence_func_info :
1485+ # Fallback if no influence functions available
1486+ return 0.0
1487+
1488+ # Build unit index mapping from all (g,t) pairs
1489+ all_units = set ()
1490+ for (g , t ) in gt_pairs :
1491+ if (g , t ) in influence_func_info :
1492+ info = influence_func_info [(g , t )]
1493+ all_units .update (info ['treated_units' ])
1494+ all_units .update (info ['control_units' ])
1495+
1496+ if not all_units :
1497+ return 0.0
1498+
1499+ all_units = sorted (all_units )
1500+ n_units = len (all_units )
1501+ unit_to_idx = {u : i for i , u in enumerate (all_units )}
1502+
1503+ # Aggregate influence functions across (g,t) pairs
1504+ psi_overall = np .zeros (n_units )
1505+
1506+ for j , (g , t ) in enumerate (gt_pairs ):
1507+ if (g , t ) not in influence_func_info :
1508+ continue
1509+
1510+ info = influence_func_info [(g , t )]
1511+ w = weights [j ]
1512+
1513+ # Treated unit contributions
1514+ for i , unit_id in enumerate (info ['treated_units' ]):
1515+ idx = unit_to_idx [unit_id ]
1516+ psi_overall [idx ] += w * info ['treated_inf' ][i ]
1517+
1518+ # Control unit contributions
1519+ for i , unit_id in enumerate (info ['control_units' ]):
1520+ idx = unit_to_idx [unit_id ]
1521+ psi_overall [idx ] += w * info ['control_inf' ][i ]
1522+
1523+ # Compute variance: Var(θ̄) = (1/n) Σᵢ ψᵢ²
1524+ variance = np .sum (psi_overall ** 2 )
1525+ return np .sqrt (variance )
1526+
14591527 def _aggregate_event_study (
14601528 self ,
14611529 group_time_effects : Dict ,
1530+ influence_func_info : Dict ,
14621531 groups : List [Any ],
14631532 time_periods : List [Any ],
14641533 balance_e : Optional [int ] = None ,
@@ -1467,17 +1536,20 @@ def _aggregate_event_study(
14671536 Aggregate effects by relative time (event study).
14681537
14691538 Computes average effect at each event time e = t - g.
1539+
1540+ Standard errors use influence function aggregation to account for
1541+ covariances across (g,t) pairs.
14701542 """
1471- # Organize effects by relative time
1472- effects_by_e : Dict [int , List [Tuple [float , float , int ]]] = {}
1543+ # Organize effects by relative time, keeping track of (g,t) pairs
1544+ effects_by_e : Dict [int , List [Tuple [Tuple [ Any , Any ] , float , int ]]] = {}
14731545
14741546 for (g , t ), data in group_time_effects .items ():
14751547 e = t - g # Relative time
14761548 if e not in effects_by_e :
14771549 effects_by_e [e ] = []
14781550 effects_by_e [e ].append ((
1551+ (g , t ), # Keep track of the (g,t) pair
14791552 data ['effect' ],
1480- data ['se' ],
14811553 data ['n_treated' ]
14821554 ))
14831555
@@ -1490,15 +1562,15 @@ def _aggregate_event_study(
14901562 groups_at_e .add (g )
14911563
14921564 # Filter effects to only include balanced groups
1493- balanced_effects : Dict [int , List [Tuple [float , float , int ]]] = {}
1565+ balanced_effects : Dict [int , List [Tuple [Tuple [ Any , Any ] , float , int ]]] = {}
14941566 for (g , t ), data in group_time_effects .items ():
14951567 if g in groups_at_e :
14961568 e = t - g
14971569 if e not in balanced_effects :
14981570 balanced_effects [e ] = []
14991571 balanced_effects [e ].append ((
1572+ (g , t ),
15001573 data ['effect' ],
1501- data ['se' ],
15021574 data ['n_treated' ]
15031575 ))
15041576 effects_by_e = balanced_effects
@@ -1507,16 +1579,19 @@ def _aggregate_event_study(
15071579 event_study_effects = {}
15081580
15091581 for e , effect_list in sorted (effects_by_e .items ()):
1510- effs = np . array ( [x [0 ] for x in effect_list ])
1511- ses = np .array ([x [1 ] for x in effect_list ])
1582+ gt_pairs = [x [0 ] for x in effect_list ]
1583+ effs = np .array ([x [1 ] for x in effect_list ])
15121584 ns = np .array ([x [2 ] for x in effect_list ], dtype = float )
15131585
15141586 # Weight by group size
15151587 weights = ns / np .sum (ns )
15161588
15171589 agg_effect = np .sum (weights * effs )
1518- agg_var = np .sum ((weights ** 2 ) * (ses ** 2 ))
1519- agg_se = np .sqrt (agg_var )
1590+
1591+ # Compute SE using influence function aggregation
1592+ agg_se = self ._compute_aggregated_se (
1593+ gt_pairs , weights , influence_func_info
1594+ )
15201595
15211596 t_stat = agg_effect / agg_se if agg_se > 0 else 0.0
15221597 p_val = compute_p_value (t_stat )
@@ -1536,35 +1611,43 @@ def _aggregate_event_study(
15361611 def _aggregate_by_group (
15371612 self ,
15381613 group_time_effects : Dict ,
1614+ influence_func_info : Dict ,
15391615 groups : List [Any ],
15401616 ) -> Dict [Any , Dict [str , Any ]]:
15411617 """
15421618 Aggregate effects by treatment cohort.
15431619
15441620 Computes average effect for each cohort across all post-treatment periods.
1621+
1622+ Standard errors use influence function aggregation to account for
1623+ covariances across time periods within a cohort.
15451624 """
15461625 group_effects = {}
15471626
15481627 for g in groups :
15491628 # Get all effects for this group (post-treatment only: t >= g)
1629+ # Keep track of (g, t) pairs for influence function aggregation
15501630 g_effects = [
1551- (data [ 'effect' ], data [ 'se' ] , data ['n_treated ' ])
1631+ (( g , t ) , data ['effect ' ])
15521632 for (gg , t ), data in group_time_effects .items ()
15531633 if gg == g and t >= g
15541634 ]
15551635
15561636 if not g_effects :
15571637 continue
15581638
1559- effs = np . array ( [x [0 ] for x in g_effects ])
1560- ses = np .array ([x [1 ] for x in g_effects ])
1639+ gt_pairs = [x [0 ] for x in g_effects ]
1640+ effs = np .array ([x [1 ] for x in g_effects ])
15611641
15621642 # Equal weight across time periods for a group
15631643 weights = np .ones (len (effs )) / len (effs )
15641644
15651645 agg_effect = np .sum (weights * effs )
1566- agg_var = np .sum ((weights ** 2 ) * (ses ** 2 ))
1567- agg_se = np .sqrt (agg_var )
1646+
1647+ # Compute SE using influence function aggregation
1648+ agg_se = self ._compute_aggregated_se (
1649+ gt_pairs , weights , influence_func_info
1650+ )
15681651
15691652 t_stat = agg_effect / agg_se if agg_se > 0 else 0.0
15701653 p_val = compute_p_value (t_stat )
0 commit comments