@@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
277
277
# ' of rank-normalized MCMC samples. Defaults to `20`.
278
278
# ' @param ref_line For the rank plots, whether to draw a horizontal line at the
279
279
# ' average number of ranks per bin. Defaults to `FALSE`.
280
+ # ' @param split_chains Logical indicating whether to split each chain into two parts.
281
+ # ' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes.
282
+ # ' Defaults to `FALSE`.
280
283
# ' @export
281
284
mcmc_rank_overlay <- function (x ,
282
285
pars = character (),
@@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x,
285
288
facet_args = list (),
286
289
... ,
287
290
n_bins = 20 ,
288
- ref_line = FALSE ) {
291
+ ref_line = FALSE ,
292
+ split_chains = FALSE ) {
289
293
check_ignored_arguments(... )
290
294
data <- mcmc_trace_data(
291
295
x ,
@@ -294,7 +298,28 @@ mcmc_rank_overlay <- function(x,
294
298
transformations = transformations
295
299
)
296
300
297
- n_chains <- unique(data $ n_chains )
301
+ # Split chains if requested
302
+ if (split_chains ) {
303
+ data $ n_chains = data $ n_chains / 2
304
+ data $ n_iterations = data $ n_iterations / 2
305
+ # Calculate midpoint for each chain
306
+ n_samples <- length(unique(data $ iteration ))
307
+ midpoint <- n_samples / 2
308
+
309
+ # Create new data frame with split chains
310
+ data <- data %> %
311
+ group_by(.data $ chain ) %> %
312
+ mutate(
313
+ chain = ifelse(
314
+ iteration < = midpoint ,
315
+ paste0(.data $ chain , " _1" ),
316
+ paste0(.data $ chain , " _2" )
317
+ )
318
+ ) %> %
319
+ ungroup()
320
+ }
321
+
322
+ n_chains <- length(unique(data $ chain ))
298
323
n_param <- unique(data $ n_parameters )
299
324
300
325
# We have to bin and count the data ourselves because
@@ -319,6 +344,7 @@ mcmc_rank_overlay <- function(x,
319
344
bin_start = unique(histobins $ bin_start ),
320
345
stringsAsFactors = FALSE
321
346
))
347
+
322
348
d_bin_counts <- all_combos %> %
323
349
left_join(d_bin_counts , by = c(" parameter" , " chain" , " bin_start" )) %> %
324
350
mutate(n = dplyr :: if_else(is.na(n ), 0L , n ))
@@ -331,7 +357,9 @@ mcmc_rank_overlay <- function(x,
331
357
mutate(bin_start = right_edge ) %> %
332
358
dplyr :: bind_rows(d_bin_counts )
333
359
334
- scale_color <- scale_color_manual(" Chain" , values = chain_colors(n_chains ))
360
+ # Update legend title based on split_chains
361
+ legend_title <- if (split_chains ) " Split Chains" else " Chain"
362
+ scale_color <- scale_color_manual(legend_title , values = chain_colors(n_chains ))
335
363
336
364
layer_ref_line <- if (ref_line ) {
337
365
geom_hline(
@@ -352,7 +380,7 @@ mcmc_rank_overlay <- function(x,
352
380
}
353
381
354
382
ggplot(d_bin_counts ) +
355
- aes(x = .data $ bin_start , y = .data $ n , color = .data $ chain ) +
383
+ aes(x = .data $ bin_start , y = .data $ n , color = .data $ chain ) +
356
384
geom_step() +
357
385
layer_ref_line +
358
386
facet_call +
0 commit comments