Skip to content

Commit 5e50a5a

Browse files
authored
Merge pull request #319 from cmu-delphi/ds/forecast
feat: add forecast method
2 parents cd3fe2e + b46ebb7 commit 5e50a5a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+325
-333
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.13
3+
Version: 0.0.14
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ S3method(extrapolate_quantiles,distribution)
4545
S3method(fit,epi_workflow)
4646
S3method(flusight_hub_formatter,canned_epipred)
4747
S3method(flusight_hub_formatter,data.frame)
48+
S3method(forecast,epi_workflow)
4849
S3method(format,dist_quantiles)
4950
S3method(is.na,dist_quantiles)
5051
S3method(is.na,distribution)
@@ -152,6 +153,7 @@ export(flatline)
152153
export(flatline_args_list)
153154
export(flatline_forecaster)
154155
export(flusight_hub_formatter)
156+
export(forecast)
155157
export(frosting)
156158
export(get_test_data)
157159
export(grab_names)
@@ -219,6 +221,7 @@ importFrom(dplyr,ungroup)
219221
importFrom(epiprocess,growth_rate)
220222
importFrom(generics,augment)
221223
importFrom(generics,fit)
224+
importFrom(generics,forecast)
222225
importFrom(ggplot2,autoplot)
223226
importFrom(hardhat,refresh_blueprint)
224227
importFrom(hardhat,run_mold)

NEWS.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,15 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
3131
- Working vignette
3232
- use `checkmate` for input validation
3333
- refactor quantile extrapolation (possibly creates different results)
34-
- force `target_date` + `forecast_date` handling to match the time_type of
35-
the epi_df. allows for annual and weekly data
34+
- force `target_date` + `forecast_date` handling to match the time_type of the
35+
epi_df. allows for annual and weekly data
3636
- add `check_enough_train_data()` that will error if training data is too small
3737
- added `check_enough_train_data()` to `arx_forecaster()`
38-
- `layer_residual_quantiles()` will now error if any of the residual quantiles are NA
38+
- `layer_residual_quantiles()` will now error if any of the residual quantiles
39+
are NA
3940
- `*_args_list()` functions now warn if `forecast_date + ahead != target_date`
40-
- the `predictor` argument in `arx_forecaster()` now defaults to the value of the `outcome` argument
41+
- the `predictor` argument in `arx_forecaster()` now defaults to the value of
42+
the `outcome` argument
4143
- `arx_fcast_epi_workflow()` and `arx_class_epi_workflow()` now default to
4244
`trainer = parsnip::logistic_reg()` to match their more canned versions.
45+
- add a `forecast()` method to simplify generating forecasts

R/arx_classifier.R

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,15 @@ arx_classifier <- function(
5151
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
5252
}
5353

54-
wf <- arx_class_epi_workflow(
55-
epi_data, outcome, predictors, trainer, args_list
56-
)
57-
58-
latest <- get_test_data(
59-
hardhat::extract_preprocessor(wf), epi_data, TRUE, args_list$nafill_buffer,
60-
args_list$forecast_date %||% max(epi_data$time_value)
61-
)
62-
54+
wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
6355
wf <- generics::fit(wf, epi_data)
64-
preds <- predict(wf, new_data = latest) %>%
56+
57+
preds <- forecast(
58+
wf,
59+
fill_locf = TRUE,
60+
n_recent = args_list$nafill_buffer,
61+
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
62+
) %>%
6563
tibble::as_tibble() %>%
6664
dplyr::select(-time_value)
6765

R/arx_forecaster.R

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,26 +38,25 @@
3838
#' trainer = quantile_reg(),
3939
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
4040
#' )
41-
arx_forecaster <- function(epi_data,
42-
outcome,
43-
predictors = outcome,
44-
trainer = parsnip::linear_reg(),
45-
args_list = arx_args_list()) {
41+
arx_forecaster <- function(
42+
epi_data,
43+
outcome,
44+
predictors = outcome,
45+
trainer = parsnip::linear_reg(),
46+
args_list = arx_args_list()) {
4647
if (!is_regression(trainer)) {
4748
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
4849
}
4950

50-
wf <- arx_fcast_epi_workflow(
51-
epi_data, outcome, predictors, trainer, args_list
52-
)
53-
54-
latest <- get_test_data(
55-
hardhat::extract_preprocessor(wf), epi_data, TRUE, args_list$nafill_buffer,
56-
args_list$forecast_date %||% max(epi_data$time_value)
57-
)
58-
51+
wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5952
wf <- generics::fit(wf, epi_data)
60-
preds <- predict(wf, new_data = latest) %>%
53+
54+
preds <- forecast(
55+
wf,
56+
fill_locf = TRUE,
57+
n_recent = args_list$nafill_buffer,
58+
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
59+
) %>%
6160
tibble::as_tibble() %>%
6261
dplyr::select(-time_value)
6362

R/autoplot.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ ggplot2::autoplot
6161
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
6262
#' step_epi_naomit()
6363
#' ewf <- epi_workflow(r, parsnip::linear_reg(), f) %>% fit(jhu)
64-
#' td <- get_test_data(r, jhu)
65-
#' predict(ewf, new_data = td)
64+
#' forecast(ewf)
6665
#' })
6766
#'
6867
#' p <- do.call(rbind, p)

R/epi_workflow.R

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,11 @@ update_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
197197
#'
198198
#' @export
199199
fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()) {
200-
object$fit$meta <- list(max_time_value = max(data$time_value), as_of = attributes(data)$metadata$as_of)
200+
object$fit$meta <- list(
201+
max_time_value = max(data$time_value),
202+
as_of = attributes(data)$metadata$as_of
203+
)
204+
object$original_data <- data
201205

202206
NextMethod()
203207
}
@@ -326,3 +330,54 @@ print.epi_workflow <- function(x, ...) {
326330
print_postprocessor(x)
327331
invisible(x)
328332
}
333+
334+
335+
#' Produce a forecast from an epi workflow
#'
#' Builds test data from the data the workflow was fit on (stored by `fit()`
#' in `object$original_data`) using `get_test_data()`, then calls `predict()`
#' on that test data.
#'
#' @param object An epi workflow. It must already have been trained with
#'   `fit()`.
#' @param ... Not used. Must be empty.
#' @param fill_locf Logical. Should we use locf to fill in missing data?
#' @param n_recent Integer or NULL. If filling missing data with locf = TRUE,
#'   how far back are we willing to tolerate missing data? Larger values allow
#'   more filling. For example, suppose n_recent = 3, then if the 3 most recent
#'   observations in any geo_value are all NA's, we won't be able to fill
#'   anything, and an error message will be thrown. The default `NULL` is
#'   forwarded to `get_test_data()` as `Inf`.
#' @param forecast_date The forecast date passed to `get_test_data()`. When
#'   `NULL` (the default), the `forecast_date` argument of a
#'   `layer_add_forecast_date()` layer in the workflow's frosting is used if
#'   one is set; otherwise the maximum `time_value` in the data `object` was
#'   fit on. If there is data latency such that recent NA's should be filled,
#'   this may be after the last available `time_value`.
#'
#' @return The result of calling `predict()` on `object` with the generated
#'   test data.
#'
#' @export
forecast.epi_workflow <- function(
    object,
    ...,
    fill_locf = FALSE,
    n_recent = NULL,
    forecast_date = NULL) {
  rlang::check_dots_empty()

  if (!object$trained) {
    cli_abort(c(
      "You cannot `forecast()` a {.cls workflow} that has not been trained.",
      i = "Please use `fit()` before forecasting."
    ))
  }

  time_values <- object$original_data$time_value

  frosting_fd <- NULL
  if (has_postprocessor(object) && detect_layer(object, "layer_add_forecast_date")) {
    frosting_fd <- extract_argument(object, "layer_add_forecast_date", "forecast_date")
    # `class()` may return more than one class (e.g. POSIXct), so the two class
    # vectors are compared as a whole. An element-wise `!=` would hand `&&` a
    # non-scalar condition, which is an error on R >= 4.3.
    if (!is.null(frosting_fd) && !identical(class(frosting_fd), class(time_values))) {
      cli_abort(c(
        "Error with layer_add_forecast_date():",
        i = "The type of `forecast_date` must match the type of the `time_value` column in the data."
      ))
    }
  }

  # Precedence for the forecast date: explicit argument, then the frosting
  # layer's value, then the latest time_value in the training data.
  test_data <- get_test_data(
    hardhat::extract_preprocessor(object),
    object$original_data,
    fill_locf = fill_locf,
    n_recent = n_recent %||% Inf,
    forecast_date = forecast_date %||% frosting_fd %||% max(time_values)
  )

  predict(object, new_data = test_data)
}

R/flatline_forecaster.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ flatline_forecaster <- function(
4949
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
5050
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
5151

52-
latest <- get_test_data(
53-
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
54-
forecast_date
55-
)
56-
5752
f <- frosting() %>%
5853
layer_predict() %>%
5954
layer_residual_quantiles(
@@ -69,7 +64,12 @@ flatline_forecaster <- function(
6964

7065
wf <- epi_workflow(r, eng, f)
7166
wf <- generics::fit(wf, epi_data)
72-
preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
67+
preds <- suppressWarnings(forecast(
68+
wf,
69+
fill_locf = TRUE,
70+
n_recent = args_list$nafill_buffer,
71+
forecast_date = forecast_date
72+
)) %>%
7373
tibble::as_tibble() %>%
7474
dplyr::select(-time_value)
7575

R/frosting.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,15 +275,14 @@ new_frosting <- function() {
275275
#' step_epi_naomit()
276276
#'
277277
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
278-
#' latest <- get_test_data(recipe = r, x = jhu)
279278
#'
280279
#' f <- frosting() %>%
281280
#' layer_predict() %>%
282281
#' layer_naomit(.pred)
283282
#'
284283
#' wf1 <- wf %>% add_frosting(f)
285284
#'
286-
#' p <- predict(wf1, latest)
285+
#' p <- forecast(wf1)
287286
#' p
288287
frosting <- function(layers = NULL, requirements = NULL) {
289288
if (!is_null(layers) || !is_null(requirements)) {

R/layer_add_target_date.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,16 @@
2828
#' step_epi_naomit()
2929
#'
3030
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
31-
#' latest <- get_test_data(r, jhu)
3231
#'
3332
#' # Use ahead + forecast date
3433
#' f <- frosting() %>%
3534
#' layer_predict() %>%
36-
#' layer_add_forecast_date(forecast_date = "2022-05-31") %>%
35+
#' layer_add_forecast_date(forecast_date = as.Date("2022-05-31")) %>%
3736
#' layer_add_target_date() %>%
3837
#' layer_naomit(.pred)
3938
#' wf1 <- wf %>% add_frosting(f)
4039
#'
41-
#' p <- predict(wf1, latest)
40+
#' p <- forecast(wf1)
4241
#' p
4342
#'
4443
#' # Use ahead + max time value from pre, fit, post
@@ -49,7 +48,7 @@
4948
#' layer_naomit(.pred)
5049
#' wf2 <- wf %>% add_frosting(f2)
5150
#'
52-
#' p2 <- predict(wf2, latest)
51+
#' p2 <- forecast(wf2)
5352
#' p2
5453
#'
5554
#' # Specify own target date
@@ -59,7 +58,7 @@
5958
#' layer_naomit(.pred)
6059
#' wf3 <- wf %>% add_frosting(f3)
6160
#'
62-
#' p3 <- predict(wf3, latest)
61+
#' p3 <- forecast(wf3)
6362
#' p3
6463
layer_add_target_date <-
6564
function(frosting, target_date = NULL, id = rand_id("add_target_date")) {

R/layer_cdc_flatline_quantiles.R

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,14 @@
6464
#'
6565
#' forecast_date <- max(case_death_rate_subset$time_value)
6666
#'
67-
#' latest <- get_test_data(
68-
#' epi_recipe(case_death_rate_subset), case_death_rate_subset
69-
#' )
70-
#'
7167
#' f <- frosting() %>%
7268
#' layer_predict() %>%
7369
#' layer_cdc_flatline_quantiles(aheads = c(7, 14, 21, 28), symmetrize = TRUE)
7470
#'
7571
#' eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")
7672
#'
7773
#' wf <- epi_workflow(r, eng, f) %>% fit(case_death_rate_subset)
78-
#' preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
74+
#' preds <- forecast(wf) %>%
7975
#' dplyr::select(-time_value) %>%
8076
#' dplyr::mutate(forecast_date = forecast_date)
8177
#' preds

R/layer_naomit.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
#'
2121
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
2222
#'
23-
#' latest <- get_test_data(recipe = r, x = jhu)
24-
#'
2523
#' f <- frosting() %>%
2624
#' layer_predict() %>%
2725
#' layer_naomit(.pred)
2826
#'
2927
#' wf1 <- wf %>% add_frosting(f)
3028
#'
31-
#' p <- predict(wf1, latest)
29+
#' p <- forecast(wf1)
3230
#' p
3331
layer_naomit <- function(frosting, ..., id = rand_id("naomit")) {
3432
arg_is_chr_scalar(id)

R/layer_point_from_distn.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,14 @@
2626
#'
2727
#' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>% fit(jhu)
2828
#'
29-
#' latest <- get_test_data(recipe = r, x = jhu)
30-
#'
3129
#' f1 <- frosting() %>%
3230
#' layer_predict() %>%
3331
#' layer_quantile_distn() %>% # puts the other quantiles in a different col
3432
#' layer_point_from_distn() %>% # mutate `.pred` to contain only a point prediction
3533
#' layer_naomit(.pred)
3634
#' wf1 <- wf %>% add_frosting(f1)
3735
#'
38-
#' p1 <- predict(wf1, latest)
36+
#' p1 <- forecast(wf1)
3937
#' p1
4038
#'
4139
#' f2 <- frosting() %>%
@@ -44,7 +42,7 @@
4442
#' layer_naomit(.pred)
4543
#' wf2 <- wf %>% add_frosting(f2)
4644
#'
47-
#' p2 <- predict(wf2, latest)
45+
#' p2 <- forecast(wf2)
4846
#' p2
4947
layer_point_from_distn <- function(frosting,
5048
...,

R/layer_population_scaling.R

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,7 @@
7878
#' fit(jhu) %>%
7979
#' add_frosting(f)
8080
#'
81-
#' latest <- get_test_data(
82-
#' recipe = r,
83-
#' x = epiprocess::jhu_csse_daily_subset %>%
84-
#' dplyr::filter(
85-
#' time_value > "2021-11-01",
86-
#' geo_value %in% c("ca", "ny")
87-
#' ) %>%
88-
#' dplyr::select(geo_value, time_value, cases)
89-
#' )
90-
#'
91-
#' predict(wf, latest)
81+
#' forecast(wf)
9282
layer_population_scaling <- function(frosting,
9383
...,
9484
df,

R/layer_predictive_distn.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,13 @@
3030
#'
3131
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
3232
#'
33-
#' latest <- get_test_data(recipe = r, x = jhu)
34-
#'
3533
#' f <- frosting() %>%
3634
#' layer_predict() %>%
3735
#' layer_predictive_distn() %>%
3836
#' layer_naomit(.pred)
3937
#' wf1 <- wf %>% add_frosting(f)
4038
#'
41-
#' p <- predict(wf1, latest)
39+
#' p <- forecast(wf1)
4240
#' p
4341
layer_predictive_distn <- function(frosting,
4442
...,

R/layer_quantile_distn.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,13 @@
2828
#' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>%
2929
#' fit(jhu)
3030
#'
31-
#' latest <- get_test_data(recipe = r, x = jhu)
32-
#'
3331
#' f <- frosting() %>%
3432
#' layer_predict() %>%
3533
#' layer_quantile_distn() %>%
3634
#' layer_naomit(.pred)
3735
#' wf1 <- wf %>% add_frosting(f)
3836
#'
39-
#' p <- predict(wf1, latest)
37+
#' p <- forecast(wf1)
4038
#' p
4139
layer_quantile_distn <- function(frosting,
4240
...,

0 commit comments

Comments
 (0)