Skip to content

Commit 3ac550c

Browse files
committed
Add split-chain option to rank overlay plots
Related to #333
1 parent 20910f5 commit 3ac550c

File tree

4 files changed

+113
-4
lines changed

4 files changed

+113
-4
lines changed

R/mcmc-traces.R

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
277277
#' of rank-normalized MCMC samples. Defaults to `20`.
278278
#' @param ref_line For the rank plots, whether to draw a horizontal line at the
279279
#' average number of ranks per bin. Defaults to `FALSE`.
280+
#' @param split_chains Logical indicating whether to split each chain into two parts.
281+
#' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes.
282+
#' Defaults to `FALSE`.
280283
#' @export
281284
mcmc_rank_overlay <- function(x,
282285
pars = character(),
@@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x,
285288
facet_args = list(),
286289
...,
287290
n_bins = 20,
288-
ref_line = FALSE) {
291+
ref_line = FALSE,
292+
split_chains = FALSE) {
289293
check_ignored_arguments(...)
290294
data <- mcmc_trace_data(
291295
x,
@@ -294,7 +298,28 @@ mcmc_rank_overlay <- function(x,
294298
transformations = transformations
295299
)
296300

297-
n_chains <- unique(data$n_chains)
301+
# Split chains if requested
302+
if (split_chains) {
303+
data$n_chains = data$n_chains/2
304+
data$n_iterations = data$n_iterations/2
305+
# Calculate midpoint for each chain
306+
n_samples <- length(unique(data$iteration))
307+
midpoint <- n_samples/2
308+
309+
# Create new data frame with split chains
310+
data <- data %>%
311+
group_by(.data$chain) %>%
312+
mutate(
313+
chain = ifelse(
314+
iteration <= midpoint,
315+
paste0(.data$chain, "_1"),
316+
paste0(.data$chain, "_2")
317+
)
318+
) %>%
319+
ungroup()
320+
}
321+
322+
n_chains <- length(unique(data$chain))
298323
n_param <- unique(data$n_parameters)
299324

300325
# We have to bin and count the data ourselves because
@@ -319,6 +344,7 @@ mcmc_rank_overlay <- function(x,
319344
bin_start = unique(histobins$bin_start),
320345
stringsAsFactors = FALSE
321346
))
347+
322348
d_bin_counts <- all_combos %>%
323349
left_join(d_bin_counts, by = c("parameter", "chain", "bin_start")) %>%
324350
mutate(n = dplyr::if_else(is.na(n), 0L, n))
@@ -331,7 +357,9 @@ mcmc_rank_overlay <- function(x,
331357
mutate(bin_start = right_edge) %>%
332358
dplyr::bind_rows(d_bin_counts)
333359

334-
scale_color <- scale_color_manual("Chain", values = chain_colors(n_chains))
360+
# Update legend title based on split_chains
361+
legend_title <- if (split_chains) "Split Chains" else "Chain"
362+
scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chains))
335363

336364
layer_ref_line <- if (ref_line) {
337365
geom_hline(
@@ -352,7 +380,7 @@ mcmc_rank_overlay <- function(x,
352380
}
353381

354382
ggplot(d_bin_counts) +
355-
aes(x = .data$bin_start, y = .data$n, color = .data$chain) +
383+
aes(x = .data$bin_start, y = .data$n, color = .data$chain) +
356384
geom_step() +
357385
layer_ref_line +
358386
facet_call +
Lines changed: 67 additions & 0 deletions
Loading

tests/testthat/data-for-mcmc-tests.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,11 @@ vdiff_dframe_rank_overlay_bins_test <- posterior::as_draws_df(
8080
)
8181
)
8282

83+
vdiff_dframe_rank_overlay_split_chain_test <- posterior::as_draws_df(
84+
list(
85+
list(theta = -2 + 0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5)),
86+
list(theta = 1 + -0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5))
87+
)
88+
)
89+
8390
set.seed(seed = NULL)

tests/testthat/test-mcmc-traces.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ test_that("mcmc_rank_overlay renders correctly", {
157157
# https://github.com/stan-dev/bayesplot/issues/331
158158
p_not_all_bins_exist <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_bins_test)
159159

160+
# https://github.com/stan-dev/bayesplot/issues/333
161+
p_split_chains <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_split_chain_test,
162+
split_chains = TRUE)
163+
160164
vdiffr::expect_doppelganger("mcmc_rank_overlay (default)", p_base)
161165
vdiffr::expect_doppelganger(
162166
"mcmc_rank_overlay (reference line)",
@@ -170,6 +174,9 @@ test_that("mcmc_rank_overlay renders correctly", {
170174

171175
# https://github.com/stan-dev/bayesplot/issues/331
172176
vdiffr::expect_doppelganger("mcmc_rank_overlay (not all bins)", p_not_all_bins_exist)
177+
178+
# https://github.com/stan-dev/bayesplot/issues/333
179+
vdiffr::expect_doppelganger("mcmc_rank_overlay (split chains)", p_split_chains)
173180
})
174181

175182
test_that("mcmc_rank_hist renders correctly", {

0 commit comments

Comments
 (0)