Skip to content

Commit 1de2d92

Browse files
authored
Merge pull request #180 from cmu-delphi/training-window
Training window
2 parents 9a80430 + 18033c9 commit 1de2d92

29 files changed

+350
-122
lines changed

DESCRIPTION

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.3.9999
3+
Version: 0.0.4
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -21,7 +21,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
2121
https://cmu-delphi.github.io/epipredict
2222
BugReports: https://github.com/cmu-delphi/epipredict/issues/
2323
Depends:
24-
epiprocess,
24+
epiprocess (>= 0.6.0),
2525
parsnip (>= 1.0.0),
2626
R (>= 3.5.0)
2727
Imports:
@@ -61,7 +61,7 @@ VignetteBuilder:
6161
knitr
6262
Remotes:
6363
cmu-delphi/epidatr,
64-
cmu-delphi/epiprocess
64+
cmu-delphi/epiprocess@dev
6565
Config/testthat/edition: 3
6666
Encoding: UTF-8
6767
LazyData: true

NAMESPACE

+7
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ S3method(prep,step_growth_rate)
4444
S3method(prep,step_lag_difference)
4545
S3method(prep,step_population_scaling)
4646
S3method(prep,step_training_window)
47+
S3method(print,arx_clist)
48+
S3method(print,arx_flist)
4749
S3method(print,epi_workflow)
50+
S3method(print,flatline_alist)
4851
S3method(print,frosting)
4952
S3method(print,layer_add_forecast_date)
5053
S3method(print,layer_add_target_date)
@@ -91,6 +94,7 @@ export(arx_args_list)
9194
export(arx_class_args_list)
9295
export(arx_classifier)
9396
export(arx_forecaster)
97+
export(bake)
9498
export(create_layer)
9599
export(default_epi_recipe_blueprint)
96100
export(detect_layer)
@@ -128,6 +132,7 @@ export(layer_threshold)
128132
export(nested_quantiles)
129133
export(new_default_epi_recipe_blueprint)
130134
export(new_epi_recipe_blueprint)
135+
export(prep)
131136
export(quantile_reg)
132137
export(remove_frosting)
133138
export(slather)
@@ -152,6 +157,8 @@ importFrom(hardhat,run_mold)
152157
importFrom(magrittr,"%>%")
153158
importFrom(methods,is)
154159
importFrom(quantreg,rq)
160+
importFrom(recipes,bake)
161+
importFrom(recipes,prep)
155162
importFrom(rlang,"!!")
156163
importFrom(rlang,":=")
157164
importFrom(rlang,`%||%`)

NEWS.md

+7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
# epipredict (development)
22

3+
4+
# epipredict 0.0.4
5+
36
* add quantile_reg()
47
* clean up documentation bugs
8+
* add smooth_quantile_reg()
9+
* add classifier
10+
* training window step debugged
11+
* `min_train_window` argument removed from canned forecasters
512

613
# epipredict 0.0.3
714

R/arx_classifier.R

+14-6
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ arx_classifier <- function(epi_data,
9797
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>%
9898
step_mutate(outcome_class = cut(!!o2, breaks = args_list$breaks),
9999
role = "outcome") %>%
100-
step_epi_naomit()
100+
step_epi_naomit() %>%
101+
step_training_window(n_recent = args_list$n_training)
101102

102103
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
103104
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
@@ -152,7 +153,7 @@ arx_classifier <- function(epi_data,
152153
#' calculation. See [epiprocess::growth_rate()] and the related Vignette for
153154
#' more details.
154155
#'
155-
#' @return A list containing updated parameter choices with class `arx_alist`.
156+
#' @return A list containing updated parameter choices with class `arx_clist`.
156157
#' @export
157158
#'
158159
#' @examples
@@ -164,7 +165,7 @@ arx_classifier <- function(epi_data,
164165
arx_class_args_list <- function(
165166
lags = c(0L, 7L, 14L),
166167
ahead = 7L,
167-
min_train_window = 20L,
168+
n_training = Inf,
168169
forecast_date = NULL,
169170
target_date = NULL,
170171
outcome_transform = c("growth_rate", "lag_difference"),
@@ -180,12 +181,14 @@ arx_class_args_list <- function(
180181
method <- match.arg(method)
181182
outcome_transform <- match.arg(outcome_transform)
182183

183-
arg_is_scalar(ahead, min_train_window, horizon, log_scale)
184+
arg_is_scalar(ahead, n_training, horizon, log_scale)
184185
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
185186
arg_is_date(forecast_date, target_date, allow_null = TRUE)
186-
arg_is_nonneg_int(ahead, min_train_window, lags, horizon)
187+
arg_is_nonneg_int(ahead, lags, horizon)
187188
arg_is_numeric(breaks)
188189
arg_is_lgl(log_scale)
190+
arg_is_pos(n_training)
191+
if (is.finite(n_training)) arg_is_pos_int(n_training)
189192
if (!is.list(additional_gr_args)) {
190193
rlang::abort(
191194
c("`additional_gr_args` must be a list.",
@@ -202,7 +205,7 @@ arx_class_args_list <- function(
202205
structure(
203206
enlist(lags = .lags,
204207
ahead,
205-
min_train_window,
208+
n_training,
206209
breaks,
207210
forecast_date,
208211
target_date,
@@ -216,3 +219,8 @@ arx_class_args_list <- function(
216219
class = "arx_clist"
217220
)
218221
}
222+
223+
#' @export
224+
print.arx_clist <- function(x, ...) {
225+
utils::str(x)
226+
}

R/arx_forecaster.R

+20-14
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ arx_forecaster <- function(epi_data,
3838

3939
# --- validation
4040
validate_forecaster_inputs(epi_data, outcome, predictors)
41-
if (!inherits(args_list, "arx_alist"))
41+
if (!inherits(args_list, "arx_flist"))
4242
cli_stop("args_list was not created using `arx_args_list().")
4343
if (!is_regression(trainer))
4444
cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.")
@@ -52,9 +52,8 @@ arx_forecaster <- function(epi_data,
5252
}
5353
r <- r %>%
5454
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
55-
step_epi_naomit()
56-
# should limit the training window here (in an open PR)
57-
# What to do if insufficient training data? Add issue.
55+
step_epi_naomit() %>%
56+
step_training_window(n_recent = args_list$n_training)
5857

5958
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
6059
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
@@ -105,9 +104,9 @@ arx_lags_validator <- function(predictors, lags) {
105104
#' in autoregressive-type models (in days).
106105
#' @param ahead Integer. Number of time steps ahead (in days) of the forecast
107106
#' date for which forecasts should be produced.
108-
#' @param min_train_window Integer. The minimal amount of training
109-
#' data (in the time unit of the `epi_df`) needed to produce a forecast.
110-
#' If smaller, the forecaster will return `NA` predictions.
107+
#' @param n_training Integer. An upper limit for the number of rows per
108+
#' key that are used for training
109+
#' (in the time unit of the `epi_df`).
111110
#' @param forecast_date Date. The date on which the forecast is created.
112111
#' The default `NULL` will attempt to determine this automatically.
113112
#' @param target_date Date. The date for which the forecast is intended.
@@ -124,16 +123,16 @@ arx_lags_validator <- function(predictors, lags) {
124123
#' [layer_residual_quantiles()] for more information. The default,
125124
#' `character(0)` performs no grouping.
126125
#'
127-
#' @return A list containing updated parameter choices with class `arx_alist`.
126+
#' @return A list containing updated parameter choices with class `arx_flist`.
128127
#' @export
129128
#'
130129
#' @examples
131130
#' arx_args_list()
132131
#' arx_args_list(symmetrize = FALSE)
133-
#' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120)
132+
#' arx_args_list(levels = c(.1, .3, .7, .9), n_training = 120)
134133
arx_args_list <- function(lags = c(0L, 7L, 14L),
135134
ahead = 7L,
136-
min_train_window = 20L,
135+
n_training = Inf,
137136
forecast_date = NULL,
138137
target_date = NULL,
139138
levels = c(0.05, 0.95),
@@ -145,24 +144,31 @@ arx_args_list <- function(lags = c(0L, 7L, 14L),
145144
.lags <- lags
146145
if (is.list(lags)) lags <- unlist(lags)
147146

148-
arg_is_scalar(ahead, min_train_window, symmetrize, nonneg)
147+
arg_is_scalar(ahead, n_training, symmetrize, nonneg)
149148
arg_is_chr(quantile_by_key, allow_null = TRUE)
150149
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
151150
arg_is_date(forecast_date, target_date, allow_null = TRUE)
152-
arg_is_nonneg_int(ahead, min_train_window, lags)
151+
arg_is_nonneg_int(ahead, lags)
153152
arg_is_lgl(symmetrize, nonneg)
154153
arg_is_probabilities(levels, allow_null = TRUE)
154+
arg_is_pos(n_training)
155+
if (is.finite(n_training)) arg_is_pos_int(n_training)
155156

156157
max_lags <- max(lags)
157158
structure(enlist(lags = .lags,
158159
ahead,
159-
min_train_window,
160+
n_training,
160161
levels,
161162
forecast_date,
162163
target_date,
163164
symmetrize,
164165
nonneg,
165166
max_lags,
166167
quantile_by_key),
167-
class = "arx_alist")
168+
class = "arx_flist")
169+
}
170+
171+
#' @export
172+
print.arx_flist <- function(x, ...) {
173+
utils::str(x)
168174
}

R/epi_check_training_set.R

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
epi_check_training_set <- function(x, rec) {
2+
# Philosophy, allow the model to be fit with warnings, whenever possible.
3+
# If geo_type / time_type of the recipe and training data don't match
4+
# we proceed and warn.
5+
# If other_keys is missing from the training set, there are other issues.
6+
validate_meta_match(x, rec$template, "geo_type", "warn")
7+
validate_meta_match(x, rec$template, "time_type", "warn")
8+
9+
# There are 3 possibilities.
10+
# 1. template has ok that are in x, but not labelled
11+
# 2. template has ok that are not in x
12+
# 3. x has ok that are not in template. Not a problem.
13+
old_ok <- attr(rec$template, "metadata")$other_keys
14+
new_ok <- attr(x, "metadata")$other_keys
15+
16+
if (!is.null(old_ok)) {
17+
if (all(old_ok %in% colnames(x))) { # case 1
18+
if (!all(old_ok %in% new_ok)) {
19+
cli::cli_warn(c(
20+
"The recipe specifies additional keys. Because these are available,",
21+
"they are being added to the metadata of the training data."
22+
))
23+
attr(x, "metadata")$other_keys <- union(new_ok, old_ok)
24+
}
25+
}
26+
missing_ok <- setdiff(old_ok, colnames(x))
27+
if (length(missing_ok) > 0) { # case 2
28+
cli::cli_abort(c(
29+
"The recipe specifies keys which are not in the training data.",
30+
i = "The training set is missing columns for {missing_ok}."
31+
))
32+
}
33+
}
34+
x
35+
}
36+
37+
validate_meta_match <- function(x, template, meta, warn_or_abort = "warn") {
38+
new_meta <- attr(x, "metadata")[[meta]]
39+
old_meta <- attr(template, "metadata")[[meta]]
40+
msg <- c(
41+
"The `{meta}` of the training data appears to be different from that",
42+
"used to construct the recipe. This may result in unexpected consequences.",
43+
i = "Training `geo_type` is '{new_meta}'.",
44+
i = "Originally, it was '{old_meta}'."
45+
)
46+
if (new_meta != old_meta) {
47+
switch(warn_or_abort,
48+
warn = cli::cli_warn(msg),
49+
abort = cli::cli_abort(msg)
50+
)
51+
}
52+
}

R/epi_keys.R

+4
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,7 @@ epi_keys_mold <- function(mold) {
3333
unname(unlist(mold_keys))
3434
}
3535

36+
kill_time_value <- function(v) {
37+
arg_is_chr(v)
38+
v[v != "time_value"]
39+
}

R/epi_recipe.R

+14-3
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ epi_recipe.formula <- function(formula, data, ...) {
146146
# we ensure that there's only 1 row in the template
147147
data <- data[1,]
148148
# check for minus:
149-
if (! epiprocess::is_epi_df(data)) {
149+
if (!epiprocess::is_epi_df(data)) {
150150
return(recipes::recipe(formula, data, ...))
151151
}
152152

@@ -280,14 +280,17 @@ add_epi_recipe <- function(
280280

281281

282282

283-
# unfortunately, everything the same as in prep.recipe except string/fctr handling
283+
# unfortunately, almost everything the same as in prep.recipe except string/fctr handling
284284
#' @export
285285
prep.epi_recipe <- function(
286286
x, training = NULL, fresh = FALSE, verbose = FALSE,
287287
retain = TRUE, log_changes = FALSE, strings_as_factors = TRUE, ...) {
288288
training <- recipes:::check_training_set(training, x, fresh)
289+
training <- epi_check_training_set(training, x)
290+
training <- dplyr::relocate(training, tidyselect::all_of(epi_keys(training)))
289291
tr_data <- recipes:::train_info(training)
290-
keys <- epi_keys(training)
292+
keys <- epi_keys(x)
293+
291294
orig_lvls <- lapply(training, recipes:::get_levels)
292295
orig_lvls <- kill_levels(orig_lvls, keys)
293296
if (strings_as_factors) {
@@ -322,12 +325,20 @@ prep.epi_recipe <- function(
322325
cat(note, "[training]", "\n")
323326
}
324327
before_nms <- names(training)
328+
before_template <- training[1, ]
325329
x$steps[[i]] <- prep(x$steps[[i]], training = training,
326330
info = x$term_info)
327331
training <- bake(x$steps[[i]], new_data = training)
328332
if (!tibble::is_tibble(training)) {
329333
abort("bake() methods should always return tibbles")
330334
}
335+
if (!is_epi_df(training)) {
336+
# tidymodels killed our class
337+
# for now, we only allow step_epi_* to alter the metadata
338+
training <- dplyr::dplyr_reconstruct(
339+
epiprocess::as_epi_df(training), before_template)
340+
}
341+
training <- dplyr::relocate(training, tidyselect::all_of(epi_keys(training)))
331342
x$term_info <- recipes:::merge_term_info(get_types(training), x$term_info)
332343
if (!is.na(x$steps[[i]]$role)) {
333344
new_vars <- setdiff(x$term_info$variable, running_info$variable)

R/epi_shift.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,5 @@ epi_shift_single <- function(x, col, shift_val, newname, key_cols) {
3636
x %>%
3737
dplyr::select(tidyselect::all_of(c(key_cols, col))) %>%
3838
dplyr::mutate(time_value = time_value + shift_val) %>%
39-
dplyr::rename(!!newname := col)
39+
dplyr::rename(!!newname := {{ col }})
4040
}

R/flatline.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ flatline <- function(formula, data) {
4343
ek <- rhs[-n]
4444
if (length(response) > 1)
4545
cli_stop("flatline forecaster can accept only 1 observed time series.")
46-
keys <- ek[ek != "time_value"]
46+
keys <- kill_time_value(ek)
4747

4848
preds <- data %>%
4949
dplyr::mutate(.pred = !!rlang::sym(observed),
@@ -54,7 +54,7 @@ flatline <- function(formula, data) {
5454
dplyr::arrange(time_value) %>%
5555
dplyr::slice_tail(n = 1L) %>%
5656
dplyr::ungroup() %>%
57-
dplyr::select(dplyr::all_of(c(keys, ".pred")))
57+
dplyr::select(tidyselect::all_of(c(keys, ".pred")))
5858

5959
structure(list(
6060
residuals = dplyr::select(preds, dplyr::all_of(c(keys, ".resid"))),

0 commit comments

Comments
 (0)