@@ -16,6 +16,8 @@ ggplot2::autoplot
1616# ' @param object An `epi_workflow`
1717# ' @param predictions A data frame with predictions. If `NULL`, only the
1818# ' original data is shown.
19+ # ' @param plot_data An epi_df of the data to plot against. This is for the case
20+ # ' where you have the actual results to compare the forecast against.
1921# ' @param .levels A numeric vector of levels to plot for any prediction bands.
2022# ' More than 3 levels begins to be difficult to see.
2123# ' @param ... Ignored
8486# ' @export
8587# ' @rdname autoplot-epipred
8688autoplot.epi_workflow <- function (
87- object , predictions = NULL ,
89+ object ,
90+ predictions = NULL ,
91+ plot_data = NULL ,
8892 .levels = c(.5 , .8 , .95 ), ... ,
8993 .color_by = c(" all_keys" , " geo_value" , " other_keys" , " .response" , " all" , " none" ),
9094 .facet_by = c(" .response" , " other_keys" , " all_keys" , " geo_value" , " all" , " none" ),
@@ -111,30 +115,32 @@ autoplot.epi_workflow <- function(
111115 }
112116 keys <- c(" geo_value" , " time_value" , " key" )
113117 mold_roles <- names(mold $ extras $ roles )
114- edf <- bind_cols(mold $ extras $ roles [mold_roles %in% keys ], y )
115- if (starts_with_impl(" ahead_" , names(y ))) {
116- old_name_y <- unlist(strsplit(names(y ), " _" ))
117- shift <- as.numeric(old_name_y [2 ])
118- new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
119- edf <- rename(edf , !! new_name_y : = !! names(y ))
120- } else if (starts_with_impl(" lag_" , names(y ))) {
121- old_name_y <- unlist(strsplit(names(y ), " _" ))
122- shift <- - as.numeric(old_name_y [2 ])
123- new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
124- edf <- rename(edf , !! new_name_y : = !! names(y ))
125- }
126-
127- if (! is.null(shift )) {
128- edf <- mutate(edf , time_value = time_value + shift )
118+ # extract the relevant column names for plotting
119+ old_name_y <- unlist(strsplit(names(y ), " _" ))
120+ new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
121+ if (is.null(plot_data )) {
122+ # the outcome has shifted, so we need to shift it forward (or back)
123+ # by the corresponding amount
124+ plot_data <- bind_cols(mold $ extras $ roles [mold_roles %in% keys ], y )
125+ if (starts_with_impl(" ahead_" , names(y ))) {
126+ shift <- as.numeric(old_name_y [2 ])
127+ } else if (starts_with_impl(" lag_" , names(y ))) {
128+ old_name_y <- unlist(strsplit(names(y ), " _" ))
129+ shift <- - as.numeric(old_name_y [2 ])
130+ }
131+ plot_data <- rename(plot_data , !! new_name_y : = !! names(y ))
132+ if (! is.null(shift )) {
133+ plot_data <- mutate(plot_data , time_value = time_value + shift )
134+ }
135+ other_keys <- setdiff(key_colnames(object ), c(" geo_value" , " time_value" ))
136+ plot_data <- as_epi_df(plot_data ,
137+ as_of = object $ fit $ meta $ as_of ,
138+ other_keys = other_keys
139+ )
129140 }
130- other_keys <- setdiff(key_colnames(object ), c(" geo_value" , " time_value" ))
131- edf <- as_epi_df(edf ,
132- as_of = object $ fit $ meta $ as_of ,
133- other_keys = other_keys
134- )
135141 if (is.null(predictions )) {
136142 return (autoplot(
137- edf , new_name_y ,
143+ plot_data , new_name_y ,
138144 .color_by = .color_by , .facet_by = .facet_by , .base_color = .base_color ,
139145 .max_facets = .max_facets
140146 ))
@@ -146,27 +152,27 @@ autoplot.epi_workflow <- function(
146152 }
147153 predictions <- rename(predictions , time_value = target_date )
148154 }
149- pred_cols_ok <- hardhat :: check_column_names(predictions , key_colnames(edf ))
155+ pred_cols_ok <- hardhat :: check_column_names(predictions , key_colnames(plot_data ))
150156 if (! pred_cols_ok $ ok ) {
151157 cli_warn(c(
152158 " `predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}." ,
153159 i = " Plotting the original data."
154160 ))
155161 return (autoplot(
156- edf , !! new_name_y ,
162+ plot_data , !! new_name_y ,
157163 .color_by = .color_by , .facet_by = .facet_by , .base_color = .base_color ,
158164 .max_facets = .max_facets
159165 ))
160166 }
161167
162168 # First we plot the history, always faceted by everything
163- bp <- autoplot(edf , !! new_name_y ,
169+ bp <- autoplot(plot_data , !! new_name_y ,
164170 .color_by = " none" , .facet_by = " all_keys" ,
165171 .base_color = " black" , .max_facets = .max_facets
166172 )
167173
168174 # Now, prepare matching facets in the predictions
169- ek <- epi_keys_only(edf )
175+ ek <- epi_keys_only(plot_data )
170176 predictions <- predictions %> %
171177 mutate(
172178 .facets = interaction(!!! rlang :: syms(as.list(ek )), sep = " /" ),
@@ -204,7 +210,7 @@ autoplot.epi_workflow <- function(
204210# ' @export
205211# ' @rdname autoplot-epipred
206212autoplot.canned_epipred <- function (
207- object , ... ,
213+ object , plot_data = NULL , ... ,
208214 .color_by = c(" all_keys" , " geo_value" , " other_keys" , " .response" , " all" , " none" ),
209215 .facet_by = c(" .response" , " other_keys" , " all_keys" , " geo_value" , " all" , " none" ),
210216 .base_color = " dodgerblue4" ,
@@ -218,7 +224,7 @@ autoplot.canned_epipred <- function(
218224 predictions <- object $ predictions %> %
219225 rename(time_value = target_date )
220226
221- autoplot(ewf , predictions ,
227+ autoplot(ewf , predictions , plot_data , ... ,
222228 .color_by = .color_by , .facet_by = .facet_by ,
223229 .base_color = .base_color , .max_facets = .max_facets
224230 )
0 commit comments