5
5
# ' that it estimates a class at a particular target horizon.
6
6
# '
7
7
# ' @inheritParams arx_forecaster
8
+ # ' @param outcome A character (scalar) specifying the outcome (in the
9
+ # ' `epi_df`). Note that as with [arx_forecaster()], this is expected to
10
+ # ' be real-valued. Conversion of this data to unordered classes is handled
11
+ # ' internally based on the `breaks` argument to [arx_class_args_list()].
12
+ # ' If discrete classes are already in the `epi_df`, it is recommended to
13
+ # ' code up a classifier from scratch using [epi_recipe()].
8
14
# ' @param trainer A `{parsnip}` model describing the type of estimation.
9
15
# ' For now, we enforce `mode = "classification"`. Typical values are
10
16
# ' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated
11
17
# ' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can
12
18
# ' also be used.
13
19
# ' @param args_list A list of customization arguments to determine
14
- # ' the type of forecasting model. See [arx_args_list ()].
20
+ # ' the type of forecasting model. See [arx_class_args_list()].
15
21
# '
16
22
# ' @return A list with (1) `predictions` an `epi_df` of predicted classes
17
23
# ' and (2) `epi_workflow`, a list that encapsulates the entire estimation
18
24
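# ' workflow. The returned list also contains (3) `metadata`, a list with the
# ' `training` metadata of `epi_data` and the `forecast_created` time, and it
# ' carries the classes `arx_class` and `canned_epipred`.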
19
25
# ' @export
26
+ # ' @seealso [arx_class_epi_workflow()], [arx_class_args_list()]
20
27
# '
21
28
# ' @examples
22
29
# ' jhu <- case_death_rate_subset %>%
34
41
# ' horizon = 14, method = "linear_reg"
35
42
# ' )
36
43
# ' )
37
arx_classifier <- function(
    epi_data,
    outcome,
    predictors,
    trainer = parsnip::logistic_reg(),
    args_list = arx_class_args_list()) {
  # Fit an ARX classifier on `epi_data` and predict on the latest data.
  #
  # Returns a list of class c("arx_class", "canned_epipred") with elements
  # `predictions` (a tibble without `time_value`), `epi_workflow` (the
  # fitted workflow), and `metadata` (training metadata and creation time).
  #
  # Validation of `epi_data`, `outcome`, `predictors`, and `args_list` is
  # performed inside `arx_class_epi_workflow()`; only the trainer is checked
  # here, before any work is done.
  if (!is_classification(trainer)) {
    cli::cli_abort(
      "`trainer` must be a {.pkg parsnip} model of mode 'classification'."
    )
  }

  # Build the unfit workflow (preprocessor, trainer, postprocessor).
  wf <- arx_class_epi_workflow(
    epi_data, outcome, predictors, trainer, args_list
  )

  # Test data is derived from the workflow's preprocessor and the training
  # data. NOTE(review): the third argument is passed positionally as `TRUE`;
  # confirm which option of `get_test_data()` this enables.
  latest <- get_test_data(
    workflows::extract_preprocessor(wf), epi_data, TRUE
  )

  wf <- generics::fit(wf, epi_data)
  preds <- predict(wf, new_data = latest) %>%
    tibble::as_tibble() %>%
    dplyr::select(-time_value)

  structure(
    list(
      predictions = preds,
      epi_workflow = wf,
      metadata = list(
        training = attr(epi_data, "metadata"),
        forecast_created = Sys.time()
      )
    ),
    class = c("arx_class", "canned_epipred")
  )
}
77
+
78
+
79
+ # ' Create a template `arx_classifier` workflow
80
+ # '
81
+ # ' This function creates an unfit workflow for use with [arx_classifier()].
82
+ # ' It is useful if you want to make small modifications to that classifier
83
+ # ' before fitting and predicting. Supplying a trainer to the function
84
+ # ' may alter the returned `epi_workflow` object but can be omitted.
85
+ # '
86
+ # ' @inheritParams arx_classifier
87
+ # ' @param trainer A `{parsnip}` model describing the type of estimation.
88
+ # ' For now, we enforce `mode = "classification"`. Typical values are
89
+ # ' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated
90
+ # ' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can
91
+ # ' also be used. May be `NULL` (the default).
92
+ # '
93
+ # ' @return An unfit `epi_workflow`.
94
+ # ' @export
95
+ # ' @seealso [arx_classifier()]
96
+ # ' @examples
97
+ # '
98
+ # ' jhu <- case_death_rate_subset %>%
99
+ # ' dplyr::filter(time_value >= as.Date("2021-11-01"))
100
+ # '
101
+ # ' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"))
102
+ # '
103
+ # ' arx_class_epi_workflow(
104
+ # ' jhu,
105
+ # ' "death_rate",
106
+ # ' c("case_rate", "death_rate"),
107
+ # ' trainer = parsnip::multinom_reg(),
108
+ # ' args_list = arx_class_args_list(
109
+ # ' breaks = c(-.05, .1), ahead = 14,
110
+ # ' horizon = 14, method = "linear_reg"
111
+ # ' )
112
+ # ' )
113
+ arx_class_epi_workflow <- function (
114
+ epi_data ,
115
+ outcome ,
116
+ predictors ,
117
+ trainer = NULL ,
118
+ args_list = arx_class_args_list()) {
119
+
120
+ validate_forecaster_inputs(epi_data , outcome , predictors )
121
+ if (! inherits(args_list , c(" arx_class" , " alist" )))
122
+ rlang :: abort(" args_list was not created using `arx_class_args_list()." )
123
+ if (! (is.null(trainer ) || is_classification(trainer )))
124
+ rlang :: abort(" `trainer` must be a `{parsnip}` model of mode 'classification'." )
49
125
lags <- arx_lags_validator(predictors , args_list $ lags )
50
126
51
127
# --- preprocessor
@@ -108,29 +184,23 @@ arx_classifier <- function(epi_data,
108
184
f <- layer_add_forecast_date(f , forecast_date = forecast_date ) %> %
109
185
layer_add_target_date(target_date = target_date )
110
186
111
-
112
- # --- create test data, fit, and return
113
- latest <- get_test_data(r , epi_data , TRUE )
114
- wf <- epi_workflow(r , trainer , f ) %> % generics :: fit(epi_data )
115
- list (
116
- predictions = predict(wf , new_data = latest ),
117
- epi_workflow = wf
118
- )
187
+ epi_workflow(r , trainer , f )
119
188
}
120
189
121
-
122
-
123
190
# ' ARX classifier argument constructor
124
191
# '
125
192
# ' Constructs a list of arguments for [arx_classifier()].
126
193
# '
127
194
# ' @inheritParams arx_args_list
128
195
# ' @param outcome_transform Scalar character. Whether the outcome should
129
- # ' be created using growth rates (as the predictors are) or lagged differences
130
- # ' or growth rates . The second case is closer to the requirements for the
196
+ # ' be created using growth rates (as the predictors are) or lagged
197
+ # ' differences . The second case is closer to the requirements for the
131
198
# ' [2022-23 CDC Flusight Hospitalization Experimental Target](https://github.com/cdcepi/Flusight-forecast-data/blob/745511c436923e1dc201dea0f4181f21a8217b52/data-experimental/README.md).
132
199
# ' See the Classification Vignette for details of how to create a reasonable
133
- # ' baseline for this case.
200
+ # ' baseline for this case. Selecting `"growth_rate"` (the default) uses
201
+ # ' [epiprocess::growth_rate()] to create the outcome using some of the
202
+ # ' additional arguments below. Choosing `"lag_difference"` instead simply
203
+ # ' uses the change from the value at the selected `horizon`.
134
204
# ' @param breaks Vector. A vector of breaks to turn real-valued growth rates
135
205
# ' into discrete classes. The default gives binary upswing classification
136
206
# ' as in [McDonald, Bien, Green, Hu, et al.](https://doi.org/10.1073/pnas.2111453118).
@@ -190,8 +260,9 @@ arx_class_args_list <- function(
190
260
arg_is_pos(n_training )
191
261
if (is.finite(n_training )) arg_is_pos_int(n_training )
192
262
if (! is.list(additional_gr_args )) {
193
- rlang :: abort(
194
- c(" `additional_gr_args` must be a list." ,
263
+ cli :: cli_abort(
264
+ c(" `additional_gr_args` must be a {.cls list}." ,
265
+ " !" = " This is a {.cls {class(additional_gr_args)}}." ,
195
266
i = " See `?epiprocess::growth_rate` for available arguments." )
196
267
)
197
268
}
@@ -216,11 +287,13 @@ arx_class_args_list <- function(
216
287
log_scale ,
217
288
additional_gr_args
218
289
),
219
- class = " arx_clist "
290
+ class = c( " arx_class " , " alist " )
220
291
)
221
292
}
222
293
223
294
# ' @export
224
print.arx_class <- function(x, ...) {
  # Print method for objects whose leading class is `arx_class`: supplies the
  # display name and defers the actual printing to the next method in the
  # object's class chain.
  # NOTE(review): in this file `arx_class` leads the class vector of both the
  # argument list (`c("arx_class", "alist")`) and the `arx_classifier()`
  # result (`c("arx_class", "canned_epipred")`), so this method serves both;
  # confirm that each next method accepts a `name` argument.
  name <- "ARX Classifier"
  NextMethod(name = name, ...)
}
299
+
0 commit comments