Skip to content

Slather access workflow and test data & update layer_add_target_date() and layer_add_forecast_date() accordingly #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 12, 2023
Merged
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ S3method(extract_layers,workflow)
S3method(extrapolate_quantiles,dist_default)
S3method(extrapolate_quantiles,dist_quantiles)
S3method(extrapolate_quantiles,distribution)
S3method(fit,epi_workflow)
S3method(format,dist_quantiles)
S3method(is.na,dist_quantiles)
S3method(is.na,distribution)
Expand Down
2 changes: 2 additions & 0 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ epi_recipe.epi_df <-
term_info = var_info,
steps = NULL,
template = x[1,],
max_time_value = max(x$time_value),
levels = NULL,
retained = NA
)
Expand Down Expand Up @@ -374,6 +375,7 @@ prep.epi_recipe <- function(
} else {
x$template <- training[0, ]
}
x$max_time_value <- max(training$time_value)
x$tr_info <- tr_data
x$levels <- lvls
x$orig_lvls <- orig_lvls
Expand Down
46 changes: 43 additions & 3 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,48 @@ is_epi_workflow <- function(x) {
inherits(x, "epi_workflow")
}

#' Fit an `epi_workflow` object
#'
#' @description
#' This is the `fit()` method for an `epi_workflow` object that
#' estimates parameters for a given model from a set of data.
#' Fitting an `epi_workflow` involves two main steps, which are
#' preprocessing the data and fitting the underlying parsnip model.
#'
#' @inheritParams workflows::fit.workflow
#'
#' @param object an `epi_workflow` object
#'
#' @param data an `epi_df` of predictors and outcomes to use when
#'   fitting the `epi_workflow`
#'
#' @param control A [workflows::control_workflow()] object
#'
#' @return The `epi_workflow` object, updated with a fit parsnip
#'   model in the `object$fit$fit` slot.
#'
#' @seealso workflows::fit-workflow
#'
#' @name fit-epi_workflow
#' @export
#' @examples
#' jhu <- case_death_rate_subset %>%
#'   filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
#'
#' r <- epi_recipe(jhu) %>%
#'   step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#'   step_epi_ahead(death_rate, ahead = 7)
#'
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
#' wf
fit.epi_workflow <- function(object, data, ...,
                             control = workflows::control_workflow()) {
  # Stash training-data metadata on the fit slot before fitting, so that
  # postprocessing layers (e.g. layer_add_forecast_date(),
  # layer_add_target_date()) can later recover the training set's maximum
  # time value and `as_of` stamp via `workflow$fit$meta`.
  object$fit$meta <- list(
    max_time_value = max(data$time_value),
    as_of = attributes(data)$metadata$as_of
  )

  # Delegate preprocessing and parsnip model fitting to the next method
  # (workflows::fit.workflow()) via standard S3 method chaining.
  NextMethod()
}

#' Predict from an epi_workflow
#'
Expand Down Expand Up @@ -113,14 +155,12 @@ predict.epi_workflow <- function(object, new_data, ...) {
i = "Do you need to call `fit()`?"))
}
components <- list()
the_fit <- workflows::extract_fit_parsnip(object)
the_recipe <- workflows::extract_recipe(object)
components$mold <- workflows::extract_mold(object)
components$forged <- hardhat::forge(new_data,
blueprint = components$mold$blueprint)
components$keys <- grab_forged_keys(components$forged,
components$mold, new_data)
components <- apply_frosting(object, components, the_fit, the_recipe, ...)
components <- apply_frosting(object, components, new_data, ...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we still need the_fit and the_recipe? Am I missing something in the changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The epi_workflow (object) contains the recipe and fit, so it seemed unnecessary to include them separately.

Copy link
Contributor

@dajmcdon dajmcdon Aug 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need some more tests then. I would think this would fail for sure.

I see now. You modified it in apply_frosting.epi_workflow().

For the future, when PRs touch lots of files like this one, it's helpful for me if your initial comment goes through the logic of the things you do. That way I can figure out where to look, and if there are simplifications that happen across multiple files, it's easier for me to understand the idea.

components$predictions
}

Expand Down
6 changes: 4 additions & 2 deletions R/frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ apply_frosting.default <- function(workflow, components, ...) {
#' @importFrom rlang abort
#' @export
apply_frosting.epi_workflow <-
function(workflow, components, the_fit, the_recipe, ...) {
function(workflow, components, new_data, ...) {

the_fit <- workflows::extract_fit_parsnip(workflow)

if (!has_postprocessor(workflow)) {
components$predictions <- predict(
Expand Down Expand Up @@ -260,7 +262,7 @@ apply_frosting.epi_workflow <-

for (l in seq_along(layers)) {
la <- layers[[l]]
components <- slather(la, components, the_fit, the_recipe)
components <- slather(la, components, workflow, new_data)
}

return(components)
Expand Down
29 changes: 21 additions & 8 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
#' @param forecast_date The forecast date to add as a column to the `epi_df`.
#' For most cases, this should be specified in the form "yyyy-mm-dd". Note that
#' when the forecast date is left unspecified, it is set to the maximum time
#' value in the test data after any processing (ex. leads and lags) has been
#' applied.
#' value from the data used in pre-processing, fitting the model, and
#' postprocessing.
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
#'
#' @details To use this function, either specify a forecast date or leave the
#' forecast date unspecified here. In the latter case, the forecast date will
#' be set as the maximum time value in the processed test data. In any case,
#' when the forecast date is less than the most recent update date of the data
#' (ie. the `as_of` value), an appropriate warning will be thrown.
#' be set as the maximum time value from the data used in pre-processing,
#' fitting the model, and postprocessing. In any case, when the forecast date is
#' less than the maximum `as_of` value (from the data used in pre-processing,
#' model fitting, and postprocessing), an appropriate warning will be thrown.
#'
#' @export
#' @examples
Expand All @@ -28,6 +29,13 @@
#' latest <- jhu %>%
#' dplyr::filter(time_value >= max(time_value) - 14)
#'
#' # Don't specify `forecast_date` (by default, this should be the last date in `latest`)
#' f <- frosting() %>% layer_predict() %>%
#' layer_naomit(.pred)
#' wf0 <- wf %>% add_frosting(f)
#' p0 <- predict(wf0, latest)
#' p0
#'
#' # Specify a `forecast_date` that is greater than or equal to `as_of` date
#' f <- frosting() %>% layer_predict() %>%
#' layer_add_forecast_date(forecast_date = "2022-05-31") %>%
Expand Down Expand Up @@ -74,14 +82,19 @@ layer_add_forecast_date_new <- function(forecast_date, id) {
}

#' @export
slather.layer_add_forecast_date <- function(object, components, the_fit, the_recipe, ...) {
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {

if (is.null(object$forecast_date)) {
max_time_value <- max(components$keys$time_value)
max_time_value <- max(workflows::extract_preprocessor(workflow)$max_time_value,
workflow$fit$meta$max_time_value,
max(new_data$time_value))
object$forecast_date <- max_time_value
}
as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of
as_of_fit <- workflow$fit$meta$as_of
as_of_post <- attributes(new_data)$metadata$as_of

as_of_date <- as.Date(attributes(components$keys)$metadata$as_of)
as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post))

if (object$forecast_date < as_of_date) {
cli_warn(
Expand Down
66 changes: 47 additions & 19 deletions R/layer_add_target_date.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#' Postprocessing step to add the target date
#'
#' @param frosting a `frosting` postprocessor
#' @param target_date The target date to add as a column to the `epi_df`.
#' By default, this is the maximum `time_value` from the processed test
#' data plus `ahead`, where `ahead` has been specified in preprocessing
#' (most likely in `step_epi_ahead`). The user may override this with a
#' date of their own (that will usually be in the form "yyyy-mm-dd").
#' @param target_date The target date to add as a column to the
#' `epi_df`. If there's a forecast date specified in a layer, then
#' it is the forecast date plus `ahead` (from `step_epi_ahead` in
#' the `epi_recipe`). Otherwise, it is the maximum `time_value`
#' (from the data used in pre-processing, fitting the model, and
#' postprocessing) plus `ahead`, where `ahead` has been specified in
#' preprocessing. The user may override these by specifying a
#' target date of their own (of the form "yyyy-mm-dd").
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
Expand All @@ -27,24 +30,35 @@
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
#' latest <- get_test_data(r, jhu)
#'
#' # Use ahead from preprocessing
#' # Use ahead + forecast date
#' f <- frosting() %>% layer_predict() %>%
#' layer_add_forecast_date(forecast_date = "2022-05-31") %>%
#' layer_add_target_date() %>%
#' layer_naomit(.pred)
#' wf1 <- wf %>% add_frosting(f)
#'
#' p <- predict(wf1, latest)
#' p
#'
#' # Override default behaviour by specifying own target date
#' f2 <- frosting() %>%
#' layer_predict() %>%
#' layer_add_target_date(target_date = "2022-01-08") %>%
#' # Use ahead + max time value from pre, fit, post
#' # which is the same if include `layer_add_forecast_date()`
#' f2 <- frosting() %>% layer_predict() %>%
#' layer_add_target_date() %>%
#' layer_naomit(.pred)
#' wf2 <- wf %>% add_frosting(f2)
#'
#' p2 <- predict(wf2, latest)
#' p2
#'
#' # Specify own target date
#' f3 <- frosting() %>%
#' layer_predict() %>%
#' layer_add_target_date(target_date = "2022-01-08") %>%
#' layer_naomit(.pred)
#' wf3 <- wf %>% add_frosting(f3)
#'
#' p3 <- predict(wf3, latest)
#' p3
layer_add_target_date <-
function(frosting, target_date = NULL, id = rand_id("add_target_date")) {
target_date <- arg_to_date(target_date, allow_null = TRUE)
Expand All @@ -63,18 +77,32 @@ layer_add_target_date_new <- function(id = id, target_date = target_date) {
}

#' @export
slather.layer_add_target_date <- function(object, components, the_fit, the_recipe, ...) {
slather.layer_add_target_date <- function(object, components, workflow, new_data, ...) {

if (is.null(object$target_date)) {
max_time_value <- max(components$keys$time_value)
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
the_recipe <- workflows::extract_recipe(workflow)
the_frosting <- extract_frosting(workflow)

if (is.null(ahead)){
stop("`ahead` must be specified in preprocessing.")
}
target_date = max_time_value + ahead
} else{
if (!is.null(object$target_date)) {
target_date = as.Date(object$target_date)
} else { # null target date case
if (detect_layer(the_frosting, "layer_add_forecast_date") &&
!is.null(extract_argument(the_frosting,
"layer_add_forecast_date", "forecast_date"))) {
forecast_date <- extract_argument(the_frosting,
"layer_add_forecast_date", "forecast_date")

ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")

target_date = forecast_date + ahead
} else {
max_time_value <- max(workflows::extract_preprocessor(workflow)$max_time_value,
workflow$fit$meta$max_time_value,
max(new_data$time_value))

ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")

target_date = max_time_value + ahead
}
}

components$predictions <- dplyr::bind_cols(components$predictions,
Expand Down
2 changes: 1 addition & 1 deletion R/layer_naomit.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ layer_naomit_new <- function(terms, id) {
}

#' @export
slather.layer_naomit <- function(object, components, the_fit, the_recipe, ...) {
slather.layer_naomit <- function(object, components, workflow, new_data, ...) {
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
2 changes: 1 addition & 1 deletion R/layer_point_from_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ layer_point_from_distn_new <- function(type, name, id) {

#' @export
slather.layer_point_from_distn <-
function(object, components, the_fit, the_recipe, ...) {
function(object, components, workflow, new_data, ...) {

dstn <- components$predictions$.pred
if (!inherits(dstn, "distribution")) {
Expand Down
2 changes: 1 addition & 1 deletion R/layer_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ layer_population_scaling_new <-

#' @export
slather.layer_population_scaling <-
function(object, components, the_fit, the_recipe, ...) {
function(object, components, workflow, new_data, ...) {
stopifnot("Only one population column allowed for scaling" =
length(object$df_pop_col) == 1)

Expand Down
4 changes: 3 additions & 1 deletion R/layer_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ layer_predict_new <- function(type, opts, dots_list, id) {
}

#' @export
slather.layer_predict <- function(object, components, the_fit, the_recipe, ...) {
slather.layer_predict <- function(object, components, workflow, new_data, ...) {

the_fit <- workflows::extract_fit_parsnip(workflow)

components$predictions <- predict(
the_fit,
Expand Down
4 changes: 3 additions & 1 deletion R/layer_predictive_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) {

#' @export
slather.layer_predictive_distn <-
function(object, components, the_fit, the_recipe, ...) {
function(object, components, workflow, new_data, ...) {

the_fit <- workflows::extract_fit_parsnip(workflow)

m <- components$predictions$.pred
r <- grab_residuals(the_fit, components)
Expand Down
2 changes: 1 addition & 1 deletion R/layer_quantile_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ layer_quantile_distn_new <- function(levels, truncate, name, id) {

#' @export
slather.layer_quantile_distn <-
function(object, components, the_fit, the_recipe, ...) {
function(object, components, workflow, new_data, ...) {

dstn <- components$predictions$.pred
if (!inherits(dstn, "distribution")) {
Expand Down
5 changes: 4 additions & 1 deletion R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ layer_residual_quantiles_new <- function(probs, symmetrize, by_key, name, id) {

#' @export
slather.layer_residual_quantiles <-
function(object, components, the_fit, the_recipe, ...) {
function(object, components, workflow, new_data, ...) {

the_fit <- workflows::extract_fit_parsnip(workflow)

if (is.null(object$probs)) return(components)

s <- ifelse(object$symmetrize, -1, NA)
Expand Down
2 changes: 1 addition & 1 deletion R/layer_threshold_preds.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ snap.dist_quantiles <- function(x, lower, upper, ...) {

#' @export
slather.layer_threshold <-
function(object, components, the_fit, the_recipe, ...) {
function(object, components, workflow, new_data, ...) {
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
2 changes: 1 addition & 1 deletion R/layer_unnest.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ layer_unnest_new <- function(terms, id) {

#' @export
slather.layer_unnest <-
function(object, components, the_fit, the_recipe, ...) {
function(object, components, workflow, new_data, ...) {
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
7 changes: 4 additions & 3 deletions R/layers.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,14 @@ detect_layer.workflow <- function(x, name, ...) {
#' `keys`).
#' * `keys` - we put the keys (`time_value`, `geo_value`, and any others)
#' here for ease.
#' @param the_fit the fitted model object as returned by calling `parsnip::fit()`
#' @param the_recipe the `epi_recipe` preprocessor
#' @param workflow an object of class workflow
#' @param new_data a data frame containing the new predictors to preprocess
#' and predict on
#'
#' @param ... additional arguments used by methods. Currently unused.
#'
#' @return The `components` list. In the same format after applying any updates.
#' @export
slather <- function(object, components, the_fit, the_recipe, ...) {
slather <- function(object, components, workflow, new_data, ...) {
  # S3 generic: dispatches on the class of `object` (the frosting layer),
  # so each layer type (layer_predict, layer_naomit, ...) supplies its own
  # postprocessing step applied to the `components` list.
  UseMethod("slather")
}
2 changes: 1 addition & 1 deletion inst/templates/layer.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ layer_{{{ name }}}_new <- function(terms, args, more_args, id) {

#' @export
slather.layer_{{{ name }}} <-
function(object, components, the_fit, the_recipe, ...) {
function(object, components, workflow, new_data, ...) {

# if layer_ used ... in tidyselect, we need to evaluate it now
exprs <- rlang::expr(c(!!!object$terms))
Expand Down
Loading