Skip to content

add functions to output an unfit classifier/forecaster workflow #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ S3method(prep,step_growth_rate)
S3method(prep,step_lag_difference)
S3method(prep,step_population_scaling)
S3method(prep,step_training_window)
S3method(print,arx_clist)
S3method(print,arx_flist)
S3method(print,alist)
S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,epi_workflow)
S3method(print,flatline_alist)
S3method(print,flatline)
S3method(print,frosting)
S3method(print,layer_add_forecast_date)
S3method(print,layer_add_target_date)
Expand Down Expand Up @@ -92,13 +94,14 @@ export(add_layer)
export(apply_frosting)
export(arx_args_list)
export(arx_class_args_list)
export(arx_class_epi_workflow)
export(arx_classifier)
export(arx_fcast_epi_workflow)
export(arx_forecaster)
export(bake)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
export(df_mat_mul)
export(dist_quantiles)
export(epi_keys)
export(epi_recipe)
Expand Down
116 changes: 90 additions & 26 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
#' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can
#' also be used.
#' @param args_list A list of customization arguments to determine
#' the type of forecasting model. See [arx_args_list()].
#' the type of forecasting model. See [arx_class_args_list()].
#'
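#' @return A list with (1) `predictions`, a tibble of predicted classes,
#' (2) `epi_workflow`, the fitted workflow that encapsulates the entire
#' estimation process, and (3) `metadata`, a list holding the training data's
#' metadata and the time the forecast was created.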
#' @export
#' @seealso [arx_class_epi_workflow()], [arx_class_args_list()]
#'
#' @examples
#' jhu <- case_death_rate_subset %>%
Expand All @@ -34,18 +35,87 @@
#' horizon = 14, method = "linear_reg"
#' )
#' )
arx_classifier <- function(epi_data,
outcome,
predictors,
trainer = parsnip::logistic_reg(),
args_list = arx_class_args_list()) {
arx_classifier <- function(
epi_data,
outcome,
predictors,
trainer = parsnip::logistic_reg(),
args_list = arx_class_args_list()) {

# --- validation
validate_forecaster_inputs(epi_data, outcome, predictors)
if (!inherits(args_list, "arx_clist"))
cli_stop("args_list was not created using `arx_class_args_list().")
if (!is_classification(trainer))
cli_stop("{trainer} must be a `parsnip` method of mode 'classification'.")
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")

wf <- arx_class_epi_workflow(
epi_data, outcome, predictors, trainer, args_list
)

latest <- get_test_data(
workflows::extract_preprocessor(wf), epi_data, TRUE
)

wf <- generics::fit(wf, epi_data)
preds <- predict(wf, new_data = latest) %>%
tibble::as_tibble() %>%
dplyr::select(-time_value)

structure(list(
predictions = preds,
epi_workflow = wf,
metadata = list(
training = attr(epi_data, "metadata"),
forecast_created = Sys.time()
)),
class = c("arx_class", "canned_epipred")
)
}


#' Create a template `arx_classifier` workflow
#'
#' This function creates an unfit workflow for use with [arx_classifier()].
#' It is useful if you want to make small modifications to that classifier
#' before fitting and predicting. Supplying a trainer to the function
#' may alter the returned `epi_workflow` object but can be omitted.
#'
#' @inheritParams arx_classifier
#' @param trainer A `{parsnip}` model describing the type of estimation.
#' For now, we enforce `mode = "classification"`. Typical values are
#' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated
#' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can
#' also be used. May be `NULL` (the default).
#'
#' @return An unfit `epi_workflow`.
#' @export
#' @seealso [arx_classifier()]
#' @examples
#'
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(time_value >= as.Date("2021-11-01"))
#'
#' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"))
#'
#' arx_class_epi_workflow(
#' jhu,
#' "death_rate",
#' c("case_rate", "death_rate"),
#' trainer = parsnip::multinom_reg(),
#' args_list = arx_class_args_list(
#' breaks = c(-.05, .1), ahead = 14,
#' horizon = 14, method = "linear_reg"
#' )
#' )
arx_class_epi_workflow <- function(
epi_data,
outcome,
predictors,
trainer = NULL,
args_list = arx_class_args_list()) {

validate_forecaster_inputs(epi_data, outcome, predictors)
if (!inherits(args_list, c("arx_class", "alist")))
rlang::abort("args_list was not created using `arx_class_args_list().")
if (!(is.null(trainer) || is_classification(trainer)))
rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.")
lags <- arx_lags_validator(predictors, args_list$lags)

# --- preprocessor
Expand Down Expand Up @@ -108,18 +178,9 @@ arx_classifier <- function(epi_data,
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
layer_add_target_date(target_date = target_date)


# --- create test data, fit, and return
latest <- get_test_data(r, epi_data, TRUE)
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
list(
predictions = predict(wf, new_data = latest),
epi_workflow = wf
)
epi_workflow(r, trainer, f)
}



#' ARX classifier argument constructor
#'
#' Constructs a list of arguments for [arx_classifier()].
Expand Down Expand Up @@ -190,8 +251,9 @@ arx_class_args_list <- function(
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (!is.list(additional_gr_args)) {
rlang::abort(
c("`additional_gr_args` must be a list.",
cli::cli_abort(
c("`additional_gr_args` must be a {.cls list}.",
"!" = "This is a {.cls {class(additional_gr_args)}}.",
i = "See `?epiprocess::growth_rate` for available arguments.")
)
}
Expand All @@ -216,11 +278,13 @@ arx_class_args_list <- function(
log_scale,
additional_gr_args
),
class = "arx_clist"
class = c("arx_class", "alist")
)
}

#' @export
print.arx_clist <- function(x, ...) {
utils::str(x)
print.arx_class <- function(x, ...) {
name <- "ARX Classifier"
NextMethod(name = name, ...)
}

138 changes: 93 additions & 45 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#' and (2) `epi_workflow`, a list that encapsulates the entire estimation
#' workflow
#' @export
#' @seealso [arx_fcast_epi_workflow()], [arx_args_list()]
#'
#' @examples
#' jhu <- case_death_rate_subset %>%
Expand All @@ -36,12 +37,72 @@ arx_forecaster <- function(epi_data,
trainer = parsnip::linear_reg(),
args_list = arx_args_list()) {

if (!is_regression(trainer))
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")

wf <- arx_fcast_epi_workflow(
epi_data, outcome, predictors, trainer, args_list
)

latest <- get_test_data(
workflows::extract_preprocessor(wf), epi_data, TRUE
)

wf <- generics::fit(wf, epi_data)
preds <- predict(wf, new_data = latest) %>%
tibble::as_tibble() %>%
dplyr::select(-time_value)

structure(list(
predictions = preds,
epi_workflow = wf,
metadata = list(
training = attr(epi_data, "metadata"),
forecast_created = Sys.time()
)),
class = c("arx_fcast", "canned_epipred")
)
}

#' Create a template `arx_forecaster` workflow
#'
#' This function creates an unfit workflow for use with [arx_forecaster()].
#' It is useful if you want to make small modifications to that forecaster
#' before fitting and predicting. Supplying a trainer to the function
#' may alter the returned `epi_workflow` object (e.g., if you intend to
#' use [quantile_reg()]) but can be omitted.
#'
#' @inheritParams arx_forecaster
#' @param trainer A `{parsnip}` model describing the type of estimation.
#' For now, we enforce `mode = "regression"`. May be `NULL` (the default).
#'
#' @return An unfitted `epi_workflow`.
#' @export
#' @seealso [arx_forecaster()]
#'
#' @examples
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(time_value >= as.Date("2021-12-01"))
#'
#' arx_fcast_epi_workflow(jhu, "death_rate",
#' c("case_rate", "death_rate"))
#'
#' arx_fcast_epi_workflow(jhu, "death_rate",
#' c("case_rate", "death_rate"), trainer = quantile_reg(),
#' args_list = arx_args_list(levels = 1:9 / 10))
arx_fcast_epi_workflow <- function(
epi_data,
outcome,
predictors,
trainer = NULL,
args_list = arx_args_list()) {

# --- validation
validate_forecaster_inputs(epi_data, outcome, predictors)
if (!inherits(args_list, "arx_flist"))
cli_stop("args_list was not created using `arx_args_list().")
if (!is_regression(trainer))
cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.")
if (!inherits(args_list, c("arx_fcast", "alist")))
cli::cli_abort("args_list was not created using `arx_args_list().")
if (!(is.null(trainer) || is_regression(trainer)))
cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.")
lags <- arx_lags_validator(predictors, args_list$lags)

# --- preprocessor
Expand Down Expand Up @@ -74,28 +135,10 @@ arx_forecaster <- function(epi_data,
layer_add_target_date(target_date = target_date)
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))

# --- create test data, fit, and return
latest <- get_test_data(r, epi_data, TRUE)
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
list(
predictions = predict(wf, new_data = latest),
epi_workflow = wf
)
epi_workflow(r, trainer, f)
}


# Expand a `lags` specification so that it has one entry per predictor.
#
# @param predictors Vector of predictors; only its length is used here.
# @param lags Either a single (non-list) lag specification or a list of
#   them. A non-list value is wrapped in a list, and a length-1 list is
#   repeated once per predictor.
# @return A list of lag specifications with one element per predictor.
#   Calls `cli_stop()` when `lags` holds more than one element and its
#   length differs from the number of predictors.
arx_lags_validator <- function(predictors, lags) {
  # NOTE(review): `{p}` in the message below is presumably interpolated by
  # `cli_stop()`, so this local must keep the name `p` -- confirm.
  p <- length(predictors)
  if (!is.list(lags)) {
    lags <- list(lags)
  }
  if (length(lags) == 1) {
    lags <- rep(lags, p)
  } else if (length(lags) != p) {
    # Reject any length mismatch, not only "too few": a list with more
    # entries than predictors cannot be matched one-to-one either.
    cli_stop(
      "You have requested {p} predictors but lags cannot be recycled to match."
    )
  }
  lags
}

#' ARX forecaster argument constructor
#'
#' Constructs a list of arguments for [arx_forecaster()].
Expand Down Expand Up @@ -130,15 +173,16 @@ arx_lags_validator <- function(predictors, lags) {
#' arx_args_list()
#' arx_args_list(symmetrize = FALSE)
#' arx_args_list(levels = c(.1, .3, .7, .9), n_training = 120)
arx_args_list <- function(lags = c(0L, 7L, 14L),
ahead = 7L,
n_training = Inf,
forecast_date = NULL,
target_date = NULL,
levels = c(0.05, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L)) {
arx_args_list <- function(
lags = c(0L, 7L, 14L),
ahead = 7L,
n_training = Inf,
forecast_date = NULL,
target_date = NULL,
levels = c(0.05, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L)) {

# error checking if lags is a list
.lags <- lags
Expand All @@ -155,20 +199,24 @@ arx_args_list <- function(lags = c(0L, 7L, 14L),
if (is.finite(n_training)) arg_is_pos_int(n_training)

max_lags <- max(lags)
structure(enlist(lags = .lags,
ahead,
n_training,
levels,
forecast_date,
target_date,
symmetrize,
nonneg,
max_lags,
quantile_by_key),
class = "arx_flist")
structure(
enlist(lags = .lags,
ahead,
n_training,
levels,
forecast_date,
target_date,
symmetrize,
nonneg,
max_lags,
quantile_by_key),
class = c("arx_fcast", "alist")
)
}


#' @export
print.arx_flist <- function(x, ...) {
utils::str(x)
print.arx_fcast <- function(x, ...) {
name <- "ARX Forecaster"
NextMethod(name = name, ...)
}
Loading