-
Notifications
You must be signed in to change notification settings - Fork 10
240 quantile pivot #241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
240 quantile pivot #241
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
8a1e2d6
start CDC baseline layer
dajmcdon cea1599
upgrade enframer
dajmcdon d606741
functions, remains to check validity
dajmcdon 7294c00
correct symmetrization, enhance documentation of the "ahead" param in…
dajmcdon f18e88f
better defaults, cli, pred is scalar in propagate_samples
dajmcdon d6a28f3
redocument
dajmcdon 237ec50
run styler
dajmcdon c13b83e
redocument after styling
dajmcdon 16f6c2c
example plotting with ggplot2 handled correctly
dajmcdon ce0b180
finish quantile pivotting helpers, redocument
dajmcdon f97166b
fix extra check note.
dajmcdon 9dd0a2c
run styler
dajmcdon 16139ff
add lifecycle, deprecate pivot_quantiles.
dajmcdon c9b4667
working cdc baseline
dajmcdon d59a691
add cdc baseline to pkgdown
dajmcdon 21b4c85
local checks pass
dajmcdon 965155d
finish quantile pivotting helpers, redocument
dajmcdon 1cf5dff
fix extra check note.
dajmcdon 1458ab0
add lifecycle, deprecate pivot_quantiles.
dajmcdon 0b9b537
Merge branch '240-quantile-pivot' of https://github.com/cmu-delphi/ep…
dajmcdon 7358c13
CI: also needs to be on the branch
dsweber2 169f764
address @nmdefries comments
dajmcdon 8d1e47d
pass local checks
dajmcdon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,4 @@ | |
^musings$ | ||
^data-raw$ | ||
^vignettes/articles$ | ||
^.git-blame-ignore-revs$ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ Imports: | |
generics, | ||
glue, | ||
hardhat (>= 1.3.0), | ||
lifecycle, | ||
magrittr, | ||
methods, | ||
quantreg, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
#' Predict the future with the most recent value | ||
#' | ||
#' This is a simple forecasting model for | ||
#' [epiprocess::epi_df] data. It uses the most recent observation as the | ||
#' forecast for any future date, and produces intervals by shuffling the quantiles | ||
#' of the residuals of such a "flatline" forecast and incrementing these | ||
#' forward over all available training data. | ||
#' | ||
#' By default, the predictive intervals are computed separately for each | ||
#' combination of `geo_value` in the `epi_data` argument. | ||
#' | ||
#' This forecaster is meant to produce exactly the CDC Baseline used for | ||
#' [COVID19ForecastHub](https://covid19forecasthub.org) | ||
#' | ||
#' @param epi_data An [epiprocess::epi_df] | ||
#' @param outcome A scalar character for the column name we wish to predict. | ||
#' @param args_list A list of additional arguments as created by the | ||
#' [cdc_baseline_args_list()] constructor function. | ||
#' | ||
#' @return A data frame of point and interval forecasts at for all | ||
#' aheads (unique horizons) for each unique combination of `key_vars`. | ||
#' @export | ||
#' | ||
#' @examples | ||
#' library(dplyr) | ||
#' weekly_deaths <- case_death_rate_subset %>% | ||
#' select(geo_value, time_value, death_rate) %>% | ||
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>% | ||
#' mutate(deaths = pmax(death_rate / 1e5 * pop, 0)) %>% | ||
#' select(-pop, -death_rate) %>% | ||
#' group_by(geo_value) %>% | ||
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>% | ||
#' ungroup() %>% | ||
#' filter(weekdays(time_value) == "Saturday") | ||
#' | ||
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") | ||
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn) | ||
#' | ||
#' if (require(ggplot2)) { | ||
#' forecast_date <- unique(preds$forecast_date) | ||
#' four_states <- c("ca", "pa", "wa", "ny") | ||
#' preds %>% | ||
#' filter(geo_value %in% four_states) %>% | ||
#' ggplot(aes(target_date)) + | ||
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) + | ||
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) + | ||
#' geom_line(aes(y = .pred), color = "orange") + | ||
#' geom_line( | ||
#' data = weekly_deaths %>% filter(geo_value %in% four_states), | ||
#' aes(x = time_value, y = deaths) | ||
#' ) + | ||
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) + | ||
#' labs(x = "Date", y = "Weekly deaths") + | ||
#' facet_wrap(~geo_value, scales = "free_y") + | ||
#' theme_bw() + | ||
#' geom_vline(xintercept = forecast_date) | ||
#' } | ||
cdc_baseline_forecaster <- function( | ||
epi_data, | ||
outcome, | ||
args_list = cdc_baseline_args_list()) { | ||
validate_forecaster_inputs(epi_data, outcome, "time_value") | ||
if (!inherits(args_list, c("cdc_flat_fcast", "alist"))) { | ||
cli_stop("args_list was not created using `cdc_baseline_args_list().") | ||
} | ||
keys <- epi_keys(epi_data) | ||
ek <- kill_time_value(keys) | ||
outcome <- rlang::sym(outcome) | ||
|
||
|
||
r <- epi_recipe(epi_data) %>% | ||
step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>% | ||
recipes::update_role(!!outcome, new_role = "predictor") %>% | ||
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>% | ||
step_training_window(n_recent = args_list$n_training) | ||
|
||
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) | ||
# target_date <- args_list$target_date %||% forecast_date + args_list$ahead | ||
|
||
|
||
latest <- get_test_data( | ||
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer, | ||
forecast_date | ||
) | ||
|
||
f <- frosting() %>% | ||
layer_predict() %>% | ||
layer_cdc_flatline_quantiles( | ||
aheads = args_list$aheads, | ||
quantile_levels = args_list$quantile_levels, | ||
nsims = args_list$nsims, | ||
by_key = args_list$quantile_by_key, | ||
symmetrize = args_list$symmetrize, | ||
nonneg = args_list$nonneg | ||
) %>% | ||
layer_add_forecast_date(forecast_date = forecast_date) %>% | ||
layer_unnest(.pred_distn_all) | ||
# layer_add_target_date(target_date = target_date) | ||
if (args_list$nonneg) f <- layer_threshold(f, ".pred") | ||
|
||
eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") | ||
|
||
wf <- epi_workflow(r, eng, f) | ||
wf <- generics::fit(wf, epi_data) | ||
preds <- suppressWarnings(predict(wf, new_data = latest)) %>% | ||
tibble::as_tibble() %>% | ||
dplyr::select(-time_value) %>% | ||
dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency) | ||
|
||
structure( | ||
list( | ||
predictions = preds, | ||
epi_workflow = wf, | ||
metadata = list( | ||
training = attr(epi_data, "metadata"), | ||
forecast_created = Sys.time() | ||
) | ||
), | ||
class = c("cdc_baseline_fcast", "canned_epipred") | ||
) | ||
} | ||
|
||
|
||
|
||
#' CDC baseline forecaster argument constructor | ||
#' | ||
#' Constructs a list of arguments for [cdc_baseline_forecaster()]. | ||
#' | ||
#' @inheritParams arx_args_list | ||
#' @param data_frequency Integer or string. This describes the frequency of the | ||
#' input `epi_df`. For typical FluSight forecasts, this would be `"1 week"`. | ||
#' Allowable arguments are integers (taken to mean numbers of days) or a | ||
#' string like `"7 days"` or `"2 weeks"`. Currently, all other periods | ||
#' (other than days or weeks) result in an error. | ||
#' @param aheads Integer vector. Unlike [arx_forecaster()], this doesn't have | ||
#' any effect on the predicted values. | ||
#' Predictions are always the most recent observation. This determines the | ||
#' set of prediction horizons for [layer_cdc_flatline_quantiles()]`. It interacts | ||
#' with the `data_frequency` argument. So, for example, if the data is daily | ||
#' and you want forecasts for 1:4 days ahead, then you would use `1:4`. However, | ||
#' if you want one-week predictions, you would set this as `c(7, 14, 21, 28)`. | ||
#' But if `data_frequency` is `"1 week"`, then you would set it as `1:4`. | ||
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce | ||
#' prediction intervals. These are created by computing the quantiles of | ||
#' training residuals. A `NULL` value will result in point forecasts only. | ||
#' @param nsims Positive integer. The number of draws from the empirical CDF. | ||
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting | ||
#' in linear interpolation on the X scale. This is achieved with | ||
#' [stats::quantile()] Type 7 (the default for that function). | ||
#' @param nonneg Logical. Force all predictive intervals be non-negative. | ||
#' Because non-negativity is forced _before_ propagating forward, this | ||
#' has slightly different behaviour than would occur if using | ||
#' [layer_threshold()]. | ||
#' | ||
#' @return A list containing updated parameter choices with class `cdc_flat_fcast`. | ||
#' @export | ||
#' | ||
#' @examples | ||
#' cdc_baseline_args_list() | ||
#' cdc_baseline_args_list(symmetrize = FALSE) | ||
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120) | ||
cdc_baseline_args_list <- function( | ||
data_frequency = "1 week", | ||
aheads = 1:4, | ||
n_training = Inf, | ||
forecast_date = NULL, | ||
quantile_levels = c(.01, .025, 1:19 / 20, .975, .99), | ||
nsims = 1e3L, | ||
symmetrize = TRUE, | ||
nonneg = TRUE, | ||
quantile_by_key = "geo_value", | ||
nafill_buffer = Inf) { | ||
arg_is_scalar(n_training, nsims, data_frequency) | ||
data_frequency <- parse_period(data_frequency) | ||
arg_is_pos_int(data_frequency) | ||
arg_is_chr(quantile_by_key, allow_empty = TRUE) | ||
arg_is_scalar(forecast_date, allow_null = TRUE) | ||
arg_is_date(forecast_date, allow_null = TRUE) | ||
arg_is_nonneg_int(aheads, nsims) | ||
arg_is_lgl(symmetrize, nonneg) | ||
arg_is_probabilities(quantile_levels, allow_null = TRUE) | ||
arg_is_pos(n_training) | ||
if (is.finite(n_training)) arg_is_pos_int(n_training) | ||
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE) | ||
|
||
structure( | ||
enlist( | ||
data_frequency, | ||
aheads, | ||
n_training, | ||
forecast_date, | ||
quantile_levels, | ||
nsims, | ||
symmetrize, | ||
nonneg, | ||
quantile_by_key, | ||
nafill_buffer | ||
), | ||
class = c("cdc_baseline_fcast", "alist") | ||
) | ||
} | ||
|
||
#' @export | ||
print.cdc_baseline_fcast <- function(x, ...) { | ||
name <- "CDC Baseline" | ||
NextMethod(name = name, ...) | ||
} | ||
|
||
parse_period <- function(x) { | ||
arg_is_scalar(x) | ||
if (is.character(x)) { | ||
x <- unlist(strsplit(x, " ")) | ||
if (length(x) == 1L) x <- as.numeric(x) | ||
if (length(x) == 2L) { | ||
mult <- substr(x[2], 1, 3) | ||
mult <- switch( | ||
mult, | ||
day = 1L, | ||
wee = 7L, | ||
cli::cli_abort("incompatible timespan in `aheads`.") | ||
) | ||
x <- as.numeric(x[1]) * mult | ||
} | ||
if (length(x) > 2L) cli::cli_abort("incompatible timespan in `aheads`.") | ||
} | ||
stopifnot(rlang::is_integerish(x)) | ||
as.integer(x) | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.