diff --git a/TODO.md b/TODO.md
index 264e76ae..5ae73fa9 100644
--- a/TODO.md
+++ b/TODO.md
@@ -71,7 +71,6 @@ Deferred items from PR reviews that were not addressed before merge.
| Issue | Location | PR | Priority |
|-------|----------|----|----------|
-| Plotly renderers silently ignore styling kwargs (marker, markersize, linewidth, capsize, ci_linewidth) that the matplotlib backend honors; thread them through or reject when `backend="plotly"` | `visualization/_event_study.py`, `_diagnostic.py`, `_power.py` | #222 | Medium |
| R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
| CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low |
| ~376 `duplicate object description` Sphinx warnings — restructure `docs/api/*.rst` to avoid duplicate `:members:` + `autosummary` | `docs/api/*.rst` | — | Low |
diff --git a/diff_diff/visualization/_common.py b/diff_diff/visualization/_common.py
index 1c08ca3f..464398c6 100644
--- a/diff_diff/visualization/_common.py
+++ b/diff_diff/visualization/_common.py
@@ -283,6 +283,42 @@ def _color_to_rgba(color, alpha=1.0):
)
+# Matplotlib marker code -> plotly symbol name mapping
+_MPL_TO_PLOTLY_SYMBOL = {
+ "o": "circle",
+ "s": "square",
+ "D": "diamond",
+ "d": "diamond",
+ "^": "triangle-up",
+ "v": "triangle-down",
+ "<": "triangle-left",
+ ">": "triangle-right",
+ "p": "pentagon",
+ "h": "hexagon",
+ "+": "cross",
+ "x": "x",
+ "*": "star",
+ ".": "circle",
+}
+
+
+def _mpl_marker_to_plotly_symbol(marker):
+ """Convert a matplotlib marker code to a plotly symbol name.
+
+ Parameters
+ ----------
+ marker : str
+ Matplotlib marker shorthand (e.g., ``"o"``, ``"s"``, ``"D"``).
+
+ Returns
+ -------
+ str
+ Plotly symbol name (e.g., ``"circle"``, ``"square"``, ``"diamond"``).
+ Returns ``"circle"`` for unrecognized markers.
+ """
+ return _MPL_TO_PLOTLY_SYMBOL.get(marker, "circle")
+
+
# Default color constants
DEFAULT_BLUE = "#2563eb"
DEFAULT_RED = "#dc2626"
diff --git a/diff_diff/visualization/_diagnostic.py b/diff_diff/visualization/_diagnostic.py
index ef876f55..66320ce5 100644
--- a/diff_diff/visualization/_diagnostic.py
+++ b/diff_diff/visualization/_diagnostic.py
@@ -111,6 +111,7 @@ def plot_sensitivity(
bounds_color=bounds_color,
bounds_alpha=bounds_alpha,
ci_color=ci_color,
+ ci_linewidth=ci_linewidth,
breakdown_color=breakdown_color,
original_color=original_color,
show=show,
@@ -242,6 +243,7 @@ def _render_sensitivity_plotly(
bounds_color,
bounds_alpha,
ci_color,
+ ci_linewidth,
breakdown_color,
original_color,
show,
@@ -291,7 +293,7 @@ def _render_sensitivity_plotly(
x=M_list,
y=list(ci_arr[:, 0]),
mode="lines",
- line=dict(color=ci_color, width=1.5),
+ line=dict(color=ci_color, width=ci_linewidth),
name="Robust CI",
)
)
@@ -300,7 +302,7 @@ def _render_sensitivity_plotly(
x=M_list,
y=list(ci_arr[:, 1]),
mode="lines",
- line=dict(color=ci_color, width=1.5),
+ line=dict(color=ci_color, width=ci_linewidth),
showlegend=False,
)
)
@@ -449,6 +451,8 @@ def plot_bacon(
xlabel=xlabel,
ylabel=ylabel,
colors=colors,
+ marker=marker,
+ markersize=markersize,
alpha=alpha,
show_weighted_avg=show_weighted_avg,
show_twfe_line=show_twfe_line,
@@ -699,13 +703,19 @@ def _render_bacon_plotly(
xlabel,
ylabel,
colors,
+ marker,
+ markersize,
alpha,
show_weighted_avg,
show_twfe_line,
show,
):
"""Render Bacon decomposition plot with plotly."""
- from diff_diff.visualization._common import _plotly_default_layout, _require_plotly
+ from diff_diff.visualization._common import (
+ _mpl_marker_to_plotly_symbol,
+ _plotly_default_layout,
+ _require_plotly,
+ )
go = _require_plotly()
@@ -727,6 +737,10 @@ def _render_bacon_plotly(
"later_vs_earlier": "Later vs Earlier (forbidden)",
}
+ # Convert matplotlib scatter area (points^2) to plotly diameter (px)
+ plotly_size = max(1, int(round(markersize**0.5)))
+ symbol = _mpl_marker_to_plotly_symbol(marker)
+
for ctype, points in by_type.items():
if not points:
continue
@@ -737,7 +751,12 @@ def _render_bacon_plotly(
x=estimates,
y=weights,
mode="markers",
- marker=dict(color=colors[ctype], size=10, opacity=alpha),
+ marker=dict(
+ color=colors[ctype],
+ size=plotly_size,
+ symbol=symbol,
+ opacity=alpha,
+ ),
name=labels[ctype],
)
)
diff --git a/diff_diff/visualization/_event_study.py b/diff_diff/visualization/_event_study.py
index 4deea681..dbc9cb0a 100644
--- a/diff_diff/visualization/_event_study.py
+++ b/diff_diff/visualization/_event_study.py
@@ -272,6 +272,8 @@ def plot_event_study(
xlabel=xlabel,
ylabel=ylabel,
color=color,
+ marker=marker,
+ markersize=markersize,
shade_pre=shade_pre,
shade_color=shade_color,
show_zero_line=show_zero_line,
@@ -422,6 +424,8 @@ def _render_event_study_plotly(
xlabel,
ylabel,
color,
+ marker,
+ markersize,
shade_pre,
shade_color,
show_zero_line,
@@ -431,6 +435,7 @@ def _render_event_study_plotly(
"""Render event study plot with plotly."""
from diff_diff.visualization._common import (
_color_to_rgba,
+ _mpl_marker_to_plotly_symbol,
_plotly_default_layout,
_require_plotly,
)
@@ -504,13 +509,15 @@ def _render_event_study_plotly(
hover_tpl = "Period: %{customdata}
Effect: %{y:.4f}"
+ symbol = _mpl_marker_to_plotly_symbol(marker)
+
if non_ref_x:
fig.add_trace(
go.Scatter(
x=non_ref_x,
y=non_ref_e,
mode="markers",
- marker=dict(color=color, size=10),
+ marker=dict(color=color, size=markersize, symbol=symbol),
name="Effect",
customdata=non_ref_labels,
hovertemplate=hover_tpl,
@@ -525,7 +532,8 @@ def _render_event_study_plotly(
mode="markers",
marker=dict(
color="white",
- size=10,
+ size=markersize,
+ symbol=symbol,
line=dict(color=color, width=2),
),
name="Reference",
@@ -842,6 +850,8 @@ def plot_honest_event_study(
ylabel=ylabel,
original_color=original_color,
honest_color=honest_color,
+ marker=marker,
+ markersize=markersize,
show=show,
)
@@ -987,11 +997,14 @@ def _render_honest_event_study_plotly(
ylabel,
original_color,
honest_color,
+ marker,
+ markersize,
show,
):
"""Render honest event study plot with plotly."""
from diff_diff.visualization._common import (
_color_to_rgba,
+ _mpl_marker_to_plotly_symbol,
_plotly_default_layout,
_require_plotly,
)
@@ -1036,13 +1049,15 @@ def _render_honest_event_study_plotly(
ref_p = [p for p, r in zip(periods, is_ref) if r]
ref_e = [e for e, r in zip(effects, is_ref) if r]
+ symbol = _mpl_marker_to_plotly_symbol(marker)
+
if non_ref_p:
fig.add_trace(
go.Scatter(
x=non_ref_p,
y=non_ref_e,
mode="markers",
- marker=dict(color=honest_color, size=10),
+ marker=dict(color=honest_color, size=markersize, symbol=symbol),
name="Effect",
)
)
@@ -1053,7 +1068,12 @@ def _render_honest_event_study_plotly(
x=ref_p,
y=ref_e,
mode="markers",
- marker=dict(color="white", size=10, line=dict(color=honest_color, width=2)),
+ marker=dict(
+ color="white",
+ size=markersize,
+ symbol=symbol,
+ line=dict(color=honest_color, width=2),
+ ),
name="Reference",
)
)
diff --git a/diff_diff/visualization/_power.py b/diff_diff/visualization/_power.py
index fd610df7..141352d9 100644
--- a/diff_diff/visualization/_power.py
+++ b/diff_diff/visualization/_power.py
@@ -165,8 +165,10 @@ def plot_power_curve(
color=color,
mde_color=mde_color,
target_color=target_color,
+ linewidth=linewidth,
show_mde_line=show_mde_line,
show_target_line=show_target_line,
+ show_grid=show_grid,
show=show,
)
@@ -291,8 +293,10 @@ def _render_power_curve_plotly(
color,
mde_color,
target_color,
+ linewidth,
show_mde_line,
show_target_line,
+ show_grid,
show,
):
"""Render power curve with plotly."""
@@ -307,7 +311,7 @@ def _render_power_curve_plotly(
x=effect_sizes,
y=powers,
mode="lines",
- line=dict(color=color, width=2),
+ line=dict(color=color, width=linewidth),
name="Power",
)
)
@@ -331,7 +335,8 @@ def _render_power_curve_plotly(
)
_plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
- fig.update_yaxes(range=[0, 1.05], tickformat=".0%")
+ fig.update_xaxes(showgrid=show_grid)
+ fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)
if show:
fig.show()
@@ -482,8 +487,10 @@ def plot_pretrends_power(
color=color,
mdv_color=mdv_color,
target_color=target_color,
+ linewidth=linewidth,
show_mdv_line=show_mdv_line,
show_target_line=show_target_line,
+ show_grid=show_grid,
show=show,
)
@@ -602,8 +609,10 @@ def _render_pretrends_power_plotly(
color,
mdv_color,
target_color,
+ linewidth,
show_mdv_line,
show_target_line,
+ show_grid,
show,
):
"""Render pre-trends power curve with plotly."""
@@ -619,7 +628,7 @@ def _render_pretrends_power_plotly(
x=M_values,
y=powers,
mode="lines",
- line=dict(color=color, width=2),
+ line=dict(color=color, width=linewidth),
name="Power",
)
)
@@ -643,7 +652,8 @@ def _render_pretrends_power_plotly(
)
_plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)
- fig.update_yaxes(range=[0, 1.05], tickformat=".0%")
+ fig.update_xaxes(showgrid=show_grid)
+ fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid)
if show:
fig.show()
diff --git a/tests/test_visualization_plotly.py b/tests/test_visualization_plotly.py
index c6428356..e84f0f47 100644
--- a/tests/test_visualization_plotly.py
+++ b/tests/test_visualization_plotly.py
@@ -318,12 +318,205 @@ def test_string_periods_in_customdata(self):
effects = {"pre": 0.0, "post": 0.5}
se = {"pre": 0.1, "post": 0.15}
- fig = plot_event_study(
- effects=effects, se=se, backend="plotly", show=False
- )
+ fig = plot_event_study(effects=effects, se=se, backend="plotly", show=False)
# At least one point trace should have customdata with original labels
point_traces = [t for t in fig.data if t.mode == "markers"]
assert len(point_traces) > 0
for trace in point_traces:
assert trace.customdata is not None, "Missing customdata on point trace"
assert trace.hovertemplate is not None, "Missing hovertemplate"
+
+
+# ── Marker Mapping ──────────────────────────────────────────────────────────
+
+
+class TestMarkerMapping:
+ """Unit tests for _mpl_marker_to_plotly_symbol."""
+
+ def test_common_markers(self):
+ from diff_diff.visualization._common import _mpl_marker_to_plotly_symbol
+
+ assert _mpl_marker_to_plotly_symbol("o") == "circle"
+ assert _mpl_marker_to_plotly_symbol("s") == "square"
+ assert _mpl_marker_to_plotly_symbol("D") == "diamond"
+ assert _mpl_marker_to_plotly_symbol("^") == "triangle-up"
+ assert _mpl_marker_to_plotly_symbol("v") == "triangle-down"
+ assert _mpl_marker_to_plotly_symbol("+") == "cross"
+ assert _mpl_marker_to_plotly_symbol("x") == "x"
+ assert _mpl_marker_to_plotly_symbol("*") == "star"
+
+ def test_dot_marker(self):
+ from diff_diff.visualization._common import _mpl_marker_to_plotly_symbol
+
+ assert _mpl_marker_to_plotly_symbol(".") == "circle"
+
+ def test_unknown_marker_returns_circle(self):
+ from diff_diff.visualization._common import _mpl_marker_to_plotly_symbol
+
+ assert _mpl_marker_to_plotly_symbol("Z") == "circle"
+ assert _mpl_marker_to_plotly_symbol("???") == "circle"
+
+
+# ── Plotly Styling Kwargs ───────────────────────────────────────────────────
+
+
+class TestPlotlyEventStudyStyling:
+ """Verify styling kwargs reach plotly traces."""
+
+ def test_marker_and_size_threaded(self):
+ from diff_diff import plot_event_study
+
+ effects = {-1: 0.0, 0: 0.5, 1: 0.6}
+ se = {-1: 0.1, 0: 0.15, 1: 0.15}
+ fig = plot_event_study(
+ effects=effects,
+ se=se,
+ marker="s",
+ markersize=12,
+ backend="plotly",
+ show=False,
+ )
+ point_traces = [t for t in fig.data if t.mode == "markers"]
+ assert len(point_traces) > 0
+ for trace in point_traces:
+ assert trace.marker.size == 12
+ assert trace.marker.symbol == "square"
+
+ def test_default_marker_values(self):
+ from diff_diff import plot_event_study
+
+ effects = {0: 0.5}
+ se = {0: 0.1}
+ fig = plot_event_study(effects=effects, se=se, backend="plotly", show=False)
+ point_traces = [t for t in fig.data if t.mode == "markers"]
+ assert len(point_traces) > 0
+ # Default: marker="o" -> circle, markersize=8
+ for trace in point_traces:
+ assert trace.marker.size == 8
+ assert trace.marker.symbol == "circle"
+
+
+class TestPlotlyHonestEventStudyStyling:
+ """Verify styling kwargs reach honest event study plotly traces."""
+
+ def test_marker_symbol_threaded(self, honest_results):
+ from diff_diff.visualization import plot_honest_event_study
+
+ fig = plot_honest_event_study(
+ honest_results, marker="D", markersize=14, backend="plotly", show=False
+ )
+ point_traces = [t for t in fig.data if t.mode == "markers"]
+ assert len(point_traces) > 0
+ for trace in point_traces:
+ assert trace.marker.size == 14
+ assert trace.marker.symbol == "diamond"
+
+
+class TestPlotlySensitivityStyling:
+ """Verify ci_linewidth reaches plotly CI line traces."""
+
+ def test_ci_linewidth_threaded(self, sensitivity_results):
+ from diff_diff.visualization import plot_sensitivity
+
+ fig = plot_sensitivity(sensitivity_results, ci_linewidth=3.0, backend="plotly", show=False)
+ # CI traces are lines (not fills) with name "Robust CI"
+ ci_traces = [t for t in fig.data if t.mode == "lines" and t.name == "Robust CI"]
+ assert len(ci_traces) > 0
+ for trace in ci_traces:
+ assert trace.line.width == 3.0
+
+
+class TestPlotlyBaconStyling:
+ """Verify marker/markersize reach Bacon scatter plotly traces."""
+
+ def test_marker_and_size_threaded(self, bacon_results):
+ from diff_diff.visualization import plot_bacon
+
+ fig = plot_bacon(
+ bacon_results,
+ marker="^",
+ markersize=100,
+ backend="plotly",
+ show=False,
+ )
+ scatter_traces = [t for t in fig.data if t.mode == "markers"]
+ assert len(scatter_traces) > 0
+ for trace in scatter_traces:
+ assert trace.marker.symbol == "triangle-up"
+ # sqrt(100) = 10
+ assert trace.marker.size == 10
+
+
+class TestPlotlyPowerCurveStyling:
+ """Verify linewidth and show_grid reach plotly power curve."""
+
+ def test_linewidth_threaded(self):
+ from diff_diff.visualization import plot_power_curve
+
+ fig = plot_power_curve(
+ effect_sizes=[0.1, 0.2, 0.3],
+ powers=[0.3, 0.6, 0.9],
+ linewidth=3.5,
+ backend="plotly",
+ show=False,
+ )
+ line_traces = [t for t in fig.data if t.mode == "lines"]
+ assert len(line_traces) > 0
+ assert line_traces[0].line.width == 3.5
+
+ def test_show_grid_false(self):
+ from diff_diff.visualization import plot_power_curve
+
+ fig = plot_power_curve(
+ effect_sizes=[0.1, 0.2, 0.3],
+ powers=[0.3, 0.6, 0.9],
+ show_grid=False,
+ backend="plotly",
+ show=False,
+ )
+ assert fig.layout.xaxis.showgrid is False
+ assert fig.layout.yaxis.showgrid is False
+
+ def test_show_grid_true(self):
+ from diff_diff.visualization import plot_power_curve
+
+ fig = plot_power_curve(
+ effect_sizes=[0.1, 0.2, 0.3],
+ powers=[0.3, 0.6, 0.9],
+ show_grid=True,
+ backend="plotly",
+ show=False,
+ )
+ assert fig.layout.xaxis.showgrid is True
+ assert fig.layout.yaxis.showgrid is True
+
+
+class TestPlotlyPretrendsPowerStyling:
+ """Verify linewidth and show_grid reach plotly pretrends power curve."""
+
+ def test_linewidth_threaded(self):
+ from diff_diff.visualization import plot_pretrends_power
+
+ fig = plot_pretrends_power(
+ M_values=[0.0, 0.5, 1.0],
+ powers=[0.1, 0.5, 0.8],
+ linewidth=4.0,
+ backend="plotly",
+ show=False,
+ )
+ line_traces = [t for t in fig.data if t.mode == "lines"]
+ assert len(line_traces) > 0
+ assert line_traces[0].line.width == 4.0
+
+ def test_show_grid_false(self):
+ from diff_diff.visualization import plot_pretrends_power
+
+ fig = plot_pretrends_power(
+ M_values=[0.0, 0.5, 1.0],
+ powers=[0.1, 0.5, 0.8],
+ show_grid=False,
+ backend="plotly",
+ show=False,
+ )
+ assert fig.layout.xaxis.showgrid is False
+ assert fig.layout.yaxis.showgrid is False