Skip to content

Commit d4e7739

Browse files
igerberclaude
andcommitted
Fix P1/P2 review findings: color parser, staircase counts, dose-response target, bacon parity
- Replace hex-only _hex_to_rgba with _color_to_rgba supporting named CSS colors, 3-digit hex, and matplotlib fallback (P1) - Use max(n_treated) across cells per cohort in plot_staircase with warning on missingness instead of first-cell value (P1) - Infer target from curve.target in plot_dose_response when curve passed directly, preventing ATT-labeled ACRT plots (P2) - Add show_weighted_avg vertical lines to plot_bacon plotly scatter (P2) - Add 9 targeted regression tests covering all fixes (P2) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d46a33f commit d4e7739

6 files changed

Lines changed: 265 additions & 28 deletions

File tree

diff_diff/visualization/_common.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,43 @@ def _plotly_default_layout(fig, *, title=None, xlabel=None, ylabel=None, show_le
6565
)
6666

6767

68-
def _hex_to_rgba(hex_color, alpha=1.0):
69-
"""Convert hex color to rgba string for plotly.
68+
_CSS_COLORS = {
69+
"red": (255, 0, 0),
70+
"blue": (0, 0, 255),
71+
"green": (0, 128, 0),
72+
"black": (0, 0, 0),
73+
"white": (255, 255, 255),
74+
"gray": (128, 128, 128),
75+
"grey": (128, 128, 128),
76+
"lightgray": (211, 211, 211),
77+
"lightgrey": (211, 211, 211),
78+
"darkgray": (169, 169, 169),
79+
"darkgrey": (169, 169, 169),
80+
"orange": (255, 165, 0),
81+
"purple": (128, 0, 128),
82+
"yellow": (255, 255, 0),
83+
"cyan": (0, 255, 255),
84+
"magenta": (255, 0, 255),
85+
"pink": (255, 192, 203),
86+
"brown": (165, 42, 42),
87+
"navy": (0, 0, 128),
88+
"teal": (0, 128, 128),
89+
"olive": (128, 128, 0),
90+
"coral": (255, 127, 80),
91+
"salmon": (250, 128, 114),
92+
}
93+
94+
95+
def _color_to_rgba(color, alpha=1.0):
96+
"""Convert any color to an ``rgba(r, g, b, a)`` string for plotly.
97+
98+
Accepts hex colors (``#rrggbb``, ``#rgb``), CSS named colors, and
99+
falls back to ``matplotlib.colors`` when available.
70100
71101
Parameters
72102
----------
73-
hex_color : str
74-
Hex color string (e.g., ``"#2563eb"``).
103+
color : str
104+
Color specification.
75105
alpha : float, default=1.0
76106
Opacity value between 0 and 1.
77107
@@ -80,9 +110,45 @@ def _hex_to_rgba(hex_color, alpha=1.0):
80110
str
81111
An ``rgba(r, g, b, a)`` string.
82112
"""
83-
hex_color = hex_color.lstrip("#")
84-
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
85-
return f"rgba({r}, {g}, {b}, {alpha})"
113+
if not isinstance(color, str):
114+
raise ValueError(f"Expected a color string, got {type(color).__name__}")
115+
116+
# 1. Hex colors: #rrggbb or #rgb
117+
stripped = color.lstrip("#")
118+
if color.startswith("#") and all(c in "0123456789abcdefABCDEF" for c in stripped):
119+
if len(stripped) == 6:
120+
r = int(stripped[0:2], 16)
121+
g = int(stripped[2:4], 16)
122+
b = int(stripped[4:6], 16)
123+
return f"rgba({r}, {g}, {b}, {alpha})"
124+
if len(stripped) == 3:
125+
r = int(stripped[0] * 2, 16)
126+
g = int(stripped[1] * 2, 16)
127+
b = int(stripped[2] * 2, 16)
128+
return f"rgba({r}, {g}, {b}, {alpha})"
129+
130+
# 2. Named CSS colors
131+
if color.lower() in _CSS_COLORS:
132+
r, g, b = _CSS_COLORS[color.lower()]
133+
return f"rgba({r}, {g}, {b}, {alpha})"
134+
135+
# 3. Already an rgba/rgb string — override alpha
136+
if color.startswith("rgba(") or color.startswith("rgb("):
137+
return color if alpha == 1.0 else color # pass through for plotly
138+
139+
# 4. Fallback: try matplotlib.colors if available
140+
try:
141+
from matplotlib.colors import to_rgb
142+
143+
r, g, b = to_rgb(color)
144+
return f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {alpha})"
145+
except (ImportError, ValueError):
146+
pass
147+
148+
raise ValueError(
149+
f"Cannot parse color '{color}'. Use hex (#rrggbb), a CSS color name, "
150+
"or install matplotlib for full color support."
151+
)
86152

87153

88154
# Default color constants

diff_diff/visualization/_continuous.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def plot_dose_response(
8787
raise ValueError(f"target must be 'att' or 'acrt', got '{target}'")
8888

8989
if curve is not None:
90+
# Infer target from curve when passed directly (not via results)
91+
if results is None and hasattr(curve, "target") and curve.target:
92+
target = curve.target
9093
dose_grid = curve.dose_grid
9194
effects = curve.effects
9295
ci_lower = curve.conf_int_lower
@@ -223,7 +226,7 @@ def _render_dose_response_plotly(
223226
):
224227
"""Render dose-response curve with plotly."""
225228
from diff_diff.visualization._common import (
226-
_hex_to_rgba,
229+
_color_to_rgba,
227230
_plotly_default_layout,
228231
_require_plotly,
229232
)
@@ -245,7 +248,7 @@ def _render_dose_response_plotly(
245248
x=dose_list + dose_list[::-1],
246249
y=list(ci_upper) + list(ci_lower)[::-1],
247250
fill="toself",
248-
fillcolor=_hex_to_rgba(band_color, 0.15),
251+
fillcolor=_color_to_rgba(band_color, 0.15),
249252
line=dict(color="rgba(0,0,0,0)"),
250253
name="95% CI",
251254
hoverinfo="skip",

diff_diff/visualization/_diagnostic.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def _render_sensitivity_plotly(
248248
):
249249
"""Render sensitivity plot with plotly."""
250250
from diff_diff.visualization._common import (
251-
_hex_to_rgba,
251+
_color_to_rgba,
252252
_plotly_default_layout,
253253
_require_plotly,
254254
)
@@ -278,7 +278,7 @@ def _render_sensitivity_plotly(
278278
x=M_list + M_list[::-1],
279279
y=list(bounds_arr[:, 1]) + list(bounds_arr[:, 0])[::-1],
280280
fill="toself",
281-
fillcolor=_hex_to_rgba(bounds_color, bounds_alpha),
281+
fillcolor=_color_to_rgba(bounds_color, bounds_alpha),
282282
line=dict(color="rgba(0,0,0,0)"),
283283
name="Identified set",
284284
)
@@ -742,6 +742,19 @@ def _render_bacon_plotly(
742742
)
743743
)
744744

745+
# Weighted average lines
746+
if show_weighted_avg:
747+
effect_by_type = results.effect_by_type()
748+
for ctype, avg_effect in effect_by_type.items():
749+
if avg_effect is not None and by_type[ctype]:
750+
fig.add_vline(
751+
x=avg_effect,
752+
line_dash="dash",
753+
line_color=colors[ctype],
754+
opacity=0.5,
755+
line_width=1.5,
756+
)
757+
745758
# TWFE line
746759
if show_twfe_line:
747760
fig.add_vline(

diff_diff/visualization/_event_study.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def _render_event_study_plotly(
430430
):
431431
"""Render event study plot with plotly."""
432432
from diff_diff.visualization._common import (
433-
_hex_to_rgba,
433+
_color_to_rgba,
434434
_plotly_default_layout,
435435
_require_plotly,
436436
)
@@ -452,7 +452,7 @@ def _render_event_study_plotly(
452452
fig.add_vrect(
453453
x0=min(pre_in_plot) - 0.5,
454454
x1=max(pre_in_plot) + 0.5,
455-
fillcolor=_hex_to_rgba(shade_color, 0.5),
455+
fillcolor=_color_to_rgba(shade_color, 0.5),
456456
line_width=0,
457457
layer="below",
458458
)
@@ -477,7 +477,7 @@ def _render_event_study_plotly(
477477
x=ci_periods + ci_periods[::-1],
478478
y=ci_hi + ci_lo[::-1],
479479
fill="toself",
480-
fillcolor=_hex_to_rgba(color, 0.15),
480+
fillcolor=_color_to_rgba(color, 0.15),
481481
line=dict(color="rgba(0,0,0,0)"),
482482
showlegend=False,
483483
hoverinfo="skip",
@@ -970,7 +970,7 @@ def _render_honest_event_study_plotly(
970970
):
971971
"""Render honest event study plot with plotly."""
972972
from diff_diff.visualization._common import (
973-
_hex_to_rgba,
973+
_color_to_rgba,
974974
_plotly_default_layout,
975975
_require_plotly,
976976
)
@@ -988,7 +988,7 @@ def _render_honest_event_study_plotly(
988988
x=list(periods) + list(periods)[::-1],
989989
y=list(original_ci_upper) + list(original_ci_lower)[::-1],
990990
fill="toself",
991-
fillcolor=_hex_to_rgba(original_color, 0.15),
991+
fillcolor=_color_to_rgba(original_color, 0.15),
992992
line=dict(color="rgba(0,0,0,0)"),
993993
name="Standard CI",
994994
hoverinfo="skip",
@@ -1001,7 +1001,7 @@ def _render_honest_event_study_plotly(
10011001
x=list(periods) + list(periods)[::-1],
10021002
y=list(honest_ci_upper) + list(honest_ci_lower)[::-1],
10031003
fill="toself",
1004-
fillcolor=_hex_to_rgba(honest_color, 0.15),
1004+
fillcolor=_color_to_rgba(honest_color, 0.15),
10051005
line=dict(color="rgba(0,0,0,0)"),
10061006
name=f"Honest CI (M={honest_M:.2f})",
10071007
hoverinfo="skip",

diff_diff/visualization/_staggered.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -298,16 +298,29 @@ def _extract_staircase_data(results, data, unit, time, first_treat):
298298
groups = sorted(results.groups)
299299
cohort_counts = []
300300
for g in groups:
301-
# Find a representative (g, t) entry to get n_treated for this cohort
302-
n_treated = None
303-
for (gg, tt), eff in results.group_time_effects.items():
301+
# Collect n_treated across all (g, t) cells for this cohort.
302+
# n_treated is a per-cell observation count that can vary with
303+
# missingness, so we use the max as the best cohort size estimate.
304+
cell_counts = []
305+
for (gg, _tt), eff in results.group_time_effects.items():
304306
if gg == g:
305-
n_treated = eff.get("n_treated", eff.get("n_obs", None))
306-
if n_treated is not None:
307-
break
308-
if n_treated is None:
309-
n_treated = 0
310-
cohort_counts.append((g, int(n_treated)))
307+
n = eff.get("n_treated", eff.get("n_obs", None))
308+
if n is not None:
309+
cell_counts.append(int(n))
310+
if not cell_counts:
311+
cohort_counts.append((g, 0))
312+
continue
313+
max_count = max(cell_counts)
314+
if min(cell_counts) != max_count:
315+
import warnings
316+
317+
warnings.warn(
318+
f"Cohort {g}: n_treated varies across cells "
319+
f"({min(cell_counts)}-{max_count}). "
320+
f"Using max as cohort size; pass data= for exact counts.",
321+
stacklevel=3,
322+
)
323+
cohort_counts.append((g, max_count))
311324

312325
return cohort_counts
313326

@@ -391,7 +404,7 @@ def _render_staircase_mpl(*, cohort_counts, figsize, title, color, show_counts,
391404
def _render_staircase_plotly(*, cohort_counts, title, color, show_counts, show):
392405
"""Render staircase plot with plotly."""
393406
from diff_diff.visualization._common import (
394-
_hex_to_rgba,
407+
_color_to_rgba,
395408
_plotly_default_layout,
396409
_require_plotly,
397410
)
@@ -419,7 +432,7 @@ def _render_staircase_plotly(*, cohort_counts, title, color, show_counts, show):
419432
mode="lines",
420433
line=dict(color=color, width=2, shape="hv"),
421434
fill="tozeroy",
422-
fillcolor=_hex_to_rgba(color, 0.15),
435+
fillcolor=_color_to_rgba(color, 0.15),
423436
name="Cumulative treated",
424437
)
425438
)

0 commit comments

Comments
 (0)