Skip to content

Add mcse plots similar to neff and rhat plots #278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method("[",mcse_ratio)
S3method("[",neff_ratio)
S3method("[",rhat)
S3method(log_posterior,CmdStanMCMC)
@@ -63,6 +64,9 @@ export(mcmc_hist)
export(mcmc_hist_by_chain)
export(mcmc_intervals)
export(mcmc_intervals_data)
export(mcmc_mcse)
export(mcmc_mcse_data)
export(mcmc_mcse_hist)
export(mcmc_neff)
export(mcmc_neff_data)
export(mcmc_neff_hist)
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -12,6 +12,10 @@
* `mcmc_areas()` and `mcmc_areas_ridges()` gain an argument `border_size` for
controlling the thickness of the ridgelines. (#224)

* New plotting functions `mcmc_mcse()` and `mcmc_mcse_hist()` that are similar
to `mcmc_neff()` and `mcmc_neff_hist()` but for plotting ratios of MCSE to
posterior SD. (#278, @VeenDuco)

* New plotting function `ppc_km_overlay_grouped()`, the grouped variant of
`ppc_km_overlay()`. (#260, @fweber144)

182 changes: 164 additions & 18 deletions R/mcmc-diagnostics.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
#' General MCMC diagnostics
#'
#' Plots of Rhat statistics, ratios of effective sample size to total sample
#' size, and autocorrelation of MCMC draws. See the **Plot Descriptions**
#' section, below, for details. For models fit using the No-U-Turn-Sampler, see
#' also [MCMC-nuts] for additional MCMC diagnostic plots.
#' size, ratios of MCSE to posterior SD, and autocorrelation of MCMC draws. See
#' the **Plot Descriptions** section, below, for details. For models fit using
#' the No-U-Turn-Sampler, see also [MCMC-nuts] for additional MCMC diagnostic
#' plots.
#'
#' @name MCMC-diagnostics
#' @family MCMC
#'
#' @param ratio For effective sample size plots, a vector of *ratios* of
#' effective sample size estimates to total sample sizes (see [neff_ratio()]).
#' For MCSE plots, a vector of *ratios* of Monte Carlo standard errors to
#' posterior standard deviations.
#' @template args-hist
#' @param size An optional value to override [ggplot2::geom_point()]'s
#' default size (for `mcmc_rhat()`, `mcmc_neff()`) or
@@ -32,9 +37,19 @@
#' histogram. Values are colored using different shades (lighter is better).
#' The chosen thresholds are somewhat arbitrary, but can be useful guidelines
#' in practice.
#' * _light_: between 0.5 and 1 (high)
#' * _mid_: between 0.1 and 0.5 (good)
#' * _dark_: below 0.1 (low)
#' * _light_: between 0.5 and 1 (good)
#' * _mid_: between 0.1 and 0.5 (ok)
#' * _dark_: below 0.1 (too low)
#' }
#'
#' \item{`mcmc_mcse()`, `mcmc_mcse_hist()`}{
#' Ratios of Monte Carlo standard errors to posterior standard deviations as
#' either points or a histogram. Values are colored using different shades
#' (lighter is better). The chosen thresholds are somewhat arbitrary, but can
#' be useful guidelines in practice.
#' * _light_: below 0.05 (good)
#' * _mid_: between 0.05 and 0.1 (ok)
#' * _dark_: above 0.1 (too high)
#' }
#'
#' \item{`mcmc_acf()`, `mcmc_acf_bar()`}{
@@ -91,6 +106,11 @@
#' mcmc_neff_hist(ratio)
#' mcmc_neff(ratio)
#'
#' # fake mcse ratio values to use for demonstration
#' ratio <- c(runif(100, 0, 1.5))
#' mcmc_mcse_hist(ratio)
#' mcmc_mcse(ratio)
#'
#' \dontrun{
#' # Example using rstanarm model (requires rstanarm package)
#' library(rstanarm)
@@ -210,8 +230,6 @@ mcmc_rhat_data <- function(rhat, ...) {

#' @rdname MCMC-diagnostics
#' @export
#' @param ratio A vector of *ratios* of effective sample size estimates to
#' total sample size. See [neff_ratio()].
#'
mcmc_neff <- function(ratio, ..., size = NULL) {
check_ignored_arguments(...)
@@ -294,6 +312,93 @@ mcmc_neff_data <- function(ratio, ...) {
diagnostic_data_frame(ratio)
}

# monte carlo standard error -------------------------------------------

#' @rdname MCMC-diagnostics
#' @export
#'
mcmc_mcse <- function(ratio, ..., size = NULL) {
  # Point plot of mcse/sd ratios: one point per parameter at its ratio, with a
  # segment running from the left edge of the panel to the point. Color and
  # fill are mapped to the `rating` column of `mcmc_mcse_data()`.
  #
  # ratio: numeric vector of ratios; validated inside `mcmc_mcse_data()`.
  # ...:   passed to `check_ignored_arguments()` only.
  # size:  optional point size, forwarded to `diagnostic_points()`.
  check_ignored_arguments(...)
  data <- mcmc_mcse_data(ratio)

  # Axis breaks: a fixed grid up to 1, extended only when the largest ratio
  # reaches 1.25 (single extra break) or 1.5 (extra breaks every 0.5).
  max_ratio <- max(ratio, na.rm = TRUE)
  if (max_ratio < 1.25) {
    additional_breaks <- numeric(0)
  } else if (max_ratio < 1.5) {
    additional_breaks <- 1.25
  } else {
    additional_breaks <- seq(1.5, max_ratio, by = 0.5)
  }
  breaks <- c(0, 0.1, 0.25, 0.5, 0.75, 1, additional_breaks)

  ggplot(
    data,
    mapping = aes_(
      x = ~ value,
      y = ~ parameter,
      color = ~ rating,
      fill = ~ rating)) +
    geom_segment(
      aes_(yend = ~ parameter, xend = -Inf),
      na.rm = TRUE) +
    diagnostic_points(size) +
    # NOTE(review): these guide lines (and the axis breaks above) do not
    # include 0.05, although the default rating cut points in
    # diagnostic_factor.mcse_ratio() are 0.05 and 0.1 -- confirm whether the
    # guides should follow the rating thresholds instead.
    vline_at(
      c(0.1, 0.5, 1),
      color = "gray",
      linetype = 2,
      size = 0.25) +
    labs(y = NULL, x = expression(mcse/sd)) +
    scale_fill_diagnostic("mcse") +
    scale_color_diagnostic("mcse") +
    scale_x_continuous(
      breaks = breaks,
      # as.character truncates trailing zeroes, while ggplot default does not
      labels = as.character(breaks),
      limits = c(0, max(1, max_ratio) + 0.05),
      expand = c(0, 0)) +
    bayesplot_theme_get() +
    yaxis_text(FALSE) +
    yaxis_title(FALSE) +
    yaxis_ticks(FALSE)
}

#' @rdname MCMC-diagnostics
#' @export
mcmc_mcse_hist <- function(ratio, ..., binwidth = NULL, breaks = NULL) {
  # Histogram of mcse/sd ratios, with bar outline and fill mapped to the
  # `rating` column of `mcmc_mcse_data()`.
  #
  # ratio:    numeric vector of ratios; validated inside `mcmc_mcse_data()`.
  # ...:      passed to `check_ignored_arguments()` only.
  # binwidth: forwarded to `geom_histogram()`.
  # breaks:   forwarded to `geom_histogram()`.
  check_ignored_arguments(...)
  plot_data <- mcmc_mcse_data(ratio)

  rating_mapping <- aes_(
    x = ~ value,
    color = ~ rating,
    fill = ~ rating
  )
  bars <- geom_histogram(
    size = .25,
    na.rm = TRUE,
    binwidth = binwidth,
    breaks = breaks
  )

  # Layers, scales and theme pieces are added in a fixed order.
  graph <- ggplot(plot_data, mapping = rating_mapping) + bars
  graph <- graph + scale_color_diagnostic("mcse")
  graph <- graph + scale_fill_diagnostic("mcse")
  graph <- graph + labs(x = expression(mcse/sd), y = NULL)
  graph <- graph + dont_expand_y_axis(c(0.005, 0))
  graph <- graph + yaxis_title(FALSE)
  graph <- graph + yaxis_text(FALSE)
  graph <- graph + yaxis_ticks(FALSE)
  graph + bayesplot_theme_get()
}

#' @rdname MCMC-diagnostics
#' @export
mcmc_mcse_data <- function(ratio, ...) {
# Builds the data frame that the mcse plotting functions draw from.
#
# ratio: numeric vector (or 1-d array) of mcse/sd ratios; negative values are
#        rejected by new_mcse_ratio().
# ...:   passed to check_ignored_arguments() only.
#
# The plotting functions read the columns `value` and `rating` (and, for the
# point plot, `parameter`) from the result.
check_ignored_arguments(...)
# Validate and class the input first, then drop NAs (a warning is raised when
# any are dropped).
ratio <- drop_NAs_and_warn(new_mcse_ratio(ratio))
diagnostic_data_frame(ratio)
}



# autocorrelation ---------------------------------------------------------

@@ -354,7 +459,7 @@ mcmc_acf_bar <-
#'
#' @param x A numeric vector.
#' @param breaks A numeric vector of length two. The resulting factor variable
#' will have three levels ('low', 'ok', and 'high') corresponding to (
#' will have three levels ('low', 'mid', and 'high') corresponding to (
#' `x <= breaks[1]`, `breaks[1] < x <= breaks[2]`, `x > breaks[2]`).
#' @return A factor the same length as `x` with three levels.
#' @noRd
@@ -364,13 +469,19 @@ diagnostic_factor <- function(x, breaks, ...) {

diagnostic_factor.rhat <- function(x, breaks = c(1.05, 1.1)) {
cut(x, breaks = c(-Inf, breaks, Inf),
labels = c("low", "ok", "high"),
labels = c("low", "mid", "high"),
ordered_result = FALSE)
}

diagnostic_factor.neff_ratio <- function(x, breaks = c(0.1, 0.5)) {
cut(x, breaks = c(-Inf, breaks, Inf),
labels = c("low", "ok", "high"),
labels = c("low", "mid", "high"),
ordered_result = FALSE)
}

diagnostic_factor.mcse_ratio <- function(x, breaks = c(0.05, 0.1)) {
  # Bins mcse/sd ratios into three unordered levels:
  #   "low"  : x <= breaks[1]
  #   "mid"  : breaks[1] < x <= breaks[2]
  #   "high" : x > breaks[2]
  cut_points <- c(-Inf, breaks, Inf)
  rating_labels <- c("low", "mid", "high")
  cut(
    x,
    breaks = cut_points,
    labels = rating_labels,
    ordered_result = FALSE
  )
}

@@ -411,17 +522,17 @@ diagnostic_points <- function(size = NULL) {

# Functions wrapping around scale_color_manual() and scale_fill_manual(), used to
# color plot elements by diagnostic rating (rhat, neff ratio, or mcse ratio)
scale_color_diagnostic <- function(diagnostic = c("rhat", "neff")) {
scale_color_diagnostic <- function(diagnostic = c("rhat", "neff", "mcse")) {
d <- match.arg(diagnostic)
diagnostic_color_scale(d, aesthetic = "color")
}

scale_fill_diagnostic <- function(diagnostic = c("rhat", "neff")) {
scale_fill_diagnostic <- function(diagnostic = c("rhat", "neff", "mcse")) {
d <- match.arg(diagnostic)
diagnostic_color_scale(d, aesthetic = "fill")
}

diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio"),
diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio", "mcse_ratio"),
aesthetic = c("color", "fill")) {
diagnostic <- match.arg(diagnostic)
aesthetic <- match.arg(aesthetic)
@@ -437,14 +548,17 @@ diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio"),
)
}

diagnostic_colors <- function(diagnostic = c("rhat", "neff_ratio"),
diagnostic_colors <- function(diagnostic = c("rhat", "neff_ratio", "mcse_ratio"),
aesthetic = c("color", "fill")) {
diagnostic <- match.arg(diagnostic)
aesthetic <- match.arg(aesthetic)
color_levels <- c("light", "mid", "dark")
if (diagnostic == "neff_ratio") {
color_levels <- rev(color_levels)
}
if (diagnostic == "mcse_ratio") {
color_levels <- color_levels
}
if (aesthetic == "color") {
color_levels <- paste0(color_levels, "_highlight")
}
@@ -455,19 +569,24 @@ diagnostic_colors <- function(diagnostic = c("rhat", "neff_ratio"),
aesthetic = aesthetic,
color_levels = color_levels,
color_labels = color_labels,
values = set_names(get_color(color_levels), c("low", "ok", "high")))
values = set_names(get_color(color_levels), c("low", "mid", "high")))
}

diagnostic_color_labels <- list(
rhat = c(
low = expression(hat(R) <= 1.05),
ok = expression(hat(R) <= 1.10),
mid = expression(hat(R) <= 1.10),
high = expression(hat(R) > 1.10)
),
neff_ratio = c(
low = expression(N[eff] / N <= 0.1),
ok = expression(N[eff] / N <= 0.5),
mid = expression(N[eff] / N <= 0.5),
high = expression(N[eff] / N > 0.5)
),
mcse_ratio = c(
low = expression(mcse / sd <= 0.05),
mid = expression(mcse / sd <= 0.1),
high = expression(mcse / sd > 0.1)
)
)

@@ -662,3 +781,30 @@ as_neff_ratio <- function(x) {
as_neff_ratio(NextMethod())
}

new_mcse_ratio <- function(x) {
  # Constructor for "mcse_ratio" objects: validates the input and then
  # attaches the class.
  # A 1-d array is converted to a vector before validation, because
  # validate_mcse_ratio() rejects arrays.
  is_1d_array <- is.array(x) && length(dim(x)) == 1
  if (is_1d_array) {
    x <- as.vector(x)
  }
  checked <- validate_mcse_ratio(x)
  as_mcse_ratio(checked)
}

validate_mcse_ratio <- function(x) {
  # Checks that `x` is a numeric, non-list, non-array object with no negative
  # values (NAs are ignored by the sign check). Returns `x` unchanged when all
  # checks pass; otherwise signals an error.
  stopifnot(is.numeric(x), !is.list(x), !is.array(x))
  has_negative <- any(x < 0, na.rm = TRUE)
  if (has_negative) {
    abort("All mcse ratios must be positive.")
  }
  x
}

as_mcse_ratio <- function(x) {
  # Attaches the class c("mcse_ratio", "numeric") to `x`, carrying over
  # whatever names `x` already has.
  element_names <- names(x)
  structure(
    x,
    class = c("mcse_ratio", "numeric"),
    names = element_names
  )
}

#' Indexing method -- needed so that sort, etc. don't strip names.
#' @export
#' @keywords internal
#' @noRd
`[.mcse_ratio` <- function (x, i, j, drop = TRUE, ...) {
  # Subset with the next `[` method, then re-apply the "mcse_ratio" class to
  # the result.
  subsetted <- NextMethod()
  as_mcse_ratio(subsetted)
}
45 changes: 37 additions & 8 deletions man/MCMC-diagnostics.Rd
171 changes: 171 additions & 0 deletions tests/testthat/_snaps/mcmc-diagnostics/mcmc-mcse-default.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
179 changes: 179 additions & 0 deletions tests/testthat/_snaps/mcmc-diagnostics/mcmc-mcse-hist-default.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions tests/testthat/test-helpers-mcmc.R
Original file line number Diff line number Diff line change
@@ -299,20 +299,20 @@ test_that("tidy parameter selection throws correct errors", {
# rhat and neff helpers ---------------------------------------------------
test_that("diagnostic_factor.rhat works", {
rhats <- new_rhat(c(low = 0.99, low = 1, low = 1.01,
ok = 1.06, ok = 1.09, ok = 1.1,
mid = 1.06, mid = 1.09, mid = 1.1,
high = 1.2, high = 1.7))

r <- diagnostic_factor(unname(rhats))
expect_equivalent(r, as.factor(names(rhats)))
expect_identical(levels(r), c("low", "ok", "high"))
expect_identical(levels(r), c("low", "mid", "high"))
})
test_that("diagnostic_factor.neff_ratio works", {
ratios <- new_neff_ratio(c(low = 0.05, low = 0.01,
ok = 0.2, ok = 0.49,
mid = 0.2, mid = 0.49,
high = 0.51, high = 0.99, high = 1))

r <- diagnostic_factor(unname(ratios))
expect_equivalent(r, as.factor(names(ratios)))
expect_identical(levels(r), c("low", "ok", "high"))
expect_identical(levels(r), c("low", "mid", "high"))
})

43 changes: 38 additions & 5 deletions tests/testthat/test-mcmc-diagnostics.R
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ context("MCMC: diagnostics")

source(test_path("data-for-mcmc-tests.R"))

test_that("rhat and neff plots return a ggplot object", {
test_that("rhat, neff, and mcse plots return a ggplot object", {
rhat <- runif(100, 1, 1.5)
expect_gg(mcmc_rhat(rhat))
expect_gg(mcmc_rhat_hist(rhat))
@@ -12,35 +12,45 @@ test_that("rhat and neff plots return a ggplot object", {
expect_gg(mcmc_neff(ratio))
expect_gg(mcmc_neff_hist(ratio))

# mcse plots: check both the point plot and the histogram
expect_gg(mcmc_mcse(ratio))
expect_gg(mcmc_mcse_hist(ratio))

# 1-D array ok
expect_gg(mcmc_rhat(array(rhat)))
expect_gg(mcmc_rhat_hist(array(rhat)))
expect_gg(mcmc_neff(array(ratio)))
expect_gg(mcmc_neff_hist(array(ratio)))
expect_gg(mcmc_mcse(array(ratio)))
expect_gg(mcmc_mcse_hist(array(ratio)))

# named ok
rhat <- setNames(runif(5, 1, 1.5), paste0("alpha[", 1:5, "]"))
expect_gg(mcmc_rhat(rhat))
})

test_that("rhat and neff plot functions throw correct errors & warnings", {
test_that("rhat, neff, and mcse plot functions throw correct errors & warnings", {
# need vector or 1D array
expect_error(mcmc_rhat_hist(cbind(1:2)), "is.array")
expect_error(mcmc_neff_hist(list(1,2)), "is.numeric")
expect_error(mcmc_mcse_hist(list(1,2)), "is.numeric")

# need positive rhat values
expect_error(mcmc_rhat(c(-1, 1, 1)), "must be positive")

# need positive mcse values
expect_error(mcmc_mcse(c(-1, 1, 1)), "must be positive")

# need ratios between 0 and 1
expect_error(mcmc_neff(c(-1, 0.5, 0.7)), "must be positive")

# drop NAs and warn
expect_warning(mcmc_rhat(c(1, 1, NA)), "Dropped 1 NAs")
expect_warning(mcmc_neff(c(0.2, NA, 1, NA)), "Dropped 2 NAs")
expect_warning(mcmc_mcse(c(0.2, NA, 1, NA)), "Dropped 2 NAs")
})


test_that("duplicated rhats and neffs are kept (#105)", {
test_that("duplicated rhats, neffs, mcses are kept (#105)", {
# https://github.com/stan-dev/bayesplot/issues/105
rhats <- runif(3, 1, 1.2)
rhats <- c(rhats, rhats, rhats)
@@ -51,13 +61,16 @@ test_that("duplicated rhats and neffs are kept (#105)", {
ratios <- c(ratios, ratios, ratios)
df <- mcmc_neff_data(ratios)
expect_equal(nrow(df), length(ratios))

df <- mcmc_mcse_data(ratios)
expect_equal(nrow(df), length(ratios))
})

test_that("'description' & 'rating' columns are correct (#176)", {
# https://github.com/stan-dev/bayesplot/issues/176
rhats <- c(1, 1.07, 1.19, 1.07, 1.3, 1)
expected_rhats <- sort(rhats)
expected_ratings <- rep(c("low", "ok", "high"), each = 2)
expected_ratings <- rep(c("low", "mid", "high"), each = 2)
expected_descriptions <-
rep(c("hat(R) <= 1.05", "hat(R) <= 1.1", "hat(R) > 1.1"), each = 2)

@@ -68,7 +81,7 @@ test_that("'description' & 'rating' columns are correct (#176)", {

ratios <- c(0.4, 0.05, 0.6)
expected_ratios <- sort(ratios)
expected_ratings <- c("low", "ok", "high")
expected_ratings <- c("low", "mid", "high")
expected_descriptions <-
c("N[eff]/N <= 0.1", "N[eff]/N <= 0.5", "N[eff]/N > 0.5")

@@ -151,3 +164,23 @@ test_that("mcmc_neff_hist renders correctly", {
p_binwidth <- mcmc_neff_hist(neffs, binwidth = .05)
vdiffr::expect_doppelganger("mcmc_neff_hist (binwidth)", p_binwidth)
})

test_that("mcmc_mcse renders correctly", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed("vdiffr")

  # 40 evenly spaced ratios from 0 to 1.5, so values above 1 are included
  ratio_grid <- seq(0, 1.5, length.out = 40)

  vdiffr::expect_doppelganger("mcmc_mcse (default)", mcmc_mcse(ratio_grid))
})

test_that("mcmc_mcse_hist renders correctly", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed("vdiffr")

  # 40 evenly spaced ratios from 0 to 1.5, so values above 1 are included
  ratio_grid <- seq(0, 1.5, length.out = 40)

  vdiffr::expect_doppelganger(
    "mcmc_mcse_hist (default)",
    mcmc_mcse_hist(ratio_grid)
  )
})