Skip to content

Commit d5cd435

Browse files
committed
autoplot new data
1 parent 06fa647 commit d5cd435

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

R/autoplot.R

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ ggplot2::autoplot
1616
#' @param object An `epi_workflow`
1717
#' @param predictions A data frame with predictions. If `NULL`, only the
1818
#' original data is shown.
19+
#' @param plot_data An epi_df of the data to plot against. This is for the case
20+
#' where you have the actual results to compare the forecast against.
1921
#' @param .levels A numeric vector of levels to plot for any prediction bands.
2022
#' More than 3 levels begins to be difficult to see.
2123
#' @param ... Ignored
@@ -82,7 +84,9 @@ NULL
8284
#' @export
8385
#' @rdname autoplot-epipred
8486
autoplot.epi_workflow <- function(
85-
object, predictions = NULL,
87+
object,
88+
predictions = NULL,
89+
plot_data = NULL,
8690
.levels = c(.5, .8, .9), ...,
8791
.color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
8892
.facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
@@ -109,30 +113,32 @@ autoplot.epi_workflow <- function(
109113
}
110114
keys <- c("geo_value", "time_value", "key")
111115
mold_roles <- names(mold$extras$roles)
112-
edf <- bind_cols(mold$extras$roles[mold_roles %in% keys], y)
113-
if (starts_with_impl("ahead_", names(y))) {
114-
old_name_y <- unlist(strsplit(names(y), "_"))
115-
shift <- as.numeric(old_name_y[2])
116-
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
117-
edf <- rename(edf, !!new_name_y := !!names(y))
118-
} else if (starts_with_impl("lag_", names(y))) {
119-
old_name_y <- unlist(strsplit(names(y), "_"))
120-
shift <- -as.numeric(old_name_y[2])
121-
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
122-
edf <- rename(edf, !!new_name_y := !!names(y))
123-
}
124-
125-
if (!is.null(shift)) {
126-
edf <- mutate(edf, time_value = time_value + shift)
116+
# extract the relevant column names for plotting
117+
old_name_y <- unlist(strsplit(names(y), "_"))
118+
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
119+
if (is.null(plot_data)) {
120+
# the outcome has shifted, so we need to shift it forward (or back)
121+
# by the corresponding amount
122+
plot_data <- bind_cols(mold$extras$roles[mold_roles %in% keys], y)
123+
if (starts_with_impl("ahead_", names(y))) {
124+
shift <- as.numeric(old_name_y[2])
125+
} else if (starts_with_impl("lag_", names(y))) {
126+
old_name_y <- unlist(strsplit(names(y), "_"))
127+
shift <- -as.numeric(old_name_y[2])
128+
}
129+
plot_data <- rename(plot_data, !!new_name_y := !!names(y))
130+
if (!is.null(shift)) {
131+
plot_data <- mutate(plot_data, time_value = time_value + shift)
132+
}
133+
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
134+
plot_data <- as_epi_df(plot_data,
135+
as_of = object$fit$meta$as_of,
136+
other_keys = other_keys
137+
)
127138
}
128-
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
129-
edf <- as_epi_df(edf,
130-
as_of = object$fit$meta$as_of,
131-
other_keys = other_keys
132-
)
133139
if (is.null(predictions)) {
134140
return(autoplot(
135-
edf, new_name_y,
141+
plot_data, new_name_y,
136142
.color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
137143
.max_facets = .max_facets
138144
))
@@ -144,27 +150,27 @@ autoplot.epi_workflow <- function(
144150
}
145151
predictions <- rename(predictions, time_value = target_date)
146152
}
147-
pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(edf))
153+
pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(plot_data))
148154
if (!pred_cols_ok$ok) {
149155
cli_warn(c(
150156
"`predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}.",
151157
i = "Plotting the original data."
152158
))
153159
return(autoplot(
154-
edf, !!new_name_y,
160+
plot_data, !!new_name_y,
155161
.color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
156162
.max_facets = .max_facets
157163
))
158164
}
159165

160166
# First we plot the history, always faceted by everything
161-
bp <- autoplot(edf, !!new_name_y,
167+
bp <- autoplot(plot_data, !!new_name_y,
162168
.color_by = "none", .facet_by = "all_keys",
163169
.base_color = "black", .max_facets = .max_facets
164170
)
165171

166172
# Now, prepare matching facets in the predictions
167-
ek <- epi_keys_only(edf)
173+
ek <- epi_keys_only(plot_data)
168174
predictions <- predictions %>%
169175
mutate(
170176
.facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"),
@@ -202,7 +208,7 @@ autoplot.epi_workflow <- function(
202208
#' @export
203209
#' @rdname autoplot-epipred
204210
autoplot.canned_epipred <- function(
205-
object, ...,
211+
object, plot_data = NULL, ...,
206212
.color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
207213
.facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
208214
.base_color = "dodgerblue4",
@@ -216,7 +222,7 @@ autoplot.canned_epipred <- function(
216222
predictions <- object$predictions %>%
217223
rename(time_value = target_date)
218224

219-
autoplot(ewf, predictions,
225+
autoplot(ewf, predictions, plot_data, ...,
220226
.color_by = .color_by, .facet_by = .facet_by,
221227
.base_color = .base_color, .max_facets = .max_facets
222228
)

0 commit comments

Comments
 (0)