Skip to content
This repository was archived by the owner on Aug 29, 2024. It is now read-only.

Commit c226881

Browse files
committed
created more advanced interface to allow more control of priors
1 parent bc323e0 commit c226881

File tree

11 files changed

+688
-350
lines changed

11 files changed

+688
-350
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
S3method(b_gt_a,beta_dist)
44
S3method(b_gt_a,gamma_dist)
55
S3method(b_gt_a,normal_gamma_dist)
6+
S3method(compute_moments,beta_dist)
7+
S3method(compute_moments,gamma_dist)
8+
S3method(compute_moments,normal_gamma_dist)
69
S3method(expected_loss_b,beta_dist)
710
S3method(expected_loss_b,gamma_dist)
811
S3method(expected_loss_b,normal_gamma_dist)
@@ -26,6 +29,7 @@ export(beta_dist)
2629
export(calc_beta_dist)
2730
export(calc_gamma_dist)
2831
export(calc_normal_gamma_dist)
32+
export(compute_moments)
2933
export(create_empty_dt)
3034
export(expected_loss_b)
3135
export(gamma_cdf)
@@ -43,6 +47,7 @@ export(normal_gamma_dist)
4347
export(plot_beta)
4448
export(plot_gamma)
4549
export(plot_normal)
50+
export(plot_relative_gain)
4651
export(poisson_dist)
4752
export(sim_effect_size)
4853
export(simulate_ab_test)

R/compute_moments.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# these functions compute the mean and standard deviation of distribution objects
2+
3+
#' @export
4+
compute_moments.beta_dist <- function(dist) {
5+
a <- dist[['alpha']]; b <- dist[['beta']]
6+
return(list(mu = a / (a + b), sigma = sqrt(a * b / (a + b) ^ 2 / (a + b + 1))))
7+
}
8+
9+
#' @export
10+
compute_moments.normal_gamma_dist <- function(dist) {
11+
mu <- dist[['mu']]; lambda <- dist[['lambda']]
12+
a <- dist[['alpha']]; b <- dist[['beta']]
13+
return(list(x = list(mu = mu, sigma = sqrt(b / lambda / (a - 1))), tau = list(mu = a / b, sigma = sqrt(a / b ^ 2))))
14+
}
15+
16+
#' @export
17+
compute_moments.gamma_dist <- function(dist) {
18+
a <- dist[['alpha']]; b <- dist[['beta']]
19+
return(list(mu = a / b, sigma = sqrt(a / b ^ 2)))
20+
}
21+
22+
#' @title Simulate Data According to Some Distribution
23+
#' @name simulate_data
24+
#' @description Simulate a vector of data from a given distribution object.
25+
#' @export
26+
#' @param dist An object of class \code{'beta_dist'}, \code{'normal_gamma_dist'},
27+
#' \code{'gamma_dist'} that specifies the parameters of some distribution
28+
#' @return A list
29+
compute_moments <- function(dist) {
30+
UseMethod('compute_moments')
31+
}

R/plotting.R

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
#' @param xlab The title of the x axis
1111
#' @param ylab The title of the y axis
1212
#' @param color The color for the plot
13-
#' @param level The desired amount of area between the lower and upper bounds. Default is \code{0.99}.
13+
#' @param support_level The desired amount of area between the lower and upper bounds. Default is \code{0.99}.
1414
#'
1515
#' @return NULL. A plot is generated
1616
plot_beta <- function(betas
1717
, title = 'Beta Distribution'
1818
, xlab = 'Rate that the Event Occurs'
1919
, ylab = 'Density of that Rate'
2020
, color = '#f65335'
21-
, level = 0.99
21+
, support_level = 0.99
2222
) {
2323

2424
n_samp <- 1e5
@@ -33,8 +33,8 @@ plot_beta <- function(betas
3333

3434
# remove the lower and upper extremes of the data
3535
beta_vec <- beta_vec[data.table::between(beta_vec
36-
, quantile(beta_vec, 0.005)
37-
, quantile(beta_vec, 0.995))]
36+
, quantile(beta_vec, (1 - support_level) / 2)
37+
, quantile(beta_vec, support_level + (1 - support_level) / 2))]
3838

3939
beta_dt <- data.table::rbindlist(list(beta_dt, data.table::data.table('variant' = rep(var_name, length(beta_vec))
4040
, 'betas' = beta_vec)))
@@ -70,7 +70,11 @@ plot_beta <- function(betas
7070
#' @param normals A list of lists of normal distributions
7171
#' @inheritParams plot_beta
7272
#' @return NULL. A plot is generated
73-
plot_normal <- function(normals) {
73+
plot_normal <- function(normals
74+
, title = 'Normal Distribution'
75+
, color = '#f65335'
76+
, support_level = 0.99
77+
) {
7478
n_samp <- 1e5
7579

7680
sd_dt <- NULL
@@ -91,12 +95,12 @@ plot_normal <- function(normals) {
9195

9296
# remove the lower and upper extremes of the data
9397
sd_vec <- sd_vec[data.table::between(sd_vec
94-
, quantile(sd_vec, 0.005)
95-
, quantile(sd_vec, 0.995))]
98+
, quantile(sd_vec, (1 - support_level) / 2)
99+
, quantile(sd_vec, support_level + (1 - support_level) / 2))]
96100
n_sd <- length(sd_vec)
97101
mu_vec <- mu_vec[data.table::between(mu_vec
98-
, quantile(mu_vec, 0.005)
99-
, quantile(mu_vec, 0.995))]
102+
, quantile(mu_vec, (1 - support_level) / 2)
103+
, quantile(mu_vec, support_level + (1 - support_level) / 2))]
100104
n_mu <- length(mu_vec)
101105

102106
sd_dt <- data.table::rbindlist(list(sd_dt, data.table::data.table('variant' = rep(var_name, n_sd)
@@ -108,7 +112,7 @@ plot_normal <- function(normals) {
108112
if (length(normals) > 1) {
109113
col_vals <- c('darkred', 'darkblue')
110114
} else {
111-
col_vals <- 'black'
115+
col_vals <- color
112116
}
113117

114118
mu_plot <- ggplot(mu_dt, aes(x = mus, colour = variant, fill = variant)) +
@@ -168,7 +172,11 @@ plot_normal <- function(normals) {
168172
#' @param gammas A list of lists of gamma distributions
169173
#' @inheritParams plot_beta
170174
#' @return NULL. A plot is generated
171-
plot_gamma <- function(gammas, title = 'Density of Gamma Distribution', level = 0.99) {
175+
plot_gamma <- function(gammas
176+
, title = 'Density of Gamma Distribution'
177+
, color = '#f65335'
178+
, support_level = 0.99
179+
) {
172180

173181
n_samp <- 1e5
174182
gamma_dt <- NULL
@@ -182,8 +190,8 @@ plot_gamma <- function(gammas, title = 'Density of Gamma Distribution', level =
182190

183191
# remove the lower and upper extremes of the data
184192
gamma_vec <- gamma_vec[data.table::between(gamma_vec
185-
, quantile(gamma_vec, 0.005)
186-
, quantile(gamma_vec, 0.995))]
193+
, quantile(gamma_vec, (1 - support_level) / 2)
194+
, quantile(gamma_vec, support_level + (1 - support_level) / 2))]
187195

188196
gamma_dt <- data.table::rbindlist(list(gamma_dt, data.table::data.table('variant' = rep(var_name, length(gamma_vec))
189197
, 'gammas' = gamma_vec)))
@@ -192,10 +200,10 @@ plot_gamma <- function(gammas, title = 'Density of Gamma Distribution', level =
192200
if (length(gammas) > 1) {
193201
col_vals <- c('darkred', 'darkblue')
194202
} else {
195-
col_vals <- 'black'
203+
col_vals <- color
196204
}
197205

198-
lambda_plot <- ggplot(gamma_dt, aes(x = gammas, colour = variant, fill =variant)) +
206+
lambda_plot <- ggplot(gamma_dt, aes(x = gammas, colour = variant, fill = variant)) +
199207
geom_density(size = 1, alpha = 0.1) +
200208
ggtitle(title) +
201209
xlab('Expected Amount of Times Event Occurrs') +
@@ -211,3 +219,23 @@ plot_gamma <- function(gammas, title = 'Density of Gamma Distribution', level =
211219
return(lambda_plot)
212220
}
213221

222+
#' @title Plot Relative Gain
223+
#' @name plot_relative_gain
224+
#' @description Plot the cumulative density of the ratio of the metric under variant B to the metric under variant A
225+
#' @importFrom purrr map
226+
#' @export
227+
#' @param dists A list of distribution objects, with elements named \code{'a'} and \code{'b'}
228+
#' @param sim_batch_size The number of objects to simulate
229+
#' @return A plot
230+
plot_relative_gain <- function(dists, sim_batch_size = 1e5, title = 'Cumulative Density of B / A') {
231+
thetas <- purrr::map(dists, function(dist) simulate_data(dist = dist, n = sim_batch_size))
232+
ratios <- thetas[['b']] / thetas[['a']]
233+
df <- data.frame(x = ratios)
234+
ecdf_plot <- ggplot(df, aes(x, colour = '#f65335')) + stat_ecdf(size = 1.5) +
235+
ggtitle(title) + xlab('Relative Gain (B / A)') + ylab('Cumulative Density') +
236+
theme(plot.title = element_text(hjust = 0.5, size = 22)
237+
, axis.title = element_text(size = 18)
238+
, axis.text = element_text(size = 14)
239+
, legend.position = 'none')
240+
return(ecdf_plot)
241+
}

0 commit comments

Comments
 (0)