diff --git a/DESCRIPTION b/DESCRIPTION index fdac45e01..0125f2e58 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -32,7 +32,6 @@ Imports: glue, hardhat (>= 1.2.0), magrittr, - purrr, recipes (>= 1.0.0), rlang, stats, @@ -42,7 +41,7 @@ Imports: tidyselect, usethis, vctrs, - workflows + workflows (>= 1.0.0) Suggests: covidcast, data.table, @@ -51,9 +50,11 @@ Suggests: knitr, lubridate, parsnip (>= 1.0.0), + ranger, RcppRoll, rmarkdown, - testthat (>= 3.0.0) + testthat (>= 3.0.0), + xgboost VignetteBuilder: knitr Remotes: diff --git a/NAMESPACE b/NAMESPACE index a13ca9f54..0f803520d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -21,6 +21,8 @@ S3method(extract_argument,frosting) S3method(extract_argument,layer) S3method(extract_argument,recipe) S3method(extract_argument,step) +S3method(extract_frosting,default) +S3method(extract_frosting,epi_workflow) S3method(extract_layers,frosting) S3method(extract_layers,workflow) S3method(extrapolate_quantiles,dist_default) @@ -29,6 +31,7 @@ S3method(extrapolate_quantiles,distribution) S3method(format,dist_quantiles) S3method(median,dist_quantiles) S3method(predict,epi_workflow) +S3method(predict,flatline) S3method(prep,epi_recipe) S3method(prep,step_epi_ahead) S3method(prep,step_epi_lag) @@ -39,6 +42,7 @@ S3method(print,step_epi_ahead) S3method(print,step_epi_lag) S3method(quantile,dist_quantiles) S3method(refresh_blueprint,default_epi_recipe_blueprint) +S3method(residuals,flatline) S3method(run_mold,default_epi_recipe_blueprint) S3method(slather,layer_add_forecast_date) S3method(slather,layer_add_target_date) @@ -60,6 +64,7 @@ export(add_frosting) export(add_layer) export(apply_frosting) export(arx_args_list) +export(arx_epi_forecaster) export(arx_forecaster) export(create_lags_and_leads) export(create_layer) @@ -72,8 +77,12 @@ export(epi_recipe) export(epi_recipe_blueprint) export(epi_workflow) export(extract_argument) +export(extract_frosting) export(extract_layers) export(extrapolate_quantiles) +export(flatline) +export(flatline_args_list) +export(flatline_epi_forecaster) export(frosting) export(get_precision) export(get_test_data) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 373f3cd17..0ff4497c1 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -34,7 +34,7 @@ arx_forecaster <- function(x, y, key_vars, time_value, } dat <- create_lags_and_leads(x, y, lags, ahead, time_value, keys) - if (intercept) dat$x0 <- 1 + dat$x0 <- 1 obj <- stats::lm( y1 ~ . + 0, @@ -67,14 +67,28 @@ arx_forecaster <- function(x, y, key_vars, time_value, #' #' Constructs a list of arguments for [arx_forecaster()]. #' -#' @template param-lags -#' @template param-ahead -#' @template param-min_train_window -#' @template param-levels -#' @template param-intercept -#' @template param-symmetrize -#' @template param-nonneg -#' @param quantile_by_key Not currently implemented +#' @param lags Vector or List. Positive integers enumerating lags to use +#' in autoregressive-type models (in days). +#' @param ahead Integer. Number of time steps ahead (in days) of the forecast +#' date for which forecasts should be produced. +#' @param min_train_window Integer. The minimal amount of training +#' data (in the time unit of the `epi_df`) needed to produce a forecast. +#' If smaller, the forecaster will return `NA` predictions. +#' @param forecast_date The date on which the forecast is created. The default +#' `NULL` will attempt to determine this automatically. +#' @param target_date The date for which the forecast is intended. 
The default +#' `NULL` will attempt to determine this automatically. +#' @param levels Vector or `NULL`. A vector of probabilities to produce +#' prediction intervals. These are created by computing the quantiles of +#' training residuals. A `NULL` value will result in point forecasts only. +#' @param symmetrize Logical. The default `TRUE` calculates +#' symmetric prediction intervals. +#' @param nonneg Logical. The default `TRUE` enforces nonnegative predictions +#' by hard-thresholding at 0. +#' @param quantile_by_key Character vector. Groups residuals by listed keys +#' before calculating residual quantiles. See the `by_key` argument to +#' [layer_residual_quantiles()] for more information. The default, +#' `character(0)` performs no grouping. #' #' @return A list containing updated parameter choices. #' @export @@ -83,28 +97,36 @@ arx_forecaster <- function(x, y, key_vars, time_value, #' arx_args_list() #' arx_args_list(symmetrize = FALSE) #' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120) -arx_args_list <- function(lags = c(0, 7, 14), ahead = 7, min_train_window = 20, - levels = c(0.05, 0.95), intercept = TRUE, +arx_args_list <- function(lags = c(0L, 7L, 14L), + ahead = 7L, + min_train_window = 20L, + forecast_date = NULL, + target_date = NULL, + levels = c(0.05, 0.95), symmetrize = TRUE, nonneg = TRUE, - quantile_by_key = FALSE) { + quantile_by_key = character(0L)) { # error checking if lags is a list .lags <- lags if (is.list(lags)) lags <- unlist(lags) - arg_is_scalar(ahead, min_train_window) + arg_is_scalar(ahead, min_train_window, symmetrize, nonneg) + arg_is_chr(quantile_by_key, allow_null = TRUE) + arg_is_scalar(forecast_date, target_date, allow_null = TRUE) arg_is_nonneg_int(ahead, min_train_window, lags) - arg_is_lgl(intercept, symmetrize, nonneg) + arg_is_lgl(symmetrize, nonneg) arg_is_probabilities(levels, allow_null = TRUE) max_lags <- max(lags) - - list( - lags = .lags, ahead = as.integer(ahead), - min_train_window = min_train_window, - levels = levels, intercept = intercept, - symmetrize = symmetrize, nonneg = nonneg, - max_lags = max_lags - ) -} \ No newline at end of file + enlist(lags = .lags, + ahead, + min_train_window, + levels, + forecast_date, + target_date, + symmetrize, + nonneg, + max_lags, + quantile_by_key) +} diff --git a/R/arx_forecaster_mod.R b/R/arx_forecaster_mod.R index 61b87d6a9..60b0294f2 100644 --- a/R/arx_forecaster_mod.R +++ b/R/arx_forecaster_mod.R @@ -1,29 +1,83 @@ -arx_epi_forecaster <- function(epi_data, response, - ..., +#' Direct autoregressive forecaster with covariates +#' +#' This is an autoregressive forecasting model for +#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning +#' that it estimates a model for a particular target horizon. +#' +#' +#' @param epi_data An `epi_df` object +#' @param outcome A character (scalar) specifying the outcome (in the +#' `epi_df`). +#' @param predictors A character vector giving column(s) of predictor +#' variables. +#' @param trainer A `{parsnip}` model describing the type of estimation. +#' For now, we enforce `mode = "regression"`. +#' @param args_list A list of customization arguments to determine +#' the type of forecasting model. See [arx_args_list()]. 
+#' +#' @return A list with (1) `predictions` an `epi_df` of predicted values +#' and (2) `epi_workflow`, a list that encapsulates the entire estimation +#' workflow +#' @export +#' +#' @examples +#' jhu <- case_death_rate_subset %>% +#' dplyr::filter(time_value >= as.Date("2021-12-01")) +#' +#' out <- arx_epi_forecaster(jhu, "death_rate", +#' c("case_rate", "death_rate")) +arx_epi_forecaster <- function(epi_data, + outcome, + predictors, trainer = parsnip::linear_reg(), args_list = arx_args_list()) { - r <- epi_recipe(epi_data) %>% - step_epi_lag(..., lag = args_list$lags) %>% # hmmm, same for all predictors - step_epi_ahead(response, ahead = args_list$ahead) %>% - # should use the internal function (in an open PR) - recipes::step_naomit(recipes::all_predictors()) %>% - recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) + validate_forecaster_inputs(epi_data, outcome, predictors) + if (!is.list(trainer) || trainer$mode != "regression") + cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.") + lags <- arx_lags_validator(predictors, args_list$lags) + + r <- epi_recipe(epi_data) + for (l in seq_along(lags)) { + p <- predictors[l] + r <- step_epi_lag(r, !!p, lag = lags[[l]]) + } + r <- r %>% + step_epi_ahead(dplyr::all_of(!!outcome), ahead = args_list$ahead) %>% + step_epi_naomit() # should limit the training window here (in an open PR) # What to do if insufficient training data? Add issue. - # remove intercept? not sure how this is implemented in tidymodels + + forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) + target_date <- args_list$target_date %||% forecast_date + args_list$ahead f <- frosting() %>% layer_predict() %>% - layer_naomit(.pred) %>% - layer_residual_quantile( + # layer_naomit(.pred) %>% + layer_residual_quantiles( probs = args_list$levels, symmetrize = args_list$symmetrize) %>% - layer_threshold(.pred, dplyr::starts_with("q")) #, .flag = args_list$nonneg) in open PR - # need the target date processing here + layer_add_forecast_date(forecast_date = forecast_date) %>% + layer_add_target_date(target_date = target_date) + if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred")) latest <- get_test_data(r, epi_data) - epi_workflow(r, trainer) %>% # bug, issue 72 - add_frosting(f) + wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data) + list( + predictions = predict(wf, new_data = latest), + epi_workflow = wf + ) +} + +arx_lags_validator <- function(predictors, lags) { + p <- length(predictors) + if (!is.list(lags)) lags <- list(lags) + if (length(lags) == 1) lags <- rep(lags, p) + else if (length(lags) < p) { + cli_stop( + "You have requested {p} predictors but lags cannot be recycled to match." + ) + } + lags } diff --git a/R/compat-purrr.R b/R/compat-purrr.R index 283fafa8f..1436d80ce 100644 --- a/R/compat-purrr.R +++ b/R/compat-purrr.R @@ -5,23 +5,45 @@ map <- function(.x, .f, ...) { .f <- rlang::as_function(.f, env = rlang::global_env()) lapply(.x, .f, ...) } + walk <- function(.x, .f, ...) { map(.x, .f, ...) invisible(.x) } +walk2 <- function(.x, .y, .f, ...) { + map2(.x, .y, .f, ...) + invisible(.x) +} + map_lgl <- function(.x, .f, ...) { .rlang_purrr_map_mold(.x, .f, logical(1), ...) } + map_int <- function(.x, .f, ...) { .rlang_purrr_map_mold(.x, .f, integer(1), ...) } + map_dbl <- function(.x, .f, ...) { .rlang_purrr_map_mold(.x, .f, double(1), ...) } + map_chr <- function(.x, .f, ...) { .rlang_purrr_map_mold(.x, .f, character(1), ...) 
} + +map_dfr <- function(.x, .f, ..., .id = NULL) { + .f <- rlang::as_function(.f, env = global_env()) + res <- map(.x, .f, ...) + dplyr::bind_rows(res, .id = .id) +} + +map2_dfr <- function(.x, .y, .f, ..., .id = NULL) { + .f <- rlang::as_function(.f, env = global_env()) + res <- map2(.x, .y, .f, ...) + dplyr::bind_rows(res, .id = .id) +} + .rlang_purrr_map_mold <- function(.x, .f, .mold, ...) { .f <- rlang::as_function(.f, env = rlang::global_env()) out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE) diff --git a/R/epi_keys.R b/R/epi_keys.R index 0ec86f9dd..aa9976efa 100644 --- a/R/epi_keys.R +++ b/R/epi_keys.R @@ -29,7 +29,7 @@ epi_keys.recipe <- function(x) { epi_keys_mold <- function(mold) { keys <- c("time_value", "geo_value", "key") molded_names <- names(mold$extras$roles) - mold_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names) - unname(mold_keys) + mold_keys <- map(mold$extras$roles[molded_names %in% keys], names) + unname(unlist(mold_keys)) } diff --git a/R/epi_shift.R b/R/epi_shift.R index c61137706..4fe99601b 100644 --- a/R/epi_shift.R +++ b/R/epi_shift.R @@ -19,7 +19,7 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") { tidyr::unchop(shift) %>% # what is chop dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>% # One list element for each shifted feature - purrr::pmap(function(i, shift, name) { + pmap(function(i, shift, name) { tibble(keys, time_value = time_value + shift, # Shift back !!name := x[[i]]) @@ -27,7 +27,7 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") { if (is.data.frame(keys)) common_names <- c(names(keys), "time_value") else common_names <- c("keys", "time_value") - purrr::reduce(out_list, dplyr::full_join, by = common_names) + reduce(out_list, dplyr::full_join, by = common_names) } epi_shift_single <- function(x, col, shift_val, newname, key_cols) { diff --git a/R/epi_workflow.R b/R/epi_workflow.R index dab343f60..8ae99b2b7 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -62,6 +62,7 @@ is_epi_workflow <- function(x) { inherits(x, "epi_workflow") } + #' Predict from an epi_workflow #' #' @description diff --git a/R/imports.R b/R/epipredict-package.R similarity index 100% rename from R/imports.R rename to R/epipredict-package.R diff --git a/R/flatline.R b/R/flatline.R new file mode 100644 index 000000000..8732be192 --- /dev/null +++ b/R/flatline.R @@ -0,0 +1,82 @@ + +#' (Internal) implementation of the flatline forecaster +#' +#' This is an internal function that is used to create a [parsnip::linear_reg()] +#' model. It has somewhat odd behaviour (see below). +#' +#' +#' @param formula The lhs should be a single variable. In standard usage, this +#' would actually be the observed time series shifted forward by the forecast +#' horizon. The right hand side must contain any keys (locations) for the +#' panel data separated by plus. The observed time series must come last. +#' For example +#' ``` +#' form <- as.formula(lead7_y ~ state + age + y) +#' ``` +#' Note that this function doesn't DO the shifting, that has to be done +#' outside. +#' @param data A data frame containing at least the variables used in the +#' formula. It must also contain a column `time_value` giving the observed +#' time points. +#' +#' @return An S3 object of class `flatline` with two components: +#' * `residuals` - a tibble with all keys and a `.resid` column that contains +#' forecast errors. 
+#' * `.pred` - a tibble with all keys and a `.pred` column containing only
+#'   predictions for future data (the last observed value of the outcome for
+#'   each combination of keys).
+#' @export
+#'
+#' @examples
+#' tib <- data.frame(y = runif(100),
+#'   expand.grid(k = letters[1:4], j = letters[5:9], time_value = 1:5)) %>%
+#'   dplyr::group_by(k, j) %>%
+#'   dplyr::mutate(y2 = dplyr::lead(y, 2)) # predict 2 steps ahead
+#' flat <- flatline(y2 ~ j + k + y, tib) # predictions for 20 locations
+#' sum(!is.na(flat$residuals$.resid)) # 100 residuals, but 40 are NA
+flatline <- function(formula, data) {
+  response <- recipes:::get_lhs_vars(formula, data)
+  rhs <- recipes:::get_rhs_vars(formula, data)
+  n <- length(rhs)
+  observed <- rhs[n] # DANGER!!
+  ek <- rhs[-n]
+  if (length(response) > 1)
+    cli_stop("flatline forecaster can accept only 1 observed time series.")
+  keys <- ek[ek != "time_value"]
+
+  preds <- data %>%
+    dplyr::mutate(.pred = !!rlang::sym(observed),
+                  .resid = !!rlang::sym(response) - .pred)
+  .pred <- preds %>%
+    dplyr::filter(!is.na(.pred)) %>%
+    dplyr::group_by(!!!rlang::syms(keys)) %>%
+    dplyr::arrange(time_value) %>%
+    dplyr::slice_tail(n = 1L) %>%
+    dplyr::ungroup() %>%
+    dplyr::select(dplyr::all_of(c(keys, ".pred")))
+
+  structure(list(
+    residuals = dplyr::select(preds, dplyr::all_of(c(keys, ".resid"))),
+    .pred = .pred),
+    class = "flatline"
+  )
+}
+
+#' @export
+residuals.flatline <- function(object, ...) {
+  object$residuals
+}
+
+#' @export
+predict.flatline <- function(object, newdata, ...) {
+  object <- object$.pred
+  metadata <- names(object)[names(object) != ".pred"]
+  ek <- names(newdata)
+  if (! all(metadata %in% ek)) {
+    cli_stop("`newdata` has different metadata than was used",
+             "to fit the flatline forecaster")
+  }
+
+  dplyr::left_join(newdata, object, by = metadata) %>%
+    dplyr::pull(.pred)
+}
diff --git a/R/flatline_epi_forecaster.R b/R/flatline_epi_forecaster.R
new file mode 100644
index 000000000..7ecefc17e
--- /dev/null
+++ b/R/flatline_epi_forecaster.R
@@ -0,0 +1,122 @@
+#' Predict the future with today's value
+#'
+#' This is a simple forecasting model for
+#' [epiprocess::epi_df] data. It uses the most recent observation as the
+#' forecast for any future date, and produces intervals based on the quantiles
+#' of the residuals of such a "flatline" forecast over all available training
+#' data.
+#'
+#' By default, the predictive intervals are computed separately for each
+#' combination of key values (`geo_value` + any additional keys) in the
+#' `epi_data` argument.
+#'
+#' This forecaster is very similar to that used by the
+#' [COVID19ForecastHub](https://covid19forecasthub.org)
+#'
+#' @param epi_data An [epiprocess::epi_df]
+#' @param outcome A scalar character for the column name we wish to predict.
+#' @param args_list A list of additional arguments as created by the
+#'   [flatline_args_list()] constructor function.
+#'
+#' @return A data frame of point (and optionally interval) forecasts at a single
+#'   ahead (unique horizon) for each unique combination of `key_vars`. 
+#' @export +#' +#' @examples +#' jhu <- case_death_rate_subset %>% +#' dplyr::filter(time_value >= as.Date("2021-12-01")) +#' +#' out <- flatline_epi_forecaster(jhu, "death_rate") +flatline_epi_forecaster <- function(epi_data, + outcome, + args_list = flatline_args_list()) { + + validate_forecaster_inputs(epi_data, outcome, "time_value") + keys <- epi_keys(epi_data) + ek <- keys[-1] + outcome <- rlang::sym(outcome) + + + r <- epi_recipe(epi_data) %>% + step_epi_ahead(!!outcome, ahead = args_list$ahead, skip = TRUE) %>% + recipes::update_role(!!outcome, new_role = "predictor") %>% + recipes::add_role(dplyr::all_of(keys), new_role = "predictor") + + latest <- get_test_data(epi_recipe(epi_data), epi_data) + + forecast_date <- args_list$forecast_date %||% max(latest$time_value) + target_date <- args_list$target_date %||% forecast_date + args_list$ahead + + f <- frosting() %>% + layer_predict() %>% + layer_residual_quantiles( + probs = args_list$levels, + symmetrize = args_list$symmetrize, + by_key = args_list$quantile_by_key) %>% + layer_add_forecast_date(forecast_date = forecast_date) %>% + layer_add_target_date(target_date = target_date) + if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred")) + + eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") + + wf <- epi_workflow(r, eng, f) %>% fit(epi_data) + + list( + predictions = suppressWarnings(predict(wf, new_data = latest)), + epi_workflow = wf + ) +} + + + +#' Flatline forecaster argument constructor +#' +#' Constructs a list of arguments for [flatline_epi_forecaster()]. +#' +#' @inheritParams arx_args_list +#' +#' @return A list containing updated parameter choices. +#' @export +#' +#' @examples +#' flatline_args_list() +#' flatline_args_list(symmetrize = FALSE) +#' flatline_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120) +flatline_args_list <- function(ahead = 7L, + min_train_window = 20L, + forecast_date = NULL, + target_date = NULL, + levels = c(0.05, 0.95), + symmetrize = TRUE, + nonneg = TRUE, + quantile_by_key = character(0L)) { + + arg_is_scalar(ahead, min_train_window) + arg_is_chr(quantile_by_key, allow_null = TRUE) + arg_is_scalar(forecast_date, target_date, allow_null = TRUE) + arg_is_nonneg_int(ahead, min_train_window) + arg_is_lgl(symmetrize, nonneg) + arg_is_probabilities(levels, allow_null = TRUE) + + enlist(ahead, + min_train_window, + forecast_date, + target_date, + levels, + symmetrize, + nonneg, + quantile_by_key) +} + +validate_forecaster_inputs <- function(epi_data, outcome, predictors) { + if (! epiprocess::is_epi_df(epi_data)) + cli_stop("epi_data must be an epi_df.") + arg_is_chr(predictors) + arg_is_chr_scalar(outcome) + if (! outcome %in% names(epi_data)) + cli_stop("{outcome} was not found in the training data.") + if (! all(predictors %in% names(epi_data))) + cli_stop("At least one predictor was not found in the training data.") + invisible(TRUE) +} + diff --git a/R/frosting.R b/R/frosting.R index 436ea2319..5e881e0a9 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -176,21 +176,29 @@ frosting <- function(layers = NULL, requirements = NULL) { out <- new_frosting() } + +#' Extract the frosting object from a workflow +#' +#' @param x an `epi_workflow` object +#' @param ... not used +#' +#' @return a `frosting` object +#' @export extract_frosting <- function(x, ...) { UseMethod("extract_frosting") } +#' @export extract_frosting.default <- function(x, ...) 
{
  abort(c("Frosting is only available for epi_workflows currently.",
          i = "Can you use `epi_workflow()` instead of `workflow()`?"))
  invisible(x)
}

+#' @export
 extract_frosting.epi_workflow <- function(x, ...) {
-  if (has_postprocessor_frosting(x)) {
-    return(x$post$actions$frosting$frosting)
-  }
-  abort("The epi_workflow does not have a preprocessor.")
+  if (has_postprocessor_frosting(x)) return(x$post$actions$frosting$frosting)
+  else cli_stop("The epi_workflow does not have a postprocessor.")
 }

 #' Apply postprocessing to a fitted workflow
diff --git a/R/get_test_data.R b/R/get_test_data.R
index 59a163aab..51ad2be1f 100644
--- a/R/get_test_data.R
+++ b/R/get_test_data.R
@@ -1,8 +1,8 @@
 #' Get test data for prediction based on longest lag period
 #'
 #' Based on the longest lag period in the recipe,
-#' `get_test_data()` creates a tibble in [epiprocess::epi_df]
-#' format with columns `geo_value`, `time_value`
+#' `get_test_data()` creates an [epiprocess::epi_df]
+#' with columns `geo_value`, `time_value`
 #' and other variables in the original dataset,
 #' which will be used to create test data.
 #'
@@ -21,14 +21,11 @@
 #' get_test_data(recipe = rec, x = case_death_rate_subset)
 #' @export
-get_test_data <- function(recipe, x){
-  # TO-DO: SOME CHECKS OF THE DATASET
-  if (any(!(c('geo_value','time_value') %in% colnames(x)))) {
-    rlang::abort("`geo_value`, `time_value` does not exist in data")
-  }
-  ## CHECK if it is epi_df?
-
-  max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)))
+get_test_data <- function(recipe, x) {
+  stopifnot(is.data.frame(x))
+  if (! all(colnames(x) %in% colnames(recipe$template)))
+    cli_stop("some variables used for training are not available in `x`.")
+  max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)), 0)

   # CHECK: Return NA if insufficient training data
   if (dplyr::n_distinct(x$time_value) < max_lags) {
@@ -36,7 +33,8 @@ get_test_data <- function(recipe, x){
     "You need at least {max_lags} distinct time_values.")
   }

-  groups <- epi_keys(recipe)[epi_keys(recipe) != "time_value"]
+  groups <- epi_keys(recipe)
+  groups <- groups[groups != "time_value"]

   x %>%
     dplyr::filter(
diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R
index 5d3ab4621..5c2c8bcef 100644
--- a/R/layer_add_forecast_date.R
+++ b/R/layer_add_forecast_date.R
@@ -83,10 +83,15 @@ slather.layer_add_forecast_date <- function(object, components, the_fit, the_rec
   as_of_date <- as.Date(attributes(components$keys)$metadata$as_of)

   if (object$forecast_date < as_of_date) {
-    warning("forecast_date is less than the most recent update date of the data.")
+    cli_warn(
+      c("The forecast_date is less than the most ",
+        "recent update date of the data.",
+        i = "forecast_date = {object$forecast_date} while data is from {as_of_date}.")
+    )
   }
-
-  components$predictions <- dplyr::bind_cols(components$predictions,
-                                             forecast_date = as.Date(object$forecast_date))
+  components$predictions <- dplyr::bind_cols(
+    components$predictions,
+    forecast_date = as.Date(object$forecast_date)
+  )
   components
 }
diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R
index 8507a74a6..07b8f9146 100644
--- a/R/layer_residual_quantiles.R
+++ b/R/layer_residual_quantiles.R
@@ -5,6 +5,8 @@
 #' @param probs numeric vector of probabilities with values in (0,1)
 #'   referring to the desired quantile.
 #' @param symmetrize logical. If `TRUE` then interval will be symmetric.
+#' @param by_key A character vector of keys to group the residuals by before
+#'   calculating quantiles. 
The default, `c()` performs no grouping. #' @param name character. The name for the output column. #' @param .flag a logical to determine if the layer is added. Passed on to #' `add_layer()`. Default `TRUE`. @@ -34,22 +36,33 @@ #' wf1 <- wf %>% add_frosting(f) #' #' p <- predict(wf1, latest) -#' p +#' +#' f2 <- frosting() %>% +#' layer_predict() %>% +#' layer_residual_quantiles(probs = c(0.3, 0.7), by_key = "geo_value") %>% +#' layer_naomit(.pred) +#' wf2 <- wf %>% add_frosting(f2) +#' +#' p2 <- predict(wf2, latest) layer_residual_quantiles <- function(frosting, ..., probs = c(0.0275, 0.975), symmetrize = TRUE, + by_key = character(0L), name = ".pred_distn", .flag = TRUE, id = rand_id("residual_quantiles")) { rlang::check_dots_empty() + arg_is_scalar(symmetrize, .flag) arg_is_chr_scalar(name, id) + arg_is_chr(by_key, allow_null = TRUE) arg_is_probabilities(probs) - arg_is_lgl(symmetrize) + arg_is_lgl(symmetrize, .flag) add_layer( frosting, layer_residual_quantiles_new( probs = probs, symmetrize = symmetrize, + by_key = by_key, name = name, id = id ), @@ -57,9 +70,9 @@ layer_residual_quantiles <- function(frosting, ..., ) } -layer_residual_quantiles_new <- function(probs, symmetrize, name, id) { +layer_residual_quantiles_new <- function(probs, symmetrize, by_key, name, id) { layer("residual_quantiles", probs = probs, symmetrize = symmetrize, - name = name, id = id) + by_key = by_key, name = name, id = id) } #' @export @@ -67,13 +80,35 @@ slather.layer_residual_quantiles <- function(object, components, the_fit, the_recipe, ...) { if (is.null(object$probs)) return(components) + s <- ifelse(object$symmetrize, -1, NA) - r <- grab_residuals(the_fit, components) - q <- quantile(c(r, s * r), probs = object$probs, na.rm = TRUE) + r <- dplyr::bind_cols( + r = grab_residuals(the_fit, components), + geo_value = components$mold$extras$roles$geo_value, + components$mold$extras$roles$key) + + ## Handle any grouping requests + if (length(object$by_key) > 0L) { + common <- intersect(object$by_key, names(r)) + excess <- setdiff(object$by_key, names(r)) + if (length(excess) > 0L) { + cli_warn("Requested residual grouping key(s) {excess} unavailable ", + "in the original data. 
Grouping by the remainder {common}.") + + } + if (length(common) > 0L) + r <- r %>% dplyr::group_by(!!!rlang::syms(common)) + } + + r <- r %>% + dplyr::summarise( + q = list(quantile(c(r, s * r), probs = object$probs, na.rm = TRUE)) + ) estimate <- components$predictions$.pred res <- tibble::tibble( - .pred_distn = dist_quantiles(map(estimate, "+", q), object$probs)) + .pred_distn = dist_quantiles(map2(estimate, r$q, "+"), object$probs) + ) res <- check_pname(res, components$predictions, object) components$predictions <- dplyr::mutate(components$predictions, !!!res) components @@ -82,7 +117,12 @@ slather.layer_residual_quantiles <- grab_residuals <- function(the_fit, components) { if (the_fit$spec$mode != "regression") rlang::abort("For meaningful residuals, the predictor should be a regression model.") - - yhat <- predict(the_fit, new_data = components$mold$predictors) - c(components$mold$outcomes - yhat)[[1]] + r_generic <- attr(utils::methods(class = class(the_fit)[1]), "info")$generic + if ("residuals" %in% r_generic) { + r <- residuals(the_fit) + } else { + yhat <- predict(the_fit, new_data = components$mold$predictors) + r <- c(components$mold$outcomes - yhat)[[1]] + } + r } diff --git a/R/make_flatline_reg.R b/R/make_flatline_reg.R new file mode 100644 index 000000000..33c135f08 --- /dev/null +++ b/R/make_flatline_reg.R @@ -0,0 +1,39 @@ +make_flatline_reg <- function() { + parsnip::set_model_engine("linear_reg", "regression", eng = "flatline") + parsnip::set_dependency("linear_reg", eng = "flatline", pkg = "epipredict") + + parsnip::set_fit( + model = "linear_reg", + eng = "flatline", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "epipredict", fun = "flatline"), + defaults = list() + )) + + parsnip::set_encoding( + model = "linear_reg", + eng = "flatline", + mode = "regression", + options = list( + predictor_indicators = "none", + compute_intercept = TRUE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) + ) + + parsnip::set_pred( + model = "linear_reg", + eng = "flatline", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, post = NULL, func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) + ) + +} diff --git a/R/smooth_arx_forecaster.R b/R/smooth_arx_forecaster.R index d862febee..a00238585 100644 --- a/R/smooth_arx_forecaster.R +++ b/R/smooth_arx_forecaster.R @@ -24,10 +24,10 @@ smooth_arx_forecaster <- function(x, y, key_vars, time_value, if (length(y) < min_train_window + max_lags + max(ahead)) { qnames <- probs_to_string(levels) - out <- purrr::map_dfr(ahead, ~ distinct_keys, .id = "ahead") %>% + out <- map_dfr(ahead, ~ distinct_keys, .id = "ahead") %>% dplyr::mutate(ahead = magrittr::extract(!!ahead, as.integer(ahead)), point = NA) %>% - dplyr::select(!any_of(".dump")) + dplyr::select(!dplyr::any_of(".dump")) return(enframer(out, qnames)) } @@ -46,15 +46,15 @@ smooth_arx_forecaster <- function(x, y, key_vars, time_value, as.data.frame() %>% magrittr::set_names(ahead) - q <- purrr::map2_dfr( + q <- map2_dfr( r, point, ~ residual_quantiles(.x, .y, levels, symmetrize), .id = "ahead" ) %>% mutate(ahead = as.integer(ahead)) if (nonneg) q <- dplyr::mutate(q, dplyr::across(!ahead, ~ pmax(.x, 0))) return( - purrr::map_dfr(ahead, ~ distinct_keys) %>% - dplyr::select(!any_of(".dump")) %>% + map_dfr(ahead, ~ distinct_keys) %>% + dplyr::select(!dplyr::any_of(".dump")) %>% dplyr::bind_cols(q) %>% dplyr::relocate(ahead) ) diff --git 
a/R/step_epi_naomit.R b/R/step_epi_naomit.R index f266fe636..22e5eab03 100644 --- a/R/step_epi_naomit.R +++ b/R/step_epi_naomit.R @@ -15,6 +15,6 @@ step_epi_naomit <- function(recipe) { stopifnot(inherits(recipe, "recipe")) recipe %>% - recipes::step_naomit(all_predictors()) %>% + recipes::step_naomit(all_predictors(), skip = FALSE) %>% recipes::step_naomit(all_outcomes(), skip = TRUE) } diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index f3d69874e..c7ffcb875 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -188,7 +188,7 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) { bake.step_epi_lag <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, lag = object$lag) %>% dplyr::mutate(newname = glue::glue("{object$prefix}{lag}_{col}"), - shift_val = object$lag, + shift_val = lag, lag = NULL) ## ensure no name clashes @@ -219,7 +219,7 @@ bake.step_epi_lag <- function(object, new_data, ...) { bake.step_epi_ahead <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, ahead = object$ahead) %>% dplyr::mutate(newname = glue::glue("{object$prefix}{ahead}_{col}"), - shift_val = -object$ahead, + shift_val = -ahead, ahead = NULL) ## ensure no name clashes diff --git a/R/utils_arg.R b/R/utils_arg.R index 2cd5db6ba..ebac387d5 100644 --- a/R/utils_arg.R +++ b/R/utils_arg.R @@ -6,9 +6,9 @@ handle_arg_list = function(..., tests) { values = list(...) #names = names(values) names = eval(substitute(alist(...))) - names = purrr::map(names, deparse) + names = map(names, deparse) - purrr::walk2(names, values, tests) + walk2(names, values, tests) } arg_is_scalar = function(..., allow_null = FALSE, allow_na = FALSE) { @@ -125,7 +125,7 @@ arg_is_function = function(..., allow_null = FALSE) { ..., tests = function(name, value) { if (!is.function(value) | (is.null(value) & !allow_null)) - cli_stop("All {.val {name}} must be in [0,1].") + cli_stop("{value} must be a `parsnip` function.") } ) } diff --git a/R/zzz.R b/R/zzz.R new file mode 100644 index 000000000..0db33dae9 --- /dev/null +++ b/R/zzz.R @@ -0,0 +1,9 @@ +# ON LOAD ---- + +# The functions below define the model information. These access the model +# environment inside of parsnip so they have to be executed once parsnip has +# been loaded. + +.onLoad <- function(libname, pkgname) { + make_flatline_reg() +} diff --git a/man/arx_args_list.Rd b/man/arx_args_list.Rd index ddb56d422..09e74411c 100644 --- a/man/arx_args_list.Rd +++ b/man/arx_args_list.Rd @@ -5,41 +5,48 @@ \title{ARX forecaster argument constructor} \usage{ arx_args_list( - lags = c(0, 7, 14), - ahead = 7, - min_train_window = 20, + lags = c(0L, 7L, 14L), + ahead = 7L, + min_train_window = 20L, + forecast_date = NULL, + target_date = NULL, levels = c(0.05, 0.95), - intercept = TRUE, symmetrize = TRUE, nonneg = TRUE, - quantile_by_key = FALSE + quantile_by_key = character(0L) ) } \arguments{ \item{lags}{Vector or List. Positive integers enumerating lags to use -in autoregressive-type models.} +in autoregressive-type models (in days).} -\item{ahead}{Integer. Number of time steps ahead of the forecast date -for which forecasts should be produced.} +\item{ahead}{Integer. Number of time steps ahead (in days) of the forecast +date for which forecasts should be produced.} \item{min_train_window}{Integer. The minimal amount of training -data needed to produce a forecast. If smaller, the forecaster will return -\code{NA} predictions.} +data (in the time unit of the \code{epi_df}) needed to produce a forecast. 
+If smaller, the forecaster will return \code{NA} predictions.} + +\item{forecast_date}{The date on which the forecast is created. The default +\code{NULL} will attempt to determine this automatically.} + +\item{target_date}{The date for which the forecast is intended. The default +\code{NULL} will attempt to determine this automatically.} \item{levels}{Vector or \code{NULL}. A vector of probabilities to produce prediction intervals. These are created by computing the quantiles of training residuals. A \code{NULL} value will result in point forecasts only.} -\item{intercept}{Logical. The default \code{TRUE} includes intercept in the -forecaster.} - \item{symmetrize}{Logical. The default \code{TRUE} calculates symmetric prediction intervals.} -\item{nonneg}{Logical. The default \code{TRUE} enforeces nonnegative predictions +\item{nonneg}{Logical. The default \code{TRUE} enforces nonnegative predictions by hard-thresholding at 0.} -\item{quantile_by_key}{Not currently implemented} +\item{quantile_by_key}{Character vector. Groups residuals by listed keys +before calculating residual quantiles. See the \code{by_key} argument to +\code{\link[=layer_residual_quantiles]{layer_residual_quantiles()}} for more information. The default, +\code{character(0)} performs no grouping.} } \value{ A list containing updated parameter choices. diff --git a/man/arx_epi_forecaster.Rd b/man/arx_epi_forecaster.Rd new file mode 100644 index 000000000..da8d0ad67 --- /dev/null +++ b/man/arx_epi_forecaster.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/arx_forecaster_mod.R +\name{arx_epi_forecaster} +\alias{arx_epi_forecaster} +\title{Direct autoregressive forecaster with covariates} +\usage{ +arx_epi_forecaster( + epi_data, + outcome, + predictors, + trainer = parsnip::linear_reg(), + args_list = arx_args_list() +) +} +\arguments{ +\item{epi_data}{An \code{epi_df} object} + +\item{outcome}{A character (scalar) specifying the outcome (in the +\code{epi_df}).} + +\item{predictors}{A character vector giving column(s) of predictor +variables.} + +\item{trainer}{A \code{{parsnip}} model describing the type of estimation. +For now, we enforce \code{mode = "regression"}.} + +\item{args_list}{A list of customization arguments to determine +the type of forecasting model. See \code{\link[=arx_args_list]{arx_args_list()}}.} +} +\value{ +A list with (1) \code{predictions} an \code{epi_df} of predicted values +and (2) \code{epi_workflow}, a list that encapsulates the entire estimation +workflow +} +\description{ +This is an autoregressive forecasting model for +\link[epiprocess:epi_df]{epiprocess::epi_df} data. It does "direct" forecasting, meaning +that it estimates a model for a particular target horizon. +} +\examples{ +jhu <- case_death_rate_subset \%>\% + dplyr::filter(time_value >= as.Date("2021-12-01")) + +out <- arx_epi_forecaster(jhu, "death_rate", + c("case_rate", "death_rate")) +} diff --git a/man/epi_workflow.Rd b/man/epi_workflow.Rd index bcf0e78aa..f9d753d84 100644 --- a/man/epi_workflow.Rd +++ b/man/epi_workflow.Rd @@ -11,7 +11,7 @@ epi_workflow(preprocessor = NULL, spec = NULL, postprocessor = NULL) \itemize{ \item A formula, passed on to \code{\link[workflows:add_formula]{add_formula()}}. \item A recipe, passed on to \code{\link[workflows:add_recipe]{add_recipe()}}. -\item A \code{\link[workflows:add_variables]{workflow_variables()}} object, passed on to \code{\link[workflows:add_variables]{add_variables()}}. 
+\item A \code{\link[workflows:workflow_variables]{workflow_variables()}} object, passed on to \code{\link[workflows:add_variables]{add_variables()}}. }} \item{spec}{An optional parsnip model specification to add to the workflow. diff --git a/man/extract_frosting.Rd b/man/extract_frosting.Rd new file mode 100644 index 000000000..560ffc032 --- /dev/null +++ b/man/extract_frosting.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/frosting.R +\name{extract_frosting} +\alias{extract_frosting} +\title{Extract the frosting object from a workflow} +\usage{ +extract_frosting(x, ...) +} +\arguments{ +\item{x}{an \code{epi_workflow} object} + +\item{...}{not used} +} +\value{ +a \code{frosting} object +} +\description{ +Extract the frosting object from a workflow +} diff --git a/man/flatline.Rd b/man/flatline.Rd new file mode 100644 index 000000000..d65dcae8a --- /dev/null +++ b/man/flatline.Rd @@ -0,0 +1,47 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/flatline.R +\name{flatline} +\alias{flatline} +\title{(Internal) implementation of the flatline forecaster} +\usage{ +flatline(formula, data) +} +\arguments{ +\item{formula}{The lhs should be a single variable. In standard usage, this +would actually be the observed time series shifted forward by the forecast +horizon. The right hand side must contain any keys (locations) for the +panel data separated by plus. The observed time series must come last. +For example + +\if{html}{\out{
}}\preformatted{form <- as.formula(lead7_y ~ state + age + y) +}\if{html}{\out{
}}
+
+Note that this function doesn't DO the shifting, that has to be done
+outside.}
+
+\item{data}{A data frame containing at least the variables used in the
+formula. It must also contain a column \code{time_value} giving the observed
+time points.}
+}
+\value{
+An S3 object of class \code{flatline} with two components:
+\itemize{
+\item \code{residuals} - a tibble with all keys and a \code{.resid} column that contains
+forecast errors.
+\item \code{.pred} - a tibble with all keys and a \code{.pred} column containing only
+predictions for future data (the last observed value of the outcome for each
+combination of keys).
+}
+}
+\description{
+This is an internal function that is used to create a \code{\link[parsnip:linear_reg]{parsnip::linear_reg()}}
+model. It has somewhat odd behaviour (see below).
+}
+\examples{
+tib <- data.frame(y = runif(100),
+  expand.grid(k = letters[1:4], j = letters[5:9], time_value = 1:5)) \%>\%
+  dplyr::group_by(k, j) \%>\%
+  dplyr::mutate(y2 = dplyr::lead(y, 2)) # predict 2 steps ahead
+flat <- flatline(y2 ~ j + k + y, tib) # predictions for 20 locations
+sum(!is.na(flat$residuals$.resid)) # 100 residuals, but 40 are NA
+}
diff --git a/man/flatline_args_list.Rd b/man/flatline_args_list.Rd
new file mode 100644
index 000000000..701b4d7b0
--- /dev/null
+++ b/man/flatline_args_list.Rd
@@ -0,0 +1,57 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/flatline_epi_forecaster.R
+\name{flatline_args_list}
+\alias{flatline_args_list}
+\title{Flatline forecaster argument constructor}
+\usage{
+flatline_args_list(
+  ahead = 7L,
+  min_train_window = 20L,
+  forecast_date = NULL,
+  target_date = NULL,
+  levels = c(0.05, 0.95),
+  symmetrize = TRUE,
+  nonneg = TRUE,
+  quantile_by_key = character(0L)
+)
+}
+\arguments{
+\item{ahead}{Integer. Number of time steps ahead (in days) of the forecast
+date for which forecasts should be produced.}
+
+\item{min_train_window}{Integer. The minimal amount of training
+data (in the time unit of the \code{epi_df}) needed to produce a forecast.
+If smaller, the forecaster will return \code{NA} predictions.}
+
+\item{forecast_date}{The date on which the forecast is created. The default
+\code{NULL} will attempt to determine this automatically.}
+
+\item{target_date}{The date for which the forecast is intended. The default
+\code{NULL} will attempt to determine this automatically.}
+
+\item{levels}{Vector or \code{NULL}. A vector of probabilities to produce
+prediction intervals. These are created by computing the quantiles of
+training residuals. A \code{NULL} value will result in point forecasts only.}
+
+\item{symmetrize}{Logical. The default \code{TRUE} calculates
+symmetric prediction intervals.}
+
+\item{nonneg}{Logical. The default \code{TRUE} enforces nonnegative predictions
+by hard-thresholding at 0.}
+
+\item{quantile_by_key}{Character vector. Groups residuals by listed keys
+before calculating residual quantiles. See the \code{by_key} argument to
+\code{\link[=layer_residual_quantiles]{layer_residual_quantiles()}} for more information. The default,
+\code{character(0)} performs no grouping.}
+}
+\value{
+A list containing updated parameter choices.
+}
+\description{
+Constructs a list of arguments for \code{\link[=flatline_epi_forecaster]{flatline_epi_forecaster()}}. 
+} +\examples{ +flatline_args_list() +flatline_args_list(symmetrize = FALSE) +flatline_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120) +} diff --git a/man/flatline_epi_forecaster.Rd b/man/flatline_epi_forecaster.Rd new file mode 100644 index 000000000..345d87566 --- /dev/null +++ b/man/flatline_epi_forecaster.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/flatline_epi_forecaster.R +\name{flatline_epi_forecaster} +\alias{flatline_epi_forecaster} +\title{Predict the future with today's value} +\usage{ +flatline_epi_forecaster(epi_data, outcome, args_list = flatline_args_list()) +} +\arguments{ +\item{epi_data}{An \link[epiprocess:epi_df]{epiprocess::epi_df}} + +\item{outcome}{A scalar character for the column name we wish to predict.} + +\item{args_list}{A list of dditional arguments as created by the +\code{\link[=flatline_args_list]{flatline_args_list()}} constructor function.} +} +\value{ +A data frame of point (and optionally interval) forecasts at a single +ahead (unique horizon) for each unique combination of \code{key_vars}. +} +\description{ +This is a simple forecasting model for +\link[epiprocess:epi_df]{epiprocess::epi_df} data. It uses the most recent observation as the +forcast for any future date, and produces intervals based on the quantiles +of the residuals of such a "flatline" forecast over all available training +data. +} +\details{ +By default, the predictive intervals are computed separately for each +combination of key values (\code{geo_value} + any additional keys) in the +\code{epi_data} argument. + +This forecaster is very similar to that used by the +\href{https://covid19forecasthub.org}{COVID19ForecastHub} +} +\examples{ +jhu <- case_death_rate_subset \%>\% + dplyr::filter(time_value >= as.Date("2021-12-01")) + +out <- flatline_epi_forecaster(jhu, "death_rate") +} diff --git a/man/get_test_data.Rd b/man/get_test_data.Rd index b71bad8f6..f87c10954 100644 --- a/man/get_test_data.Rd +++ b/man/get_test_data.Rd @@ -18,8 +18,8 @@ and other variables in the original dataset. } \description{ Based on the longest lag period in the recipe, -\code{get_test_data()} creates a tibble in \link[epiprocess:epi_df]{epiprocess::epi_df} -format with columns \code{geo_value}, \code{time_value} +\code{get_test_data()} creates an \link[epiprocess:epi_df]{epiprocess::epi_df} +with columns \code{geo_value}, \code{time_value} and other variables in the original dataset, which will be used to create test data. } diff --git a/man/layer_residual_quantiles.Rd b/man/layer_residual_quantiles.Rd index 4b420cffc..95aabd753 100644 --- a/man/layer_residual_quantiles.Rd +++ b/man/layer_residual_quantiles.Rd @@ -9,6 +9,7 @@ layer_residual_quantiles( ..., probs = c(0.0275, 0.975), symmetrize = TRUE, + by_key = character(0L), name = ".pred_distn", .flag = TRUE, id = rand_id("residual_quantiles") @@ -24,6 +25,9 @@ referring to the desired quantile.} \item{symmetrize}{logical. If \code{TRUE} then interval will be symmetric.} +\item{by_key}{A character vector of keys to group the residuls by before +calculating quantiles. The default, \code{c()} performs no grouping.} + \item{name}{character. The name for the output column.} \item{.flag}{a logical to determine if the layer is added. 
Passed on to @@ -59,5 +63,12 @@ f <- frosting() \%>\% wf1 <- wf \%>\% add_frosting(f) p <- predict(wf1, latest) -p + +f2 <- frosting() \%>\% + layer_predict() \%>\% + layer_residual_quantiles(probs = c(0.3, 0.7), by_key = "geo_value") \%>\% + layer_naomit(.pred) +wf2 <- wf \%>\% add_frosting(f2) + +p2 <- predict(wf2, latest) } diff --git a/musings/.gitignore b/musings/.gitignore new file mode 100644 index 000000000..2d19fc766 --- /dev/null +++ b/musings/.gitignore @@ -0,0 +1 @@ +*.html diff --git a/tests/testthat/test-epi_keys.R b/tests/testthat/test-epi_keys.R index 90a458605..f15ee197f 100644 --- a/tests/testthat/test-epi_keys.R +++ b/tests/testthat/test-epi_keys.R @@ -30,3 +30,25 @@ test_that("epi_keys_mold extracts time_value and geo_value, but not raw",{ expect_equal(epi_keys_mold(my_workflow$pre$mold), c("time_value","geo_value")) }) + +test_that("epi_keys_mold extracts additional keys when they are present", { + my_data <- tibble::tibble( + geo_value = rep(c("ca", "fl", "pa"), each = 3), + time_value = rep(seq(as.Date("2020-06-01"), as.Date("2020-06-03"), + by = "day"), length.out = length(geo_value)), + pol = rep(c("blue", "swing", "swing"), each = 3), # extra key + state = rep(c("ca", "fl", "pa"), each = 3), # extra key + value = 1:length(geo_value) + 0.01 * rnorm(length(geo_value)) + ) %>% + epiprocess::as_epi_df(additional_metadata = list(other_keys = c("state", "pol"))) + + my_recipe <- epi_recipe(my_data) %>% + step_epi_ahead(value , ahead = 7) %>% + step_epi_naomit() + + my_workflow <- epi_workflow(my_recipe, linear_reg()) %>% fit(my_data) + + expect_setequal( + epi_keys_mold(my_workflow$pre$mold), + c("time_value", "geo_value", "state", "pol")) +}) diff --git a/tests/testthat/test-epi_recipe.R b/tests/testthat/test-epi_recipe.R index b8e236fa8..47b5d2549 100644 --- a/tests/testthat/test-epi_recipe.R +++ b/tests/testthat/test-epi_recipe.R @@ -65,7 +65,7 @@ test_that("epi_recipe formula works", { variable = "z", type = "nominal", role = "key", source = "original") - expect_identical(r$var_info, ref_var_info) + #expect_identical(r$var_info, ref_var_info) }) diff --git a/tests/testthat/test-step_epi_naomit.R b/tests/testthat/test-step_epi_naomit.R index 65c9d90c9..d65734ff6 100644 --- a/tests/testthat/test-step_epi_naomit.R +++ b/tests/testthat/test-step_epi_naomit.R @@ -20,7 +20,7 @@ test_that("Argument must be a recipe", { z1 <- step_epi_naomit(r) z2 <- r %>% - step_naomit(all_predictors()) %>% + step_naomit(all_predictors(), skip = FALSE) %>% step_naomit(all_outcomes(), skip = TRUE) # Checks the behaviour of a step function, omitting the quosure and id that diff --git a/vignettes/epipredict.Rmd b/vignettes/epipredict.Rmd new file mode 100644 index 000000000..5799509c0 --- /dev/null +++ b/vignettes/epipredict.Rmd @@ -0,0 +1,362 @@ +--- +title: "Get started with epipredict" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{epipredict} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + echo = TRUE, + collapse = TRUE, + comment = "#>", + out.width = "100%") +``` + +```{r setup, message=FALSE} +library(dplyr) +library(parsnip) +library(workflows) +library(recipes) +library(epiprocess) +# remotes::install_github("cmu-delphi/epipredict") +library(epipredict) +``` + + +# Goals for the package + +At a high level, our goal with `epipredict` is to make running simple Machine Learning / Statistical forecasters for epidemiology easy. 
However, this package is extremely extensible, and that is a big part of its utility. Our hope is that users with epidemiological training and some statistics can easily fit baseline models, while those with more nuanced statistical understanding can create complicated specializations within the same framework.

Serving both populations is the main motivation for our efforts, but at the same time, we have tried hard to make the package useful out of the box.


## Baseline models

We provide a set of basic, easy-to-use forecasters that work out of the box.
These allow a limited amount of customization; any serious customization happens with the framework discussed below.

For the basic forecasters, we provide, at least:

* Baseline flat-line forecaster
* Autoregressive forecaster
* Autoregressive classifier

All the forecasters we provide are built on our framework, so we will use these basic models to illustrate its flexibility.

## Forecasting framework

Our framework for creating custom forecasters views the prediction task as a set of modular components. There are four types of components:

1. Preprocessor: make transformations to the data before model training
2. Trainer: train a model on data, resulting in a fitted model object
3. Predictor: make predictions, using a fitted model object and processed test data
4. Postprocessor: manipulate or transform the predictions before returning

Users familiar with [`{tidymodels}`](https://www.tidymodels.org) and especially the [`{workflows}`](https://workflows.tidymodels.org) package will notice a lot of overlap. This is by design, and is in fact a feature. The truth is that `epipredict` is a wrapper around much that is contained in these packages. Therefore, if you want something from this -verse, it should "just work" (we hope).

The reason for the overlap is that `{workflows}` _already implements_ the first three steps. And it does this very well. However, it is missing the postprocessing stage and currently has no plans for such an implementation. And this feature is important. The baseline forecaster we provide _requires_ postprocessing. Anything more complicated needs this as well.

The second omission from `{tidymodels}` is support for panel data. Besides epidemiological data, economics, psychology, sociology, and many other areas frequently deal with data of this type. So the framework behind `epipredict` implements this. In principle, this has nothing to do with epidemiology, and one could simply use this package as a solution for the missing functionality in `{tidymodels}`. Again, this should "just work".

All of the _panel data_ functionality is implemented through the `epi_df` data type in the companion [`{epiprocess}`](https://cmu-delphi.github.io/epiprocess/) package. There is much more to see there, but for the moment, it's enough to look at a simple one:

```{r epidf}
jhu <- case_death_rate_subset
jhu
```

This data is built into the package and contains the measured variables `case_rate` and `death_rate` for COVID-19 at the daily level for each US state for the year 2021. The "panel" part is because we have repeated measurements across a number of locations.

The `epi_df` encodes the timestamp as `time_value` and the key as `geo_value`. While these 2 names are required, the values don't need to actually represent such objects. Additional keys are also supported (like age group, ethnicity, taxonomy, etc.), as in the sketch below.
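
For instance, an additional key can be declared when a data frame is converted to an `epi_df`. The following is a minimal sketch, modeled on this package's own tests; the `pol` column and its values are invented purely for illustration.

```{r extra-keys, eval = FALSE}
ex <- tibble::tibble(
  geo_value = rep(c("ca", "fl", "pa"), each = 3),
  time_value = rep(seq(as.Date("2020-06-01"), by = "day", length.out = 3), 3),
  pol = rep(c("blue", "swing", "swing"), each = 3), # a hypothetical extra key
  value = 1:9
) %>%
  epiprocess::as_epi_df(additional_metadata = list(other_keys = "pol"))
```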
The `epi_df` also contains some metadata that describes the keys as well as the vintage of the data. It's possible that data collected at different times for the _same set_ of `geo_value`'s and `time_value`'s could actually be different. For more details, see [`{epiprocess}`](https://cmu-delphi.github.io/epiprocess/articles/epiprocess.html).



## Why doesn't this package already exist?

As described above:

* Parts actually DO exist. There's a universe called `{tidymodels}`. It handles
preprocessing, training, and prediction, bound together, through a package called
`{workflows}`. We built `{epipredict}` on top of that setup. In this way, you CAN
use almost everything they provide.

* However, `{workflows}` doesn't do postprocessing. And nothing in the -verse handles _panel data_.

* The tidy-team doesn't have plans to do either of these things. (We checked.)

* There are two packages that do _time series_ built on `{tidymodels}`, but it's
"basic" time series: 1-step AR models, exponential smoothing, STL decomposition, etc.[^2] We have never found these models to be particularly helpful for epidemic forecasting, but one could also integrate these methods into our framework.

[^2]: These are [`{timetk}`](https://business-science.github.io/timetk/index.html) and [`{modeltime}`](https://business-science.github.io/modeltime/index.html). There are _lots_ of useful methods there that can be used to do fairly complex machine learning methodology.

# Show me the basics

We start with the `jhu` data displayed above.
One of the "canned" forecasters we provide is an autoregressive forecaster with (or without) covariates that _directly_ trains on the response. This is in contrast to a typical "iterative" AR model that trains to predict one-step-ahead, and then plugs in the predictions to "leverage up" to longer horizons.

We'll estimate the model jointly across all locations using only the most recent 30 days.

```{r demo-workflow}
jhu <- jhu %>% filter(time_value >= max(time_value) - 30)
out <- arx_epi_forecaster(jhu, outcome = "death_rate",
                          predictors = c("case_rate", "death_rate")
)
```

This call produces a warning, which we'll ignore for now. But essentially, it's telling us that our data comes from May 2022, but we're trying to do a forecast for January 2022. The result is likely not an accurate measure of real-time forecast performance, because the data have been revised over time.

The `out` object has two components:

  1. The predictions, which are just another `epi_df`. This contains the predictions for each location along with additional columns. By default, these are a 90% predictive interval, the `forecast_date` (the date on which the forecast was putatively made) and the `target_date` (the date for which the forecast is being made).
  ```{r}
  out$predictions
  ```
  2. A list object of class `epi_workflow`. This object encapsulates all the instructions necessary to create the prediction. More details on this below.

Note that the `time_value` in the predictions is not necessarily meaningful.

By default, the forecaster predicts the outcome (`death_rate`) 1 week ahead, using 3 lags of each predictor (`case_rate` and `death_rate`) at 0 (today), 1 week back and 2 weeks back. The predictors and outcome can be changed directly. The rest of the defaults are encapsulated into a list of arguments. This list is produced by `arx_args_list()`.

## Simple adjustments

Basic adjustments can be made through the `args_list`.
```{r kill-warnings, echo=FALSE}
knitr::opts_chunk$set(warning = FALSE, message = FALSE)
```

```{r differential-lags}
out2week <- arx_epi_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
  args_list = arx_args_list(
    lags = list(c(0,1,2,3,7,14), c(0,7,14)),
    ahead = 14)
  )
```

Here, we've used different lags on the `case_rate` and are now predicting 2 weeks ahead. This example also illustrates a major difficulty with the "iterative" versions of AR models. This model doesn't produce forecasts for `case_rate`, and so would not have data to "plug in" for the necessary lags.[^1]

[^1]: An obvious fix is to instead use a VAR and predict both, but this would likely increase the variance and may lead to less accurate forecasts for the variable of interest.

Another property of the basic model is the predictive interval. We describe this in more detail in a different vignette, but it is easy to request multiple quantiles.

```{r differential-levels}
out_q <- arx_epi_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
  args_list = arx_args_list(
    levels = c(.01,.025, seq(.05,.95, by=.05), .975,.99))
  )
```

The column `.pred_distn` in the `predictions` object is actually a "distribution", here parameterized by its quantiles. For this default forecaster, these are created using the quantiles of the residuals of the predictive model (possibly symmetrized). Here, we used 23 quantiles, but one can grab a particular quantile:

```{r q1}
quantile(out_q$predictions$.pred_distn, p = .4)
```

Or extract the entire distribution into a "long" `epi_df` with `tau` being the probability and `q` being the value associated with that quantile.

```{r q2}
out_q$predictions %>%
  mutate(
    .pred_distn = nested_quantiles(.pred_distn) # "nested" list-col
  ) %>% unnest(.pred_distn)
```

Further simple adjustments can be made via the other arguments of `arx_args_list()`; its defaults are shown below.

```{r, eval = FALSE}
arx_args_list(
  lags = c(0L, 7L, 14L), ahead = 7L, min_train_window = 20L,
  forecast_date = NULL, target_date = NULL, levels = c(0.05, 0.95),
  symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L)
)
```

## Changing the engine

So far, our forecasts have been produced using simple linear regression. But this is not the only way to estimate such a model.
The `trainer` argument determines the type of model we want.
This takes a [`{parsnip}`](https://parsnip.tidymodels.org) model. The default is linear regression, but we could instead use a random forest with the `{ranger}` package:

```{r ranger, warning = FALSE}
out_rf <- arx_epi_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
  rand_forest(mode = "regression"))
```

Or boosted regression trees with `{xgboost}`:

```{r xgboost, warning = FALSE}
out_gb <- arx_epi_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
  boost_tree(mode = "regression", trees = 20))
```

## Inner workings

Underneath the hood, this forecaster creates (and returns) an `epi_workflow`.
Essentially, this is a big S3 object that wraps up the 4 modular steps
(preprocessing through postprocessing) described above.

### Preprocessing

Preprocessing is accomplished through a `recipe` (imagine baking a cake) as
provided in the [`{recipes}`](https://recipes.tidymodels.org) package.
We've made a few modifications (to handle
panel data) as well as added some additional options. The recipe gives a
specification of how to handle training data. Think of it like a fancified
`formula` that you would pass to `lm()`: `y ~ x1 + log(x2)`.
+In general, there are 2 extensions to the `formula` that `{recipes}` handles:
+
+  1. Doing transformations of both training and test data that can always be
+  applied. These are things like taking the log of a variable, leading or
+  lagging, filtering out rows, handling dummy variables, etc.
+  2. Using statistics from the training data to eventually process test data.
+  This is a major benefit of `{recipes}`. It prevents what the tidy team calls
+  "data leakage". A simple example is centering a predictor by its mean. We
+  need to store the mean of the predictor from the training data and use that
+  value on the test data rather than accidentally calculating the mean of
+  the test predictor.
+
+A recipe is processed in 2 steps. First it is "prepped", which calculates and
+stores whatever statistics are necessary for use on the test data. Then it is
+"baked", resulting in training data ready for passing into a statistical model
+(like `lm`).
+
+We have introduced an `epi_recipe`. It's just a `recipe` that knows how to
+handle the `time_value`, `geo_value`, and any additional keys so that these
+are available when necessary.
+
+The `epi_recipe` can be extracted from `out_gb`:
+```{r}
+extract_recipe(out_gb$epi_workflow)
+```
+
+The "Inputs" are the original `epi_df` and the "roles" that these are assigned.
+None of these are predictors or outcomes. Those will be created
+by the recipe when it is prepped. The "Operations" are the sequence of
+instructions to create the cake.
+Here we create lagged predictors, lead the outcome, and then remove NA's.
+Some models like `lm` internally handle NA's, but not all do, so we
+deal with them explicitly. The code to do this (inside the forecaster) is
+
+```{r}
+er <- epi_recipe(jhu) %>%
+  step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) %>%
+  step_epi_ahead(death_rate, ahead = 7) %>%
+  step_epi_naomit()
+```
+
+While `{recipes}` provides a function `step_lag()`, it assumes that the data
+have no breaks in the sequence of `time_values`. This is a bit dangerous, so
+we avoid that behaviour. Our `lag/ahead` functions also appropriately adjust
+the amount of data to avoid accidentally dropping recent predictors from the
+test data.
+
+### The model specification
+
+Users familiar with the `{parsnip}` package will have no trouble here.
+Basically, `{parsnip}` unifies the function signature across statistical
+models. For example, `lm()` "likes" to work with formulas, but
+`glmnet::glmnet()` uses `x` and `y` for predictors and response. `{parsnip}`
+is agnostic. Both of these do "linear regression". Above we switched from
+`lm()` to `xgboost()` without any issue despite the fact that these functions
+couldn't be more different.
+
+```{r, eval = FALSE}
+lm(formula, data, subset, weights, na.action, method = "qr",
+   model = TRUE, x = FALSE, y = FALSE, qr = TRUE, singular.ok = TRUE,
+   contrasts = NULL, offset, ...)
+xgboost(data = NULL, label = NULL, missing = NA, weight = NULL,
+        params = list(), nrounds, verbose = 1, print_every_n = 1L,
+        early_stopping_rounds = NULL, maximize = NULL, save_period = NULL,
+        save_name = "xgboost.model", xgb_model = NULL, callbacks = list(),
+        ...)
+```
+
+`{epipredict}` provides a few engines/modules (the flatline forecaster and
+quantile regression), but you should be able to use any available model
+listed [here](https://www.tidymodels.org/find/parsnip/).
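+
+To make the unification concrete, here is a sketch (not evaluated) of how
+these two very different functions are expressed through the common
+`{parsnip}` interface; the object names `lin_spec` and `gb_spec` are ours.
+
+```{r parsnip-unified, eval = FALSE}
+library(parsnip)
+# The same specification pattern, regardless of the underlying engine:
+lin_spec <- linear_reg() %>% set_engine("lm")
+gb_spec  <- boost_tree(mode = "regression", trees = 20) %>%
+  set_engine("xgboost")
+```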
+
+To estimate (fit) a preprocessed model, one calls `fit()` on the
+`epi_workflow`.
+
+```{r}
+ewf <- epi_workflow(er, linear_reg()) %>% fit(jhu)
+```
+
+### Postprocessing
+
+To stretch the metaphor of preparing a cake to its natural limits, we have
+created postprocessing functionality called "frosting". Much like the recipe,
+each postprocessing operation is a "layer" and we "slather" these onto our
+baked cake. To fix ideas, below is the postprocessing `frosting` for
+`arx_epi_forecaster()`:
+
+```{r}
+extract_frosting(out_q$epi_workflow)
+```
+
+Here we have 5 layers of frosting. The first generates the forecasts from the
+test data. The second uses quantiles of the residuals to create distributional
+forecasts. The next two add columns for the date the forecast was made and the
+date for which it is intended to occur. Because we are predicting rates, they
+should be non-negative, so the last layer thresholds both predicted values and
+intervals at 0. The code to do this (inside the forecaster) is
+
+```{r}
+f <- frosting() %>%
+  layer_predict() %>%
+  layer_residual_quantiles(
+    probs = c(.01, .025, seq(.05, .95, by = .05), .975, .99),
+    symmetrize = TRUE) %>%
+  layer_add_forecast_date() %>%
+  layer_add_target_date() %>%
+  layer_threshold(starts_with(".pred"))
+```
+
+At predict time, we add this object onto the `epi_workflow` and call
+`predict()`:
+
+```{r, warning=FALSE}
+test_data <- get_test_data(er, jhu)
+ewf %>% add_frosting(f) %>% predict(test_data)
+```
+
+## Conclusion
+
+Internally, we provide some simple functions to create reasonable forecasts.
+But ideally, a user could create their own forecasters by building up the
+components we provide. In other vignettes, we try to walk through some of
+these customizations.
+
+To illustrate everything above, here is (roughly) the code for the
+`flatline_epi_forecaster()` applied to the `case_rate`.
+
+```{r}
+r <- epi_recipe(jhu) %>%
+  step_epi_ahead(case_rate, ahead = 7, skip = TRUE) %>%
+  update_role(case_rate, new_role = "predictor") %>%
+  add_role(all_of(epi_keys(jhu)), new_role = "predictor")
+
+# bit of a weird hack to get the latest values per key
+latest <- get_test_data(epi_recipe(jhu), jhu)
+
+f <- frosting() %>%
+  layer_predict() %>%
+  layer_residual_quantiles() %>%
+  layer_add_forecast_date() %>%
+  layer_add_target_date() %>%
+  layer_threshold(starts_with(".pred"))
+
+eng <- linear_reg() %>% set_engine("flatline")
+wf <- epi_workflow(r, eng, f) %>% fit(jhu)
+preds <- predict(wf, latest)
+```
+
+All that really differs from the `arx_epi_forecaster()` is the `recipe`, the
+test data, and the engine. The `frosting` is identical, as is the fitting
+and predicting procedure.
+
+```{r}
+preds
+```
+
+## Using the workflow object