diff --git a/TODO.md b/TODO.md index 264e76ae..5ae73fa9 100644 --- a/TODO.md +++ b/TODO.md @@ -71,7 +71,6 @@ Deferred items from PR reviews that were not addressed before merge. | Issue | Location | PR | Priority | |-------|----------|----|----------| -| Plotly renderers silently ignore styling kwargs (marker, markersize, linewidth, capsize, ci_linewidth) that the matplotlib backend honors; thread them through or reject when `backend="plotly"` | `visualization/_event_study.py`, `_diagnostic.py`, `_power.py` | #222 | Medium | | R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low | | CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low | | ~376 `duplicate object description` Sphinx warnings — restructure `docs/api/*.rst` to avoid duplicate `:members:` + `autosummary` | `docs/api/*.rst` | — | Low | diff --git a/diff_diff/visualization/_common.py b/diff_diff/visualization/_common.py index 1c08ca3f..464398c6 100644 --- a/diff_diff/visualization/_common.py +++ b/diff_diff/visualization/_common.py @@ -283,6 +283,42 @@ def _color_to_rgba(color, alpha=1.0): ) +# Matplotlib marker code -> plotly symbol name mapping +_MPL_TO_PLOTLY_SYMBOL = { + "o": "circle", + "s": "square", + "D": "diamond", + "d": "diamond", + "^": "triangle-up", + "v": "triangle-down", + "<": "triangle-left", + ">": "triangle-right", + "p": "pentagon", + "h": "hexagon", + "+": "cross", + "x": "x", + "*": "star", + ".": "circle", +} + + +def _mpl_marker_to_plotly_symbol(marker): + """Convert a matplotlib marker code to a plotly symbol name. + + Parameters + ---------- + marker : str + Matplotlib marker shorthand (e.g., ``"o"``, ``"s"``, ``"D"``). + + Returns + ------- + str + Plotly symbol name (e.g., ``"circle"``, ``"square"``, ``"diamond"``). + Returns ``"circle"`` for unrecognized markers. + """ + return _MPL_TO_PLOTLY_SYMBOL.get(marker, "circle") + + # Default color constants DEFAULT_BLUE = "#2563eb" DEFAULT_RED = "#dc2626" diff --git a/diff_diff/visualization/_diagnostic.py b/diff_diff/visualization/_diagnostic.py index ef876f55..66320ce5 100644 --- a/diff_diff/visualization/_diagnostic.py +++ b/diff_diff/visualization/_diagnostic.py @@ -111,6 +111,7 @@ def plot_sensitivity( bounds_color=bounds_color, bounds_alpha=bounds_alpha, ci_color=ci_color, + ci_linewidth=ci_linewidth, breakdown_color=breakdown_color, original_color=original_color, show=show, @@ -242,6 +243,7 @@ def _render_sensitivity_plotly( bounds_color, bounds_alpha, ci_color, + ci_linewidth, breakdown_color, original_color, show, @@ -291,7 +293,7 @@ def _render_sensitivity_plotly( x=M_list, y=list(ci_arr[:, 0]), mode="lines", - line=dict(color=ci_color, width=1.5), + line=dict(color=ci_color, width=ci_linewidth), name="Robust CI", ) ) @@ -300,7 +302,7 @@ def _render_sensitivity_plotly( x=M_list, y=list(ci_arr[:, 1]), mode="lines", - line=dict(color=ci_color, width=1.5), + line=dict(color=ci_color, width=ci_linewidth), showlegend=False, ) ) @@ -449,6 +451,8 @@ def plot_bacon( xlabel=xlabel, ylabel=ylabel, colors=colors, + marker=marker, + markersize=markersize, alpha=alpha, show_weighted_avg=show_weighted_avg, show_twfe_line=show_twfe_line, @@ -699,13 +703,19 @@ def _render_bacon_plotly( xlabel, ylabel, colors, + marker, + markersize, alpha, show_weighted_avg, show_twfe_line, show, ): """Render Bacon decomposition plot with plotly.""" - from diff_diff.visualization._common import _plotly_default_layout, _require_plotly + from diff_diff.visualization._common import ( + _mpl_marker_to_plotly_symbol, + _plotly_default_layout, + _require_plotly, + ) go = _require_plotly() @@ -727,6 +737,10 @@ def _render_bacon_plotly( "later_vs_earlier": "Later vs Earlier (forbidden)", } + # Convert matplotlib scatter area (points^2) to plotly diameter (px) + plotly_size = max(1, int(round(markersize**0.5))) + symbol = _mpl_marker_to_plotly_symbol(marker) + for ctype, points in by_type.items(): if not points: continue @@ -737,7 +751,12 @@ def _render_bacon_plotly( x=estimates, y=weights, mode="markers", - marker=dict(color=colors[ctype], size=10, opacity=alpha), + marker=dict( + color=colors[ctype], + size=plotly_size, + symbol=symbol, + opacity=alpha, + ), name=labels[ctype], ) ) diff --git a/diff_diff/visualization/_event_study.py b/diff_diff/visualization/_event_study.py index 4deea681..dbc9cb0a 100644 --- a/diff_diff/visualization/_event_study.py +++ b/diff_diff/visualization/_event_study.py @@ -272,6 +272,8 @@ def plot_event_study( xlabel=xlabel, ylabel=ylabel, color=color, + marker=marker, + markersize=markersize, shade_pre=shade_pre, shade_color=shade_color, show_zero_line=show_zero_line, @@ -422,6 +424,8 @@ def _render_event_study_plotly( xlabel, ylabel, color, + marker, + markersize, shade_pre, shade_color, show_zero_line, @@ -431,6 +435,7 @@ def _render_event_study_plotly( """Render event study plot with plotly.""" from diff_diff.visualization._common import ( _color_to_rgba, + _mpl_marker_to_plotly_symbol, _plotly_default_layout, _require_plotly, ) @@ -504,13 +509,15 @@ def _render_event_study_plotly( hover_tpl = "Period: %{customdata}
Effect: %{y:.4f}" + symbol = _mpl_marker_to_plotly_symbol(marker) + if non_ref_x: fig.add_trace( go.Scatter( x=non_ref_x, y=non_ref_e, mode="markers", - marker=dict(color=color, size=10), + marker=dict(color=color, size=markersize, symbol=symbol), name="Effect", customdata=non_ref_labels, hovertemplate=hover_tpl, @@ -525,7 +532,8 @@ def _render_event_study_plotly( mode="markers", marker=dict( color="white", - size=10, + size=markersize, + symbol=symbol, line=dict(color=color, width=2), ), name="Reference", @@ -842,6 +850,8 @@ def plot_honest_event_study( ylabel=ylabel, original_color=original_color, honest_color=honest_color, + marker=marker, + markersize=markersize, show=show, ) @@ -987,11 +997,14 @@ def _render_honest_event_study_plotly( ylabel, original_color, honest_color, + marker, + markersize, show, ): """Render honest event study plot with plotly.""" from diff_diff.visualization._common import ( _color_to_rgba, + _mpl_marker_to_plotly_symbol, _plotly_default_layout, _require_plotly, ) @@ -1036,13 +1049,15 @@ def _render_honest_event_study_plotly( ref_p = [p for p, r in zip(periods, is_ref) if r] ref_e = [e for e, r in zip(effects, is_ref) if r] + symbol = _mpl_marker_to_plotly_symbol(marker) + if non_ref_p: fig.add_trace( go.Scatter( x=non_ref_p, y=non_ref_e, mode="markers", - marker=dict(color=honest_color, size=10), + marker=dict(color=honest_color, size=markersize, symbol=symbol), name="Effect", ) ) @@ -1053,7 +1068,12 @@ def _render_honest_event_study_plotly( x=ref_p, y=ref_e, mode="markers", - marker=dict(color="white", size=10, line=dict(color=honest_color, width=2)), + marker=dict( + color="white", + size=markersize, + symbol=symbol, + line=dict(color=honest_color, width=2), + ), name="Reference", ) ) diff --git a/diff_diff/visualization/_power.py b/diff_diff/visualization/_power.py index fd610df7..141352d9 100644 --- a/diff_diff/visualization/_power.py +++ b/diff_diff/visualization/_power.py @@ -165,8 +165,10 @@ def plot_power_curve( color=color, mde_color=mde_color, target_color=target_color, + linewidth=linewidth, show_mde_line=show_mde_line, show_target_line=show_target_line, + show_grid=show_grid, show=show, ) @@ -291,8 +293,10 @@ def _render_power_curve_plotly( color, mde_color, target_color, + linewidth, show_mde_line, show_target_line, + show_grid, show, ): """Render power curve with plotly.""" @@ -307,7 +311,7 @@ def _render_power_curve_plotly( x=effect_sizes, y=powers, mode="lines", - line=dict(color=color, width=2), + line=dict(color=color, width=linewidth), name="Power", ) ) @@ -331,7 +335,8 @@ def _render_power_curve_plotly( ) _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel) - fig.update_yaxes(range=[0, 1.05], tickformat=".0%") + fig.update_xaxes(showgrid=show_grid) + fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid) if show: fig.show() @@ -482,8 +487,10 @@ def plot_pretrends_power( color=color, mdv_color=mdv_color, target_color=target_color, + linewidth=linewidth, show_mdv_line=show_mdv_line, show_target_line=show_target_line, + show_grid=show_grid, show=show, ) @@ -602,8 +609,10 @@ def _render_pretrends_power_plotly( color, mdv_color, target_color, + linewidth, show_mdv_line, show_target_line, + show_grid, show, ): """Render pre-trends power curve with plotly.""" @@ -619,7 +628,7 @@ def _render_pretrends_power_plotly( x=M_values, y=powers, mode="lines", - line=dict(color=color, width=2), + line=dict(color=color, width=linewidth), name="Power", ) ) @@ -643,7 +652,8 @@ def _render_pretrends_power_plotly( ) _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel) - fig.update_yaxes(range=[0, 1.05], tickformat=".0%") + fig.update_xaxes(showgrid=show_grid) + fig.update_yaxes(range=[0, 1.05], tickformat=".0%", showgrid=show_grid) if show: fig.show() diff --git a/tests/test_visualization_plotly.py b/tests/test_visualization_plotly.py index c6428356..e84f0f47 100644 --- a/tests/test_visualization_plotly.py +++ b/tests/test_visualization_plotly.py @@ -318,12 +318,205 @@ def test_string_periods_in_customdata(self): effects = {"pre": 0.0, "post": 0.5} se = {"pre": 0.1, "post": 0.15} - fig = plot_event_study( - effects=effects, se=se, backend="plotly", show=False - ) + fig = plot_event_study(effects=effects, se=se, backend="plotly", show=False) # At least one point trace should have customdata with original labels point_traces = [t for t in fig.data if t.mode == "markers"] assert len(point_traces) > 0 for trace in point_traces: assert trace.customdata is not None, "Missing customdata on point trace" assert trace.hovertemplate is not None, "Missing hovertemplate" + + +# ── Marker Mapping ────────────────────────────────────────────────────────── + + +class TestMarkerMapping: + """Unit tests for _mpl_marker_to_plotly_symbol.""" + + def test_common_markers(self): + from diff_diff.visualization._common import _mpl_marker_to_plotly_symbol + + assert _mpl_marker_to_plotly_symbol("o") == "circle" + assert _mpl_marker_to_plotly_symbol("s") == "square" + assert _mpl_marker_to_plotly_symbol("D") == "diamond" + assert _mpl_marker_to_plotly_symbol("^") == "triangle-up" + assert _mpl_marker_to_plotly_symbol("v") == "triangle-down" + assert _mpl_marker_to_plotly_symbol("+") == "cross" + assert _mpl_marker_to_plotly_symbol("x") == "x" + assert _mpl_marker_to_plotly_symbol("*") == "star" + + def test_dot_marker(self): + from diff_diff.visualization._common import _mpl_marker_to_plotly_symbol + + assert _mpl_marker_to_plotly_symbol(".") == "circle" + + def test_unknown_marker_returns_circle(self): + from diff_diff.visualization._common import _mpl_marker_to_plotly_symbol + + assert _mpl_marker_to_plotly_symbol("Z") == "circle" + assert _mpl_marker_to_plotly_symbol("???") == "circle" + + +# ── Plotly Styling Kwargs ─────────────────────────────────────────────────── + + +class TestPlotlyEventStudyStyling: + """Verify styling kwargs reach plotly traces.""" + + def test_marker_and_size_threaded(self): + from diff_diff import plot_event_study + + effects = {-1: 0.0, 0: 0.5, 1: 0.6} + se = {-1: 0.1, 0: 0.15, 1: 0.15} + fig = plot_event_study( + effects=effects, + se=se, + marker="s", + markersize=12, + backend="plotly", + show=False, + ) + point_traces = [t for t in fig.data if t.mode == "markers"] + assert len(point_traces) > 0 + for trace in point_traces: + assert trace.marker.size == 12 + assert trace.marker.symbol == "square" + + def test_default_marker_values(self): + from diff_diff import plot_event_study + + effects = {0: 0.5} + se = {0: 0.1} + fig = plot_event_study(effects=effects, se=se, backend="plotly", show=False) + point_traces = [t for t in fig.data if t.mode == "markers"] + assert len(point_traces) > 0 + # Default: marker="o" -> circle, markersize=8 + for trace in point_traces: + assert trace.marker.size == 8 + assert trace.marker.symbol == "circle" + + +class TestPlotlyHonestEventStudyStyling: + """Verify styling kwargs reach honest event study plotly traces.""" + + def test_marker_symbol_threaded(self, honest_results): + from diff_diff.visualization import plot_honest_event_study + + fig = plot_honest_event_study( + honest_results, marker="D", markersize=14, backend="plotly", show=False + ) + point_traces = [t for t in fig.data if t.mode == "markers"] + assert len(point_traces) > 0 + for trace in point_traces: + assert trace.marker.size == 14 + assert trace.marker.symbol == "diamond" + + +class TestPlotlySensitivityStyling: + """Verify ci_linewidth reaches plotly CI line traces.""" + + def test_ci_linewidth_threaded(self, sensitivity_results): + from diff_diff.visualization import plot_sensitivity + + fig = plot_sensitivity(sensitivity_results, ci_linewidth=3.0, backend="plotly", show=False) + # CI traces are lines (not fills) with name "Robust CI" + ci_traces = [t for t in fig.data if t.mode == "lines" and t.name == "Robust CI"] + assert len(ci_traces) > 0 + for trace in ci_traces: + assert trace.line.width == 3.0 + + +class TestPlotlyBaconStyling: + """Verify marker/markersize reach Bacon scatter plotly traces.""" + + def test_marker_and_size_threaded(self, bacon_results): + from diff_diff.visualization import plot_bacon + + fig = plot_bacon( + bacon_results, + marker="^", + markersize=100, + backend="plotly", + show=False, + ) + scatter_traces = [t for t in fig.data if t.mode == "markers"] + assert len(scatter_traces) > 0 + for trace in scatter_traces: + assert trace.marker.symbol == "triangle-up" + # sqrt(100) = 10 + assert trace.marker.size == 10 + + +class TestPlotlyPowerCurveStyling: + """Verify linewidth and show_grid reach plotly power curve.""" + + def test_linewidth_threaded(self): + from diff_diff.visualization import plot_power_curve + + fig = plot_power_curve( + effect_sizes=[0.1, 0.2, 0.3], + powers=[0.3, 0.6, 0.9], + linewidth=3.5, + backend="plotly", + show=False, + ) + line_traces = [t for t in fig.data if t.mode == "lines"] + assert len(line_traces) > 0 + assert line_traces[0].line.width == 3.5 + + def test_show_grid_false(self): + from diff_diff.visualization import plot_power_curve + + fig = plot_power_curve( + effect_sizes=[0.1, 0.2, 0.3], + powers=[0.3, 0.6, 0.9], + show_grid=False, + backend="plotly", + show=False, + ) + assert fig.layout.xaxis.showgrid is False + assert fig.layout.yaxis.showgrid is False + + def test_show_grid_true(self): + from diff_diff.visualization import plot_power_curve + + fig = plot_power_curve( + effect_sizes=[0.1, 0.2, 0.3], + powers=[0.3, 0.6, 0.9], + show_grid=True, + backend="plotly", + show=False, + ) + assert fig.layout.xaxis.showgrid is True + assert fig.layout.yaxis.showgrid is True + + +class TestPlotlyPretrendsPowerStyling: + """Verify linewidth and show_grid reach plotly pretrends power curve.""" + + def test_linewidth_threaded(self): + from diff_diff.visualization import plot_pretrends_power + + fig = plot_pretrends_power( + M_values=[0.0, 0.5, 1.0], + powers=[0.1, 0.5, 0.8], + linewidth=4.0, + backend="plotly", + show=False, + ) + line_traces = [t for t in fig.data if t.mode == "lines"] + assert len(line_traces) > 0 + assert line_traces[0].line.width == 4.0 + + def test_show_grid_false(self): + from diff_diff.visualization import plot_pretrends_power + + fig = plot_pretrends_power( + M_values=[0.0, 0.5, 1.0], + powers=[0.1, 0.5, 0.8], + show_grid=False, + backend="plotly", + show=False, + ) + assert fig.layout.xaxis.showgrid is False + assert fig.layout.yaxis.showgrid is False