From 7557e7b92330e5e8840cbaf749cfe9ea7c666d6b Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Sat, 27 May 2023 10:10:30 -0700 Subject: [PATCH 1/5] add functions to output an unfit classifier/forecaster workflow --- NAMESPACE | 2 + R/arx_classifier.R | 90 ++++++++++++++++++++++++------- R/arx_forecaster.R | 78 +++++++++++++++++++++------ R/utils-misc.R | 4 +- man/arx_classifier.Rd | 3 ++ man/arx_forecaster.Rd | 3 ++ man/arxc_epi_workflow_template.Rd | 62 +++++++++++++++++++++ man/arxf_epi_workflow_template.Rd | 53 ++++++++++++++++++ man/layer_predict.Rd | 2 +- man/predict-epi_workflow.Rd | 2 +- 10 files changed, 260 insertions(+), 39 deletions(-) create mode 100644 man/arxc_epi_workflow_template.Rd create mode 100644 man/arxf_epi_workflow_template.Rd diff --git a/NAMESPACE b/NAMESPACE index b65a7e6b2..dc927c1ce 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -94,6 +94,8 @@ export(arx_args_list) export(arx_class_args_list) export(arx_classifier) export(arx_forecaster) +export(arxc_epi_workflow_template) +export(arxf_epi_workflow_template) export(bake) export(create_layer) export(default_epi_recipe_blueprint) diff --git a/R/arx_classifier.R b/R/arx_classifier.R index 78a80534b..8a4060263 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -17,6 +17,7 @@ #' and (2) `epi_workflow`, a list that encapsulates the entire estimation #' workflow #' @export +#' @seealso [arxc_epi_workflow_template()] #' #' @examples #' jhu <- case_death_rate_subset %>% @@ -34,18 +35,78 @@ #' horizon = 14, method = "linear_reg" #' ) #' ) -arx_classifier <- function(epi_data, - outcome, - predictors, - trainer = parsnip::logistic_reg(), - args_list = arx_class_args_list()) { +arx_classifier <- function( + epi_data, + outcome, + predictors, + trainer = parsnip::logistic_reg(), + args_list = arx_class_args_list()) { + + if (!is_classification(trainer)) + rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.") + + wf <- arxc_epi_workflow_template( + epi_data, outcome, predictors, trainer, args_list + ) + + latest <- get_test_data( + workflows::extract_preprocessor(wf), epi_data, TRUE + ) + + wf <- generics::fit(wf, epi_data) + list( + predictions = predict(wf, new_data = latest), + epi_workflow = wf + ) +} + + +#' Create a template `arx_classifier` workflow +#' +#' This function creates an unfit workflow for use with [arx_classifier()]. +#' It is useful if you want to make small modifications to that classifier +#' before fitting and predicting. Supplying a trainer to the function +#' may alter the returned `epi_workflow` object but can be omitted. +#' +#' @inheritParams arx_classifier +#' @param trainer A `{parsnip}` model describing the type of estimation. +#' For now, we enforce `mode = "classification"`. Typical values are +#' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated +#' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can +#' also be used. May be `NULL` (the default). +#' +#' @return An unfit `epi_workflow`. +#' @export +#' @seealso [arx_classifier()] +#' @examples +#' +#' jhu <- case_death_rate_subset %>% +#' dplyr::filter(time_value >= as.Date("2021-11-01")) +#' +#' arxc_epi_workflow_template(jhu, "death_rate", c("case_rate", "death_rate")) +#' +#' arxc_epi_workflow_template( +#' jhu, +#' "death_rate", +#' c("case_rate", "death_rate"), +#' trainer = parsnip::multinom_reg(), +#' args_list = arx_class_args_list( +#' breaks = c(-.05, .1), ahead = 14, +#' horizon = 14, method = "linear_reg" +#' ) +#' ) +arxc_epi_workflow_template <- function( + epi_data, + outcome, + predictors, + trainer = NULL, + args_list = arx_class_args_list()) { - # --- validation validate_forecaster_inputs(epi_data, outcome, predictors) if (!inherits(args_list, "arx_clist")) - cli_stop("args_list was not created using `arx_class_args_list().") - if (!is_classification(trainer)) - cli_stop("{trainer} must be a `parsnip` method of mode 'classification'.") + rlang::abort("args_list was not created using `arx_class_args_list().") + if (!(is.null(trainer) || is_classification(trainer))) + rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.") lags <- arx_lags_validator(predictors, args_list$lags) # --- preprocessor @@ -108,18 +169,9 @@ arx_classifier <- function(epi_data, f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>% layer_add_target_date(target_date = target_date) - - # --- create test data, fit, and return - latest <- get_test_data(r, epi_data, TRUE) - wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data) - list( - predictions = predict(wf, new_data = latest), - epi_workflow = wf - ) + epi_workflow(r, trainer, f) } - - #' ARX classifier argument constructor #' #' Constructs a list of arguments for [arx_classifier()]. diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 34612202a..53cfb256e 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -19,6 +19,7 @@ #' and (2) `epi_workflow`, a list that encapsulates the entire estimation #' workflow #' @export +#' @seealso [arxf_epi_workflow_template()] #' #' @examples #' jhu <- case_death_rate_subset %>% @@ -36,12 +37,63 @@ arx_forecaster <- function(epi_data, trainer = parsnip::linear_reg(), args_list = arx_args_list()) { + if (!is_regression(trainer)) + rlang::abort("`trainer` must be a `{parsnip}` model of mode 'regression'.") + + wf <- arxf_epi_workflow_template( + epi_data, outcome, predictors, trainer, args_list + ) + + latest <- get_test_data( + workflows::extract_preprocessor(wf), epi_data, TRUE + ) + + wf <- generics::fit(wf, epi_data) + list( + predictions = predict(wf, new_data = latest), + epi_workflow = wf + ) +} + +#' Create a template `arx_forecaster` workflow +#' +#' This function creates an unfit workflow for use with [arx_forecaster()]. +#' It is useful if you want to make small modifications to that forecaster +#' before fitting and predicting. Supplying a trainer to the function +#' may alter the returned `epi_workflow` object (e.g., if you intend to +#' use [quantile_reg()]) but can be omitted. +#' +#' @inheritParams arx_forecaster +#' @param trainer A `{parsnip}` model describing the type of estimation. +#' For now, we enforce `mode = "regression"`. May be `NULL` (the default). +#' +#' @return An unfitted `epi_workflow`. +#' @export +#' @seealso [arx_forecaster()] +#' +#' @examples +#' jhu <- case_death_rate_subset %>% +#' dplyr::filter(time_value >= as.Date("2021-12-01")) +#' +#' arxf_epi_workflow_template(jhu, "death_rate", +#' c("case_rate", "death_rate")) +#' +#' arxf_epi_workflow_template(jhu, "death_rate", +#' c("case_rate", "death_rate"), trainer = quantile_reg(), +#' args_list = arx_args_list(levels = 1:9 / 10)) +arxf_epi_workflow_template <- function( + epi_data, + outcome, + predictors, + trainer = NULL, + args_list = arx_args_list()) { + # --- validation validate_forecaster_inputs(epi_data, outcome, predictors) if (!inherits(args_list, "arx_flist")) - cli_stop("args_list was not created using `arx_args_list().") - if (!is_regression(trainer)) - cli_stop("{trainer} must be a `parsnip` method of mode 'regression'.") + cli::cli_abort("args_list was not created using `arx_args_list().") + if (!(is.null(trainer) || is_regression(trainer))) + cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.") lags <- arx_lags_validator(predictors, args_list$lags) # --- preprocessor @@ -74,24 +126,20 @@ arx_forecaster <- function(epi_data, layer_add_target_date(target_date = target_date) if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred")) - # --- create test data, fit, and return - latest <- get_test_data(r, epi_data, TRUE) - wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data) - list( - predictions = predict(wf, new_data = latest), - epi_workflow = wf - ) + epi_workflow(r, trainer, f) } arx_lags_validator <- function(predictors, lags) { p <- length(predictors) if (!is.list(lags)) lags <- list(lags) - if (length(lags) == 1) lags <- rep(lags, p) - else if (length(lags) < p) { - cli_stop( - "You have requested {p} predictors but lags cannot be recycled to match." - ) + l <- length(lags) + if (l == 1) lags <- rep(lags, p) + else if (length(lags) != p) { + cli::cli_abort(c( + "You have requested {p} predictor(s) but {l} different lags.", + i = "Lags a vector or a list with length == number of predictors." + )) } lags } diff --git a/R/utils-misc.R b/R/utils-misc.R index 0fca30064..93b2baa9f 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -60,9 +60,7 @@ grab_forged_keys <- function(forged, mold, new_data) { } get_parsnip_mode <- function(trainer) { - if (inherits(trainer, "model_spec")) { - return(trainer$mode) - } + if (inherits(trainer, "model_spec")) return(trainer$mode) cc <- class(trainer) cli::cli_abort( c("`trainer` must be a `parsnip` model.", diff --git a/man/arx_classifier.Rd b/man/arx_classifier.Rd index 942a64787..69ab2d527 100644 --- a/man/arx_classifier.Rd +++ b/man/arx_classifier.Rd @@ -57,3 +57,6 @@ out <- arx_classifier( ) ) } +\seealso{ +\code{\link[=arxc_epi_workflow_template]{arxc_epi_workflow_template()}} +} diff --git a/man/arx_forecaster.Rd b/man/arx_forecaster.Rd index 46367d844..7b409737f 100644 --- a/man/arx_forecaster.Rd +++ b/man/arx_forecaster.Rd @@ -48,3 +48,6 @@ out <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), trainer = quantile_reg(), args_list = arx_args_list(levels = 1:9 / 10)) } +\seealso{ +\code{\link[=arxf_epi_workflow_template]{arxf_epi_workflow_template()}} +} diff --git a/man/arxc_epi_workflow_template.Rd b/man/arxc_epi_workflow_template.Rd new file mode 100644 index 000000000..68be2af79 --- /dev/null +++ b/man/arxc_epi_workflow_template.Rd @@ -0,0 +1,62 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/arx_classifier.R +\name{arxc_epi_workflow_template} +\alias{arxc_epi_workflow_template} +\title{Create a template \code{arx_classifier} workflow} +\usage{ +arxc_epi_workflow_template( + epi_data, + outcome, + predictors, + trainer = NULL, + args_list = arx_class_args_list() +) +} +\arguments{ +\item{epi_data}{An \code{epi_df} object} + +\item{outcome}{A character (scalar) specifying the outcome (in the +\code{epi_df}).} + +\item{predictors}{A character vector giving column(s) of predictor +variables.} + +\item{trainer}{A \code{{parsnip}} model describing the type of estimation. +For now, we enforce \code{mode = "classification"}. Typical values are +\code{\link[parsnip:logistic_reg]{parsnip::logistic_reg()}} or \code{\link[parsnip:multinom_reg]{parsnip::multinom_reg()}}. More complicated +trainers like \code{\link[parsnip:naive_Bayes]{parsnip::naive_Bayes()}} or \code{\link[parsnip:rand_forest]{parsnip::rand_forest()}} can +also be used. May be \code{NULL} (the default).} + +\item{args_list}{A list of customization arguments to determine +the type of forecasting model. See \code{\link[=arx_args_list]{arx_args_list()}}.} +} +\value{ +An unfit \code{epi_workflow}. +} +\description{ +This function creates an unfit workflow for use with \code{\link[=arx_classifier]{arx_classifier()}}. +It is useful if you want to make small modifications to that classifier +before fitting and predicting. Supplying a trainer to the function +may alter the returned \code{epi_workflow} object but can be omitted. +} +\examples{ + +jhu <- case_death_rate_subset \%>\% + dplyr::filter(time_value >= as.Date("2021-11-01")) + +arxc_epi_workflow_template(jhu, "death_rate", c("case_rate", "death_rate")) + +arxc_epi_workflow_template( + jhu, + "death_rate", + c("case_rate", "death_rate"), + trainer = parsnip::multinom_reg(), + args_list = arx_class_args_list( + breaks = c(-.05, .1), ahead = 14, + horizon = 14, method = "linear_reg" + ) +) +} +\seealso{ +\code{\link[=arx_classifier]{arx_classifier()}} +} diff --git a/man/arxf_epi_workflow_template.Rd b/man/arxf_epi_workflow_template.Rd new file mode 100644 index 000000000..6c8d87878 --- /dev/null +++ b/man/arxf_epi_workflow_template.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/arx_forecaster.R +\name{arxf_epi_workflow_template} +\alias{arxf_epi_workflow_template} +\title{Create a template \code{arx_forecaster} workflow} +\usage{ +arxf_epi_workflow_template( + epi_data, + outcome, + predictors, + trainer = NULL, + args_list = arx_args_list() +) +} +\arguments{ +\item{epi_data}{An \code{epi_df} object} + +\item{outcome}{A character (scalar) specifying the outcome (in the +\code{epi_df}).} + +\item{predictors}{A character vector giving column(s) of predictor +variables.} + +\item{trainer}{A \code{{parsnip}} model describing the type of estimation. +For now, we enforce \code{mode = "regression"}. May be \code{NULL} (the default).} + +\item{args_list}{A list of customization arguments to determine +the type of forecasting model. See \code{\link[=arx_args_list]{arx_args_list()}}.} +} +\value{ +An unfitted \code{epi_workflow}. +} +\description{ +This function creates an unfit workflow for use with \code{\link[=arx_forecaster]{arx_forecaster()}}. +It is useful if you want to make small modifications to that forecaster +before fitting and predicting. Supplying a trainer to the function +may alter the returned \code{epi_workflow} object (e.g., if you intend to +use \code{\link[=quantile_reg]{quantile_reg()}}) but can be omitted. +} +\examples{ +jhu <- case_death_rate_subset \%>\% + dplyr::filter(time_value >= as.Date("2021-12-01")) + +arxf_epi_workflow_template(jhu, "death_rate", + c("case_rate", "death_rate")) + +arxf_epi_workflow_template(jhu, "death_rate", + c("case_rate", "death_rate"), trainer = quantile_reg(), + args_list = arx_args_list(levels = 1:9 / 10)) +} +\seealso{ +\code{\link[=arx_forecaster]{arx_forecaster()}} +} diff --git a/man/layer_predict.Rd b/man/layer_predict.Rd index e0986bdf5..1326dfe75 100644 --- a/man/layer_predict.Rd +++ b/man/layer_predict.Rd @@ -42,7 +42,7 @@ the standard error of fit or prediction (on the scale of the linear predictors). Default value is \code{FALSE}. \item \code{quantile}: for \code{type} equal to \code{quantile}, the quantiles of the distribution. Default is \code{(1:9)/10}. -\item \code{time}: for \code{type} equal to \code{"survival"} or \code{"hazard"}, the +\item \code{eval_time}: for \code{type} equal to \code{"survival"} or \code{"hazard"}, the time points at which the survival probability or hazard is estimated. }} diff --git a/man/predict-epi_workflow.Rd b/man/predict-epi_workflow.Rd index 5058e4bb4..d92fd8ca9 100644 --- a/man/predict-epi_workflow.Rd +++ b/man/predict-epi_workflow.Rd @@ -31,7 +31,7 @@ the standard error of fit or prediction (on the scale of the linear predictors). Default value is \code{FALSE}. \item \code{quantile}: for \code{type} equal to \code{quantile}, the quantiles of the distribution. Default is \code{(1:9)/10}. -\item \code{time}: for \code{type} equal to \code{"survival"} or \code{"hazard"}, the +\item \code{eval_time}: for \code{type} equal to \code{"survival"} or \code{"hazard"}, the time points at which the survival probability or hazard is estimated. }} } From 3ba6d5777a38db99ed360d3eb4e8241d0a0eaf37 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 15 Jun 2023 12:53:27 -0700 Subject: [PATCH 2/5] add classes to args_lists, move shared canned forecaster code to a separate file --- NAMESPACE | 4 +-- R/arx_classifier.R | 14 +++----- R/arx_forecaster.R | 64 ++++++++++++++---------------------- R/canned-forecaster-common.R | 41 +++++++++++++++++++++++ R/flatline_forecaster.R | 37 +++++++-------------- 5 files changed, 82 insertions(+), 78 deletions(-) create mode 100644 R/canned-forecaster-common.R diff --git a/NAMESPACE b/NAMESPACE index dc927c1ce..8225c3f38 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -44,10 +44,8 @@ S3method(prep,step_growth_rate) S3method(prep,step_lag_difference) S3method(prep,step_population_scaling) S3method(prep,step_training_window) -S3method(print,arx_clist) -S3method(print,arx_flist) +S3method(print,alist) S3method(print,epi_workflow) -S3method(print,flatline_alist) S3method(print,frosting) S3method(print,layer_add_forecast_date) S3method(print,layer_add_target_date) diff --git a/R/arx_classifier.R b/R/arx_classifier.R index 8a4060263..3ab51d52a 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -43,7 +43,7 @@ arx_classifier <- function( args_list = arx_class_args_list()) { if (!is_classification(trainer)) - rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.") + cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.") wf <- arxc_epi_workflow_template( epi_data, outcome, predictors, trainer, args_list @@ -242,8 +242,9 @@ arx_class_args_list <- function( arg_is_pos(n_training) if (is.finite(n_training)) arg_is_pos_int(n_training) if (!is.list(additional_gr_args)) { - rlang::abort( - c("`additional_gr_args` must be a list.", + cli::cli_abort( + c("`additional_gr_args` must be a {.cls list}.", + "!" = "This is a {.cls {class(additional_gr_args)}}.", i = "See `?epiprocess::growth_rate` for available arguments.") ) } @@ -268,11 +269,6 @@ arx_class_args_list <- function( log_scale, additional_gr_args ), - class = "arx_clist" + class = c("arx_class", "alist") ) } - -#' @export -print.arx_clist <- function(x, ...) { - utils::str(x) -} diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 53cfb256e..25ee1d162 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -38,7 +38,7 @@ arx_forecaster <- function(epi_data, args_list = arx_args_list()) { if (!is_regression(trainer)) - rlang::abort("`trainer` must be a `{parsnip}` model of mode 'regression'.") + cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") wf <- arxf_epi_workflow_template( epi_data, outcome, predictors, trainer, args_list @@ -130,20 +130,6 @@ arxf_epi_workflow_template <- function( } -arx_lags_validator <- function(predictors, lags) { - p <- length(predictors) - if (!is.list(lags)) lags <- list(lags) - l <- length(lags) - if (l == 1) lags <- rep(lags, p) - else if (length(lags) != p) { - cli::cli_abort(c( - "You have requested {p} predictor(s) but {l} different lags.", - i = "Lags a vector or a list with length == number of predictors." - )) - } - lags -} - #' ARX forecaster argument constructor #' #' Constructs a list of arguments for [arx_forecaster()]. @@ -178,15 +164,16 @@ arx_lags_validator <- function(predictors, lags) { #' arx_args_list() #' arx_args_list(symmetrize = FALSE) #' arx_args_list(levels = c(.1, .3, .7, .9), n_training = 120) -arx_args_list <- function(lags = c(0L, 7L, 14L), - ahead = 7L, - n_training = Inf, - forecast_date = NULL, - target_date = NULL, - levels = c(0.05, 0.95), - symmetrize = TRUE, - nonneg = TRUE, - quantile_by_key = character(0L)) { +arx_args_list <- function( + lags = c(0L, 7L, 14L), + ahead = 7L, + n_training = Inf, + forecast_date = NULL, + target_date = NULL, + levels = c(0.05, 0.95), + symmetrize = TRUE, + nonneg = TRUE, + quantile_by_key = character(0L)) { # error checking if lags is a list .lags <- lags @@ -203,20 +190,17 @@ arx_args_list <- function(lags = c(0L, 7L, 14L), if (is.finite(n_training)) arg_is_pos_int(n_training) max_lags <- max(lags) - structure(enlist(lags = .lags, - ahead, - n_training, - levels, - forecast_date, - target_date, - symmetrize, - nonneg, - max_lags, - quantile_by_key), - class = "arx_flist") -} - -#' @export -print.arx_flist <- function(x, ...) { - utils::str(x) + structure( + enlist(lags = .lags, + ahead, + n_training, + levels, + forecast_date, + target_date, + symmetrize, + nonneg, + max_lags, + quantile_by_key), + class = c("arx_fcast", "alist") + ) } diff --git a/R/canned-forecaster-common.R b/R/canned-forecaster-common.R new file mode 100644 index 000000000..eae1d38ca --- /dev/null +++ b/R/canned-forecaster-common.R @@ -0,0 +1,41 @@ +validate_forecaster_inputs <- function(epi_data, outcome, predictors) { + if (!epiprocess::is_epi_df(epi_data)) { + cli::cli_abort(c( + "`epi_data` must be an {.cls epi_df}.", + "!" = "This one is a {.cls {class(epi_data)}}." + )) + } + arg_is_chr(predictors) + arg_is_chr_scalar(outcome) + if (!outcome %in% names(epi_data)) + cli::cli_abort("{outcome} was not found in the training data.") + check <- hardhat::check_column_names(epi_data, predictors) + if (!check$ok) { + cli::cli_abort(c( + "At least one predictor was not found in the training data.", + "!" = "The following required columns are missing: {check$missing_names}." + )) + } + invisible(TRUE) +} + + +arx_lags_validator <- function(predictors, lags) { + p <- length(predictors) + if (!is.list(lags)) lags <- list(lags) + l <- length(lags) + if (l == 1) lags <- rep(lags, p) + else if (length(lags) != p) { + cli::cli_abort(c( + "You have requested {p} predictor(s) but {l} different lags.", + i = "Lags must be a vector or a list with length == number of predictors." + )) + } + lags +} + + +#' @export +print.alist <- function(x, ...) { + utils::str(x) +} diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index 0eb718b52..47a3bc481 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -107,32 +107,17 @@ flatline_args_list <- function( arg_is_pos(n_training) if (is.finite(n_training)) arg_is_pos_int(n_training) - structure(enlist(ahead, - n_training, - forecast_date, - target_date, - levels, - symmetrize, - nonneg, - quantile_by_key), - class = "flatline_alist") -} - -validate_forecaster_inputs <- function(epi_data, outcome, predictors) { - if (!epiprocess::is_epi_df(epi_data)) - cli_stop("epi_data must be an epi_df.") - arg_is_chr(predictors) - arg_is_chr_scalar(outcome) - if (!outcome %in% names(epi_data)) - cli_stop("{outcome} was not found in the training data.") - if (!all(predictors %in% names(epi_data))) - cli_stop("At least one predictor was not found in the training data.") - invisible(TRUE) -} - -#' @export -print.flatline_alist <- function(x, ...) { - utils::str(x) + structure( + enlist(ahead, + n_training, + forecast_date, + target_date, + levels, + symmetrize, + nonneg, + quantile_by_key), + class = c("flatline", "alist") + ) } From 2daed632f6998a932432d6f628dd5eb5904856cb Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 15 Jun 2023 14:50:26 -0700 Subject: [PATCH 3/5] improved printing for all canned forecasters --- NAMESPACE | 4 +++ R/arx_classifier.R | 24 +++++++++++--- R/arx_forecaster.R | 24 +++++++++++--- ...d-forecaster-common.R => canned-epipred.R} | 33 +++++++++++++++++++ R/flatline_forecaster.R | 28 ++++++++++++---- 5 files changed, 98 insertions(+), 15 deletions(-) rename R/{canned-forecaster-common.R => canned-epipred.R} (54%) diff --git a/NAMESPACE b/NAMESPACE index 8225c3f38..f6f21fdd1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -45,7 +45,11 @@ S3method(prep,step_lag_difference) S3method(prep,step_population_scaling) S3method(prep,step_training_window) S3method(print,alist) +S3method(print,arx_class) +S3method(print,arx_fcast) +S3method(print,canned_epipred) S3method(print,epi_workflow) +S3method(print,flatline) S3method(print,frosting) S3method(print,layer_add_forecast_date) S3method(print,layer_add_target_date) diff --git a/R/arx_classifier.R b/R/arx_classifier.R index 3ab51d52a..9e9570718 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -54,9 +54,18 @@ arx_classifier <- function( ) wf <- generics::fit(wf, epi_data) - list( - predictions = predict(wf, new_data = latest), - epi_workflow = wf + preds <- predict(wf, new_data = latest) %>% + tibble::as_tibble() %>% + dplyr::select(-time_value) + + structure(list( + predictions = preds, + epi_workflow = wf, + metadata = list( + training = attr(epi_data, "metadata"), + forecast_created = Sys.time() + )), + class = c("arx_class", "canned_epipred") ) } @@ -103,7 +112,7 @@ arxc_epi_workflow_template <- function( args_list = arx_class_args_list()) { validate_forecaster_inputs(epi_data, outcome, predictors) - if (!inherits(args_list, "arx_clist")) + if (!inherits(args_list, c("arx_class", "alist"))) rlang::abort("args_list was not created using `arx_class_args_list().") if (!(is.null(trainer) || is_classification(trainer))) rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.") @@ -272,3 +281,10 @@ arx_class_args_list <- function( class = c("arx_class", "alist") ) } + +#' @export +print.arx_class <- function(x, ...) { + name <- "ARX Classifier" + NextMethod(name = name, ...) +} + diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 25ee1d162..128ecdb6e 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -49,9 +49,18 @@ arx_forecaster <- function(epi_data, ) wf <- generics::fit(wf, epi_data) - list( - predictions = predict(wf, new_data = latest), - epi_workflow = wf + preds <- predict(wf, new_data = latest) %>% + tibble::as_tibble() %>% + dplyr::select(-time_value) + + structure(list( + predictions = preds, + epi_workflow = wf, + metadata = list( + training = attr(epi_data, "metadata"), + forecast_created = Sys.time() + )), + class = c("arx_fcast", "canned_epipred") ) } @@ -90,7 +99,7 @@ arxf_epi_workflow_template <- function( # --- validation validate_forecaster_inputs(epi_data, outcome, predictors) - if (!inherits(args_list, "arx_flist")) + if (!inherits(args_list, c("arx_fcast", "alist"))) cli::cli_abort("args_list was not created using `arx_args_list().") if (!(is.null(trainer) || is_regression(trainer))) cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.") @@ -204,3 +213,10 @@ arx_args_list <- function( class = c("arx_fcast", "alist") ) } + + +#' @export +print.arx_fcast <- function(x, ...) { + name <- "ARX Forecaster" + NextMethod(name = name, ...) +} diff --git a/R/canned-forecaster-common.R b/R/canned-epipred.R similarity index 54% rename from R/canned-forecaster-common.R rename to R/canned-epipred.R index eae1d38ca..d2b4af22f 100644 --- a/R/canned-forecaster-common.R +++ b/R/canned-epipred.R @@ -39,3 +39,36 @@ arx_lags_validator <- function(predictors, lags) { print.alist <- function(x, ...) { utils::str(x) } + +#' @export +print.canned_epipred <- function(x, name, ...) { + cat("\n") + cli::cli_rule("A basic forecaster of type {.pkg {name}}") + + cat("\n") + cli::cli_text( + "This forecaster was fit on {.val {format(x$metadata$forecast_created)}}." + ) + cat("\n") + cli::cli_text("Training data was an {.cls epi_df} with ") + cli::cli_ul(c( + "Geography: {.val {x$metadata$training$geo_type}},", + "Time type: {.val {x$metadata$training$time_type}},", + "Using data up-to-date as of: {.val {format(x$metadata$training$as_of)}}." + )) + + cat("\n") + cli::cli_rule("Predictions") + n_geos <- dplyr::n_distinct(x$predictions$geo_value) + fds <- unique(x$predictions$forecast_date) + tds <- unique(x$predictions$target_date) + + cat("\n") + cli::cli_text("A total of {nrow(x$predictions)} predictions are available for") + cli::cli_ul(c( + "{n_geos} unique geographic regions,", + "At forecast dates: {.val {fds}},", + "For target dates: {.val {tds}}." + )) + cat("\n") +} diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index 47a3bc481..8843f10fe 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -33,7 +33,7 @@ flatline_forecaster <- function( args_list = flatline_args_list()) { validate_forecaster_inputs(epi_data, outcome, "time_value") - if (!inherits(args_list, "flatline_alist")) { + if (!inherits(args_list, c("flatline", "alist"))) { cli_stop("args_list was not created using `flatline_args_list().") } keys <- epi_keys(epi_data) @@ -64,11 +64,21 @@ flatline_forecaster <- function( eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") - wf <- epi_workflow(r, eng, f) %>% fit(epi_data) - - list( - predictions = suppressWarnings(predict(wf, new_data = latest)), - epi_workflow = wf + wf <- epi_workflow(r, eng, f) + + wf <- generics::fit(wf, epi_data) + preds <- suppressWarnings(predict(wf, new_data = latest)) %>% + tibble::as_tibble() %>% + dplyr::select(-time_value) + + structure(list( + predictions = preds, + epi_workflow = wf, + metadata = list( + training = attr(epi_data, "metadata"), + forecast_created = Sys.time() + )), + class = c("flatline", "canned_epipred") ) } @@ -120,4 +130,8 @@ flatline_args_list <- function( ) } - +#' @export +print.flatline <- function(x, ...) { + name <- "flatline" + NextMethod(name = name, ...) +} From f639dce06c3751eeb1d87c5bbe0baff26acd0f70 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 15 Jun 2023 14:56:11 -0700 Subject: [PATCH 4/5] better naming for the template constructors --- NAMESPACE | 4 ++-- R/arx_classifier.R | 12 ++++++------ R/arx_forecaster.R | 10 +++++----- ...orkflow_template.Rd => arx_class_epi_workflow.Rd} | 12 ++++++------ man/arx_classifier.Rd | 4 ++-- ...orkflow_template.Rd => arx_fcast_epi_workflow.Rd} | 10 +++++----- man/arx_forecaster.Rd | 2 +- 7 files changed, 27 insertions(+), 27 deletions(-) rename man/{arxc_epi_workflow_template.Rd => arx_class_epi_workflow.Rd} (86%) rename man/{arxf_epi_workflow_template.Rd => arx_fcast_epi_workflow.Rd} (88%) diff --git a/NAMESPACE b/NAMESPACE index f6f21fdd1..d08a09807 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -94,10 +94,10 @@ export(add_layer) export(apply_frosting) export(arx_args_list) export(arx_class_args_list) +export(arx_class_epi_workflow) export(arx_classifier) +export(arx_fcast_epi_workflow) export(arx_forecaster) -export(arxc_epi_workflow_template) -export(arxf_epi_workflow_template) export(bake) export(create_layer) export(default_epi_recipe_blueprint) diff --git a/R/arx_classifier.R b/R/arx_classifier.R index 9e9570718..17e8a9c9d 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -11,13 +11,13 @@ #' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can #' also be used. #' @param args_list A list of customization arguments to determine -#' the type of forecasting model. See [arx_args_list()]. +#' the type of forecasting model. See [arx_class_args_list()]. #' #' @return A list with (1) `predictions` an `epi_df` of predicted classes #' and (2) `epi_workflow`, a list that encapsulates the entire estimation #' workflow #' @export -#' @seealso [arxc_epi_workflow_template()] +#' @seealso [arx_class_epi_workflow()], [arx_class_args_list()] #' #' @examples #' jhu <- case_death_rate_subset %>% @@ -45,7 +45,7 @@ arx_classifier <- function( if (!is_classification(trainer)) cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.") - wf <- arxc_epi_workflow_template( + wf <- arx_class_epi_workflow( epi_data, outcome, predictors, trainer, args_list ) @@ -92,9 +92,9 @@ arx_classifier <- function( #' jhu <- case_death_rate_subset %>% #' dplyr::filter(time_value >= as.Date("2021-11-01")) #' -#' arxc_epi_workflow_template(jhu, "death_rate", c("case_rate", "death_rate")) +#' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate")) #' -#' arxc_epi_workflow_template( +#' arx_class_epi_workflow( #' jhu, #' "death_rate", #' c("case_rate", "death_rate"), @@ -104,7 +104,7 @@ arx_classifier <- function( #' horizon = 14, method = "linear_reg" #' ) #' ) -arxc_epi_workflow_template <- function( +arx_class_epi_workflow <- function( epi_data, outcome, predictors, diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 128ecdb6e..d883f3e2a 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -19,7 +19,7 @@ #' and (2) `epi_workflow`, a list that encapsulates the entire estimation #' workflow #' @export -#' @seealso [arxf_epi_workflow_template()] +#' @seealso [arx_fcast_epi_workflow()], [arx_args_list()] #' #' @examples #' jhu <- case_death_rate_subset %>% @@ -40,7 +40,7 @@ arx_forecaster <- function(epi_data, if (!is_regression(trainer)) cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") - wf <- arxf_epi_workflow_template( + wf <- arx_fcast_epi_workflow( epi_data, outcome, predictors, trainer, args_list ) @@ -84,13 +84,13 @@ arx_forecaster <- function(epi_data, #' jhu <- case_death_rate_subset %>% #' dplyr::filter(time_value >= as.Date("2021-12-01")) #' -#' arxf_epi_workflow_template(jhu, "death_rate", +#' arx_fcast_epi_workflow(jhu, "death_rate", #' c("case_rate", "death_rate")) #' -#' arxf_epi_workflow_template(jhu, "death_rate", +#' arx_fcast_epi_workflow(jhu, "death_rate", #' c("case_rate", "death_rate"), trainer = quantile_reg(), #' args_list = arx_args_list(levels = 1:9 / 10)) -arxf_epi_workflow_template <- function( +arx_fcast_epi_workflow <- function( epi_data, outcome, predictors, diff --git a/man/arxc_epi_workflow_template.Rd b/man/arx_class_epi_workflow.Rd similarity index 86% rename from man/arxc_epi_workflow_template.Rd rename to man/arx_class_epi_workflow.Rd index 68be2af79..aaaf20cc9 100644 --- a/man/arxc_epi_workflow_template.Rd +++ b/man/arx_class_epi_workflow.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/arx_classifier.R -\name{arxc_epi_workflow_template} -\alias{arxc_epi_workflow_template} +\name{arx_class_epi_workflow} +\alias{arx_class_epi_workflow} \title{Create a template \code{arx_classifier} workflow} \usage{ -arxc_epi_workflow_template( +arx_class_epi_workflow( epi_data, outcome, predictors, @@ -28,7 +28,7 @@ trainers like \code{\link[parsnip:naive_Bayes]{parsnip::naive_Bayes()}} or \code also be used. May be \code{NULL} (the default).} \item{args_list}{A list of customization arguments to determine -the type of forecasting model. See \code{\link[=arx_args_list]{arx_args_list()}}.} +the type of forecasting model. See \code{\link[=arx_class_args_list]{arx_class_args_list()}}.} } \value{ An unfit \code{epi_workflow}. @@ -44,9 +44,9 @@ may alter the returned \code{epi_workflow} object but can be omitted. jhu <- case_death_rate_subset \%>\% dplyr::filter(time_value >= as.Date("2021-11-01")) -arxc_epi_workflow_template(jhu, "death_rate", c("case_rate", "death_rate")) +arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate")) -arxc_epi_workflow_template( +arx_class_epi_workflow( jhu, "death_rate", c("case_rate", "death_rate"), diff --git a/man/arx_classifier.Rd b/man/arx_classifier.Rd index 69ab2d527..b6227a5a6 100644 --- a/man/arx_classifier.Rd +++ b/man/arx_classifier.Rd @@ -28,7 +28,7 @@ trainers like \code{\link[parsnip:naive_Bayes]{parsnip::naive_Bayes()}} or \code also be used.} \item{args_list}{A list of customization arguments to determine -the type of forecasting model. See \code{\link[=arx_args_list]{arx_args_list()}}.} +the type of forecasting model. See \code{\link[=arx_class_args_list]{arx_class_args_list()}}.} } \value{ A list with (1) \code{predictions} an \code{epi_df} of predicted classes @@ -58,5 +58,5 @@ out <- arx_classifier( ) } \seealso{ -\code{\link[=arxc_epi_workflow_template]{arxc_epi_workflow_template()}} +\code{\link[=arx_class_epi_workflow]{arx_class_epi_workflow()}}, \code{\link[=arx_class_args_list]{arx_class_args_list()}} } diff --git a/man/arxf_epi_workflow_template.Rd b/man/arx_fcast_epi_workflow.Rd similarity index 88% rename from man/arxf_epi_workflow_template.Rd rename to man/arx_fcast_epi_workflow.Rd index 6c8d87878..fdd309959 100644 --- a/man/arxf_epi_workflow_template.Rd +++ b/man/arx_fcast_epi_workflow.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/arx_forecaster.R -\name{arxf_epi_workflow_template} -\alias{arxf_epi_workflow_template} +\name{arx_fcast_epi_workflow} +\alias{arx_fcast_epi_workflow} \title{Create a template \code{arx_forecaster} workflow} \usage{ -arxf_epi_workflow_template( +arx_fcast_epi_workflow( epi_data, outcome, predictors, @@ -41,10 +41,10 @@ use \code{\link[=quantile_reg]{quantile_reg()}}) but can be omitted. jhu <- case_death_rate_subset \%>\% dplyr::filter(time_value >= as.Date("2021-12-01")) -arxf_epi_workflow_template(jhu, "death_rate", +arx_fcast_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate")) -arxf_epi_workflow_template(jhu, "death_rate", +arx_fcast_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"), trainer = quantile_reg(), args_list = arx_args_list(levels = 1:9 / 10)) } diff --git a/man/arx_forecaster.Rd b/man/arx_forecaster.Rd index 7b409737f..d4866aa0e 100644 --- a/man/arx_forecaster.Rd +++ b/man/arx_forecaster.Rd @@ -49,5 +49,5 @@ out <- arx_forecaster(jhu, "death_rate", args_list = arx_args_list(levels = 1:9 / 10)) } \seealso{ -\code{\link[=arxf_epi_workflow_template]{arxf_epi_workflow_template()}} +\code{\link[=arx_fcast_epi_workflow]{arx_fcast_epi_workflow()}}, \code{\link[=arx_args_list]{arx_args_list()}} } From 7b4c39d15248843bf1bced325469fb7fe8144d30 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 15 Jun 2023 15:05:06 -0700 Subject: [PATCH 5/5] tests pass, remove df_mat_mul() (not sure how it got here) --- NAMESPACE | 1 - man/df_mat_mul.Rd | 35 ------------------- {R => musings}/df_mat_mul.R | 0 {tests/testthat => musings}/test-df_mat_mul.R | 0 tests/testthat/test-arx_args_list.R | 2 +- tests/testthat/test-arx_cargs_list.R | 2 +- 6 files changed, 2 insertions(+), 38 deletions(-) delete mode 100644 man/df_mat_mul.Rd rename {R => musings}/df_mat_mul.R (100%) rename {tests/testthat => musings}/test-df_mat_mul.R (100%) diff --git a/NAMESPACE b/NAMESPACE index d08a09807..2833d8aca 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -102,7 +102,6 @@ export(bake) export(create_layer) export(default_epi_recipe_blueprint) export(detect_layer) -export(df_mat_mul) export(dist_quantiles) export(epi_keys) export(epi_recipe) diff --git a/man/df_mat_mul.Rd b/man/df_mat_mul.Rd deleted file mode 100644 index 57596dd2d..000000000 --- a/man/df_mat_mul.Rd +++ /dev/null @@ -1,35 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/df_mat_mul.R -\name{df_mat_mul} -\alias{df_mat_mul} -\title{Multiply columns of a \code{data.frame} by a matrix} -\usage{ -df_mat_mul(dat, mat, out_names = "out", ...) -} -\arguments{ -\item{dat}{A data.frame} - -\item{mat}{A matrix} - -\item{out_names}{Character vector. Creates the names of the resulting -columns after multiplication. If a scalar, this is treated as a -prefix and the remaining columns will be numbered sequentially.} - -\item{...}{<\code{\link[dplyr:dplyr_tidy_select]{tidy-select}}> One or more unquoted -expressions separated by commas. Variable names can be used as if they -were positions in the data frame, so expressions like \code{x:y} can -be used to select a range of variables.} -} -\value{ -A data.frame with the new columns at the right. Original -columns are removed. -} -\description{ -Multiply columns of a \code{data.frame} by a matrix -} -\examples{ -df <- data.frame(matrix(1:200, ncol = 10)) -mat <- matrix(1:10, ncol = 2) -df_mat_mul(df, mat, "z", dplyr::num_range("X", 2:6)) -} -\keyword{internal} diff --git a/R/df_mat_mul.R b/musings/df_mat_mul.R similarity index 100% rename from R/df_mat_mul.R rename to musings/df_mat_mul.R diff --git a/tests/testthat/test-df_mat_mul.R b/musings/test-df_mat_mul.R similarity index 100% rename from tests/testthat/test-df_mat_mul.R rename to musings/test-df_mat_mul.R diff --git a/tests/testthat/test-arx_args_list.R b/tests/testthat/test-arx_args_list.R index 834dd8996..25e3194de 100644 --- a/tests/testthat/test-arx_args_list.R +++ b/tests/testthat/test-arx_args_list.R @@ -1,5 +1,5 @@ test_that("arx_args checks inputs", { - expect_s3_class(arx_args_list(), "arx_flist") + expect_s3_class(arx_args_list(), c("arx_fcast", "alist")) expect_error(arx_args_list(ahead = c(0, 4))) expect_error(arx_args_list(n_training = c(28, 65))) diff --git a/tests/testthat/test-arx_cargs_list.R b/tests/testthat/test-arx_cargs_list.R index 699bba94b..40035890d 100644 --- a/tests/testthat/test-arx_cargs_list.R +++ b/tests/testthat/test-arx_cargs_list.R @@ -1,5 +1,5 @@ test_that("arx_class_args checks inputs", { - expect_s3_class(arx_class_args_list(), "arx_clist") + expect_s3_class(arx_class_args_list(), c("arx_class", "alist")) expect_error(arx_class_args_list(ahead = c(0, 4))) expect_error(arx_class_args_list(n_training = c(28, 65)))