diff --git a/DESCRIPTION b/DESCRIPTION index fdac45e01..0125f2e58 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -32,7 +32,6 @@ Imports: glue, hardhat (>= 1.2.0), magrittr, - purrr, recipes (>= 1.0.0), rlang, stats, @@ -42,7 +41,7 @@ Imports: tidyselect, usethis, vctrs, - workflows + workflows (>= 1.0.0) Suggests: covidcast, data.table, @@ -51,9 +50,11 @@ Suggests: knitr, lubridate, parsnip (>= 1.0.0), + ranger, RcppRoll, rmarkdown, - testthat (>= 3.0.0) + testthat (>= 3.0.0), + xgboost VignetteBuilder: knitr Remotes: diff --git a/NAMESPACE b/NAMESPACE index a13ca9f54..0f803520d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -21,6 +21,8 @@ S3method(extract_argument,frosting) S3method(extract_argument,layer) S3method(extract_argument,recipe) S3method(extract_argument,step) +S3method(extract_frosting,default) +S3method(extract_frosting,epi_workflow) S3method(extract_layers,frosting) S3method(extract_layers,workflow) S3method(extrapolate_quantiles,dist_default) @@ -29,6 +31,7 @@ S3method(extrapolate_quantiles,distribution) S3method(format,dist_quantiles) S3method(median,dist_quantiles) S3method(predict,epi_workflow) +S3method(predict,flatline) S3method(prep,epi_recipe) S3method(prep,step_epi_ahead) S3method(prep,step_epi_lag) @@ -39,6 +42,7 @@ S3method(print,step_epi_ahead) S3method(print,step_epi_lag) S3method(quantile,dist_quantiles) S3method(refresh_blueprint,default_epi_recipe_blueprint) +S3method(residuals,flatline) S3method(run_mold,default_epi_recipe_blueprint) S3method(slather,layer_add_forecast_date) S3method(slather,layer_add_target_date) @@ -60,6 +64,7 @@ export(add_frosting) export(add_layer) export(apply_frosting) export(arx_args_list) +export(arx_epi_forecaster) export(arx_forecaster) export(create_lags_and_leads) export(create_layer) @@ -72,8 +77,12 @@ export(epi_recipe) export(epi_recipe_blueprint) export(epi_workflow) export(extract_argument) +export(extract_frosting) export(extract_layers) export(extrapolate_quantiles) +export(flatline) +export(flatline_args_list) +export(flatline_epi_forecaster) export(frosting) export(get_precision) export(get_test_data) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 373f3cd17..0ff4497c1 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -34,7 +34,7 @@ arx_forecaster <- function(x, y, key_vars, time_value, } dat <- create_lags_and_leads(x, y, lags, ahead, time_value, keys) - if (intercept) dat$x0 <- 1 + dat$x0 <- 1 obj <- stats::lm( y1 ~ . + 0, @@ -67,14 +67,28 @@ arx_forecaster <- function(x, y, key_vars, time_value, #' #' Constructs a list of arguments for [arx_forecaster()]. #' -#' @template param-lags -#' @template param-ahead -#' @template param-min_train_window -#' @template param-levels -#' @template param-intercept -#' @template param-symmetrize -#' @template param-nonneg -#' @param quantile_by_key Not currently implemented +#' @param lags Vector or List. Positive integers enumerating lags to use +#' in autoregressive-type models (in days). +#' @param ahead Integer. Number of time steps ahead (in days) of the forecast +#' date for which forecasts should be produced. +#' @param min_train_window Integer. The minimal amount of training +#' data (in the time unit of the `epi_df`) needed to produce a forecast. +#' If smaller, the forecaster will return `NA` predictions. +#' @param forecast_date The date on which the forecast is created. The default +#' `NULL` will attempt to determine this automatically. +#' @param target_date The date for which the forecast is intended. The default +#' `NULL` will attempt to determine this automatically. +#' @param levels Vector or `NULL`. A vector of probabilities to produce +#' prediction intervals. These are created by computing the quantiles of +#' training residuals. A `NULL` value will result in point forecasts only. +#' @param symmetrize Logical. The default `TRUE` calculates +#' symmetric prediction intervals. +#' @param nonneg Logical. The default `TRUE` enforces nonnegative predictions +#' by hard-thresholding at 0. +#' @param quantile_by_key Character vector. Groups residuals by listed keys +#' before calculating residual quantiles. See the `by_key` argument to +#' [layer_residual_quantiles()] for more information. The default, +#' `character(0)` performs no grouping. #' #' @return A list containing updated parameter choices. #' @export @@ -83,28 +97,36 @@ arx_forecaster <- function(x, y, key_vars, time_value, #' arx_args_list() #' arx_args_list(symmetrize = FALSE) #' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120) -arx_args_list <- function(lags = c(0, 7, 14), ahead = 7, min_train_window = 20, - levels = c(0.05, 0.95), intercept = TRUE, +arx_args_list <- function(lags = c(0L, 7L, 14L), + ahead = 7L, + min_train_window = 20L, + forecast_date = NULL, + target_date = NULL, + levels = c(0.05, 0.95), symmetrize = TRUE, nonneg = TRUE, - quantile_by_key = FALSE) { + quantile_by_key = character(0L)) { # error checking if lags is a list .lags <- lags if (is.list(lags)) lags <- unlist(lags) - arg_is_scalar(ahead, min_train_window) + arg_is_scalar(ahead, min_train_window, symmetrize, nonneg) + arg_is_chr(quantile_by_key, allow_null = TRUE) + arg_is_scalar(forecast_date, target_date, allow_null = TRUE) arg_is_nonneg_int(ahead, min_train_window, lags) - arg_is_lgl(intercept, symmetrize, nonneg) + arg_is_lgl(symmetrize, nonneg) arg_is_probabilities(levels, allow_null = TRUE) max_lags <- max(lags) - - list( - lags = .lags, ahead = as.integer(ahead), - min_train_window = min_train_window, - levels = levels, intercept = intercept, - symmetrize = symmetrize, nonneg = nonneg, - max_lags = max_lags - ) -} \ No newline at end of file + enlist(lags = .lags, + ahead, + min_train_window, + levels, + forecast_date, + target_date, + symmetrize, + nonneg, + max_lags, + quantile_by_key) +} diff --git a/R/arx_forecaster_mod.R b/R/arx_forecaster_mod.R index 61b87d6a9..60b0294f2 100644 --- a/R/arx_forecaster_mod.R +++ b/R/arx_forecaster_mod.R @@ -1,29 +1,83 @@ -arx_epi_forecaster <- function(epi_data, response, - ..., +#' Direct autoregressive forecaster with covariates +#' +#' This is an autoregressive forecasting model for +#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning +#' that it estimates a model for a particular target horizon. +#' +#' +#' @param epi_data An `epi_df` object +#' @param outcome A character (scalar) specifying the outcome (in the +#' `epi_df`). +#' @param predictors A character vector giving column(s) of predictor +#' variables. +#' @param trainer A `{parsnip}` model describing the type of estimation. +#' For now, we enforce `mode = "regression"`. +#' @param args_list A list of customization arguments to determine +#' the type of forecasting model. See [arx_args_list()]. +#' +#' @return A list with (1) `predictions` an `epi_df` of predicted values +#' and (2) `epi_workflow`, a list that encapsulates the entire estimation +#' workflow +#' @export +#' +#' @examples +#' jhu <- case_death_rate_subset %>% +#' dplyr::filter(time_value >= as.Date("2021-12-01")) +#' +#' out <- arx_epi_forecaster(jhu, "death_rate", +#' c("case_rate", "death_rate")) +arx_epi_forecaster <- function(epi_data, + outcome, + predictors, trainer = parsnip::linear_reg(), args_list = arx_args_list()) { - r <- epi_recipe(epi_data) %>% - step_epi_lag(..., lag = args_list$lags) %>% # hmmm, same for all predictors - step_epi_ahead(response, ahead = args_list$ahead) %>% - # should use the internal function (in an open PR) - recipes::step_naomit(recipes::all_predictors()) %>% - recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) + validate_forecaster_inputs(epi_data, outcome, predictors) + if (!is.list(trainer) || trainer$mode != "regression") + cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.") + lags <- arx_lags_validator(predictors, args_list$lags) + + r <- epi_recipe(epi_data) + for (l in seq_along(lags)) { + p <- predictors[l] + r <- step_epi_lag(r, !!p, lag = lags[[l]]) + } + r <- r %>% + step_epi_ahead(dplyr::all_of(!!outcome), ahead = args_list$ahead) %>% + step_epi_naomit() # should limit the training window here (in an open PR) # What to do if insufficient training data? Add issue. - # remove intercept? not sure how this is implemented in tidymodels + + forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) + target_date <- args_list$target_date %||% forecast_date + args_list$ahead f <- frosting() %>% layer_predict() %>% - layer_naomit(.pred) %>% - layer_residual_quantile( + # layer_naomit(.pred) %>% + layer_residual_quantiles( probs = args_list$levels, symmetrize = args_list$symmetrize) %>% - layer_threshold(.pred, dplyr::starts_with("q")) #, .flag = args_list$nonneg) in open PR - # need the target date processing here + layer_add_forecast_date(forecast_date = forecast_date) %>% + layer_add_target_date(target_date = target_date) + if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred")) latest <- get_test_data(r, epi_data) - epi_workflow(r, trainer) %>% # bug, issue 72 - add_frosting(f) + wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data) + list( + predictions = predict(wf, new_data = latest), + epi_workflow = wf + ) +} + +arx_lags_validator <- function(predictors, lags) { + p <- length(predictors) + if (!is.list(lags)) lags <- list(lags) + if (length(lags) == 1) lags <- rep(lags, p) + else if (length(lags) < p) { + cli_stop( + "You have requested {p} predictors but lags cannot be recycled to match." + ) + } + lags } diff --git a/R/compat-purrr.R b/R/compat-purrr.R index 283fafa8f..1436d80ce 100644 --- a/R/compat-purrr.R +++ b/R/compat-purrr.R @@ -5,23 +5,45 @@ map <- function(.x, .f, ...) { .f <- rlang::as_function(.f, env = rlang::global_env()) lapply(.x, .f, ...) } + walk <- function(.x, .f, ...) { map(.x, .f, ...) invisible(.x) } +walk2 <- function(.x, .y, .f, ...) { + map2(.x, .y, .f, ...) + invisible(.x) +} + map_lgl <- function(.x, .f, ...) { .rlang_purrr_map_mold(.x, .f, logical(1), ...) } + map_int <- function(.x, .f, ...) { .rlang_purrr_map_mold(.x, .f, integer(1), ...) } + map_dbl <- function(.x, .f, ...) { .rlang_purrr_map_mold(.x, .f, double(1), ...) } + map_chr <- function(.x, .f, ...) { .rlang_purrr_map_mold(.x, .f, character(1), ...) } + +map_dfr <- function(.x, .f, ..., .id = NULL) { + .f <- rlang::as_function(.f, env = global_env()) + res <- map(.x, .f, ...) + dplyr::bind_rows(res, .id = .id) +} + +map2_dfr <- function(.x, .y, .f, ..., .id = NULL) { + .f <- rlang::as_function(.f, env = global_env()) + res <- map2(.x, .y, .f, ...) + dplyr::bind_rows(res, .id = .id) +} + .rlang_purrr_map_mold <- function(.x, .f, .mold, ...) { .f <- rlang::as_function(.f, env = rlang::global_env()) out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE) diff --git a/R/epi_keys.R b/R/epi_keys.R index 0ec86f9dd..aa9976efa 100644 --- a/R/epi_keys.R +++ b/R/epi_keys.R @@ -29,7 +29,7 @@ epi_keys.recipe <- function(x) { epi_keys_mold <- function(mold) { keys <- c("time_value", "geo_value", "key") molded_names <- names(mold$extras$roles) - mold_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names) - unname(mold_keys) + mold_keys <- map(mold$extras$roles[molded_names %in% keys], names) + unname(unlist(mold_keys)) } diff --git a/R/epi_shift.R b/R/epi_shift.R index c61137706..4fe99601b 100644 --- a/R/epi_shift.R +++ b/R/epi_shift.R @@ -19,7 +19,7 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") { tidyr::unchop(shift) %>% # what is chop dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>% # One list element for each shifted feature - purrr::pmap(function(i, shift, name) { + pmap(function(i, shift, name) { tibble(keys, time_value = time_value + shift, # Shift back !!name := x[[i]]) @@ -27,7 +27,7 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") { if (is.data.frame(keys)) common_names <- c(names(keys), "time_value") else common_names <- c("keys", "time_value") - purrr::reduce(out_list, dplyr::full_join, by = common_names) + reduce(out_list, dplyr::full_join, by = common_names) } epi_shift_single <- function(x, col, shift_val, newname, key_cols) { diff --git a/R/epi_workflow.R b/R/epi_workflow.R index dab343f60..8ae99b2b7 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -62,6 +62,7 @@ is_epi_workflow <- function(x) { inherits(x, "epi_workflow") } + #' Predict from an epi_workflow #' #' @description diff --git a/R/imports.R b/R/epipredict-package.R similarity index 100% rename from R/imports.R rename to R/epipredict-package.R diff --git a/R/flatline.R b/R/flatline.R new file mode 100644 index 000000000..8732be192 --- /dev/null +++ b/R/flatline.R @@ -0,0 +1,82 @@ + +#' (Internal) implementation of the flatline forecaster +#' +#' This is an internal function that is used to create a [parsnip::linear_reg()] +#' model. It has somewhat odd behaviour (see below). +#' +#' +#' @param formula The lhs should be a single variable. In standard usage, this +#' would actually be the observed time series shifted forward by the forecast +#' horizon. The right hand side must contain any keys (locations) for the +#' panel data separated by plus. The observed time series must come last. +#' For example +#' ``` +#' form <- as.formula(lead7_y ~ state + age + y) +#' ``` +#' Note that this function doesn't DO the shifting, that has to be done +#' outside. +#' @param data A data frame containing at least the variables used in the +#' formula. It must also contain a column `time_value` giving the observed +#' time points. +#' +#' @return An S3 object of class `flatline` with two components: +#' * `residuals` - a tibble with all keys and a `.resid` column that contains +#' forecast errors. +#' * `.pred` - a tibble with all keys and a `.pred` column containing only +#' predictions for future data (the last observed of the outcome for each +#' combination of keys. +#' @export +#' +#' @examples +#' tib <- data.frame(y = runif(100), +#' expand.grid(k = letters[1:4], j = letters[5:9], time_value = 1:5)) %>% +#' dplyr::group_by(k, j) %>% +#' dplyr::mutate(y2 = dplyr::lead(y, 2)) # predict 2 steps ahead +#' flat <- flatline(y2 ~ j + k + y, tib) # predictions for 20 locations +#' sum(!is.na(flat$residuals$.resid)) # 100 residuals, but 40 are NA +flatline <- function(formula, data) { + response <- recipes:::get_lhs_vars(formula, data) + rhs <- recipes:::get_rhs_vars(formula, data) + n <- length(rhs) + observed <- rhs[n] # DANGER!! + ek <- rhs[-n] + if (length(response) > 1) + cli_stop("flatline forecaster can accept only 1 observed time series.") + keys <- ek[ek != "time_value"] + + preds <- data %>% + dplyr::mutate(.pred = !!rlang::sym(observed), + .resid = !!rlang::sym(response) - .pred) + .pred <- preds %>% + dplyr::filter(!is.na(.pred)) %>% + dplyr::group_by(!!!rlang::syms(keys)) %>% + dplyr::arrange(time_value) %>% + dplyr::slice_tail(n = 1L) %>% + dplyr::ungroup() %>% + dplyr::select(dplyr::all_of(c(keys, ".pred"))) + + structure(list( + residuals = dplyr::select(preds, dplyr::all_of(c(keys, ".resid"))), + .pred = .pred), + class = "flatline" + ) +} + +#' @export +residuals.flatline <- function(object, ...) { + object$residuals +} + +#' @export +predict.flatline <- function(object, newdata, ...) { + object <- object$.pred + metadata <- names(object)[names(object) != ".pred"] + ek <- names(newdata) + if (! all(metadata %in% ek)) { + cli_stop("`newdata` has different metadata than was used", + "to fit the flatline forecaster") + } + + dplyr::left_join(newdata, object, by = metadata) %>% + dplyr::pull(.pred) +} diff --git a/R/flatline_epi_forecaster.R b/R/flatline_epi_forecaster.R new file mode 100644 index 000000000..7ecefc17e --- /dev/null +++ b/R/flatline_epi_forecaster.R @@ -0,0 +1,122 @@ +#' Predict the future with today's value +#' +#' This is a simple forecasting model for +#' [epiprocess::epi_df] data. It uses the most recent observation as the +#' forcast for any future date, and produces intervals based on the quantiles +#' of the residuals of such a "flatline" forecast over all available training +#' data. +#' +#' By default, the predictive intervals are computed separately for each +#' combination of key values (`geo_value` + any additional keys) in the +#' `epi_data` argument. +#' +#' This forecaster is very similar to that used by the +#' [COVID19ForecastHub](https://covid19forecasthub.org) +#' +#' @param epi_data An [epiprocess::epi_df] +#' @param outcome A scalar character for the column name we wish to predict. +#' @param args_list A list of dditional arguments as created by the +#' [flatline_args_list()] constructor function. +#' +#' @return A data frame of point (and optionally interval) forecasts at a single +#' ahead (unique horizon) for each unique combination of `key_vars`. +#' @export +#' +#' @examples +#' jhu <- case_death_rate_subset %>% +#' dplyr::filter(time_value >= as.Date("2021-12-01")) +#' +#' out <- flatline_epi_forecaster(jhu, "death_rate") +flatline_epi_forecaster <- function(epi_data, + outcome, + args_list = flatline_args_list()) { + + validate_forecaster_inputs(epi_data, outcome, "time_value") + keys <- epi_keys(epi_data) + ek <- keys[-1] + outcome <- rlang::sym(outcome) + + + r <- epi_recipe(epi_data) %>% + step_epi_ahead(!!outcome, ahead = args_list$ahead, skip = TRUE) %>% + recipes::update_role(!!outcome, new_role = "predictor") %>% + recipes::add_role(dplyr::all_of(keys), new_role = "predictor") + + latest <- get_test_data(epi_recipe(epi_data), epi_data) + + forecast_date <- args_list$forecast_date %||% max(latest$time_value) + target_date <- args_list$target_date %||% forecast_date + args_list$ahead + + f <- frosting() %>% + layer_predict() %>% + layer_residual_quantiles( + probs = args_list$levels, + symmetrize = args_list$symmetrize, + by_key = args_list$quantile_by_key) %>% + layer_add_forecast_date(forecast_date = forecast_date) %>% + layer_add_target_date(target_date = target_date) + if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred")) + + eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") + + wf <- epi_workflow(r, eng, f) %>% fit(epi_data) + + list( + predictions = suppressWarnings(predict(wf, new_data = latest)), + epi_workflow = wf + ) +} + + + +#' Flatline forecaster argument constructor +#' +#' Constructs a list of arguments for [flatline_epi_forecaster()]. +#' +#' @inheritParams arx_args_list +#' +#' @return A list containing updated parameter choices. +#' @export +#' +#' @examples +#' flatline_args_list() +#' flatline_args_list(symmetrize = FALSE) +#' flatline_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120) +flatline_args_list <- function(ahead = 7L, + min_train_window = 20L, + forecast_date = NULL, + target_date = NULL, + levels = c(0.05, 0.95), + symmetrize = TRUE, + nonneg = TRUE, + quantile_by_key = character(0L)) { + + arg_is_scalar(ahead, min_train_window) + arg_is_chr(quantile_by_key, allow_null = TRUE) + arg_is_scalar(forecast_date, target_date, allow_null = TRUE) + arg_is_nonneg_int(ahead, min_train_window) + arg_is_lgl(symmetrize, nonneg) + arg_is_probabilities(levels, allow_null = TRUE) + + enlist(ahead, + min_train_window, + forecast_date, + target_date, + levels, + symmetrize, + nonneg, + quantile_by_key) +} + +validate_forecaster_inputs <- function(epi_data, outcome, predictors) { + if (! epiprocess::is_epi_df(epi_data)) + cli_stop("epi_data must be an epi_df.") + arg_is_chr(predictors) + arg_is_chr_scalar(outcome) + if (! outcome %in% names(epi_data)) + cli_stop("{outcome} was not found in the training data.") + if (! all(predictors %in% names(epi_data))) + cli_stop("At least one predictor was not found in the training data.") + invisible(TRUE) +} + diff --git a/R/frosting.R b/R/frosting.R index 436ea2319..5e881e0a9 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -176,21 +176,29 @@ frosting <- function(layers = NULL, requirements = NULL) { out <- new_frosting() } + +#' Extract the frosting object from a workflow +#' +#' @param x an `epi_workflow` object +#' @param ... not used +#' +#' @return a `frosting` object +#' @export extract_frosting <- function(x, ...) { UseMethod("extract_frosting") } +#' @export extract_frosting.default <- function(x, ...) { abort(c("Frosting is only available for epi_workflows currently.", i = "Can you use `epi_workflow()` instead of `workflow()`?")) invisible(x) } +#' @export extract_frosting.epi_workflow <- function(x, ...) { - if (has_postprocessor_frosting(x)) { - return(x$post$actions$frosting$frosting) - } - abort("The epi_workflow does not have a preprocessor.") + if (has_postprocessor_frosting(x)) return(x$post$actions$frosting$frosting) + else cli_stop("The epi_workflow does not have a postprocessor.") } #' Apply postprocessing to a fitted workflow diff --git a/R/get_test_data.R b/R/get_test_data.R index 59a163aab..51ad2be1f 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -1,8 +1,8 @@ #' Get test data for prediction based on longest lag period #' #' Based on the longest lag period in the recipe, -#' `get_test_data()` creates a tibble in [epiprocess::epi_df] -#' format with columns `geo_value`, `time_value` +#' `get_test_data()` creates an [epiprocess::epi_df] +#' with columns `geo_value`, `time_value` #' and other variables in the original dataset, #' which will be used to create test data. #' @@ -21,14 +21,11 @@ #' get_test_data(recipe = rec, x = case_death_rate_subset) #' @export -get_test_data <- function(recipe, x){ - # TO-DO: SOME CHECKS OF THE DATASET - if (any(!(c('geo_value','time_value') %in% colnames(x)))) { - rlang::abort("`geo_value`, `time_value` does not exist in data") - } - ## CHECK if it is epi_df? - - max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0))) +get_test_data <- function(recipe, x) { + stopifnot(is.data.frame(x)) + if (! all(colnames(x) %in% colnames(recipe$template))) + cli_stop("some variables used for training are not available in `x`.") + max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)), 0) # CHECK: Return NA if insufficient training data if (dplyr::n_distinct(x$time_value) < max_lags) { @@ -36,7 +33,8 @@ get_test_data <- function(recipe, x){ "You need at least {max_lags} distinct time_values.") } - groups <- epi_keys(recipe)[epi_keys(recipe) != "time_value"] + groups <- epi_keys(recipe) + groups <- groups[groups != "time_value"] x %>% dplyr::filter( diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 5d3ab4621..5c2c8bcef 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -83,10 +83,15 @@ slather.layer_add_forecast_date <- function(object, components, the_fit, the_rec as_of_date <- as.Date(attributes(components$keys)$metadata$as_of) if (object$forecast_date < as_of_date) { - warning("forecast_date is less than the most recent update date of the data.") + cli_warn( + c("The forecast_date is less than the most ", + "recent update date of the data.", + i = "forecast_date = {object$forecast_date} while data is from {as_of_date}.") + ) } - - components$predictions <- dplyr::bind_cols(components$predictions, - forecast_date = as.Date(object$forecast_date)) + components$predictions <- dplyr::bind_cols( + components$predictions, + forecast_date = as.Date(object$forecast_date) + ) components } diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index 8507a74a6..07b8f9146 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -5,6 +5,8 @@ #' @param probs numeric vector of probabilities with values in (0,1) #' referring to the desired quantile. #' @param symmetrize logical. If `TRUE` then interval will be symmetric. +#' @param by_key A character vector of keys to group the residuls by before +#' calculating quantiles. The default, `c()` performs no grouping. #' @param name character. The name for the output column. #' @param .flag a logical to determine if the layer is added. Passed on to #' `add_layer()`. Default `TRUE`. @@ -34,22 +36,33 @@ #' wf1 <- wf %>% add_frosting(f) #' #' p <- predict(wf1, latest) -#' p +#' +#' f2 <- frosting() %>% +#' layer_predict() %>% +#' layer_residual_quantiles(probs = c(0.3, 0.7), by_key = "geo_value") %>% +#' layer_naomit(.pred) +#' wf2 <- wf %>% add_frosting(f2) +#' +#' p2 <- predict(wf2, latest) layer_residual_quantiles <- function(frosting, ..., probs = c(0.0275, 0.975), symmetrize = TRUE, + by_key = character(0L), name = ".pred_distn", .flag = TRUE, id = rand_id("residual_quantiles")) { rlang::check_dots_empty() + arg_is_scalar(symmetrize, .flag) arg_is_chr_scalar(name, id) + arg_is_chr(by_key, allow_null = TRUE) arg_is_probabilities(probs) - arg_is_lgl(symmetrize) + arg_is_lgl(symmetrize, .flag) add_layer( frosting, layer_residual_quantiles_new( probs = probs, symmetrize = symmetrize, + by_key = by_key, name = name, id = id ), @@ -57,9 +70,9 @@ layer_residual_quantiles <- function(frosting, ..., ) } -layer_residual_quantiles_new <- function(probs, symmetrize, name, id) { +layer_residual_quantiles_new <- function(probs, symmetrize, by_key, name, id) { layer("residual_quantiles", probs = probs, symmetrize = symmetrize, - name = name, id = id) + by_key = by_key, name = name, id = id) } #' @export @@ -67,13 +80,35 @@ slather.layer_residual_quantiles <- function(object, components, the_fit, the_recipe, ...) { if (is.null(object$probs)) return(components) + s <- ifelse(object$symmetrize, -1, NA) - r <- grab_residuals(the_fit, components) - q <- quantile(c(r, s * r), probs = object$probs, na.rm = TRUE) + r <- dplyr::bind_cols( + r = grab_residuals(the_fit, components), + geo_value = components$mold$extras$roles$geo_value, + components$mold$extras$roles$key) + + ## Handle any grouping requests + if (length(object$by_key) > 0L) { + common <- intersect(object$by_key, names(r)) + excess <- setdiff(object$by_key, names(r)) + if (length(excess) > 0L) { + cli_warn("Requested residual grouping key(s) {excess} unavailable ", + "in the original data. Grouping by the remainder {common}.") + + } + if (length(common) > 0L) + r <- r %>% dplyr::group_by(!!!rlang::syms(common)) + } + + r <- r %>% + dplyr::summarise( + q = list(quantile(c(r, s * r), probs = object$probs, na.rm = TRUE)) + ) estimate <- components$predictions$.pred res <- tibble::tibble( - .pred_distn = dist_quantiles(map(estimate, "+", q), object$probs)) + .pred_distn = dist_quantiles(map2(estimate, r$q, "+"), object$probs) + ) res <- check_pname(res, components$predictions, object) components$predictions <- dplyr::mutate(components$predictions, !!!res) components @@ -82,7 +117,12 @@ slather.layer_residual_quantiles <- grab_residuals <- function(the_fit, components) { if (the_fit$spec$mode != "regression") rlang::abort("For meaningful residuals, the predictor should be a regression model.") - - yhat <- predict(the_fit, new_data = components$mold$predictors) - c(components$mold$outcomes - yhat)[[1]] + r_generic <- attr(utils::methods(class = class(the_fit)[1]), "info")$generic + if ("residuals" %in% r_generic) { + r <- residuals(the_fit) + } else { + yhat <- predict(the_fit, new_data = components$mold$predictors) + r <- c(components$mold$outcomes - yhat)[[1]] + } + r } diff --git a/R/make_flatline_reg.R b/R/make_flatline_reg.R new file mode 100644 index 000000000..33c135f08 --- /dev/null +++ b/R/make_flatline_reg.R @@ -0,0 +1,39 @@ +make_flatline_reg <- function() { + parsnip::set_model_engine("linear_reg", "regression", eng = "flatline") + parsnip::set_dependency("linear_reg", eng = "flatline", pkg = "epipredict") + + parsnip::set_fit( + model = "linear_reg", + eng = "flatline", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "epipredict", fun = "flatline"), + defaults = list() + )) + + parsnip::set_encoding( + model = "linear_reg", + eng = "flatline", + mode = "regression", + options = list( + predictor_indicators = "none", + compute_intercept = TRUE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) + ) + + parsnip::set_pred( + model = "linear_reg", + eng = "flatline", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, post = NULL, func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) + ) + +} diff --git a/R/smooth_arx_forecaster.R b/R/smooth_arx_forecaster.R index d862febee..a00238585 100644 --- a/R/smooth_arx_forecaster.R +++ b/R/smooth_arx_forecaster.R @@ -24,10 +24,10 @@ smooth_arx_forecaster <- function(x, y, key_vars, time_value, if (length(y) < min_train_window + max_lags + max(ahead)) { qnames <- probs_to_string(levels) - out <- purrr::map_dfr(ahead, ~ distinct_keys, .id = "ahead") %>% + out <- map_dfr(ahead, ~ distinct_keys, .id = "ahead") %>% dplyr::mutate(ahead = magrittr::extract(!!ahead, as.integer(ahead)), point = NA) %>% - dplyr::select(!any_of(".dump")) + dplyr::select(!dplyr::any_of(".dump")) return(enframer(out, qnames)) } @@ -46,15 +46,15 @@ smooth_arx_forecaster <- function(x, y, key_vars, time_value, as.data.frame() %>% magrittr::set_names(ahead) - q <- purrr::map2_dfr( + q <- map2_dfr( r, point, ~ residual_quantiles(.x, .y, levels, symmetrize), .id = "ahead" ) %>% mutate(ahead = as.integer(ahead)) if (nonneg) q <- dplyr::mutate(q, dplyr::across(!ahead, ~ pmax(.x, 0))) return( - purrr::map_dfr(ahead, ~ distinct_keys) %>% - dplyr::select(!any_of(".dump")) %>% + map_dfr(ahead, ~ distinct_keys) %>% + dplyr::select(!dplyr::any_of(".dump")) %>% dplyr::bind_cols(q) %>% dplyr::relocate(ahead) ) diff --git a/R/step_epi_naomit.R b/R/step_epi_naomit.R index f266fe636..22e5eab03 100644 --- a/R/step_epi_naomit.R +++ b/R/step_epi_naomit.R @@ -15,6 +15,6 @@ step_epi_naomit <- function(recipe) { stopifnot(inherits(recipe, "recipe")) recipe %>% - recipes::step_naomit(all_predictors()) %>% + recipes::step_naomit(all_predictors(), skip = FALSE) %>% recipes::step_naomit(all_outcomes(), skip = TRUE) } diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index f3d69874e..c7ffcb875 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -188,7 +188,7 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) { bake.step_epi_lag <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, lag = object$lag) %>% dplyr::mutate(newname = glue::glue("{object$prefix}{lag}_{col}"), - shift_val = object$lag, + shift_val = lag, lag = NULL) ## ensure no name clashes @@ -219,7 +219,7 @@ bake.step_epi_lag <- function(object, new_data, ...) { bake.step_epi_ahead <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, ahead = object$ahead) %>% dplyr::mutate(newname = glue::glue("{object$prefix}{ahead}_{col}"), - shift_val = -object$ahead, + shift_val = -ahead, ahead = NULL) ## ensure no name clashes diff --git a/R/utils_arg.R b/R/utils_arg.R index 2cd5db6ba..ebac387d5 100644 --- a/R/utils_arg.R +++ b/R/utils_arg.R @@ -6,9 +6,9 @@ handle_arg_list = function(..., tests) { values = list(...) #names = names(values) names = eval(substitute(alist(...))) - names = purrr::map(names, deparse) + names = map(names, deparse) - purrr::walk2(names, values, tests) + walk2(names, values, tests) } arg_is_scalar = function(..., allow_null = FALSE, allow_na = FALSE) { @@ -125,7 +125,7 @@ arg_is_function = function(..., allow_null = FALSE) { ..., tests = function(name, value) { if (!is.function(value) | (is.null(value) & !allow_null)) - cli_stop("All {.val {name}} must be in [0,1].") + cli_stop("{value} must be a `parsnip` function.") } ) } diff --git a/R/zzz.R b/R/zzz.R new file mode 100644 index 000000000..0db33dae9 --- /dev/null +++ b/R/zzz.R @@ -0,0 +1,9 @@ +# ON LOAD ---- + +# The functions below define the model information. These access the model +# environment inside of parsnip so they have to be executed once parsnip has +# been loaded. + +.onLoad <- function(libname, pkgname) { + make_flatline_reg() +} diff --git a/man/arx_args_list.Rd b/man/arx_args_list.Rd index ddb56d422..09e74411c 100644 --- a/man/arx_args_list.Rd +++ b/man/arx_args_list.Rd @@ -5,41 +5,48 @@ \title{ARX forecaster argument constructor} \usage{ arx_args_list( - lags = c(0, 7, 14), - ahead = 7, - min_train_window = 20, + lags = c(0L, 7L, 14L), + ahead = 7L, + min_train_window = 20L, + forecast_date = NULL, + target_date = NULL, levels = c(0.05, 0.95), - intercept = TRUE, symmetrize = TRUE, nonneg = TRUE, - quantile_by_key = FALSE + quantile_by_key = character(0L) ) } \arguments{ \item{lags}{Vector or List. Positive integers enumerating lags to use -in autoregressive-type models.} +in autoregressive-type models (in days).} -\item{ahead}{Integer. Number of time steps ahead of the forecast date -for which forecasts should be produced.} +\item{ahead}{Integer. Number of time steps ahead (in days) of the forecast +date for which forecasts should be produced.} \item{min_train_window}{Integer. The minimal amount of training -data needed to produce a forecast. If smaller, the forecaster will return -\code{NA} predictions.} +data (in the time unit of the \code{epi_df}) needed to produce a forecast. +If smaller, the forecaster will return \code{NA} predictions.} + +\item{forecast_date}{The date on which the forecast is created. The default +\code{NULL} will attempt to determine this automatically.} + +\item{target_date}{The date for which the forecast is intended. The default +\code{NULL} will attempt to determine this automatically.} \item{levels}{Vector or \code{NULL}. A vector of probabilities to produce prediction intervals. These are created by computing the quantiles of training residuals. A \code{NULL} value will result in point forecasts only.} -\item{intercept}{Logical. The default \code{TRUE} includes intercept in the -forecaster.} - \item{symmetrize}{Logical. The default \code{TRUE} calculates symmetric prediction intervals.} -\item{nonneg}{Logical. The default \code{TRUE} enforeces nonnegative predictions +\item{nonneg}{Logical. The default \code{TRUE} enforces nonnegative predictions by hard-thresholding at 0.} -\item{quantile_by_key}{Not currently implemented} +\item{quantile_by_key}{Character vector. Groups residuals by listed keys +before calculating residual quantiles. See the \code{by_key} argument to +\code{\link[=layer_residual_quantiles]{layer_residual_quantiles()}} for more information. The default, +\code{character(0)} performs no grouping.} } \value{ A list containing updated parameter choices. diff --git a/man/arx_epi_forecaster.Rd b/man/arx_epi_forecaster.Rd new file mode 100644 index 000000000..da8d0ad67 --- /dev/null +++ b/man/arx_epi_forecaster.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/arx_forecaster_mod.R +\name{arx_epi_forecaster} +\alias{arx_epi_forecaster} +\title{Direct autoregressive forecaster with covariates} +\usage{ +arx_epi_forecaster( + epi_data, + outcome, + predictors, + trainer = parsnip::linear_reg(), + args_list = arx_args_list() +) +} +\arguments{ +\item{epi_data}{An \code{epi_df} object} + +\item{outcome}{A character (scalar) specifying the outcome (in the +\code{epi_df}).} + +\item{predictors}{A character vector giving column(s) of predictor +variables.} + +\item{trainer}{A \code{{parsnip}} model describing the type of estimation. +For now, we enforce \code{mode = "regression"}.} + +\item{args_list}{A list of customization arguments to determine +the type of forecasting model. See \code{\link[=arx_args_list]{arx_args_list()}}.} +} +\value{ +A list with (1) \code{predictions} an \code{epi_df} of predicted values +and (2) \code{epi_workflow}, a list that encapsulates the entire estimation +workflow +} +\description{ +This is an autoregressive forecasting model for +\link[epiprocess:epi_df]{epiprocess::epi_df} data. It does "direct" forecasting, meaning +that it estimates a model for a particular target horizon. +} +\examples{ +jhu <- case_death_rate_subset \%>\% + dplyr::filter(time_value >= as.Date("2021-12-01")) + +out <- arx_epi_forecaster(jhu, "death_rate", + c("case_rate", "death_rate")) +} diff --git a/man/epi_workflow.Rd b/man/epi_workflow.Rd index bcf0e78aa..f9d753d84 100644 --- a/man/epi_workflow.Rd +++ b/man/epi_workflow.Rd @@ -11,7 +11,7 @@ epi_workflow(preprocessor = NULL, spec = NULL, postprocessor = NULL) \itemize{ \item A formula, passed on to \code{\link[workflows:add_formula]{add_formula()}}. \item A recipe, passed on to \code{\link[workflows:add_recipe]{add_recipe()}}. -\item A \code{\link[workflows:add_variables]{workflow_variables()}} object, passed on to \code{\link[workflows:add_variables]{add_variables()}}. +\item A \code{\link[workflows:workflow_variables]{workflow_variables()}} object, passed on to \code{\link[workflows:add_variables]{add_variables()}}. }} \item{spec}{An optional parsnip model specification to add to the workflow. diff --git a/man/extract_frosting.Rd b/man/extract_frosting.Rd new file mode 100644 index 000000000..560ffc032 --- /dev/null +++ b/man/extract_frosting.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/frosting.R +\name{extract_frosting} +\alias{extract_frosting} +\title{Extract the frosting object from a workflow} +\usage{ +extract_frosting(x, ...) +} +\arguments{ +\item{x}{an \code{epi_workflow} object} + +\item{...}{not used} +} +\value{ +a \code{frosting} object +} +\description{ +Extract the frosting object from a workflow +} diff --git a/man/flatline.Rd b/man/flatline.Rd new file mode 100644 index 000000000..d65dcae8a --- /dev/null +++ b/man/flatline.Rd @@ -0,0 +1,47 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/flatline.R +\name{flatline} +\alias{flatline} +\title{(Internal) implementation of the flatline forecaster} +\usage{ +flatline(formula, data) +} +\arguments{ +\item{formula}{The lhs should be a single variable. In standard usage, this +would actually be the observed time series shifted forward by the forecast +horizon. The right hand side must contain any keys (locations) for the +panel data separated by plus. The observed time series must come last. +For example + +\if{html}{\out{