Skip to content

Commit 9d84f6a

Browse files
authored
Merge pull request #187 from cmu-delphi/djm/unfit-wf
add functions to output an unfit classifier/forecaster workflow
2 parents 3c90302 + 7b4c39d commit 9d84f6a

15 files changed

+417
-145
lines changed

NAMESPACE

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@ S3method(prep,step_growth_rate)
4444
S3method(prep,step_lag_difference)
4545
S3method(prep,step_population_scaling)
4646
S3method(prep,step_training_window)
47-
S3method(print,arx_clist)
48-
S3method(print,arx_flist)
47+
S3method(print,alist)
48+
S3method(print,arx_class)
49+
S3method(print,arx_fcast)
50+
S3method(print,canned_epipred)
4951
S3method(print,epi_workflow)
50-
S3method(print,flatline_alist)
52+
S3method(print,flatline)
5153
S3method(print,frosting)
5254
S3method(print,layer_add_forecast_date)
5355
S3method(print,layer_add_target_date)
@@ -92,13 +94,14 @@ export(add_layer)
9294
export(apply_frosting)
9395
export(arx_args_list)
9496
export(arx_class_args_list)
97+
export(arx_class_epi_workflow)
9598
export(arx_classifier)
99+
export(arx_fcast_epi_workflow)
96100
export(arx_forecaster)
97101
export(bake)
98102
export(create_layer)
99103
export(default_epi_recipe_blueprint)
100104
export(detect_layer)
101-
export(df_mat_mul)
102105
export(dist_quantiles)
103106
export(epi_keys)
104107
export(epi_recipe)

R/arx_classifier.R

Lines changed: 89 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#' and (2) `epi_workflow`, a list that encapsulates the entire estimation
2424
#' workflow
2525
#' @export
26+
#' @seealso [arx_class_epi_workflow()], [arx_class_args_list()]
2627
#'
2728
#' @examples
2829
#' jhu <- case_death_rate_subset %>%
@@ -40,18 +41,87 @@
4041
#' horizon = 14, method = "linear_reg"
4142
#' )
4243
#' )
43-
arx_classifier <- function(epi_data,
44-
outcome,
45-
predictors,
46-
trainer = parsnip::logistic_reg(),
47-
args_list = arx_class_args_list()) {
44+
arx_classifier <- function(
45+
epi_data,
46+
outcome,
47+
predictors,
48+
trainer = parsnip::logistic_reg(),
49+
args_list = arx_class_args_list()) {
4850

49-
# --- validation
50-
validate_forecaster_inputs(epi_data, outcome, predictors)
51-
if (!inherits(args_list, "arx_clist"))
52-
cli_stop("args_list was not created using `arx_class_args_list().")
5351
if (!is_classification(trainer))
54-
cli_stop("{trainer} must be a `parsnip` method of mode 'classification'.")
52+
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
53+
54+
wf <- arx_class_epi_workflow(
55+
epi_data, outcome, predictors, trainer, args_list
56+
)
57+
58+
latest <- get_test_data(
59+
workflows::extract_preprocessor(wf), epi_data, TRUE
60+
)
61+
62+
wf <- generics::fit(wf, epi_data)
63+
preds <- predict(wf, new_data = latest) %>%
64+
tibble::as_tibble() %>%
65+
dplyr::select(-time_value)
66+
67+
structure(list(
68+
predictions = preds,
69+
epi_workflow = wf,
70+
metadata = list(
71+
training = attr(epi_data, "metadata"),
72+
forecast_created = Sys.time()
73+
)),
74+
class = c("arx_class", "canned_epipred")
75+
)
76+
}
77+
78+
79+
#' Create a template `arx_classifier` workflow
80+
#'
81+
#' This function creates an unfit workflow for use with [arx_classifier()].
82+
#' It is useful if you want to make small modifications to that classifier
83+
#' before fitting and predicting. Supplying a trainer to the function
84+
#' may alter the returned `epi_workflow` object but can be omitted.
85+
#'
86+
#' @inheritParams arx_classifier
87+
#' @param trainer A `{parsnip}` model describing the type of estimation.
88+
#' For now, we enforce `mode = "classification"`. Typical values are
89+
#' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated
90+
#' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can
91+
#' also be used. May be `NULL` (the default).
92+
#'
93+
#' @return An unfit `epi_workflow`.
94+
#' @export
95+
#' @seealso [arx_classifier()]
96+
#' @examples
97+
#'
98+
#' jhu <- case_death_rate_subset %>%
99+
#' dplyr::filter(time_value >= as.Date("2021-11-01"))
100+
#'
101+
#' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"))
102+
#'
103+
#' arx_class_epi_workflow(
104+
#' jhu,
105+
#' "death_rate",
106+
#' c("case_rate", "death_rate"),
107+
#' trainer = parsnip::multinom_reg(),
108+
#' args_list = arx_class_args_list(
109+
#' breaks = c(-.05, .1), ahead = 14,
110+
#' horizon = 14, method = "linear_reg"
111+
#' )
112+
#' )
113+
arx_class_epi_workflow <- function(
114+
epi_data,
115+
outcome,
116+
predictors,
117+
trainer = NULL,
118+
args_list = arx_class_args_list()) {
119+
120+
validate_forecaster_inputs(epi_data, outcome, predictors)
121+
# Require BOTH classes: `inherits()` given a vector is TRUE on any match, so
# every "alist" (e.g. one built by `arx_args_list()`) would otherwise pass.
if (!(inherits(args_list, "arx_class") && inherits(args_list, "alist")))
122+
cli::cli_abort("args_list was not created using `arx_class_args_list()`.")
123+
if (!(is.null(trainer) || is_classification(trainer)))
124+
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
55125
lags <- arx_lags_validator(predictors, args_list$lags)
56126

57127
# --- preprocessor
@@ -114,18 +184,9 @@ arx_classifier <- function(epi_data,
114184
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
115185
layer_add_target_date(target_date = target_date)
116186

117-
118-
# --- create test data, fit, and return
119-
latest <- get_test_data(r, epi_data, TRUE)
120-
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
121-
list(
122-
predictions = predict(wf, new_data = latest),
123-
epi_workflow = wf
124-
)
187+
epi_workflow(r, trainer, f)
125188
}
126189

127-
128-
129190
#' ARX classifier argument constructor
130191
#'
131192
#' Constructs a list of arguments for [arx_classifier()].
@@ -199,8 +260,9 @@ arx_class_args_list <- function(
199260
arg_is_pos(n_training)
200261
if (is.finite(n_training)) arg_is_pos_int(n_training)
201262
if (!is.list(additional_gr_args)) {
202-
rlang::abort(
203-
c("`additional_gr_args` must be a list.",
263+
cli::cli_abort(
264+
c("`additional_gr_args` must be a {.cls list}.",
265+
"!" = "This is a {.cls {class(additional_gr_args)}}.",
204266
i = "See `?epiprocess::growth_rate` for available arguments.")
205267
)
206268
}
@@ -225,11 +287,13 @@ arx_class_args_list <- function(
225287
log_scale,
226288
additional_gr_args
227289
),
228-
class = "arx_clist"
290+
class = c("arx_class", "alist")
229291
)
230292
}
231293

232294
#' @export
233-
print.arx_clist <- function(x, ...) {
234-
utils::str(x)
295+
print.arx_class <- function(x, ...) {
296+
name <- "ARX Classifier"
297+
NextMethod(name = name, ...)
235298
}
299+

R/arx_forecaster.R

Lines changed: 93 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#' and (2) `epi_workflow`, a list that encapsulates the entire estimation
2020
#' workflow
2121
#' @export
22+
#' @seealso [arx_fcast_epi_workflow()], [arx_args_list()]
2223
#'
2324
#' @examples
2425
#' jhu <- case_death_rate_subset %>%
@@ -36,12 +37,72 @@ arx_forecaster <- function(epi_data,
3637
trainer = parsnip::linear_reg(),
3738
args_list = arx_args_list()) {
3839

40+
if (!is_regression(trainer))
41+
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
42+
43+
wf <- arx_fcast_epi_workflow(
44+
epi_data, outcome, predictors, trainer, args_list
45+
)
46+
47+
latest <- get_test_data(
48+
workflows::extract_preprocessor(wf), epi_data, TRUE
49+
)
50+
51+
wf <- generics::fit(wf, epi_data)
52+
preds <- predict(wf, new_data = latest) %>%
53+
tibble::as_tibble() %>%
54+
dplyr::select(-time_value)
55+
56+
structure(list(
57+
predictions = preds,
58+
epi_workflow = wf,
59+
metadata = list(
60+
training = attr(epi_data, "metadata"),
61+
forecast_created = Sys.time()
62+
)),
63+
class = c("arx_fcast", "canned_epipred")
64+
)
65+
}
66+
67+
#' Create a template `arx_forecaster` workflow
68+
#'
69+
#' This function creates an unfit workflow for use with [arx_forecaster()].
70+
#' It is useful if you want to make small modifications to that forecaster
71+
#' before fitting and predicting. Supplying a trainer to the function
72+
#' may alter the returned `epi_workflow` object (e.g., if you intend to
73+
#' use [quantile_reg()]) but can be omitted.
74+
#'
75+
#' @inheritParams arx_forecaster
76+
#' @param trainer A `{parsnip}` model describing the type of estimation.
77+
#' For now, we enforce `mode = "regression"`. May be `NULL` (the default).
78+
#'
79+
#' @return An unfitted `epi_workflow`.
80+
#' @export
81+
#' @seealso [arx_forecaster()]
82+
#'
83+
#' @examples
84+
#' jhu <- case_death_rate_subset %>%
85+
#' dplyr::filter(time_value >= as.Date("2021-12-01"))
86+
#'
87+
#' arx_fcast_epi_workflow(jhu, "death_rate",
88+
#' c("case_rate", "death_rate"))
89+
#'
90+
#' arx_fcast_epi_workflow(jhu, "death_rate",
91+
#' c("case_rate", "death_rate"), trainer = quantile_reg(),
92+
#' args_list = arx_args_list(levels = 1:9 / 10))
93+
arx_fcast_epi_workflow <- function(
94+
epi_data,
95+
outcome,
96+
predictors,
97+
trainer = NULL,
98+
args_list = arx_args_list()) {
99+
39100
# --- validation
40101
validate_forecaster_inputs(epi_data, outcome, predictors)
41-
if (!inherits(args_list, "arx_flist"))
42-
cli_stop("args_list was not created using `arx_args_list().")
43-
if (!is_regression(trainer))
44-
cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.")
102+
# Require BOTH classes: `inherits()` given a vector is TRUE on any match, so
# every "alist" (e.g. one built by `arx_class_args_list()`) would otherwise pass.
if (!(inherits(args_list, "arx_fcast") && inherits(args_list, "alist")))
103+
cli::cli_abort("args_list was not created using `arx_args_list()`.")
104+
# cli glue-evaluates `{...}` in messages, so name the argument literally and
# use inline markup for the package instead of interpolating R objects.
if (!(is.null(trainer) || is_regression(trainer)))
105+
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
45106
lags <- arx_lags_validator(predictors, args_list$lags)
46107

47108
# --- preprocessor
@@ -78,28 +139,10 @@ arx_forecaster <- function(epi_data,
78139
layer_add_target_date(target_date = target_date)
79140
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
80141

81-
# --- create test data, fit, and return
82-
latest <- get_test_data(r, epi_data, TRUE)
83-
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
84-
list(
85-
predictions = predict(wf, new_data = latest),
86-
epi_workflow = wf
87-
)
142+
epi_workflow(r, trainer, f)
88143
}
89144

90145

91-
arx_lags_validator <- function(predictors, lags) {
92-
p <- length(predictors)
93-
if (!is.list(lags)) lags <- list(lags)
94-
if (length(lags) == 1) lags <- rep(lags, p)
95-
else if (length(lags) < p) {
96-
cli_stop(
97-
"You have requested {p} predictors but lags cannot be recycled to match."
98-
)
99-
}
100-
lags
101-
}
102-
103146
#' ARX forecaster argument constructor
104147
#'
105148
#' Constructs a list of arguments for [arx_forecaster()].
@@ -138,15 +181,16 @@ arx_lags_validator <- function(predictors, lags) {
138181
#' arx_args_list()
139182
#' arx_args_list(symmetrize = FALSE)
140183
#' arx_args_list(levels = c(.1, .3, .7, .9), n_training = 120)
141-
arx_args_list <- function(lags = c(0L, 7L, 14L),
142-
ahead = 7L,
143-
n_training = Inf,
144-
forecast_date = NULL,
145-
target_date = NULL,
146-
levels = c(0.05, 0.95),
147-
symmetrize = TRUE,
148-
nonneg = TRUE,
149-
quantile_by_key = character(0L)) {
184+
arx_args_list <- function(
185+
lags = c(0L, 7L, 14L),
186+
ahead = 7L,
187+
n_training = Inf,
188+
forecast_date = NULL,
189+
target_date = NULL,
190+
levels = c(0.05, 0.95),
191+
symmetrize = TRUE,
192+
nonneg = TRUE,
193+
quantile_by_key = character(0L)) {
150194

151195
# error checking if lags is a list
152196
.lags <- lags
@@ -163,22 +207,26 @@ arx_args_list <- function(lags = c(0L, 7L, 14L),
163207
if (is.finite(n_training)) arg_is_pos_int(n_training)
164208

165209
max_lags <- max(lags)
166-
structure(enlist(lags = .lags,
167-
ahead,
168-
n_training,
169-
levels,
170-
forecast_date,
171-
target_date,
172-
symmetrize,
173-
nonneg,
174-
max_lags,
175-
quantile_by_key),
176-
class = "arx_flist")
210+
structure(
211+
enlist(lags = .lags,
212+
ahead,
213+
n_training,
214+
levels,
215+
forecast_date,
216+
target_date,
217+
symmetrize,
218+
nonneg,
219+
max_lags,
220+
quantile_by_key),
221+
class = c("arx_fcast", "alist")
222+
)
177223
}
178224

225+
179226
#' @export
180-
print.arx_flist <- function(x, ...) {
181-
utils::str(x)
227+
print.arx_fcast <- function(x, ...) {
228+
name <- "ARX Forecaster"
229+
NextMethod(name = name, ...)
182230
}
183231

184232
compare_quantile_args <- function(alist, tlist) {

0 commit comments

Comments
 (0)