Skip to content

Commit e063e29

Browse files
igerberclaude
andcommitted
Fix plotly heatmap masking, cmap passthrough, and __all__ exports
- Plotly mask_insignificant now uses grey overlay trace instead of NaN replacement, preserving distinction between insignificant and missing - Remove cmap name swapping — pass through unchanged to plotly - Add 4 new plot functions to diff_diff.__all__ - Add regression tests for heatmap masking, colorscale, and exports Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0c7b234 commit e063e29

3 files changed

Lines changed: 80 additions & 7 deletions

File tree

diff_diff/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@
259259
"plot_group_effects",
260260
"plot_sensitivity",
261261
"plot_honest_event_study",
262+
"plot_synth_weights",
263+
"plot_staircase",
264+
"plot_dose_response",
265+
"plot_group_time_heatmap",
262266
# Parallel trends testing
263267
"check_parallel_trends",
264268
"check_parallel_trends_robust",

diff_diff/visualization/_staggered.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -757,11 +757,9 @@ def _render_group_time_heatmap_plotly(
757757

758758
go = _require_plotly()
759759

760-
# Map matplotlib cmap names to plotly
760+
# Pass cmap name through to plotly unchanged — plotly supports the same
761+
# diverging colorscale names as matplotlib (RdBu, RdBu_r, etc.)
761762
plotly_cmap = cmap
762-
cmap_mapping = {"RdBu_r": "RdBu", "RdBu": "RdBu_r", "coolwarm": "RdBu"}
763-
if cmap in cmap_mapping:
764-
plotly_cmap = cmap_mapping[cmap]
765763

766764
# Build text annotations
767765
text = None
@@ -777,10 +775,11 @@ def _render_group_time_heatmap_plotly(
777775
row.append(f"{val:{fmt}}")
778776
text.append(row)
779777

780-
display = effect_matrix.copy()
778+
# Build significance mask for overlay (do NOT replace with NaN — that
779+
# conflates "insignificant" with "missing cell")
780+
sig_mask = None
781781
if mask_insignificant and p_matrix is not None:
782782
sig_mask = p_matrix > alpha
783-
display = np.where(sig_mask, np.nan, display)
784783

785784
# Center the colorscale
786785
finite_vals = effect_matrix[np.isfinite(effect_matrix)]
@@ -791,9 +790,10 @@ def _render_group_time_heatmap_plotly(
791790
else:
792791
zmin, zmax = -1, 1
793792

793+
# Main heatmap — always shows all values (insignificant cells greyed via opacity)
794794
fig = go.Figure(
795795
data=go.Heatmap(
796-
z=display,
796+
z=effect_matrix,
797797
x=[str(t) for t in time_labels],
798798
y=[str(g) for g in group_labels],
799799
colorscale=plotly_cmap,
@@ -805,6 +805,20 @@ def _render_group_time_heatmap_plotly(
805805
)
806806
)
807807

808+
# Grey overlay for insignificant cells (preserves underlying value)
809+
if sig_mask is not None and np.any(sig_mask):
810+
grey_z = np.where(sig_mask, 1.0, np.nan)
811+
fig.add_trace(
812+
go.Heatmap(
813+
z=grey_z,
814+
x=[str(t) for t in time_labels],
815+
y=[str(g) for g in group_labels],
816+
colorscale=[[0, "rgba(255,255,255,0.6)"], [1, "rgba(255,255,255,0.6)"]],
817+
showscale=False,
818+
hoverinfo="skip",
819+
)
820+
)
821+
808822
_plotly_default_layout(
809823
fig,
810824
title=title,

tests/test_visualization_plotly.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,58 @@ def test_css_color_outside_original_subset(self):
238238
show=False,
239239
)
240240
assert isinstance(fig, go.Figure)
241+
242+
243+
class TestPlotlyHeatmapMasking:
244+
"""Regression: mask_insignificant must grey out, not NaN-ify cells."""
245+
246+
def test_mask_preserves_values(self, cs_results):
247+
from diff_diff.visualization import plot_group_time_heatmap
248+
249+
fig = plot_group_time_heatmap(
250+
cs_results, mask_insignificant=True, backend="plotly", show=False
251+
)
252+
assert isinstance(fig, go.Figure)
253+
# Should have 2 traces: main heatmap + grey overlay
254+
assert len(fig.data) == 2
255+
# Main heatmap should NOT have NaN where cells were insignificant
256+
import numpy as np
257+
258+
main_z = fig.data[0].z
259+
assert np.any(np.isfinite(main_z))
260+
261+
def test_no_mask_single_trace(self, cs_results):
262+
from diff_diff.visualization import plot_group_time_heatmap
263+
264+
fig = plot_group_time_heatmap(
265+
cs_results, mask_insignificant=False, backend="plotly", show=False
266+
)
267+
assert isinstance(fig, go.Figure)
268+
assert len(fig.data) == 1 # Only main heatmap, no overlay
269+
270+
def test_cmap_not_swapped(self, cs_results):
271+
"""RdBu_r should not be swapped — last color should be warm (red)."""
272+
from diff_diff.visualization import plot_group_time_heatmap
273+
274+
fig = plot_group_time_heatmap(cs_results, cmap="RdBu_r", backend="plotly", show=False)
275+
assert isinstance(fig, go.Figure)
276+
# Plotly resolves named colorscales to tuples. RdBu_r ends with red.
277+
cs = fig.data[0].colorscale
278+
# Last entry should be a reddish color (high R value)
279+
last_color = cs[-1][1] # e.g. "rgb(103,0,31)"
280+
assert "103" in last_color or "178" in last_color # dark red end of RdBu_r
281+
282+
283+
class TestTopLevelExports:
284+
"""Regression: new plot functions must be in diff_diff.__all__."""
285+
286+
def test_new_plots_in_all(self):
287+
import diff_diff
288+
289+
for name in [
290+
"plot_synth_weights",
291+
"plot_staircase",
292+
"plot_dose_response",
293+
"plot_group_time_heatmap",
294+
]:
295+
assert name in diff_diff.__all__, f"{name} missing from diff_diff.__all__"

0 commit comments

Comments
 (0)