Skip to content

Make layer_predict forward stored dots_list to predict() #358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.16
Version: 0.0.17
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ import(parsnip)
import(recipes)
importFrom(checkmate,assert)
importFrom(checkmate,assert_character)
importFrom(checkmate,assert_class)
importFrom(checkmate,assert_date)
importFrom(checkmate,assert_function)
importFrom(checkmate,assert_int)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Revise `compat-purrr` to use the r-lang `standalone-*` version (via
`{usethis}`)
- `epi_recipe()` will now warn when given non-`epi_df` data
- `layer_predict()` and `predict.epi_workflow()` will now appropriately forward
`...` args intended for `predict.model_fit()`
- `bake.epi_recipe()` will now re-infer the geo and time type in case baking the
steps has changed the appropriate values
2 changes: 1 addition & 1 deletion R/arx_classifier.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Direct autoregressive classifier with covariates
#'
#' This is an autoregressive classification model for
#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning
#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning
#' that it estimates a class at a particular target horizon.
#'
#' @inheritParams arx_forecaster
Expand Down
2 changes: 1 addition & 1 deletion R/arx_forecaster.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Direct autoregressive forecaster with covariates
#'
#' This is an autoregressive forecasting model for
#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning
#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning
#' that it estimates a model for a particular target horizon.
#'
#'
Expand Down
4 changes: 2 additions & 2 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Predict the future with the most recent value
#'
#' This is a simple forecasting model for
#' [epiprocess::epi_df] data. It uses the most recent observation as the
#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the
#' forecast for any future date, and produces intervals by shuffling the quantiles
#' of the residuals of such a "flatline" forecast and incrementing these
#' forward over all available training data.
Expand All @@ -12,7 +12,7 @@
#' This forecaster is meant to produce exactly the CDC Baseline used for
#' [COVID19ForecastHub](https://covid19forecasthub.org)
#'
#' @param epi_data An [`epiprocess::epi_df`]
#' @param epi_data An [`epiprocess::epi_df`][epiprocess::as_epi_df]
#' @param outcome A scalar character for the column name we wish to predict.
#' @param args_list A list of additional arguments as created by the
#' [cdc_baseline_args_list()] constructor function.
Expand Down
2 changes: 1 addition & 1 deletion R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@

#' Subset of Statistics Canada median employment income for postsecondary graduates
#'
#' @format An [epiprocess::epi_df] with 10193 rows and 8 variables:
#' @format An [epiprocess::epi_df][epiprocess::as_epi_df] with 10193 rows and 8 variables:
#' \describe{
#' \item{geo_value}{The province in Canada associated with each
#' row of measurements.}
Expand Down
10 changes: 7 additions & 3 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ is_epi_recipe <- function(x) {
#' @details
#' `add_epi_recipe` has the same behaviour as
#' [workflows::add_recipe()] but sets a different
#' default blueprint to automatically handle [epiprocess::epi_df] data.
#' default blueprint to automatically handle [epiprocess::epi_df][epiprocess::as_epi_df] data.
#'
#' @param x A `workflow` or `epi_workflow`
#'
Expand Down Expand Up @@ -572,9 +572,13 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") {
}
new_data <- NextMethod("bake")
if (!is.null(meta)) {
# Baking should have dropped epi_df-ness and metadata. Re-infer some
# metadata and assume others remain the same as the object/template:
new_data <- as_epi_df(
new_data, meta$geo_type, meta$time_type, meta$as_of,
meta$additional_metadata %||% list()
new_data,
as_of = meta$as_of,
# avoid NULL if meta is from saved older epi_df:
additional_metadata = meta$additional_metadata %||% list()
)
}
new_data
Expand Down
10 changes: 5 additions & 5 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
#' - Call [parsnip::predict.model_fit()] for you using the underlying fit
#' parsnip model.
#'
#' - Ensure that the returned object is an [epiprocess::epi_df] where
#' - Ensure that the returned object is an [epiprocess::epi_df][epiprocess::as_epi_df] where
#' possible. Specifically, the output will have `time_value` and
#' `geo_value` columns as well as the prediction.
#'
#' @inheritParams parsnip::predict.model_fit
#'
#' @param object An epi_workflow that has been fit by
#' [workflows::fit.workflow()]
#'
#' @param new_data A data frame containing the new predictors to preprocess
#' and predict on
#'
#' @inheritParams parsnip::predict.model_fit
#'
#' @return
#' A data frame of model predictions, with as many rows as `new_data` has.
#' If `new_data` is an `epi_df` or a data frame with `time_value` or
Expand All @@ -152,7 +152,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
#'
#' preds <- predict(wf, latest)
#' preds
predict.epi_workflow <- function(object, new_data, ...) {
predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) {
if (!workflows::is_trained_workflow(object)) {
cli::cli_abort(c(
"Can't predict on an untrained epi_workflow.",
Expand All @@ -168,7 +168,7 @@ predict.epi_workflow <- function(object, new_data, ...) {
components$forged,
components$mold, new_data
)
components <- apply_frosting(object, components, new_data, ...)
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
components$predictions
}

Expand Down
2 changes: 1 addition & 1 deletion R/epipredict-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @importFrom cli cli_abort
#' @importFrom checkmate assert assert_character assert_int assert_scalar
#' assert_logical assert_numeric assert_number assert_integer
#' assert_integerish assert_date assert_function
#' assert_integerish assert_date assert_function assert_class
#' @import epiprocess parsnip
## usethis namespace: end
NULL
4 changes: 2 additions & 2 deletions R/flatline_forecaster.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Predict the future with today's value
#'
#' This is a simple forecasting model for
#' [epiprocess::epi_df] data. It uses the most recent observation as the
#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the
#' forcast for any future date, and produces intervals based on the quantiles
#' of the residuals of such a "flatline" forecast over all available training
#' data.
Expand All @@ -13,7 +13,7 @@
#' This forecaster is very similar to that used by the
#' [COVID19ForecastHub](https://covid19forecasthub.org)
#'
#' @param epi_data An [epiprocess::epi_df]
#' @param epi_data An [epiprocess::epi_df][epiprocess::as_epi_df]
#' @param outcome A scalar character for the column name we wish to predict.
#' @param args_list A list of dditional arguments as created by the
#' [flatline_args_list()] constructor function.
Expand Down
26 changes: 23 additions & 3 deletions R/frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,11 @@ apply_frosting.default <- function(workflow, components, ...) {
#' @rdname apply_frosting
#' @importFrom rlang is_null
#' @importFrom rlang abort
#' @param type,opts forwarded (along with `...`) to [`predict.model_fit()`] and
#' [`slather()`] for supported layers
#' @export
apply_frosting.epi_workflow <-
function(workflow, components, new_data, ...) {
function(workflow, components, new_data, type = NULL, opts = list(), ...) {
the_fit <- workflows::extract_fit_parsnip(workflow)

if (!has_postprocessor(workflow)) {
Expand All @@ -376,7 +378,7 @@ apply_frosting.epi_workflow <-
"Returning unpostprocessed predictions."
))
components$predictions <- predict(
the_fit, components$forged$predictors, ...
the_fit, components$forged$predictors, type, opts, ...
)
components$predictions <- dplyr::bind_cols(
components$keys, components$predictions
Expand All @@ -397,10 +399,28 @@ apply_frosting.epi_workflow <-
layers
)
}
if (length(layers) > 1L &&
(!is.null(type) || !identical(opts, list()) || rlang::dots_n(...) > 0L)) {
cli_abort("
Passing `type`, `opts`, or `...` into `predict.epi_workflow()` is not
supported if you have frosting layers other than `layer_predict`. Please
provide these arguments earlier (i.e. while constructing the frosting
object) by passing them into an explicit call to `layer_predict(), and
adjust the remaining layers to account for resulting differences in
output format from these settings.
", class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers")
}

for (l in seq_along(layers)) {
la <- layers[[l]]
components <- slather(la, components, workflow, new_data)
if (inherits(la, "layer_predict")) {
components <- slather(la, components, workflow, new_data, type = type, opts = opts, ...)
} else {
# The check above should ensure we have default `type` and `opts`, and
# empty `...`; don't forward these default `type` and `opts`, to avoid
# upsetting some slather method validation.
components <- slather(la, components, workflow, new_data)
}
}

return(components)
Expand Down
2 changes: 1 addition & 1 deletion R/get_test_data.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Get test data for prediction based on longest lag period
#'
#' Based on the longest lag period in the recipe,
#' `get_test_data()` creates an [epi_df]
#' `get_test_data()` creates an [epi_df][epiprocess::as_epi_df]
#' with columns `geo_value`, `time_value`
#' and other variables in the original dataset,
#' which will be used to create features necessary to produce forecasts.
Expand Down
1 change: 1 addition & 0 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ layer_add_forecast_date_new <- function(forecast_date, id) {

#' @export
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
if (is.null(object$forecast_date)) {
max_time_value <- as.Date(max(
workflows::extract_preprocessor(workflow)$max_time_value,
Expand Down
1 change: 1 addition & 0 deletions R/layer_naomit.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ layer_naomit_new <- function(terms, id) {

#' @export
slather.layer_naomit <- function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
2 changes: 1 addition & 1 deletion R/layer_point_from_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ layer_point_from_distn_new <- function(type, name, id) {
#' @export
slather.layer_point_from_distn <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
dstn <- components$predictions$.pred
if (!inherits(dstn, "distribution")) {
rlang::warn(
Expand All @@ -86,6 +85,7 @@ slather.layer_point_from_distn <-
)
return(components)
}
rlang::check_dots_empty()

dstn <- match.fun(object$type)(dstn)
if (is.null(object$name)) {
Expand Down
2 changes: 1 addition & 1 deletion R/layer_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ layer_population_scaling_new <-
#' @export
slather.layer_population_scaling <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
stopifnot(
"Only one population column allowed for scaling" =
length(object$df_pop_col) == 1
)
rlang::check_dots_empty()

if (is.null(object$by)) {
object$by <- intersect(
Expand Down
30 changes: 25 additions & 5 deletions R/layer_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,19 @@ layer_predict <-
id = rand_id("predict_default")) {
arg_is_chr_scalar(id)
arg_is_chr_scalar(type, allow_null = TRUE)
assert_class(opts, "list")
dots_list <- rlang::dots_list(..., .homonyms = "error", .check_assign = TRUE)
if (any(rlang::names2(dots_list) == "")) {
cli_abort("All `...` arguments must be named.",
class = "epipredict__layer_predict__unnamed_dot"
)
}
add_layer(
frosting,
layer_predict_new(
type = type,
opts = opts,
dots_list = rlang::list2(...), # can't figure how to use this
dots_list = dots_list,
id = id
)
)
Expand All @@ -62,14 +69,27 @@ layer_predict_new <- function(type, opts, dots_list, id) {
}

#' @export
slather.layer_predict <- function(object, components, workflow, new_data, ...) {
slather.layer_predict <- function(object, components, workflow, new_data, type = NULL, opts = list(), ...) {
arg_is_chr_scalar(type, allow_null = TRUE)
if (!is.null(object$type) && !is.null(type) && !identical(object$type, type)) {
cli_abort("
Conflicting `type` settings were specified during frosting construction
(in call to `layer_predict()`) and while slathering (in call to
`slather()`/ `predict()`/etc.): {object$type} vs. {type}. Please remove
one of these `type` settings.
", class = "epipredict__layer_predict__conflicting_type_settings")
}
assert_class(opts, "list")

the_fit <- workflows::extract_fit_parsnip(workflow)

components$predictions <- predict(
components$predictions <- rlang::inject(predict(
the_fit,
components$forged$predictors,
type = object$type, opts = object$opts
)
type = object$type %||% type,
opts = c(object$opts, opts),
!!!object$dots_list, ...
))
components$predictions <- dplyr::bind_cols(
components$keys, components$predictions
)
Expand Down
1 change: 1 addition & 0 deletions R/layer_predictive_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) {
slather.layer_predictive_distn <-
function(object, components, workflow, new_data, ...) {
the_fit <- workflows::extract_fit_parsnip(workflow)
rlang::check_dots_empty()

m <- components$predictions$.pred
r <- grab_residuals(the_fit, components)
Expand Down
2 changes: 2 additions & 0 deletions R/layer_quantile_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ slather.layer_quantile_distn <-
"These are of class {.cls {class(dstn)}}."
))
}
rlang::check_dots_empty()

dstn <- dist_quantiles(
quantile(dstn, object$quantile_levels),
object$quantile_levels
Expand Down
2 changes: 2 additions & 0 deletions R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ layer_residual_quantiles_new <- function(
#' @export
slather.layer_residual_quantiles <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()

the_fit <- workflows::extract_fit_parsnip(workflow)

if (is.null(object$quantile_levels)) {
Expand Down
1 change: 1 addition & 0 deletions R/layer_threshold_preds.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ snap.dist_quantiles <- function(x, lower, upper, ...) {
#' @export
slather.layer_threshold <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
1 change: 1 addition & 0 deletions R/layer_unnest.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ layer_unnest_new <- function(terms, id) {
#' @export
slather.layer_unnest <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
1 change: 1 addition & 0 deletions inst/templates/layer.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ layer_{{{ name }}}_new <- function(terms, args, more_args, id) {
#' @export
slather.layer_{{{ name }}} <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()

# if layer_ used ... in tidyselect, we need to evaluate it now
exprs <- rlang::expr(c(!!!object$terms))
Expand Down
5 changes: 4 additions & 1 deletion man/apply_frosting.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/get_test_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading