Skip to content

Commit 172ab1e

Browse files
authored
Merge pull request #116 from cmu-delphi/draft-intro
Intro vignette
2 parents d3f17ec + 38f3b1d commit 172ab1e

35 files changed

+1131
-106
lines changed

DESCRIPTION

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ Imports:
3232
glue,
3333
hardhat (>= 1.2.0),
3434
magrittr,
35-
purrr,
3635
recipes (>= 1.0.0),
3736
rlang,
3837
stats,
@@ -42,7 +41,7 @@ Imports:
4241
tidyselect,
4342
usethis,
4443
vctrs,
45-
workflows
44+
workflows (>= 1.0.0)
4645
Suggests:
4746
covidcast,
4847
data.table,
@@ -51,9 +50,11 @@ Suggests:
5150
knitr,
5251
lubridate,
5352
parsnip (>= 1.0.0),
53+
ranger,
5454
RcppRoll,
5555
rmarkdown,
56-
testthat (>= 3.0.0)
56+
testthat (>= 3.0.0),
57+
xgboost
5758
VignetteBuilder:
5859
knitr
5960
Remotes:

NAMESPACE

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ S3method(extract_argument,frosting)
2121
S3method(extract_argument,layer)
2222
S3method(extract_argument,recipe)
2323
S3method(extract_argument,step)
24+
S3method(extract_frosting,default)
25+
S3method(extract_frosting,epi_workflow)
2426
S3method(extract_layers,frosting)
2527
S3method(extract_layers,workflow)
2628
S3method(extrapolate_quantiles,dist_default)
@@ -29,6 +31,7 @@ S3method(extrapolate_quantiles,distribution)
2931
S3method(format,dist_quantiles)
3032
S3method(median,dist_quantiles)
3133
S3method(predict,epi_workflow)
34+
S3method(predict,flatline)
3235
S3method(prep,epi_recipe)
3336
S3method(prep,step_epi_ahead)
3437
S3method(prep,step_epi_lag)
@@ -39,6 +42,7 @@ S3method(print,step_epi_ahead)
3942
S3method(print,step_epi_lag)
4043
S3method(quantile,dist_quantiles)
4144
S3method(refresh_blueprint,default_epi_recipe_blueprint)
45+
S3method(residuals,flatline)
4246
S3method(run_mold,default_epi_recipe_blueprint)
4347
S3method(slather,layer_add_forecast_date)
4448
S3method(slather,layer_add_target_date)
@@ -60,6 +64,7 @@ export(add_frosting)
6064
export(add_layer)
6165
export(apply_frosting)
6266
export(arx_args_list)
67+
export(arx_epi_forecaster)
6368
export(arx_forecaster)
6469
export(create_lags_and_leads)
6570
export(create_layer)
@@ -72,8 +77,12 @@ export(epi_recipe)
7277
export(epi_recipe_blueprint)
7378
export(epi_workflow)
7479
export(extract_argument)
80+
export(extract_frosting)
7581
export(extract_layers)
7682
export(extrapolate_quantiles)
83+
export(flatline)
84+
export(flatline_args_list)
85+
export(flatline_epi_forecaster)
7786
export(frosting)
7887
export(get_precision)
7988
export(get_test_data)

R/arx_forecaster.R

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ arx_forecaster <- function(x, y, key_vars, time_value,
3434
}
3535

3636
dat <- create_lags_and_leads(x, y, lags, ahead, time_value, keys)
37-
if (intercept) dat$x0 <- 1
37+
dat$x0 <- 1
3838

3939
obj <- stats::lm(
4040
y1 ~ . + 0,
@@ -67,14 +67,28 @@ arx_forecaster <- function(x, y, key_vars, time_value,
6767
#'
6868
#' Constructs a list of arguments for [arx_forecaster()].
6969
#'
70-
#' @template param-lags
71-
#' @template param-ahead
72-
#' @template param-min_train_window
73-
#' @template param-levels
74-
#' @template param-intercept
75-
#' @template param-symmetrize
76-
#' @template param-nonneg
77-
#' @param quantile_by_key Not currently implemented
70+
#' @param lags Vector or List. Nonnegative integers enumerating lags to use
71+
#' in autoregressive-type models (in days).
72+
#' @param ahead Integer. Number of time steps ahead (in days) of the forecast
73+
#' date for which forecasts should be produced.
74+
#' @param min_train_window Integer. The minimal amount of training
75+
#' data (in the time unit of the `epi_df`) needed to produce a forecast.
76+
#' If smaller, the forecaster will return `NA` predictions.
77+
#' @param forecast_date The date on which the forecast is created. The default
78+
#' `NULL` will attempt to determine this automatically.
79+
#' @param target_date The date for which the forecast is intended. The default
80+
#' `NULL` will attempt to determine this automatically.
81+
#' @param levels Vector or `NULL`. A vector of probabilities to produce
82+
#' prediction intervals. These are created by computing the quantiles of
83+
#' training residuals. A `NULL` value will result in point forecasts only.
84+
#' @param symmetrize Logical. The default `TRUE` calculates
85+
#' symmetric prediction intervals.
86+
#' @param nonneg Logical. The default `TRUE` enforces nonnegative predictions
87+
#' by hard-thresholding at 0.
88+
#' @param quantile_by_key Character vector. Groups residuals by listed keys
89+
#' before calculating residual quantiles. See the `by_key` argument to
90+
#' [layer_residual_quantiles()] for more information. The default,
91+
#' `character(0)` performs no grouping.
7892
#'
7993
#' @return A list containing updated parameter choices.
8094
#' @export
@@ -83,28 +97,36 @@ arx_forecaster <- function(x, y, key_vars, time_value,
8397
#' arx_args_list()
8498
#' arx_args_list(symmetrize = FALSE)
8599
#' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120)
86-
arx_args_list <- function(lags = c(0, 7, 14), ahead = 7, min_train_window = 20,
87-
levels = c(0.05, 0.95), intercept = TRUE,
100+
arx_args_list <- function(lags = c(0L, 7L, 14L),
101+
ahead = 7L,
102+
min_train_window = 20L,
103+
forecast_date = NULL,
104+
target_date = NULL,
105+
levels = c(0.05, 0.95),
88106
symmetrize = TRUE,
89107
nonneg = TRUE,
90-
quantile_by_key = FALSE) {
108+
quantile_by_key = character(0L)) {
91109

92110
# error checking if lags is a list
93111
.lags <- lags
94112
if (is.list(lags)) lags <- unlist(lags)
95113

96-
arg_is_scalar(ahead, min_train_window)
114+
arg_is_scalar(ahead, min_train_window, symmetrize, nonneg)
115+
arg_is_chr(quantile_by_key, allow_null = TRUE)
116+
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
97117
arg_is_nonneg_int(ahead, min_train_window, lags)
98-
arg_is_lgl(intercept, symmetrize, nonneg)
118+
arg_is_lgl(symmetrize, nonneg)
99119
arg_is_probabilities(levels, allow_null = TRUE)
100120

101121
max_lags <- max(lags)
102-
103-
list(
104-
lags = .lags, ahead = as.integer(ahead),
105-
min_train_window = min_train_window,
106-
levels = levels, intercept = intercept,
107-
symmetrize = symmetrize, nonneg = nonneg,
108-
max_lags = max_lags
109-
)
110-
}
122+
enlist(lags = .lags,
123+
ahead,
124+
min_train_window,
125+
levels,
126+
forecast_date,
127+
target_date,
128+
symmetrize,
129+
nonneg,
130+
max_lags,
131+
quantile_by_key)
132+
}

R/arx_forecaster_mod.R

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,83 @@
1-
arx_epi_forecaster <- function(epi_data, response,
2-
...,
1+
#' Direct autoregressive forecaster with covariates
#'
#' This is an autoregressive forecasting model for
#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning
#' that it estimates a model for a particular target horizon.
#'
#' The forecast date defaults to the latest `time_value` in `epi_data` and
#' the target date defaults to the forecast date plus `args_list$ahead`.
#' Either may be overridden through `args_list`; a user-supplied target date
#' is used as given.
#'
#' @param epi_data An `epi_df` object
#' @param outcome A character (scalar) specifying the outcome (in the
#'   `epi_df`).
#' @param predictors A character vector giving column(s) of predictor
#'   variables.
#' @param trainer A `{parsnip}` model describing the type of estimation.
#'   For now, we enforce `mode = "regression"`.
#' @param args_list A list of customization arguments to determine
#'   the type of forecasting model. See [arx_args_list()].
#'
#' @return A list with (1) `predictions`, an `epi_df` of predicted values,
#'   and (2) `epi_workflow`, the fitted workflow that encapsulates the entire
#'   estimation procedure.
#' @export
#'
#' @examples
#' jhu <- case_death_rate_subset %>%
#'   dplyr::filter(time_value >= as.Date("2021-12-01"))
#'
#' out <- arx_epi_forecaster(jhu, "death_rate",
#'   c("case_rate", "death_rate"))
arx_epi_forecaster <- function(epi_data,
                               outcome,
                               predictors,
                               trainer = parsnip::linear_reg(),
                               args_list = arx_args_list()) {

  validate_forecaster_inputs(epi_data, outcome, predictors)
  if (!is.list(trainer) || trainer$mode != "regression") {
    cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.")
  }
  # One lag vector per predictor (a single vector is recycled).
  lags <- arx_lags_validator(predictors, args_list$lags)

  # Preprocessing: lag each predictor, lead the outcome, drop incomplete rows.
  r <- epi_recipe(epi_data)
  for (l in seq_along(lags)) {
    p <- predictors[l]
    r <- step_epi_lag(r, !!p, lag = lags[[l]])
  }
  r <- r %>%
    step_epi_ahead(dplyr::all_of(!!outcome), ahead = args_list$ahead) %>%
    step_epi_naomit()
  # should limit the training window here (in an open PR)
  # What to do if insufficient training data? Add issue.

  forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
  # `%||%` binds tighter than `+`, so the fallback must be parenthesised;
  # otherwise `ahead` would also be added to a user-supplied target date.
  target_date <- args_list$target_date %||%
    (forecast_date + args_list$ahead)

  # Postprocessing: predict, attach residual quantiles and the two dates.
  f <- frosting() %>%
    layer_predict() %>%
    # layer_naomit(.pred) %>%
    layer_residual_quantiles(
      probs = args_list$levels,
      symmetrize = args_list$symmetrize) %>%
    layer_add_forecast_date(forecast_date = forecast_date) %>%
    layer_add_target_date(target_date = target_date)
  if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))

  latest <- get_test_data(r, epi_data)

  wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
  list(
    predictions = predict(wf, new_data = latest),
    epi_workflow = wf
  )
}
71+
2872

73+
# Normalise `lags` to a list holding one lag vector per predictor.
#
# `predictors` is the character vector of predictor names; `lags` is either a
# single vector (applied to every predictor) or a list of vectors. A bare
# vector or a length-one list is recycled to `length(predictors)`. Any other
# length must match the number of predictors exactly, otherwise an error is
# raised. Returns the list of lag vectors.
arx_lags_validator <- function(predictors, lags) {
  p <- length(predictors)
  if (!is.list(lags)) lags <- list(lags)
  if (length(lags) == 1) {
    lags <- rep(lags, p)
  } else if (length(lags) != p) {
    # Too many entries is as wrong as too few: the caller indexes
    # `predictors` by position in `lags`, so extras would map to NA.
    cli_stop(
      "You have requested {p} predictors but lags cannot be recycled to match."
    )
  }
  lags
}

R/compat-purrr.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,45 @@ map <- function(.x, .f, ...) {
55
.f <- rlang::as_function(.f, env = rlang::global_env())
66
lapply(.x, .f, ...)
77
}
8+
89
walk <- function(.x, .f, ...) {
  # Run `.f` over `.x` for its side effects; the mapped results are dropped
  # and the input is handed back invisibly.
  map(.x = .x, .f = .f, ...)
  invisible(.x)
}
1213

14+
walk2 <- function(.x, .y, .f, ...) {
  # Two-input counterpart of `walk()`: call `.f` via `map2()` for side
  # effects only, then give back `.x` invisibly.
  map2(.x, .y, .f, ...)
  return(invisible(.x))
}
18+
1319
map_lgl <- function(.x, .f, ...) {
  # Typed map: each call to `.f` must yield a length-one logical.
  .rlang_purrr_map_mold(.x, .f, .mold = logical(1), ...)
}
22+
1623
map_int <- function(.x, .f, ...) {
  # Typed map: each call to `.f` must yield a length-one integer.
  .rlang_purrr_map_mold(.x, .f, .mold = integer(1), ...)
}
26+
1927
map_dbl <- function(.x, .f, ...) {
  # Typed map: each call to `.f` must yield a length-one double.
  .rlang_purrr_map_mold(.x, .f, .mold = double(1), ...)
}
30+
2231
map_chr <- function(.x, .f, ...) {
  # Typed map: each call to `.f` must yield a length-one character.
  .rlang_purrr_map_mold(.x, .f, .mold = character(1), ...)
}
34+
35+
map_dfr <- function(.x, .f, ..., .id = NULL) {
  # Map `.f` over `.x` and row-bind the results into one data frame.
  # `.id` is forwarded to `dplyr::bind_rows()`.
  # `global_env()` is namespace-qualified, as in the other helpers in this
  # file, so it resolves without relying on an import of that symbol.
  .f <- rlang::as_function(.f, env = rlang::global_env())
  res <- map(.x, .f, ...)
  dplyr::bind_rows(res, .id = .id)
}
40+
41+
map2_dfr <- function(.x, .y, .f, ..., .id = NULL) {
  # Map `.f` over `.x` and `.y` in parallel and row-bind the results into
  # one data frame. `.id` is forwarded to `dplyr::bind_rows()`.
  # `global_env()` is namespace-qualified, as in the other helpers in this
  # file, so it resolves without relying on an import of that symbol.
  .f <- rlang::as_function(.f, env = rlang::global_env())
  res <- map2(.x, .y, .f, ...)
  dplyr::bind_rows(res, .id = .id)
}
46+
2547
.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) {
2648
.f <- rlang::as_function(.f, env = rlang::global_env())
2749
out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE)

R/epi_keys.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ epi_keys.recipe <- function(x) {
2929
epi_keys_mold <- function(mold) {
  # Collect the column names stored under the key-type roles of a mold
  # (`time_value`, `geo_value`, `key`) and return them as an unnamed vector.
  roles <- mold$extras$roles
  is_key_role <- names(roles) %in% c("time_value", "geo_value", "key")
  key_names <- map(roles[is_key_role], names)
  unname(unlist(key_names))
}
3535

R/epi_shift.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") {
1919
tidyr::unchop(shift) %>% # what is chop
2020
dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>%
2121
# One list element for each shifted feature
22-
purrr::pmap(function(i, shift, name) {
22+
pmap(function(i, shift, name) {
2323
tibble(keys,
2424
time_value = time_value + shift, # Shift back
2525
!!name := x[[i]])
2626
})
2727
if (is.data.frame(keys)) common_names <- c(names(keys), "time_value")
2828
else common_names <- c("keys", "time_value")
2929

30-
purrr::reduce(out_list, dplyr::full_join, by = common_names)
30+
reduce(out_list, dplyr::full_join, by = common_names)
3131
}
3232

3333
epi_shift_single <- function(x, col, shift_val, newname, key_cols) {

R/epi_workflow.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ is_epi_workflow <- function(x) {
6262
inherits(x, "epi_workflow")
6363
}
6464

65+
6566
#' Predict from an epi_workflow
6667
#'
6768
#' @description
File renamed without changes.

0 commit comments

Comments
 (0)