Skip to content

Commit 8c48889

Browse files
committed
Merge branch 'v0.0.5' into smooth-quant-reg
2 parents ff8febf + 9d84f6a commit 8c48889

22 files changed

+532
-163
lines changed

.github/workflows/R-CMD-check.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
22
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
3+
#
4+
# Created with usethis + edited to use API key.
35
on:
46
push:
57
branches: [main, master]
@@ -27,3 +29,5 @@ jobs:
2729
needs: check
2830

2931
- uses: r-lib/actions/check-r-package@v2
32+
env:
33+
DELPHI_EPIDATA_KEY: ${{ secrets.SECRET_EPIPREDICT_GHACTIONS_DELPHI_EPIDATA_KEY }}

.github/workflows/pkgdown.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
22
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
3+
#
4+
# Created with usethis + edited to use API key.
35
on:
46
push:
57
branches: [main, master]
@@ -32,6 +34,8 @@ jobs:
3234
needs: website
3335

3436
- name: Build site
37+
env:
38+
DELPHI_EPIDATA_KEY: ${{ secrets.SECRET_EPIPREDICT_GHACTIONS_DELPHI_EPIDATA_KEY }}
3539
run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE)
3640
shell: Rscript {0}
3741

DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.4
3+
Version: 0.0.5
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),

NAMESPACE

+7-4
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ S3method(prep,step_growth_rate)
4646
S3method(prep,step_lag_difference)
4747
S3method(prep,step_population_scaling)
4848
S3method(prep,step_training_window)
49-
S3method(print,arx_clist)
50-
S3method(print,arx_flist)
49+
S3method(print,alist)
50+
S3method(print,arx_class)
51+
S3method(print,arx_fcast)
52+
S3method(print,canned_epipred)
5153
S3method(print,epi_workflow)
52-
S3method(print,flatline_alist)
54+
S3method(print,flatline)
5355
S3method(print,frosting)
5456
S3method(print,layer_add_forecast_date)
5557
S3method(print,layer_add_target_date)
@@ -96,13 +98,14 @@ export(add_layer)
9698
export(apply_frosting)
9799
export(arx_args_list)
98100
export(arx_class_args_list)
101+
export(arx_class_epi_workflow)
99102
export(arx_classifier)
103+
export(arx_fcast_epi_workflow)
100104
export(arx_forecaster)
101105
export(bake)
102106
export(create_layer)
103107
export(default_epi_recipe_blueprint)
104108
export(detect_layer)
105-
export(df_mat_mul)
106109
export(dist_quantiles)
107110
export(epi_keys)
108111
export(epi_recipe)

NEWS.md

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
# epipredict (development)
22

3+
# epipredict 0.0.5
4+
5+
* add `smooth_quantile_reg()`
6+
* improved printing of various methods / internals
7+
* canned forecasters get a class
8+
* fixed quantile bug in `flatline_forecaster()`
9+
* add functionality to output the unfit workflow from the canned forecasters
10+
* add `pivot_quantiles()` for easier plotting
11+
312

413
# epipredict 0.0.4
514

R/arx_classifier.R

+102-29
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,25 @@
55
#' that it estimates a class at a particular target horizon.
66
#'
77
#' @inheritParams arx_forecaster
8+
#' @param outcome A character (scalar) specifying the outcome (in the
9+
#' `epi_df`). Note that as with [arx_forecaster()], this is expected to
10+
#' be real-valued. Conversion of this data to unordered classes is handled
11+
#' internally based on the `breaks` argument to [arx_class_args_list()].
12+
#' If discrete classes are already in the `epi_df`, it is recommended to
13+
#' code up a classifier from scratch using [epi_recipe()].
814
#' @param trainer A `{parsnip}` model describing the type of estimation.
915
#' For now, we enforce `mode = "classification"`. Typical values are
1016
#' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated
1117
#' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can
1218
#' also be used.
1319
#' @param args_list A list of customization arguments to determine
14-
#' the type of forecasting model. See [arx_args_list()].
20+
#' the type of forecasting model. See [arx_class_args_list()].
1521
#'
1622
#' @return A list with (1) `predictions` an `epi_df` of predicted classes
1723
#' and (2) `epi_workflow`, a list that encapsulates the entire estimation
1824
#' workflow
1925
#' @export
26+
#' @seealso [arx_class_epi_workflow()], [arx_class_args_list()]
2027
#'
2128
#' @examples
2229
#' jhu <- case_death_rate_subset %>%
@@ -34,18 +41,87 @@
3441
#' horizon = 14, method = "linear_reg"
3542
#' )
3643
#' )
37-
arx_classifier <- function(epi_data,
38-
outcome,
39-
predictors,
40-
trainer = parsnip::logistic_reg(),
41-
args_list = arx_class_args_list()) {
44+
arx_classifier <- function(
45+
epi_data,
46+
outcome,
47+
predictors,
48+
trainer = parsnip::logistic_reg(),
49+
args_list = arx_class_args_list()) {
4250

43-
# --- validation
44-
validate_forecaster_inputs(epi_data, outcome, predictors)
45-
if (!inherits(args_list, "arx_clist"))
46-
cli_stop("args_list was not created using `arx_class_args_list().")
4751
if (!is_classification(trainer))
48-
cli_stop("{trainer} must be a `parsnip` method of mode 'classification'.")
52+
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
53+
54+
wf <- arx_class_epi_workflow(
55+
epi_data, outcome, predictors, trainer, args_list
56+
)
57+
58+
latest <- get_test_data(
59+
workflows::extract_preprocessor(wf), epi_data, TRUE
60+
)
61+
62+
wf <- generics::fit(wf, epi_data)
63+
preds <- predict(wf, new_data = latest) %>%
64+
tibble::as_tibble() %>%
65+
dplyr::select(-time_value)
66+
67+
structure(list(
68+
predictions = preds,
69+
epi_workflow = wf,
70+
metadata = list(
71+
training = attr(epi_data, "metadata"),
72+
forecast_created = Sys.time()
73+
)),
74+
class = c("arx_class", "canned_epipred")
75+
)
76+
}
77+
78+
79+
#' Create a template `arx_classifier` workflow
80+
#'
81+
#' This function creates an unfit workflow for use with [arx_classifier()].
82+
#' It is useful if you want to make small modifications to that classifier
83+
#' before fitting and predicting. Supplying a trainer to the function
84+
#' may alter the returned `epi_workflow` object but can be omitted.
85+
#'
86+
#' @inheritParams arx_classifier
87+
#' @param trainer A `{parsnip}` model describing the type of estimation.
88+
#' For now, we enforce `mode = "classification"`. Typical values are
89+
#' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated
90+
#' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can
91+
#' also be used. May be `NULL` (the default).
92+
#'
93+
#' @return An unfit `epi_workflow`.
94+
#' @export
95+
#' @seealso [arx_classifier()]
96+
#' @examples
97+
#'
98+
#' jhu <- case_death_rate_subset %>%
99+
#' dplyr::filter(time_value >= as.Date("2021-11-01"))
100+
#'
101+
#' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"))
102+
#'
103+
#' arx_class_epi_workflow(
104+
#' jhu,
105+
#' "death_rate",
106+
#' c("case_rate", "death_rate"),
107+
#' trainer = parsnip::multinom_reg(),
108+
#' args_list = arx_class_args_list(
109+
#' breaks = c(-.05, .1), ahead = 14,
110+
#' horizon = 14, method = "linear_reg"
111+
#' )
112+
#' )
113+
arx_class_epi_workflow <- function(
114+
epi_data,
115+
outcome,
116+
predictors,
117+
trainer = NULL,
118+
args_list = arx_class_args_list()) {
119+
120+
validate_forecaster_inputs(epi_data, outcome, predictors)
121+
if (!inherits(args_list, c("arx_class", "alist")))
122+
rlang::abort("args_list was not created using `arx_class_args_list().")
123+
if (!(is.null(trainer) || is_classification(trainer)))
124+
rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.")
49125
lags <- arx_lags_validator(predictors, args_list$lags)
50126

51127
# --- preprocessor
@@ -108,29 +184,23 @@ arx_classifier <- function(epi_data,
108184
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
109185
layer_add_target_date(target_date = target_date)
110186

111-
112-
# --- create test data, fit, and return
113-
latest <- get_test_data(r, epi_data, TRUE)
114-
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
115-
list(
116-
predictions = predict(wf, new_data = latest),
117-
epi_workflow = wf
118-
)
187+
epi_workflow(r, trainer, f)
119188
}
120189

121-
122-
123190
#' ARX classifier argument constructor
124191
#'
125192
#' Constructs a list of arguments for [arx_classifier()].
126193
#'
127194
#' @inheritParams arx_args_list
128195
#' @param outcome_transform Scalar character. Whether the outcome should
129-
#' be created using growth rates (as the predictors are) or lagged differences
130-
#' or growth rates. The second case is closer to the requirements for the
196+
#' be created using growth rates (as the predictors are) or lagged
197+
#' differences. The second case is closer to the requirements for the
131198
#' [2022-23 CDC Flusight Hospitalization Experimental Target](https://github.com/cdcepi/Flusight-forecast-data/blob/745511c436923e1dc201dea0f4181f21a8217b52/data-experimental/README.md).
132199
#' See the Classification Vignette for details of how to create a reasonable
133-
#' baseline for this case.
200+
#' baseline for this case. Selecting `"growth_rate"` (the default) uses
201+
#' [epiprocess::growth_rate()] to create the outcome using some of the
202+
#' additional arguments below. Choosing `"lag_difference"` instead simply
203+
#' uses the change from the value at the selected `horizon`.
134204
#' @param breaks Vector. A vector of breaks to turn real-valued growth rates
135205
#' into discrete classes. The default gives binary upswing classification
136206
#' as in [McDonald, Bien, Green, Hu, et al.](https://doi.org/10.1073/pnas.2111453118).
@@ -190,8 +260,9 @@ arx_class_args_list <- function(
190260
arg_is_pos(n_training)
191261
if (is.finite(n_training)) arg_is_pos_int(n_training)
192262
if (!is.list(additional_gr_args)) {
193-
rlang::abort(
194-
c("`additional_gr_args` must be a list.",
263+
cli::cli_abort(
264+
c("`additional_gr_args` must be a {.cls list}.",
265+
"!" = "This is a {.cls {class(additional_gr_args)}}.",
195266
i = "See `?epiprocess::growth_rate` for available arguments.")
196267
)
197268
}
@@ -216,11 +287,13 @@ arx_class_args_list <- function(
216287
log_scale,
217288
additional_gr_args
218289
),
219-
class = "arx_clist"
290+
class = c("arx_class", "alist")
220291
)
221292
}
222293

223294
#' @export
224-
print.arx_clist <- function(x, ...) {
225-
utils::str(x)
295+
print.arx_class <- function(x, ...) {
296+
name <- "ARX Classifier"
297+
NextMethod(name = name, ...)
226298
}
299+

0 commit comments

Comments
 (0)