Skip to content

Commit fc67900

Browse files
authored
Merge pull request #295 from cmu-delphi/djm/plotting
Djm/plotting
2 parents 0c0a5e8 + 8ac8d4c commit fc67900

10 files changed

+456
-8
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
^data-raw$
1717
^vignettes/articles$
1818
^.git-blame-ignore-revs$
19+
^DEVELOPMENT\.md$
1920
^doc$
2021
^Meta$
2122
^.lintr$

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
2323
https://cmu-delphi.github.io/epipredict
2424
BugReports: https://github.com/cmu-delphi/epipredict/issues/
2525
Depends:
26-
epiprocess (>= 0.6.0),
26+
epiprocess (>= 0.7.5),
2727
parsnip (>= 1.0.0),
2828
R (>= 3.5.0)
2929
Imports:
@@ -32,6 +32,7 @@ Imports:
3232
distributional,
3333
dplyr,
3434
generics,
35+
ggplot2,
3536
glue,
3637
hardhat (>= 1.3.0),
3738
magrittr,
@@ -51,7 +52,6 @@ Suggests:
5152
data.table,
5253
epidatr (>= 1.0.0),
5354
fs,
54-
ggplot2,
5555
knitr,
5656
lubridate,
5757
poissonreg,

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ S3method(adjust_frosting,frosting)
1010
S3method(apply_frosting,default)
1111
S3method(apply_frosting,epi_workflow)
1212
S3method(augment,epi_workflow)
13+
S3method(autoplot,canned_epipred)
14+
S3method(autoplot,epi_workflow)
1315
S3method(bake,check_enough_train_data)
1416
S3method(bake,epi_recipe)
1517
S3method(bake,step_epi_ahead)
@@ -23,6 +25,7 @@ S3method(detect_layer,workflow)
2325
S3method(epi_keys,data.frame)
2426
S3method(epi_keys,default)
2527
S3method(epi_keys,epi_df)
28+
S3method(epi_keys,epi_workflow)
2629
S3method(epi_keys,recipe)
2730
S3method(epi_recipe,default)
2831
S3method(epi_recipe,epi_df)
@@ -128,6 +131,7 @@ export(arx_class_epi_workflow)
128131
export(arx_classifier)
129132
export(arx_fcast_epi_workflow)
130133
export(arx_forecaster)
134+
export(autoplot)
131135
export(bake)
132136
export(cdc_baseline_args_list)
133137
export(cdc_baseline_forecaster)
@@ -215,6 +219,7 @@ importFrom(dplyr,ungroup)
215219
importFrom(epiprocess,growth_rate)
216220
importFrom(generics,augment)
217221
importFrom(generics,fit)
222+
importFrom(ggplot2,autoplot)
218223
importFrom(hardhat,refresh_blueprint)
219224
importFrom(hardhat,run_mold)
220225
importFrom(magrittr,"%>%")

R/autoplot.R

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
#' @importFrom ggplot2 autoplot
2+
#' @export
3+
ggplot2::autoplot
4+
5+
#' Automatically plot an `epi_workflow` or `canned_epipred` object
6+
#'
7+
#' For a fit workflow, the training data will be displayed, the response by
8+
#' default. If `predictions` is not `NULL` then point and interval forecasts
9+
#' will be shown as well. Unfit workflows will result in an error, (you
10+
#' can simply call `autoplot()` on the original `epi_df`).
11+
#'
12+
#'
13+
#'
14+
#'
15+
#' @inheritParams epiprocess::autoplot.epi_df
16+
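#' @param object A fitted `epi_workflow`, or a `canned_epipred` object (the output of a canned forecaster such as `arx_forecaster()`).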
17+
#'
18+
#' @param predictions A data frame with predictions. If `NULL`, only the
19+
#' original data is shown.
20+
#' @param .levels A numeric vector of levels to plot for any prediction bands.
21+
#' More than 3 levels begins to be difficult to see.
22+
#' @param ... Ignored
23+
#' @param .color_by A character string indicating how to color the data. See
24+
#' `epiprocess::autoplot.epi_df()` for more details.
25+
#' @param .facet_by A character string indicating how to facet the data. See
26+
#' `epiprocess::autoplot.epi_df()` for more details.
27+
#' @param .base_color If available, prediction bands will be shown with this
28+
#' color.
29+
#' @param .point_pred_color If available, point forecasts will be shown with this
30+
#' color.
31+
#' @param .max_facets The maximum number of facets to show. If the number of
32+
#' facets is greater than this value, only the top facets will be shown.
33+
#'
34+
#' @name autoplot-epipred
35+
#' @examples
36+
#' jhu <- case_death_rate_subset %>%
37+
#' filter(time_value >= as.Date("2021-11-01"))
38+
#'
39+
#' r <- epi_recipe(jhu) %>%
40+
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
41+
#' step_epi_ahead(death_rate, ahead = 7) %>%
42+
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
43+
#' step_epi_naomit()
44+
#'
45+
#' f <- frosting() %>%
46+
#' layer_residual_quantiles(
47+
#' quantile_levels = c(.025, .1, .25, .75, .9, .975)
48+
#' ) %>%
49+
#' layer_threshold(dplyr::starts_with(".pred")) %>%
50+
#' layer_add_target_date()
51+
#'
52+
#' wf <- epi_workflow(r, parsnip::linear_reg(), f) %>% fit(jhu)
53+
#'
54+
#' autoplot(wf)
55+
#'
56+
#' latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14)
57+
#' preds <- predict(wf, latest)
58+
#' autoplot(wf, preds, .max_facets = 4)
59+
#'
60+
#' # ------- Show multiple horizons
61+
#'
62+
#' p <- lapply(c(7, 14, 21, 28), \(h) {
63+
#' r <- epi_recipe(jhu) %>%
64+
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
65+
#' step_epi_ahead(death_rate, ahead = h) %>%
66+
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
67+
#' step_epi_naomit()
68+
#' ewf <- epi_workflow(r, parsnip::linear_reg(), f) %>% fit(jhu)
69+
#' td <- get_test_data(r, jhu)
70+
#' predict(ewf, new_data = td)
71+
#' })
72+
#'
73+
#' p <- do.call(rbind, p)
74+
#' autoplot(wf, p, .max_facets = 4)
75+
#'
76+
#' # ------- Plotting canned forecaster output
77+
#'
78+
#' jhu <- case_death_rate_subset %>% filter(time_value >= as.Date("2021-11-01"))
79+
#' flat <- flatline_forecaster(jhu, "death_rate")
80+
#' autoplot(flat, .max_facets = 4)
81+
#'
82+
#' arx <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
83+
#' args_list = arx_args_list(ahead = 14L)
84+
#' )
85+
#' autoplot(arx, .max_facets = 6)
86+
NULL
87+
88+
#' @export
89+
#' @rdname autoplot-epipred
90+
autoplot.epi_workflow <- function(
    object, predictions = NULL,
    .levels = c(.5, .8, .95), ...,
    .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
    .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
    .base_color = "dodgerblue4",
    .point_pred_color = "orange",
    .max_facets = Inf) {
  # Plot the training outcome stored in a fitted workflow and, when
  # `predictions` is supplied, overlay prediction bands (`.pred_distn`) and
  # point forecasts (`.pred`). Returns the plot object built by `autoplot()`
  # on the reconstructed `epi_df`, with any prediction layers added.
  rlang::check_dots_empty()
  arg_is_probabilities(.levels)
  # Keep the validated single choice instead of discarding it, so the value
  # forwarded below is always a length-1 string.
  .color_by <- rlang::arg_match(.color_by)
  .facet_by <- rlang::arg_match(.facet_by)

  if (!workflows::is_trained_workflow(object)) {
    cli::cli_abort(c(
      "Can't plot an untrained {.cls epi_workflow}.",
      i = "Do you need to call `fit()`?"
    ))
  }

  mold <- workflows::extract_mold(object)
  y <- mold$outcomes
  if (ncol(y) > 1) {
    y <- y[, 1]
    cli::cli_warn("Multiple outcome variables were detected. Displaying only 1.")
  }
  keys <- c("time_value", "geo_value", "key")
  mold_roles <- names(mold$extras$roles)
  edf <- dplyr::bind_cols(mold$extras$roles[mold_roles %in% keys], y)

  # Defaults for an outcome that is neither `ahead_<k>_<name>` nor
  # `lag_<k>_<name>`: no time shift, and the column keeps its own name.
  # Without these, `shift` / `new_name_y` would be undefined further down.
  shift <- NULL
  new_name_y <- names(y)
  is_ahead <- starts_with_impl("ahead_", names(y))
  is_lag <- starts_with_impl("lag_", names(y))
  if (is_ahead || is_lag) {
    # Outcome names look like `<prefix>_<k>_<original name>`: the second piece
    # is the shift size, everything after it is the original column name.
    old_name_y <- unlist(strsplit(names(y), "_"))
    shift <- as.numeric(old_name_y[2])
    if (is_lag) shift <- -shift
    new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
    edf <- dplyr::rename(edf, !!new_name_y := !!names(y))
  }

  if (!is.null(shift)) {
    # Move each row's time_value by the shift so the outcome is plotted at
    # the date it refers to.
    edf <- dplyr::mutate(edf, time_value = time_value + shift)
  }
  extra_keys <- setdiff(epi_keys_mold(mold), c("time_value", "geo_value"))
  if (length(extra_keys) == 0L) extra_keys <- NULL
  edf <- as_epi_df(edf,
    as_of = object$fit$meta$as_of,
    additional_metadata = list(other_keys = extra_keys)
  )
  if (is.null(predictions)) {
    # No predictions: plot only the (re-aligned) training outcome.
    return(autoplot(
      edf, !!new_name_y,
      .color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
      .max_facets = .max_facets
    ))
  }

  if ("target_date" %in% names(predictions)) {
    # Predictions are plotted at their target date; drop any existing
    # `time_value` first so the rename cannot create a duplicate column.
    if ("time_value" %in% names(predictions)) {
      predictions <- dplyr::select(predictions, -time_value)
    }
    predictions <- dplyr::rename(predictions, time_value = target_date)
  }
  pred_cols_ok <- hardhat::check_column_names(predictions, epi_keys(edf))
  if (!pred_cols_ok$ok) {
    cli::cli_warn(c(
      "`predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}.",
      i = "Plotting the original data."
    ))
    return(autoplot(
      edf, !!new_name_y,
      .color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
      .max_facets = .max_facets
    ))
  }

  # First we plot the history, always faceted by everything
  bp <- autoplot(edf, !!new_name_y,
    .color_by = "none", .facet_by = "all_keys",
    .base_color = "black", .max_facets = .max_facets
  )

  # Now, prepare matching facets in the predictions
  ek <- kill_time_value(epi_keys(edf))
  predictions <- predictions %>%
    dplyr::mutate(
      .facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"),
    )
  if (.max_facets < Inf) {
    # Keep only predictions whose facet is among the first `.max_facets`
    # facet levels present in the history plot's data.
    top_n <- levels(as.factor(bp$data$.facets))[seq_len(.max_facets)]
    predictions <- dplyr::filter(predictions, .facets %in% top_n) %>%
      dplyr::mutate(.facets = droplevels(.facets))
  }

  if (".pred_distn" %in% names(predictions)) {
    bp <- plot_bands(bp, predictions, .levels, .base_color)
  }

  if (".pred" %in% names(predictions)) {
    # A line needs more than one target date; otherwise draw points.
    ntarget_dates <- dplyr::n_distinct(predictions$time_value)
    if (ntarget_dates > 1L) {
      bp <- bp +
        ggplot2::geom_line(
          data = predictions, ggplot2::aes(y = .data$.pred),
          color = .point_pred_color
        )
    } else {
      bp <- bp +
        ggplot2::geom_point(
          data = predictions, ggplot2::aes(y = .data$.pred),
          color = .point_pred_color
        )
    }
  }
  bp
}
208+
209+
#' @export
210+
#' @rdname autoplot-epipred
211+
autoplot.canned_epipred <- function(
    object, ...,
    .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
    .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
    .base_color = "dodgerblue4",
    .point_pred_color = "orange",
    .max_facets = Inf) {
  # Plot the output of a canned forecaster: pulls the fitted workflow and its
  # predictions out of `object` and delegates to the `epi_workflow` method.
  rlang::check_dots_empty()
  # Keep the validated single choice rather than discarding it.
  .color_by <- rlang::arg_match(.color_by)
  .facet_by <- rlang::arg_match(.facet_by)

  ewf <- object$epi_workflow
  # Predictions are drawn at their target date.
  predictions <- object$predictions %>%
    dplyr::rename(time_value = target_date)

  # `.point_pred_color` must be forwarded explicitly; otherwise the workflow
  # method silently falls back to its own default colour.
  autoplot(ewf, predictions,
    .color_by = .color_by, .facet_by = .facet_by,
    .base_color = .base_color, .point_pred_color = .point_pred_color,
    .max_facets = .max_facets
  )
}
231+
232+
# Does each element of `vars` begin with the prefix `x`?
# Compares `x` against the first `nchar(x)` characters of `vars` and returns
# the resulting logical vector.
starts_with_impl <- function(x, vars) {
  prefix_len <- nchar(x)
  leading <- substr(vars, 1, prefix_len)
  x == leading
}
236+
237+
# Add one prediction band per entry of `levels` on top of `base_plot`.
#
# `predictions` must contain `time_value` and a `.pred_distn` column. The
# distribution is re-evaluated at the lower/upper quantile levels implied by
# `levels` and widened into one column per quantile (via the package helpers
# `dist_quantiles()` and `pivot_quantiles_wider()`); the new columns are then
# paired outside-in to form the bands. With more than one distinct
# `time_value` the bands are ribbons, otherwise vertical line ranges.
# Returns `base_plot` with the extra layers added.
# NOTE(review): the outside-in pairing assumes `levels` is sorted increasingly
# and that the widened quantile columns come back in increasing order -- confirm.
plot_bands <- function(
    base_plot, predictions,
    levels = c(.5, .8, .95),
    fill = "blue4",
    alpha = 0.6,
    linewidth = 0.05) {
  innames <- names(predictions)
  n <- length(levels)
  # Nothing to draw; also avoids the `1:0` trap in the loop below.
  if (n == 0L) {
    return(base_plot)
  }
  # Spread the opacity over the inner bands; guard against dividing by zero
  # when there is a single level (the value is then unused).
  alpha <- alpha / max(n - 1, 1)
  # Lower-tail probabilities followed by the matching upper-tail ones.
  l <- (1 - levels) / 2
  l <- c(rev(l), 1 - l)

  ntarget_dates <- dplyr::n_distinct(predictions$time_value)

  predictions <- predictions %>%
    dplyr::mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, l), l)) %>%
    pivot_quantiles_wider(.pred_distn)
  # Columns created by the widening step are the quantile columns.
  qnames <- setdiff(names(predictions), innames)

  for (i in seq_len(n)) {
    # Pair the i-th column from the start with the i-th from the end.
    bottom <- qnames[i]
    top <- rev(qnames)[i]
    if (i == 1) {
      # The first (outermost) band uses a fixed, light opacity.
      if (ntarget_dates > 1L) {
        base_plot <- base_plot +
          ggplot2::geom_ribbon(
            data = predictions,
            ggplot2::aes(ymin = .data[[bottom]], ymax = .data[[top]]),
            alpha = 0.2, linewidth = linewidth, fill = fill
          )
      } else {
        base_plot <- base_plot +
          ggplot2::geom_linerange(
            data = predictions,
            ggplot2::aes(ymin = .data[[bottom]], ymax = .data[[top]]),
            alpha = 0.2, linewidth = 2, color = fill
          )
      }
    } else {
      if (ntarget_dates > 1L) {
        base_plot <- base_plot +
          ggplot2::geom_ribbon(
            data = predictions,
            ggplot2::aes(ymin = .data[[bottom]], ymax = .data[[top]]),
            fill = fill, alpha = alpha
          )
      } else {
        base_plot <- base_plot +
          ggplot2::geom_linerange(
            data = predictions,
            ggplot2::aes(ymin = .data[[bottom]], ymax = .data[[top]]),
            color = fill, alpha = alpha, linewidth = 2
          )
      }
    }
  }
  base_plot
}
295+
296+
# Map quantile levels to the distinct central-interval coverages they bound:
# a level below 0.5 gives `1 - 2 * x`, a level above 0.5 gives
# `1 - 2 * (1 - x)`, and exactly 0.5 gives 0. Duplicates are removed.
find_level <- function(x) {
  in_lower_tail <- x < .5
  in_upper_tail <- x > .5
  lower_part <- in_lower_tail * (1 - 2 * x)
  upper_part <- in_upper_tail * (1 - 2 * (1 - x))
  unique(lower_part + upper_part)
}

R/epi_keys.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,19 @@ epi_keys.data.frame <- function(x, other_keys = character(0L), ...) {
2424

2525
#' @export
2626
epi_keys.epi_df <- function(x, ...) {
  # Keys of an `epi_df`: always `time_value` and `geo_value`, followed by any
  # `other_keys` recorded in the object's "metadata" attribute.
  # `exact = TRUE` stops `attr()` from partially matching a differently named
  # attribute (e.g. "metadata_extra") when "metadata" itself is absent.
  metadata <- attr(x, "metadata", exact = TRUE)
  c("time_value", "geo_value", metadata$other_keys)
}
2929

3030
#' @export
3131
epi_keys.recipe <- function(x, ...) {
  # Names of the recipe variables whose role marks them as a key.
  key_roles <- c("time_value", "geo_value", "key")
  is_key <- x$var_info$role %in% key_roles
  x$var_info$variable[is_key]
}
3434

35+
#' @export
36+
epi_keys.epi_workflow <- function(x, ...) {
  # Keys of a workflow are read from its mold, as returned by
  # `hardhat::extract_mold()`.
  workflow_mold <- hardhat::extract_mold(x)
  epi_keys_mold(workflow_mold)
}
39+
3540
# a mold is a list extracted from a fitted workflow, gives info about
3641
# training data. But it doesn't have a class
3742
epi_keys_mold <- function(mold) {
@@ -45,3 +50,7 @@ kill_time_value <- function(v) {
4550
arg_is_chr(v)
4651
v[v != "time_value"]
4752
}
53+
54+
# Keys of `x` as returned by `epi_keys()`, with "time_value" removed.
# Extra arguments are forwarded to `epi_keys()`.
epi_keys_only <- function(x, ...) {
  all_keys <- epi_keys(x, ...)
  kill_time_value(all_keys)
}

0 commit comments

Comments
 (0)