19
19
# ' and (2) `epi_workflow`, a list that encapsulates the entire estimation
20
20
# ' workflow
21
21
# ' @export
22
+ # ' @seealso [arx_fcast_epi_workflow()], [arx_args_list()]
22
23
# '
23
24
# ' @examples
24
25
# ' jhu <- case_death_rate_subset %>%
@@ -36,12 +37,72 @@ arx_forecaster <- function(epi_data,
36
37
trainer = parsnip :: linear_reg(),
37
38
args_list = arx_args_list()) {
38
39
40
+ if (! is_regression(trainer ))
41
+ cli :: cli_abort(" `trainer` must be a {.pkg parsnip} model of mode 'regression'." )
42
+
43
+ wf <- arx_fcast_epi_workflow(
44
+ epi_data , outcome , predictors , trainer , args_list
45
+ )
46
+
47
+ latest <- get_test_data(
48
+ workflows :: extract_preprocessor(wf ), epi_data , TRUE
49
+ )
50
+
51
+ wf <- generics :: fit(wf , epi_data )
52
+ preds <- predict(wf , new_data = latest ) %> %
53
+ tibble :: as_tibble() %> %
54
+ dplyr :: select(- time_value )
55
+
56
+ structure(list (
57
+ predictions = preds ,
58
+ epi_workflow = wf ,
59
+ metadata = list (
60
+ training = attr(epi_data , " metadata" ),
61
+ forecast_created = Sys.time()
62
+ )),
63
+ class = c(" arx_fcast" , " canned_epipred" )
64
+ )
65
+ }
66
+
67
+ # ' Create a template `arx_forecaster` workflow
68
+ # '
69
+ # ' This function creates an unfit workflow for use with [arx_forecaster()].
70
+ # ' It is useful if you want to make small modifications to that forecaster
71
+ # ' before fitting and predicting. Supplying a trainer to the function
72
+ # ' may alter the returned `epi_workflow` object (e.g., if you intend to
73
+ # ' use [quantile_reg()]) but can be omitted.
74
+ # '
75
+ # ' @inheritParams arx_forecaster
76
+ # ' @param trainer A `{parsnip}` model describing the type of estimation.
77
+ # ' For now, we enforce `mode = "regression"`. May be `NULL` (the default).
78
+ # '
79
+ # ' @return An unfitted `epi_workflow`.
80
+ # ' @export
81
+ # ' @seealso [arx_forecaster()]
82
+ # '
83
+ # ' @examples
84
+ # ' jhu <- case_death_rate_subset %>%
85
+ # ' dplyr::filter(time_value >= as.Date("2021-12-01"))
86
+ # '
87
+ # ' arx_fcast_epi_workflow(jhu, "death_rate",
88
+ # ' c("case_rate", "death_rate"))
89
+ # '
90
+ # ' arx_fcast_epi_workflow(jhu, "death_rate",
91
+ # ' c("case_rate", "death_rate"), trainer = quantile_reg(),
92
+ # ' args_list = arx_args_list(levels = 1:9 / 10))
93
+ arx_fcast_epi_workflow <- function (
94
+ epi_data ,
95
+ outcome ,
96
+ predictors ,
97
+ trainer = NULL ,
98
+ args_list = arx_args_list()) {
99
+
39
100
# --- validation
40
101
validate_forecaster_inputs(epi_data , outcome , predictors )
41
- if (! inherits(args_list , " arx_flist " ))
42
- cli_stop (" args_list was not created using `arx_args_list()." )
43
- if (! is_regression(trainer ))
44
- cli_stop (" {trainer} must be a `parsnip` method of mode 'regression'." )
102
+ if (! inherits(args_list , c( " arx_fcast " , " alist " ) ))
103
+ cli :: cli_abort (" args_list was not created using `arx_args_list()." )
104
+ if (! (is.null( trainer ) || is_regression(trainer ) ))
105
+ cli :: cli_abort (" {trainer} must be a `{ parsnip}` model of mode 'regression'." )
45
106
lags <- arx_lags_validator(predictors , args_list $ lags )
46
107
47
108
# --- preprocessor
@@ -78,28 +139,10 @@ arx_forecaster <- function(epi_data,
78
139
layer_add_target_date(target_date = target_date )
79
140
if (args_list $ nonneg ) f <- layer_threshold(f , dplyr :: starts_with(" .pred" ))
80
141
81
- # --- create test data, fit, and return
82
- latest <- get_test_data(r , epi_data , TRUE )
83
- wf <- epi_workflow(r , trainer , f ) %> % generics :: fit(epi_data )
84
- list (
85
- predictions = predict(wf , new_data = latest ),
86
- epi_workflow = wf
87
- )
142
+ epi_workflow(r , trainer , f )
88
143
}
89
144
90
145
91
- arx_lags_validator <- function (predictors , lags ) {
92
- p <- length(predictors )
93
- if (! is.list(lags )) lags <- list (lags )
94
- if (length(lags ) == 1 ) lags <- rep(lags , p )
95
- else if (length(lags ) < p ) {
96
- cli_stop(
97
- " You have requested {p} predictors but lags cannot be recycled to match."
98
- )
99
- }
100
- lags
101
- }
102
-
103
146
# ' ARX forecaster argument constructor
104
147
# '
105
148
# ' Constructs a list of arguments for [arx_forecaster()].
@@ -138,15 +181,16 @@ arx_lags_validator <- function(predictors, lags) {
138
181
# ' arx_args_list()
139
182
# ' arx_args_list(symmetrize = FALSE)
140
183
# ' arx_args_list(levels = c(.1, .3, .7, .9), n_training = 120)
141
- arx_args_list <- function (lags = c(0L , 7L , 14L ),
142
- ahead = 7L ,
143
- n_training = Inf ,
144
- forecast_date = NULL ,
145
- target_date = NULL ,
146
- levels = c(0.05 , 0.95 ),
147
- symmetrize = TRUE ,
148
- nonneg = TRUE ,
149
- quantile_by_key = character (0L )) {
184
+ arx_args_list <- function (
185
+ lags = c(0L , 7L , 14L ),
186
+ ahead = 7L ,
187
+ n_training = Inf ,
188
+ forecast_date = NULL ,
189
+ target_date = NULL ,
190
+ levels = c(0.05 , 0.95 ),
191
+ symmetrize = TRUE ,
192
+ nonneg = TRUE ,
193
+ quantile_by_key = character (0L )) {
150
194
151
195
# error checking if lags is a list
152
196
.lags <- lags
@@ -163,22 +207,26 @@ arx_args_list <- function(lags = c(0L, 7L, 14L),
163
207
if (is.finite(n_training )) arg_is_pos_int(n_training )
164
208
165
209
max_lags <- max(lags )
166
- structure(enlist(lags = .lags ,
167
- ahead ,
168
- n_training ,
169
- levels ,
170
- forecast_date ,
171
- target_date ,
172
- symmetrize ,
173
- nonneg ,
174
- max_lags ,
175
- quantile_by_key ),
176
- class = " arx_flist" )
210
+ structure(
211
+ enlist(lags = .lags ,
212
+ ahead ,
213
+ n_training ,
214
+ levels ,
215
+ forecast_date ,
216
+ target_date ,
217
+ symmetrize ,
218
+ nonneg ,
219
+ max_lags ,
220
+ quantile_by_key ),
221
+ class = c(" arx_fcast" , " alist" )
222
+ )
177
223
}
178
224
225
+
179
226
# ' @export
180
- print.arx_flist <- function (x , ... ) {
181
- utils :: str(x )
227
+ print.arx_fcast <- function (x , ... ) {
228
+ name <- " ARX Forecaster"
229
+ NextMethod(name = name , ... )
182
230
}
183
231
184
232
compare_quantile_args <- function (alist , tlist ) {
0 commit comments