Skip to content

Djm/plotting #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
13b9fd9
implement epi_df method
dajmcdon Nov 5, 2023
dd17058
document / export
dajmcdon Nov 5, 2023
d6524ee
missing rlang hook
dajmcdon Nov 7, 2023
04742ed
fix a bug that might have allowed unordered quantile extrapolation, s…
dajmcdon Nov 15, 2023
8c62229
add autoplot method
dajmcdon Nov 15, 2023
fb74311
ignore vignette caches
dajmcdon Feb 6, 2024
08947ed
use the autoplot method in epiprocess
dajmcdon Feb 6, 2024
103a601
merge dev
dajmcdon Feb 20, 2024
a2af588
fix: naming issue, missing dots arg
dajmcdon Feb 20, 2024
593208b
add documentation, fix ahead/lag behaviour
dajmcdon Mar 5, 2024
dde25f1
drafted
dajmcdon Mar 5, 2024
6ee75c4
import ggplot2
dajmcdon Mar 5, 2024
af47f09
remove test for old internal function
dajmcdon Mar 5, 2024
6809dda
typo in subfunction name
dajmcdon Mar 5, 2024
59a5485
prefix head()/tail()
dajmcdon Mar 5, 2024
042073a
rm: from suggests to imports
dajmcdon Mar 5, 2024
db5acde
ignore the dev md
dajmcdon Mar 5, 2024
6eb8863
merge dev, quantile handling
dajmcdon Mar 6, 2024
95ce4d3
complete dev quantile merge
dajmcdon Mar 6, 2024
046e7de
remove test for deleted function
dajmcdon Mar 6, 2024
a06286e
complete merge, checks pass
dajmcdon Mar 6, 2024
86b8708
add plotting to pkgdown, clean up references
dajmcdon Mar 6, 2024
410da60
revert changes
dajmcdon Mar 6, 2024
29673be
bump version, add to news
dajmcdon Mar 6, 2024
e66e75f
Merge branch 'dev' into djm/plotting
dajmcdon Apr 2, 2024
77e684b
docs+lint: add arg checking and doc variables in autoplot
dshemetov Apr 4, 2024
e0512e7
test: add a few autoplot snapshots
dshemetov Apr 4, 2024
e4bebf2
Merge branch 'dev' into djm/plotting
dshemetov Apr 4, 2024
e530cca
Merge branch 'dev' into djm/plotting
dshemetov Apr 4, 2024
70d9f55
fix: ggplot2::ggsave
dshemetov Apr 4, 2024
8ac8d4c
test: remove autoplot snapshots
dshemetov Apr 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
^data-raw$
^vignettes/articles$
^.git-blame-ignore-revs$
^DEVELOPMENT\.md$
^doc$
^Meta$
^.lintr$
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
https://cmu-delphi.github.io/epipredict
BugReports: https://github.com/cmu-delphi/epipredict/issues/
Depends:
epiprocess (>= 0.6.0),
epiprocess (>= 0.7.5),
parsnip (>= 1.0.0),
R (>= 3.5.0)
Imports:
Expand All @@ -32,6 +32,7 @@ Imports:
distributional,
dplyr,
generics,
ggplot2,
glue,
hardhat (>= 1.3.0),
magrittr,
Expand All @@ -51,7 +52,6 @@ Suggests:
data.table,
epidatr (>= 1.0.0),
fs,
ggplot2,
knitr,
lubridate,
poissonreg,
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ S3method(adjust_frosting,frosting)
S3method(apply_frosting,default)
S3method(apply_frosting,epi_workflow)
S3method(augment,epi_workflow)
S3method(autoplot,canned_epipred)
S3method(autoplot,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,epi_recipe)
S3method(bake,step_epi_ahead)
Expand All @@ -23,6 +25,7 @@ S3method(detect_layer,workflow)
S3method(epi_keys,data.frame)
S3method(epi_keys,default)
S3method(epi_keys,epi_df)
S3method(epi_keys,epi_workflow)
S3method(epi_keys,recipe)
S3method(epi_recipe,default)
S3method(epi_recipe,epi_df)
Expand Down Expand Up @@ -128,6 +131,7 @@ export(arx_class_epi_workflow)
export(arx_classifier)
export(arx_fcast_epi_workflow)
export(arx_forecaster)
export(autoplot)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
Expand Down Expand Up @@ -215,6 +219,7 @@ importFrom(dplyr,ungroup)
importFrom(epiprocess,growth_rate)
importFrom(generics,augment)
importFrom(generics,fit)
importFrom(ggplot2,autoplot)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
Expand Down
298 changes: 298 additions & 0 deletions R/autoplot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
#' @importFrom ggplot2 autoplot
#' @export
ggplot2::autoplot

#' Automatically plot an `epi_workflow` or `canned_epipred` object
#'
#' For a fitted workflow, the training data is displayed, showing the response
#' variable by default. If `predictions` is not `NULL`, point and interval
#' forecasts are shown as well. Calling this on an unfit workflow results in an
#' error (you can simply call `autoplot()` on the original `epi_df` instead).
#'
#'
#'
#'
#' @inheritParams epiprocess::autoplot.epi_df
#' @param object An `epi_workflow` or `canned_epipred` object.
#'
#' @param predictions A data frame with predictions. If `NULL`, only the
#' original data is shown.
#' @param .levels A numeric vector of levels to plot for any prediction bands.
#' More than 3 levels begins to be difficult to see.
#' @param ... Ignored
#' @param .color_by A character string indicating how to color the data. See
#' `epiprocess::autoplot.epi_df()` for more details.
#' @param .facet_by A character string indicating how to facet the data. See
#' `epiprocess::autoplot.epi_df()` for more details.
#' @param .base_color If available, prediction bands will be shown with this
#' color.
#' @param .point_pred_color If available, point forecasts will be shown with this
#' color.
#' @param .max_facets The maximum number of facets to show. If the number of
#' facets is greater than this value, only the top facets will be shown.
#'
#' @name autoplot-epipred
#' @examples
#' jhu <- case_death_rate_subset %>%
#' filter(time_value >= as.Date("2021-11-01"))
#'
#' r <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7) %>%
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
#' step_epi_naomit()
#'
#' f <- frosting() %>%
#' layer_residual_quantiles(
#' quantile_levels = c(.025, .1, .25, .75, .9, .975)
#' ) %>%
#' layer_threshold(dplyr::starts_with(".pred")) %>%
#' layer_add_target_date()
#'
#' wf <- epi_workflow(r, parsnip::linear_reg(), f) %>% fit(jhu)
#'
#' autoplot(wf)
#'
#' latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14)
#' preds <- predict(wf, latest)
#' autoplot(wf, preds, .max_facets = 4)
#'
#' # ------- Show multiple horizons
#'
#' p <- lapply(c(7, 14, 21, 28), \(h) {
#' r <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = h) %>%
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
#' step_epi_naomit()
#' ewf <- epi_workflow(r, parsnip::linear_reg(), f) %>% fit(jhu)
#' td <- get_test_data(r, jhu)
#' predict(ewf, new_data = td)
#' })
#'
#' p <- do.call(rbind, p)
#' autoplot(wf, p, .max_facets = 4)
#'
#' # ------- Plotting canned forecaster output
#'
#' jhu <- case_death_rate_subset %>% filter(time_value >= as.Date("2021-11-01"))
#' flat <- flatline_forecaster(jhu, "death_rate")
#' autoplot(flat, .max_facets = 4)
#'
#' arx <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
#' args_list = arx_args_list(ahead = 14L)
#' )
#' autoplot(arx, .max_facets = 6)
NULL

#' @export
#' @rdname autoplot-epipred
autoplot.epi_workflow <- function(
    object, predictions = NULL,
    .levels = c(.5, .8, .95), ...,
    .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
    .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
    .base_color = "dodgerblue4",
    .point_pred_color = "orange",
    .max_facets = Inf) {
  rlang::check_dots_empty()
  arg_is_probabilities(.levels)
  # Keep the matched value so that a single, validated choice is forwarded to
  # the `epi_df` plotting method below.
  .color_by <- rlang::arg_match(.color_by)
  .facet_by <- rlang::arg_match(.facet_by)

  if (!workflows::is_trained_workflow(object)) {
    cli::cli_abort(c(
      "Can't plot an untrained {.cls epi_workflow}.",
      i = "Do you need to call `fit()`?"
    ))
  }

  # Rebuild the training data (key columns + outcome) from the mold.
  mold <- workflows::extract_mold(object)
  y <- mold$outcomes
  if (ncol(y) > 1) {
    y <- y[, 1, drop = FALSE]
    cli::cli_warn("Multiple outcome variables were detected. Displaying only 1.")
  }
  keys <- c("time_value", "geo_value", "key")
  mold_roles <- names(mold$extras$roles)
  edf <- dplyr::bind_cols(mold$extras$roles[mold_roles %in% keys], y)

  # Defaults for an outcome that carries no `ahead_` / `lag_` prefix: keep its
  # name and do not shift the time axis. Without these, `shift` and
  # `new_name_y` would be undefined further down.
  shift <- NULL
  new_name_y <- names(y)
  is_ahead <- starts_with_impl("ahead_", names(y))
  is_lag <- starts_with_impl("lag_", names(y))
  if (is_ahead || is_lag) {
    # Outcome names look like `<ahead|lag>_<n>_<original name>`; recover the
    # horizon `n` and the original name (which may itself contain `_`).
    name_parts <- unlist(strsplit(names(y), "_"))
    horizon <- as.numeric(name_parts[2])
    shift <- if (is_ahead) horizon else -horizon
    new_name_y <- paste(name_parts[-c(1:2)], collapse = "_")
    edf <- dplyr::rename(edf, !!new_name_y := !!names(y))
  }

  if (!is.null(shift)) {
    # Move the outcome back onto the dates it actually refers to.
    edf <- dplyr::mutate(edf, time_value = time_value + shift)
  }
  extra_keys <- setdiff(epi_keys_mold(mold), c("time_value", "geo_value"))
  if (length(extra_keys) == 0L) extra_keys <- NULL
  edf <- as_epi_df(edf,
    as_of = object$fit$meta$as_of,
    additional_metadata = list(other_keys = extra_keys)
  )

  # No predictions supplied: just plot the training data.
  if (is.null(predictions)) {
    return(autoplot(
      edf, !!new_name_y,
      .color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
      .max_facets = .max_facets
    ))
  }

  # Predictions are plotted at their target date when one is available.
  if ("target_date" %in% names(predictions)) {
    if ("time_value" %in% names(predictions)) {
      predictions <- dplyr::select(predictions, -time_value)
    }
    predictions <- dplyr::rename(predictions, time_value = target_date)
  }
  pred_cols_ok <- hardhat::check_column_names(predictions, epi_keys(edf))
  if (!pred_cols_ok$ok) {
    cli::cli_warn(c(
      "`predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}.",
      i = "Plotting the original data."
    ))
    return(autoplot(
      edf, !!new_name_y,
      .color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
      .max_facets = .max_facets
    ))
  }

  # First we plot the history, always faceted by everything
  bp <- autoplot(edf, !!new_name_y,
    .color_by = "none", .facet_by = "all_keys",
    .base_color = "black", .max_facets = .max_facets
  )

  # Now, prepare matching facets in the predictions
  ek <- kill_time_value(epi_keys(edf))
  predictions <- predictions %>%
    dplyr::mutate(
      .facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"),
    )
  if (.max_facets < Inf) {
    # Restrict the predictions to the facets kept in the history plot.
    top_n <- levels(as.factor(bp$data$.facets))[seq_len(.max_facets)]
    predictions <- dplyr::filter(predictions, .facets %in% top_n) %>%
      dplyr::mutate(.facets = droplevels(.facets))
  }

  if (".pred_distn" %in% names(predictions)) {
    bp <- plot_bands(bp, predictions, .levels, .base_color)
  }

  if (".pred" %in% names(predictions)) {
    # A line needs at least two target dates; otherwise draw a point.
    ntarget_dates <- dplyr::n_distinct(predictions$time_value)
    if (ntarget_dates > 1L) {
      bp <- bp +
        ggplot2::geom_line(
          data = predictions, ggplot2::aes(y = .data$.pred),
          color = .point_pred_color
        )
    } else {
      bp <- bp +
        ggplot2::geom_point(
          data = predictions, ggplot2::aes(y = .data$.pred),
          color = .point_pred_color
        )
    }
  }
  bp
}

#' @export
#' @rdname autoplot-epipred
autoplot.canned_epipred <- function(
    object, ...,
    .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
    .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
    .base_color = "dodgerblue4",
    .point_pred_color = "orange",
    .max_facets = Inf) {
  rlang::check_dots_empty()
  # Keep the matched value so a single, validated choice is forwarded.
  .color_by <- rlang::arg_match(.color_by)
  .facet_by <- rlang::arg_match(.facet_by)

  # A canned forecaster result bundles the fitted workflow with its
  # predictions; the predictions are plotted at their target date.
  ewf <- object$epi_workflow
  predictions <- object$predictions %>%
    dplyr::rename(time_value = target_date)

  # Delegate to the `epi_workflow` method. `.point_pred_color` must be passed
  # along explicitly, otherwise the caller's choice is silently replaced by
  # that method's default.
  autoplot(ewf, predictions,
    .color_by = .color_by, .facet_by = .facet_by,
    .base_color = .base_color,
    .point_pred_color = .point_pred_color,
    .max_facets = .max_facets
  )
}

# Vectorised prefix test: for each element of `vars`, does it begin with `x`?
# Compares `x` against the leading `nchar(x)` characters of every element.
starts_with_impl <- function(x, vars) {
  prefix_width <- nchar(x)
  leading_chars <- substr(vars, 1, prefix_width)
  x == leading_chars
}

# Add prediction bands to `base_plot`, one layer per entry of `levels`.
#
# base_plot:   the plot to add layers to.
# predictions: data with a `.pred_distn` column and a `time_value` column.
# levels:      central interval levels (e.g. 0.8 for the 10%-90% band).
# fill:        colour used for every band.
# alpha:       total opacity shared among the bands after the first one.
# linewidth:   line width passed to the first ribbon layer.
#
# Returns `base_plot` with the band layers added. Ribbons are used when there
# is more than one distinct `time_value`, vertical line ranges otherwise.
plot_bands <- function(
    base_plot, predictions,
    levels = c(.5, .8, .95),
    fill = "blue4",
    alpha = 0.6,
    linewidth = 0.05) {
  n <- length(levels)
  # Nothing to draw without at least one level.
  if (n == 0L) {
    return(base_plot)
  }
  innames <- names(predictions)
  # Bands after the first share `alpha`; guard the single-level case so the
  # divisor is never zero.
  alpha <- alpha / max(n - 1, 1)
  # Lower tail probability of each band followed by the matching upper tail
  # probabilities, arranged so that entry `i` pairs with entry `2n + 1 - i`.
  l <- (1 - levels) / 2
  l <- c(rev(l), 1 - l)

  ntarget_dates <- dplyr::n_distinct(predictions$time_value)

  predictions <- predictions %>%
    dplyr::mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, l), l)) %>%
    pivot_quantiles_wider(.pred_distn)
  # The newly created quantile columns, in the order they were added.
  qnames <- setdiff(names(predictions), innames)

  for (i in seq_len(n)) {
    bottom <- qnames[i]
    top <- rev(qnames)[i]
    # `aes()` captures its expressions and only evaluates them when the plot
    # is built. Inject the current column names with `!!` so that each layer
    # keeps its own pair instead of whatever `bottom`/`top` hold once the loop
    # has finished.
    band_aes <- ggplot2::aes(ymin = .data[[!!bottom]], ymax = .data[[!!top]])
    # The first band uses a fixed opacity; the others use the shared one.
    band_alpha <- if (i == 1L) 0.2 else alpha

    if (ntarget_dates > 1L) {
      band_layer <- if (i == 1L) {
        ggplot2::geom_ribbon(
          data = predictions, band_aes,
          alpha = band_alpha, linewidth = linewidth, fill = fill
        )
      } else {
        ggplot2::geom_ribbon(
          data = predictions, band_aes,
          fill = fill, alpha = band_alpha
        )
      }
    } else {
      band_layer <- ggplot2::geom_linerange(
        data = predictions, band_aes,
        color = fill, alpha = band_alpha, linewidth = 2
      )
    }
    base_plot <- base_plot + band_layer
  }
  base_plot
}

# Map quantile levels `x` to the level of the central interval they bound:
# values below 0.5 give `1 - 2 * x`, values above 0.5 give `1 - 2 * (1 - x)`,
# and exactly 0.5 gives 0. Duplicates are removed.
find_level <- function(x) {
  from_lower <- (x < .5) * (1 - 2 * x)
  from_upper <- (x > .5) * (1 - 2 * (1 - x))
  unique(from_lower + from_upper)
}
11 changes: 10 additions & 1 deletion R/epi_keys.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,19 @@ epi_keys.data.frame <- function(x, other_keys = character(0L), ...) {

#' @export
epi_keys.epi_df <- function(x, ...) {
  # The keys are `time_value` and `geo_value`, followed by any `other_keys`
  # stored in the object's "metadata" attribute (none if it is absent).
  c("time_value", "geo_value", attr(x, "metadata")$other_keys)
}

#' @export
epi_keys.recipe <- function(x, ...) {
  # Variables whose recipe role marks them as a key column.
  info <- x$var_info
  key_roles <- c("time_value", "geo_value", "key")
  info$variable[info$role %in% key_roles]
}

#' @export
epi_keys.epi_workflow <- function(x, ...) {
  # Extract the mold from the workflow and read the key columns from it.
  workflow_mold <- hardhat::extract_mold(x)
  epi_keys_mold(workflow_mold)
}

# a mold is a list extracted from a fitted workflow, gives info about
# training data. But it doesn't have a class
epi_keys_mold <- function(mold) {
Expand All @@ -45,3 +50,7 @@ kill_time_value <- function(v) {
arg_is_chr(v)
v[v != "time_value"]
}

# Key column names of `x` with "time_value" removed.
epi_keys_only <- function(x, ...) {
  all_keys <- epi_keys(x, ...)
  kill_time_value(all_keys)
}
Loading