Intro vignette #116

Merged (22 commits, Aug 11, 2022)
7 changes: 4 additions & 3 deletions DESCRIPTION
@@ -32,7 +32,6 @@ Imports:
glue,
hardhat (>= 1.2.0),
magrittr,
purrr,
recipes (>= 1.0.0),
rlang,
stats,
@@ -42,7 +41,7 @@ Imports:
tidyselect,
usethis,
vctrs,
workflows
workflows (>= 1.0.0)
Suggests:
covidcast,
data.table,
@@ -51,9 +50,11 @@ Suggests:
knitr,
lubridate,
parsnip (>= 1.0.0),
ranger,
RcppRoll,
rmarkdown,
testthat (>= 3.0.0)
testthat (>= 3.0.0),
xgboost
VignetteBuilder:
knitr
Remotes:
9 changes: 9 additions & 0 deletions NAMESPACE
@@ -21,6 +21,8 @@ S3method(extract_argument,frosting)
S3method(extract_argument,layer)
S3method(extract_argument,recipe)
S3method(extract_argument,step)
S3method(extract_frosting,default)
S3method(extract_frosting,epi_workflow)
S3method(extract_layers,frosting)
S3method(extract_layers,workflow)
S3method(extrapolate_quantiles,dist_default)
@@ -29,6 +31,7 @@ S3method(extrapolate_quantiles,distribution)
S3method(format,dist_quantiles)
S3method(median,dist_quantiles)
S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,epi_recipe)
S3method(prep,step_epi_ahead)
S3method(prep,step_epi_lag)
@@ -39,6 +42,7 @@ S3method(print,step_epi_ahead)
S3method(print,step_epi_lag)
S3method(quantile,dist_quantiles)
S3method(refresh_blueprint,default_epi_recipe_blueprint)
S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
S3method(slather,layer_add_forecast_date)
S3method(slather,layer_add_target_date)
@@ -60,6 +64,7 @@ export(add_frosting)
export(add_layer)
export(apply_frosting)
export(arx_args_list)
export(arx_epi_forecaster)
export(arx_forecaster)
export(create_lags_and_leads)
export(create_layer)
@@ -72,8 +77,12 @@ export(epi_recipe)
export(epi_recipe_blueprint)
export(epi_workflow)
export(extract_argument)
export(extract_frosting)
export(extract_layers)
export(extrapolate_quantiles)
export(flatline)
export(flatline_args_list)
export(flatline_epi_forecaster)
export(frosting)
export(get_precision)
export(get_test_data)
68 changes: 45 additions & 23 deletions R/arx_forecaster.R
@@ -34,7 +34,7 @@ arx_forecaster <- function(x, y, key_vars, time_value,
}

dat <- create_lags_and_leads(x, y, lags, ahead, time_value, keys)
if (intercept) dat$x0 <- 1
dat$x0 <- 1

obj <- stats::lm(
y1 ~ . + 0,
@@ -67,14 +67,28 @@ arx_forecaster <- function(x, y, key_vars, time_value,
#'
#' Constructs a list of arguments for [arx_forecaster()].
#'
#' @template param-lags
#' @template param-ahead
#' @template param-min_train_window
#' @template param-levels
#' @template param-intercept
#' @template param-symmetrize
#' @template param-nonneg
#' @param quantile_by_key Not currently implemented
#' @param lags Vector or List. Positive integers enumerating lags to use
#' in autoregressive-type models (in days).
#' @param ahead Integer. Number of time steps ahead (in days) of the forecast
#' date for which forecasts should be produced.
#' @param min_train_window Integer. The minimal amount of training
#' data (in the time unit of the `epi_df`) needed to produce a forecast.
#' If smaller, the forecaster will return `NA` predictions.
#' @param forecast_date The date on which the forecast is created. The default
#' `NULL` will attempt to determine this automatically.
#' @param target_date The date for which the forecast is intended. The default
#' `NULL` will attempt to determine this automatically.
#' @param levels Vector or `NULL`. A vector of probabilities to produce
#' prediction intervals. These are created by computing the quantiles of
#' training residuals. A `NULL` value will result in point forecasts only.
#' @param symmetrize Logical. The default `TRUE` calculates
#' symmetric prediction intervals.
#' @param nonneg Logical. The default `TRUE` enforces nonnegative predictions
#' by hard-thresholding at 0.
#' @param quantile_by_key Character vector. Groups residuals by listed keys
#' before calculating residual quantiles. See the `by_key` argument to
#' [layer_residual_quantiles()] for more information. The default,
#' `character(0)` performs no grouping.
#'
#' @return A list containing updated parameter choices.
#' @export
@@ -83,28 +97,36 @@ arx_forecaster <- function(x, y, key_vars, time_value,
#' arx_args_list()
#' arx_args_list(symmetrize = FALSE)
#' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120)
arx_args_list <- function(lags = c(0, 7, 14), ahead = 7, min_train_window = 20,
levels = c(0.05, 0.95), intercept = TRUE,
arx_args_list <- function(lags = c(0L, 7L, 14L),
Review comments on this line:

ChloeYou (Contributor) commented on Aug 9, 2022:
Is there a reason why intercept was removed? Is it because it seems like we would always have the intercept in the ARX model no matter what?

Contributor:
Should we necessarily have these default lags and ahead arguments? The defaults seem specifically defined for time_type = day, but there are many other time types for which the defaults make a bit less sense.

Contributor:
If we do leave the defaults as is, it may be good to briefly mention the defaults somewhere else in arx_args_list() aside from the initial list of arguments, to emphasize them (under the argument descriptions seems the natural choice, where I see mention of defaults for some but not all of the arguments).

dajmcdon (Contributor, author) replied on Aug 10, 2022:
On intercept: yes, it's always there.
On the time_type: good point. Our implementation always modifies the time_value by a day. I added this to the documentation.
Actually, can someone create an issue to:
  1. adjust the lead/lag function to accept a time_unit argument (days, months, weeks, years) and perform the operations appropriately, and
  2. add an argument here (and elsewhere) to allow a user to specify the units (possibly different for lags/ahead/train_window)?

ahead = 7L,
min_train_window = 20L,
forecast_date = NULL,
target_date = NULL,
levels = c(0.05, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = FALSE) {
quantile_by_key = character(0L)) {

# error checking if lags is a list
.lags <- lags
if (is.list(lags)) lags <- unlist(lags)

arg_is_scalar(ahead, min_train_window)
arg_is_scalar(ahead, min_train_window, symmetrize, nonneg)
arg_is_chr(quantile_by_key, allow_null = TRUE)
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
arg_is_nonneg_int(ahead, min_train_window, lags)
arg_is_lgl(intercept, symmetrize, nonneg)
arg_is_lgl(symmetrize, nonneg)
arg_is_probabilities(levels, allow_null = TRUE)

max_lags <- max(lags)

list(
lags = .lags, ahead = as.integer(ahead),
min_train_window = min_train_window,
levels = levels, intercept = intercept,
symmetrize = symmetrize, nonneg = nonneg,
max_lags = max_lags
)
}
enlist(lags = .lags,
ahead,
min_train_window,
levels,
forecast_date,
target_date,
symmetrize,
nonneg,
max_lags,
quantile_by_key)
}
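
As a quick orientation to the new argument list, here is a hypothetical call exercising the added arguments (the dates and per-predictor lags below are illustrative, not package defaults):

arx_args_list(
  lags = list(c(0L, 7L, 14L), c(0L, 7L)),  # one lag set per predictor
  ahead = 14L,
  forecast_date = as.Date("2021-12-31"),   # otherwise determined from the data
  target_date = as.Date("2022-01-14"),     # otherwise forecast_date + ahead
  quantile_by_key = "geo_value"            # group residuals by geo_value before taking quantiles
)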
84 changes: 69 additions & 15 deletions R/arx_forecaster_mod.R
@@ -1,29 +1,83 @@
arx_epi_forecaster <- function(epi_data, response,
...,
#' Direct autoregressive forecaster with covariates
#'
#' This is an autoregressive forecasting model for
#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning
#' that it estimates a model for a particular target horizon.
#'
#'
#' @param epi_data An `epi_df` object
#' @param outcome A character (scalar) specifying the outcome (in the
#' `epi_df`).
#' @param predictors A character vector giving column(s) of predictor
#' variables.
#' @param trainer A `{parsnip}` model describing the type of estimation.
#' For now, we enforce `mode = "regression"`.
#' @param args_list A list of customization arguments to determine
#' the type of forecasting model. See [arx_args_list()].
Review comments on this line:

Contributor:
Should users know the default of arx_args_list() creates 0, 7, 14 periods lagged predictors and 7 day ahead outcomes?

Contributor (author):
I agree, but I don't want to put this in the docs at the moment. Looking at the function would tell you.

#'
#' @return A list with (1) `predictions` an `epi_df` of predicted values
#' and (2) `epi_workflow`, a list that encapsulates the entire estimation
#' workflow
#' @export
#'
#' @examples
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(time_value >= as.Date("2021-12-01"))
#'
#' out <- arx_epi_forecaster(jhu, "death_rate",
#' c("case_rate", "death_rate"))
arx_epi_forecaster <- function(epi_data,
outcome,
predictors,
trainer = parsnip::linear_reg(),
args_list = arx_args_list()) {

r <- epi_recipe(epi_data) %>%
step_epi_lag(..., lag = args_list$lags) %>% # hmmm, same for all predictors
step_epi_ahead(response, ahead = args_list$ahead) %>%
# should use the internal function (in an open PR)
recipes::step_naomit(recipes::all_predictors()) %>%
recipes::step_naomit(recipes::all_outcomes(), skip = TRUE)
validate_forecaster_inputs(epi_data, outcome, predictors)
if (!is.list(trainer) || trainer$mode != "regression")
cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.")
lags <- arx_lags_validator(predictors, args_list$lags)

r <- epi_recipe(epi_data)
for (l in seq_along(lags)) {
p <- predictors[l]
r <- step_epi_lag(r, !!p, lag = lags[[l]])
}
r <- r %>%
step_epi_ahead(dplyr::all_of(!!outcome), ahead = args_list$ahead) %>%
step_epi_naomit()
# should limit the training window here (in an open PR)
# What to do if insufficient training data? Add issue.
# remove intercept? not sure how this is implemented in tidymodels

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
f <- frosting() %>%
layer_predict() %>%
layer_naomit(.pred) %>%
layer_residual_quantile(
# layer_naomit(.pred) %>%
layer_residual_quantiles(
probs = args_list$levels,
symmetrize = args_list$symmetrize) %>%
layer_threshold(.pred, dplyr::starts_with("q")) #, .flag = args_list$nonneg) in open PR
# need the target date processing here
layer_add_forecast_date(forecast_date = forecast_date) %>%
layer_add_target_date(target_date = target_date)
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))

latest <- get_test_data(r, epi_data)

epi_workflow(r, trainer) %>% # bug, issue 72
add_frosting(f)
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
list(
predictions = predict(wf, new_data = latest),
epi_workflow = wf
)
}


arx_lags_validator <- function(predictors, lags) {
p <- length(predictors)
if (!is.list(lags)) lags <- list(lags)
if (length(lags) == 1) lags <- rep(lags, p)
else if (length(lags) < p) {
cli_stop(
"You have requested {p} predictors but lags cannot be recycled to match."
)
}
lags
}
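
A brief sketch of what the validator does with the lags argument (hypothetical calls, not from the package's tests):

arx_lags_validator(c("case_rate", "death_rate"), c(0, 7, 14))
# -> list(c(0, 7, 14), c(0, 7, 14))   # a single lag vector is recycled to every predictor
arx_lags_validator(c("case_rate", "death_rate"), list(c(0, 7, 14), c(0, 7)))
# -> list(c(0, 7, 14), c(0, 7))       # one lag set per predictor, used as given
arx_lags_validator(c("cases", "deaths", "hosp"), list(c(0, 7), c(0, 7)))
# -> error: 2 lag sets cannot be recycled to 3 predictors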
22 changes: 22 additions & 0 deletions R/compat-purrr.R
@@ -5,23 +5,45 @@ map <- function(.x, .f, ...) {
.f <- rlang::as_function(.f, env = rlang::global_env())
lapply(.x, .f, ...)
}

walk <- function(.x, .f, ...) {
map(.x, .f, ...)
invisible(.x)
}

walk2 <- function(.x, .y, .f, ...) {
map2(.x, .y, .f, ...)
invisible(.x)
}

map_lgl <- function(.x, .f, ...) {
.rlang_purrr_map_mold(.x, .f, logical(1), ...)
}

map_int <- function(.x, .f, ...) {
.rlang_purrr_map_mold(.x, .f, integer(1), ...)
}

map_dbl <- function(.x, .f, ...) {
.rlang_purrr_map_mold(.x, .f, double(1), ...)
}

map_chr <- function(.x, .f, ...) {
.rlang_purrr_map_mold(.x, .f, character(1), ...)
}

map_dfr <- function(.x, .f, ..., .id = NULL) {
.f <- rlang::as_function(.f, env = global_env())
res <- map(.x, .f, ...)
dplyr::bind_rows(res, .id = .id)
}

map2_dfr <- function(.x, .y, .f, ..., .id = NULL) {
.f <- rlang::as_function(.f, env = global_env())
res <- map2(.x, .y, .f, ...)
dplyr::bind_rows(res, .id = .id)
}

.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) {
.f <- rlang::as_function(.f, env = rlang::global_env())
out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE)
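
Since purrr is dropped from Imports in DESCRIPTION, these vendored helpers stand in for their purrr equivalents throughout the package. A minimal sketch of how they behave (illustrative calls; formula lambdas work because the helpers route through rlang::as_function()):

map(1:3, function(x) x + 1)                # list(2, 3, 4)
map_dbl(1:3, ~ .x^2)                       # c(1, 4, 9)
map_chr(list(a = 1, b = 2), ~ format(.x))  # c("1", "2")
walk(letters[1:3], message)                # emits one message per element, returns the input invisibly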
4 changes: 2 additions & 2 deletions R/epi_keys.R
@@ -29,7 +29,7 @@ epi_keys.recipe <- function(x) {
epi_keys_mold <- function(mold) {
keys <- c("time_value", "geo_value", "key")
molded_names <- names(mold$extras$roles)
mold_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names)
unname(mold_keys)
mold_keys <- map(mold$extras$roles[molded_names %in% keys], names)
unname(unlist(mold_keys))
}
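
One plausible reason for switching epi_keys_mold() from purrr::map_chr() to map() plus unlist() (besides dropping the purrr dependency) is that a role can contribute more than one column name, which a length-1 character mold cannot hold. A hypothetical illustration, using lapply() in place of the internal map():

roles <- list(                        # hypothetical mold$extras$roles
  geo_value  = data.frame(geo_value = "ca"),
  key        = data.frame(age_group = "65+", sex = "f"),   # a role with two key columns
  time_value = data.frame(time_value = as.Date("2021-12-01"))
)
keys <- c("time_value", "geo_value", "key")
unname(unlist(lapply(roles[names(roles) %in% keys], names)))
# -> c("geo_value", "age_group", "sex", "time_value")
# A vapply()-based map_chr() would error on the "key" element, since names() returns length 2 there.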

4 changes: 2 additions & 2 deletions R/epi_shift.R
@@ -19,15 +19,15 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") {
tidyr::unchop(shift) %>% # what is chop
dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>%
# One list element for each shifted feature
purrr::pmap(function(i, shift, name) {
pmap(function(i, shift, name) {
tibble(keys,
time_value = time_value + shift, # Shift back
!!name := x[[i]])
})
if (is.data.frame(keys)) common_names <- c(names(keys), "time_value")
else common_names <- c("keys", "time_value")

purrr::reduce(out_list, dplyr::full_join, by = common_names)
reduce(out_list, dplyr::full_join, by = common_names)
}

epi_shift_single <- function(x, col, shift_val, newname, key_cols) {
1 change: 1 addition & 0 deletions R/epi_workflow.R
@@ -62,6 +62,7 @@ is_epi_workflow <- function(x) {
inherits(x, "epi_workflow")
}


#' Predict from an epi_workflow
#'
#' @description
File renamed without changes.