Skip to content

Commit 6792363

Browse files
authored
Merge pull request #123 from cmu-delphi/move-old-forecasters
Move old forecasters
2 parents 172ab1e + 8d1609f commit 6792363

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+267
-631
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.1
3+
Version: 0.0.2
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut","cre")),
66
person("Jacob", "Bien", role = "aut"),
@@ -19,6 +19,7 @@ Description: A forecasting "framework" for creating epidemiological forecasts
1919
License: MIT + file LICENSE
2020
URL: https://github.com/cmu-delphi/epipredict/,
2121
https://cmu-delphi.github.io/epipredict
22+
BugReports: https://github.com/cmu-delphi/epipredict/issues/
2223
Depends:
2324
R (>= 3.5.0)
2425
Imports:

NAMESPACE

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ export(add_frosting)
6464
export(add_layer)
6565
export(apply_frosting)
6666
export(arx_args_list)
67-
export(arx_epi_forecaster)
6867
export(arx_forecaster)
69-
export(create_lags_and_leads)
7068
export(create_layer)
7169
export(default_epi_recipe_blueprint)
7270
export(detect_layer)
@@ -82,18 +80,13 @@ export(extract_layers)
8280
export(extrapolate_quantiles)
8381
export(flatline)
8482
export(flatline_args_list)
85-
export(flatline_epi_forecaster)
83+
export(flatline_forecaster)
8684
export(frosting)
87-
export(get_precision)
8885
export(get_test_data)
8986
export(grab_names)
9087
export(is_epi_recipe)
9188
export(is_epi_workflow)
9289
export(is_layer)
93-
export(knn_iteraive_ar_args_list)
94-
export(knn_iteraive_ar_forecaster)
95-
export(knnarx_args_list)
96-
export(knnarx_forecaster)
9790
export(layer)
9891
export(layer_add_forecast_date)
9992
export(layer_add_target_date)
@@ -108,8 +101,6 @@ export(new_default_epi_recipe_blueprint)
108101
export(new_epi_recipe_blueprint)
109102
export(remove_frosting)
110103
export(slather)
111-
export(smooth_arx_args_list)
112-
export(smooth_arx_forecaster)
113104
export(step_epi_ahead)
114105
export(step_epi_lag)
115106
export(step_epi_naomit)
@@ -139,5 +130,4 @@ importFrom(stats,predict)
139130
importFrom(stats,qnorm)
140131
importFrom(stats,quantile)
141132
importFrom(stats,residuals)
142-
importFrom(stats,setNames)
143133
importFrom(tibble,tibble)

R/arx_forecaster.R

Lines changed: 69 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,87 @@
1-
#' AR forecaster with optional covariates
1+
#' Direct autoregressive forecaster with covariates
22
#'
3-
#' @param x Covariates. Allowed to be missing (resulting in AR on `y`).
4-
#' @param y Response.
5-
#' @param key_vars Factor(s). A prediction will be made for each unique
6-
#' combination.
7-
#' @param time_value the time value associated with each row of measurements.
8-
#' @param args Additional arguments specifying the forecasting task. Created
9-
#' by calling `arx_args_list()`.
3+
#' This is an autoregressive forecasting model for
4+
#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning
5+
#' that it estimates a model for a particular target horizon.
106
#'
11-
#' @return A data frame of point (and optionally interval) forecasts at a single
12-
#' ahead (unique horizon) for each unique combination of `key_vars`.
7+
#'
8+
#' @param epi_data An `epi_df` object
9+
#' @param outcome A character (scalar) specifying the outcome (in the
10+
#' `epi_df`).
11+
#' @param predictors A character vector giving column(s) of predictor
12+
#' variables.
13+
#' @param trainer A `{parsnip}` model describing the type of estimation.
14+
#' For now, we enforce `mode = "regression"`.
15+
#' @param args_list A list of customization arguments to determine
16+
#' the type of forecasting model. See [arx_args_list()].
17+
#'
18+
#' @return A list with (1) `predictions`, an `epi_df` of predicted values
19+
#' and (2) `epi_workflow`, a list that encapsulates the entire estimation
20+
#' workflow
1321
#' @export
14-
arx_forecaster <- function(x, y, key_vars, time_value,
15-
args = arx_args_list()) {
22+
#'
23+
#' @examples
24+
#' jhu <- case_death_rate_subset %>%
25+
#' dplyr::filter(time_value >= as.Date("2021-12-01"))
26+
#'
27+
#' out <- arx_forecaster(jhu, "death_rate",
28+
#' c("case_rate", "death_rate"))
29+
arx_forecaster <- function(epi_data,
30+
outcome,
31+
predictors,
32+
trainer = parsnip::linear_reg(),
33+
args_list = arx_args_list()) {
1634

17-
# TODO: function to verify standard forecaster signature inputs
35+
validate_forecaster_inputs(epi_data, outcome, predictors)
36+
if (!is.list(trainer) || trainer$mode != "regression")
37+
cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.")
38+
lags <- arx_lags_validator(predictors, args_list$lags)
1839

19-
assign_arg_list(args)
20-
if (is.null(key_vars)) { # this is annoying/repetitive, seemingly necessary?
21-
keys <- NULL
22-
distinct_keys <- tibble(.dump = NA)
23-
} else {
24-
keys <- tibble::tibble(key_vars)
25-
distinct_keys <- dplyr::distinct(keys)
40+
r <- epi_recipe(epi_data)
41+
for (l in seq_along(lags)) {
42+
p <- predictors[l]
43+
r <- step_epi_lag(r, !!p, lag = lags[[l]])
2644
}
45+
r <- r %>%
46+
step_epi_ahead(dplyr::all_of(!!outcome), ahead = args_list$ahead) %>%
47+
step_epi_naomit()
48+
# should limit the training window here (in an open PR)
49+
# What to do if insufficient training data? Add issue.
2750

28-
# Return NA if insufficient training data
29-
if (length(y) < min_train_window + max_lags + ahead) {
30-
qnames <- probs_to_string(levels)
31-
out <- dplyr::bind_cols(distinct_keys, point = NA) %>%
32-
dplyr::select(!dplyr::any_of(".dump"))
33-
return(enframer(out, qnames))
34-
}
51+
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
52+
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
53+
f <- frosting() %>%
54+
layer_predict() %>%
55+
# layer_naomit(.pred) %>%
56+
layer_residual_quantiles(
57+
probs = args_list$levels,
58+
symmetrize = args_list$symmetrize) %>%
59+
layer_add_forecast_date(forecast_date = forecast_date) %>%
60+
layer_add_target_date(target_date = target_date)
61+
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
3562

36-
dat <- create_lags_and_leads(x, y, lags, ahead, time_value, keys)
37-
dat$x0 <- 1
63+
latest <- get_test_data(r, epi_data)
3864

39-
obj <- stats::lm(
40-
y1 ~ . + 0,
41-
data = dat %>% dplyr::select(starts_with(c("x", "y")))
65+
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
66+
list(
67+
predictions = predict(wf, new_data = latest),
68+
epi_workflow = wf
4269
)
70+
}
4371

44-
point <- make_predictions(obj, dat, time_value, keys)
45-
46-
# Residuals, simplest case, requires
47-
# 1. same quantiles for all keys
48-
# 2. `residuals(obj)` works
49-
r <- residuals(obj)
50-
q <- residual_quantiles(r, point, levels, symmetrize)
5172

52-
# Harder case requires handling failures of 1 and or 2, neither implemented
53-
# 1. different quantiles by key, need to bind the keys, then group_modify
54-
# 2 fails. need to bind the keys, grab, y and yhat, subtract
55-
if (nonneg) {
56-
q <- dplyr::mutate(q, dplyr::across(dplyr::everything(), ~ pmax(.x, 0)))
73+
arx_lags_validator <- function(predictors, lags) {
74+
p <- length(predictors)
75+
if (!is.list(lags)) lags <- list(lags)
76+
if (length(lags) == 1) lags <- rep(lags, p)
77+
else if (length(lags) < p) {
78+
cli_stop(
79+
"You have requested {p} predictors but lags cannot be recycled to match."
80+
)
5781
}
58-
59-
return(
60-
dplyr::bind_cols(distinct_keys, q) %>%
61-
dplyr::select(!dplyr::any_of(".dump"))
62-
)
82+
lags
6383
}
6484

65-
6685
#' ARX forecaster argument constructor
6786
#'
6887
#' Constructs a list of arguments for [arx_forecaster()].

R/arx_forecaster_mod.R

Lines changed: 0 additions & 83 deletions
This file was deleted.

R/blueprint-epi_recipe-default.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#' @details The `bake_dependent_roles` are automatically set to `epi_df` defaults.
1010
#' @return A recipe blueprint.
1111
#'
12+
#' @keywords internal
1213
#' @export
1314
new_epi_recipe_blueprint <-
1415
function(intercept = FALSE, allow_novel_levels = FALSE, fresh = TRUE,

R/df_mat_mul.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#' @return A data.frame with the new columns at the right. Original
1414
#' columns are removed.
1515
#' @export
16+
#' @keywords internal
1617
#'
1718
#' @examples
1819
#' df <- data.frame(matrix(1:200, ncol = 10))

R/epi_keys.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
#' @param x a data.frame, tibble, or epi_df
44
#'
55
#' @return If an `epi_df`, this returns all "keys". Otherwise `NULL`
6+
#' @keywords internal
67
#' @export
7-
#'
88
epi_keys <- function(x) {
99
UseMethod("epi_keys")
1010
}

R/epi_recipe.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ epi_form2args <- function(formula, data, ...) {
215215
#' @param x An object.
216216
#' @return `TRUE` if the object inherits from `epi_recipe`.
217217
#'
218+
#' @keywords internal
218219
#' @export
219220
is_epi_recipe <- function(x) {
220221
inherits(x, "epi_recipe")

R/epi_shift.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#' @param keys Data frame, vector, or `NULL`. Additional grouping vars.
1111
#' @param out_name Chr. The output list will use this as a prefix.
1212
#'
13+
#' @keywords internal
14+
#'
1315
#' @return a list of tibbles
1416
epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") {
1517
if (!is.data.frame(x)) x <- data.frame(x)

R/epi_workflow.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL)
5757
#' @param x An object.
5858
#' @return `TRUE` if the object inherits from `epi_workflow`.
5959
#'
60+
#' @keywords internal
6061
#' @export
6162
is_epi_workflow <- function(x) {
6263
inherits(x, "epi_workflow")

R/extract.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#' @return An object originally passed as an argument to a layer or step
99
#' @export
1010
#'
11+
#' @keywords internal
12+
#'
1113
#' @examples
1214
#' f <- frosting() %>%
1315
#' layer_predict() %>%

R/flatline.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#' horizon. The right hand side must contain any keys (locations) for the
1111
#' panel data separated by plus. The observed time series must come last.
1212
#' For example
13-
#' ```
13+
#' ```r
1414
#' form <- as.formula(lead7_y ~ state + age + y)
1515
#' ```
1616
#' Note that this function doesn't DO the shifting, that has to be done
@@ -26,6 +26,7 @@
2626
#' predictions for future data (the last observed value of the outcome for each
2727
#' combination of keys).
2828
#' @export
29+
#' @keywords internal
2930
#'
3031
#' @examples
3132
#' tib <- data.frame(y = runif(100),

R/flatline_epi_forecaster.R renamed to R/flatline_forecaster.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
#' jhu <- case_death_rate_subset %>%
2727
#' dplyr::filter(time_value >= as.Date("2021-12-01"))
2828
#'
29-
#' out <- flatline_epi_forecaster(jhu, "death_rate")
30-
flatline_epi_forecaster <- function(epi_data,
31-
outcome,
32-
args_list = flatline_args_list()) {
29+
#' out <- flatline_forecaster(jhu, "death_rate")
30+
flatline_forecaster <- function(epi_data,
31+
outcome,
32+
args_list = flatline_args_list()) {
3333

3434
validate_forecaster_inputs(epi_data, outcome, "time_value")
3535
keys <- epi_keys(epi_data)
@@ -71,7 +71,7 @@ flatline_epi_forecaster <- function(epi_data,
7171

7272
#' Flatline forecaster argument constructor
7373
#'
74-
#' Constructs a list of arguments for [flatline_epi_forecaster()].
74+
#' Constructs a list of arguments for [flatline_forecaster()].
7575
#'
7676
#' @inheritParams arx_args_list
7777
#'

R/grab_names.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#' be used to select a range of variables.
1313
#'
1414
#' @export
15+
#' @keywords internal
1516
#' @return a character vector
1617
#' @examples
1718
#' df <- data.frame(a = 1, b = 2, cc = rep(NA, 3))

0 commit comments

Comments
 (0)