Skip to content

Commit 230985d

Browse files
committed
move out all the old deprecated code from the package
1 parent 172ab1e commit 230985d

12 files changed

+140
-139
lines changed

R/arx_forecaster.R

Lines changed: 69 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,87 @@
1-
#' AR forecaster with optional covariates
1+
#' Direct autoregressive forecaster with covariates
22
#'
3-
#' @param x Covariates. Allowed to be missing (resulting in AR on `y`).
4-
#' @param y Response.
5-
#' @param key_vars Factor(s). A prediction will be made for each unique
6-
#' combination.
7-
#' @param time_value the time value associated with each row of measurements.
8-
#' @param args Additional arguments specifying the forecasting task. Created
9-
#' by calling `arx_args_list()`.
3+
#' This is an autoregressive forecasting model for
4+
#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning
5+
#' that it estimates a model for a particular target horizon.
106
#'
11-
#' @return A data frame of point (and optionally interval) forecasts at a single
12-
#' ahead (unique horizon) for each unique combination of `key_vars`.
7+
#'
8+
#' @param epi_data An `epi_df` object
9+
#' @param outcome A character (scalar) specifying the outcome (in the
10+
#' `epi_df`).
11+
#' @param predictors A character vector giving column(s) of predictor
12+
#' variables.
13+
#' @param trainer A `{parsnip}` model describing the type of estimation.
14+
#' For now, we enforce `mode = "regression"`.
15+
#' @param args_list A list of customization arguments to determine
16+
#' the type of forecasting model. See [arx_args_list()].
17+
#'
18+
#' @return A list with (1) `predictions` an `epi_df` of predicted values
19+
#' and (2) `epi_workflow`, a list that encapsulates the entire estimation
20+
#' workflow
1321
#' @export
14-
arx_forecaster <- function(x, y, key_vars, time_value,
15-
args = arx_args_list()) {
22+
#'
23+
#' @examples
24+
#' jhu <- case_death_rate_subset %>%
25+
#' dplyr::filter(time_value >= as.Date("2021-12-01"))
26+
#'
27+
#' out <- arx_forecaster(jhu, "death_rate",
28+
#' c("case_rate", "death_rate"))
29+
arx_forecaster <- function(epi_data,
30+
outcome,
31+
predictors,
32+
trainer = parsnip::linear_reg(),
33+
args_list = arx_args_list()) {
1634

17-
# TODO: function to verify standard forecaster signature inputs
35+
validate_forecaster_inputs(epi_data, outcome, predictors)
36+
if (!is.list(trainer) || trainer$mode != "regression")
37+
cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.")
38+
lags <- arx_lags_validator(predictors, args_list$lags)
1839

19-
assign_arg_list(args)
20-
if (is.null(key_vars)) { # this is annoying/repetitive, seemingly necessary?
21-
keys <- NULL
22-
distinct_keys <- tibble(.dump = NA)
23-
} else {
24-
keys <- tibble::tibble(key_vars)
25-
distinct_keys <- dplyr::distinct(keys)
40+
r <- epi_recipe(epi_data)
41+
for (l in seq_along(lags)) {
42+
p <- predictors[l]
43+
r <- step_epi_lag(r, !!p, lag = lags[[l]])
2644
}
45+
r <- r %>%
46+
step_epi_ahead(dplyr::all_of(!!outcome), ahead = args_list$ahead) %>%
47+
step_epi_naomit()
48+
# should limit the training window here (in an open PR)
49+
# What to do if insufficient training data? Add issue.
2750

28-
# Return NA if insufficient training data
29-
if (length(y) < min_train_window + max_lags + ahead) {
30-
qnames <- probs_to_string(levels)
31-
out <- dplyr::bind_cols(distinct_keys, point = NA) %>%
32-
dplyr::select(!dplyr::any_of(".dump"))
33-
return(enframer(out, qnames))
34-
}
51+
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
52+
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
53+
f <- frosting() %>%
54+
layer_predict() %>%
55+
# layer_naomit(.pred) %>%
56+
layer_residual_quantiles(
57+
probs = args_list$levels,
58+
symmetrize = args_list$symmetrize) %>%
59+
layer_add_forecast_date(forecast_date = forecast_date) %>%
60+
layer_add_target_date(target_date = target_date)
61+
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
3562

36-
dat <- create_lags_and_leads(x, y, lags, ahead, time_value, keys)
37-
dat$x0 <- 1
63+
latest <- get_test_data(r, epi_data)
3864

39-
obj <- stats::lm(
40-
y1 ~ . + 0,
41-
data = dat %>% dplyr::select(starts_with(c("x", "y")))
65+
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
66+
list(
67+
predictions = predict(wf, new_data = latest),
68+
epi_workflow = wf
4269
)
70+
}
4371

44-
point <- make_predictions(obj, dat, time_value, keys)
45-
46-
# Residuals, simplest case, requires
47-
# 1. same quantiles for all keys
48-
# 2. `residuals(obj)` works
49-
r <- residuals(obj)
50-
q <- residual_quantiles(r, point, levels, symmetrize)
5172

52-
# Harder case requires handling failures of 1 and or 2, neither implemented
53-
# 1. different quantiles by key, need to bind the keys, then group_modify
54-
# 2 fails. need to bind the keys, grab, y and yhat, subtract
55-
if (nonneg) {
56-
q <- dplyr::mutate(q, dplyr::across(dplyr::everything(), ~ pmax(.x, 0)))
73+
arx_lags_validator <- function(predictors, lags) {
74+
p <- length(predictors)
75+
if (!is.list(lags)) lags <- list(lags)
76+
if (length(lags) == 1) lags <- rep(lags, p)
77+
else if (length(lags) < p) {
78+
cli_stop(
79+
"You have requested {p} predictors but lags cannot be recycled to match."
80+
)
5781
}
58-
59-
return(
60-
dplyr::bind_cols(distinct_keys, q) %>%
61-
dplyr::select(!dplyr::any_of(".dump"))
62-
)
82+
lags
6383
}
6484

65-
6685
#' ARX forecaster argument constructor
6786
#'
6887
#' Constructs a list of arguments for [arx_forecaster()].

R/arx_forecaster_mod.R

Lines changed: 0 additions & 83 deletions
This file was deleted.

musings/arx_forecaster_old.R

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#' AR forecaster with optional covariates
2+
#'
3+
#' @param x Covariates. Allowed to be missing (resulting in AR on `y`).
4+
#' @param y Response.
5+
#' @param key_vars Factor(s). A prediction will be made for each unique
6+
#' combination.
7+
#' @param time_value the time value associated with each row of measurements.
8+
#' @param args Additional arguments specifying the forecasting task. Created
9+
#' by calling `arx_args_list()`.
10+
#'
11+
#' @return A data frame of point (and optionally interval) forecasts at a single
12+
#' ahead (unique horizon) for each unique combination of `key_vars`.
13+
#' @export
14+
arx_forecaster <- function(x, y, key_vars, time_value,
15+
args = arx_args_list()) {
16+
17+
# TODO: function to verify standard forecaster signature inputs
18+
19+
assign_arg_list(args)
20+
if (is.null(key_vars)) { # this is annoying/repetitive, seemingly necessary?
21+
keys <- NULL
22+
distinct_keys <- tibble(.dump = NA)
23+
} else {
24+
keys <- tibble::tibble(key_vars)
25+
distinct_keys <- dplyr::distinct(keys)
26+
}
27+
28+
# Return NA if insufficient training data
29+
if (length(y) < min_train_window + max_lags + ahead) {
30+
qnames <- probs_to_string(levels)
31+
out <- dplyr::bind_cols(distinct_keys, point = NA) %>%
32+
dplyr::select(!dplyr::any_of(".dump"))
33+
return(enframer(out, qnames))
34+
}
35+
36+
dat <- create_lags_and_leads(x, y, lags, ahead, time_value, keys)
37+
dat$x0 <- 1
38+
39+
obj <- stats::lm(
40+
y1 ~ . + 0,
41+
data = dat %>% dplyr::select(starts_with(c("x", "y")))
42+
)
43+
44+
point <- make_predictions(obj, dat, time_value, keys)
45+
46+
# Residuals, simplest case, requires
47+
# 1. same quantiles for all keys
48+
# 2. `residuals(obj)` works
49+
r <- residuals(obj)
50+
q <- residual_quantiles(r, point, levels, symmetrize)
51+
52+
# Harder case requires handling failures of 1 and or 2, neither implemented
53+
# 1. different quantiles by key, need to bind the keys, then group_modify
54+
# 2 fails. need to bind the keys, grab, y and yhat, subtract
55+
if (nonneg) {
56+
q <- dplyr::mutate(q, dplyr::across(dplyr::everything(), ~ pmax(.x, 0)))
57+
}
58+
59+
return(
60+
dplyr::bind_cols(distinct_keys, q) %>%
61+
dplyr::select(!dplyr::any_of(".dump"))
62+
)
63+
}
64+
65+
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

vignettes/epipredict.Rmd

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ We'll estimate the model jointly across all locations using only the most recent
103103

104104
```{r demo-workflow}
105105
jhu <- jhu %>% filter(time_value >= max(time_value) - 30)
106-
out <- arx_epi_forecaster(jhu, outcome = "death_rate",
106+
out <- arx_forecaster(jhu, outcome = "death_rate",
107107
predictors = c("case_rate", "death_rate")
108108
)
109109
```
@@ -131,7 +131,7 @@ knitr::opts_chunk$set(warning = FALSE, message = FALSE)
131131
```
132132

133133
```{r differential-lags}
134-
out2week <- arx_epi_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
134+
out2week <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
135135
args_list = arx_args_list(
136136
lags = list(c(0,1,2,3,7,14), c(0,7,14)),
137137
ahead = 14)
@@ -145,7 +145,7 @@ Here, we've used different lags on the `case_rate` and are now predicting 2 week
145145
Another property of the basic model is the predictive interval. We describe this in more detail in a different vignette, but it is easy to request multiple quantiles.
146146

147147
```{r differential-levels}
148-
out_q <- arx_epi_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
148+
out_q <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
149149
args_list = arx_args_list(
150150
levels = c(.01,.025, seq(.05,.95, by=.05), .975,.99))
151151
)
@@ -183,14 +183,14 @@ The `trainer` argument determines the type of model we want.
183183
This takes a [`{parsnip}`](https://parsnip.tidymodels.org) model. The default is linear regression, but we could instead use a random forest with the `{ranger}` package:
184184

185185
```{r ranger, warning = FALSE}
186-
out_rf <- arx_epi_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
186+
out_rf <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
187187
rand_forest(mode = "regression"))
188188
```
189189

190190
Or boosted regression trees with `{xgboost}`:
191191

192192
```{r xgboost, warning = FALSE}
193-
out_gb <- arx_epi_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
193+
out_gb <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"),
194194
boost_tree(mode = "regression", trees = 20))
195195
```
196196

@@ -290,7 +290,7 @@ To stretch the metaphor of preparing a cake to its natural limits, we have
290290
created postprocessing functionality called "frosting". Much like the recipe,
291291
each postprocessing operation is a "layer" and we "slather" these onto our
292292
baked cake. To fix ideas, below is the postprocessing `frosting` for
293-
`arx_epi_forecaster()`
293+
`arx_forecaster()`
294294

295295
```{r}
296296
extract_frosting(out_q$epi_workflow)

0 commit comments

Comments
 (0)