From 40a7508348c5fa511916e7e50a602aa732533f0a Mon Sep 17 00:00:00 2001 From: rachlobay Date: Mon, 17 Jul 2023 11:20:10 -0700 Subject: [PATCH 01/15] Add new_data to slather signature --- R/epi_workflow.R | 2 +- R/frosting.R | 4 ++-- R/layer_naomit.R | 3 ++- R/layers.R | 4 +++- man/apply_frosting.Rd | 5 ++++- man/create_layer.Rd | 7 ++----- man/slather.Rd | 5 ++++- 7 files changed, 18 insertions(+), 12 deletions(-) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index c3eb409e6..2f18a4e5a 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -120,7 +120,7 @@ predict.epi_workflow <- function(object, new_data, ...) { blueprint = components$mold$blueprint) components$keys <- grab_forged_keys(components$forged, components$mold, new_data) - components <- apply_frosting(object, components, the_fit, the_recipe, ...) + components <- apply_frosting(object, components, the_fit, the_recipe, new_data, ...) components$predictions } diff --git a/R/frosting.R b/R/frosting.R index c8f30a2cb..41546d57b 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -227,7 +227,7 @@ apply_frosting.default <- function(workflow, components, ...) { #' @importFrom rlang abort #' @export apply_frosting.epi_workflow <- - function(workflow, components, the_fit, the_recipe, ...) { + function(workflow, components, the_fit, the_recipe, new_data, ...) { if (!has_postprocessor(workflow)) { components$predictions <- predict( @@ -260,7 +260,7 @@ apply_frosting.epi_workflow <- for (l in seq_along(layers)) { la <- layers[[l]] - components <- slather(la, components, the_fit, the_recipe) + components <- slather(la, components, the_fit, the_recipe, new_data) } return(components) diff --git a/R/layer_naomit.R b/R/layer_naomit.R index 24c276507..86ce46e6b 100644 --- a/R/layer_naomit.R +++ b/R/layer_naomit.R @@ -46,7 +46,8 @@ layer_naomit_new <- function(terms, id) { } #' @export -slather.layer_naomit <- function(object, components, the_fit, the_recipe, ...) { +slather.layer_naomit <- function(object, components, the_fit, the_recipe, new_data, ...) { + newd1 <<- new_data exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layers.R b/R/layers.R index 781c96089..64d481bb6 100644 --- a/R/layers.R +++ b/R/layers.R @@ -125,11 +125,13 @@ detect_layer.workflow <- function(x, name, ...) { #' here for ease. #' @param the_fit the fitted model object as returned by calling `parsnip::fit()` #' @param the_recipe the `epi_recipe` preprocessor +#' @param new_data A data frame containing the new predictors to preprocess +#' and predict on #' #' @param ... additional arguments used by methods. Currently unused. #' #' @return The `components` list. In the same format after applying any updates. #' @export -slather <- function(object, components, the_fit, the_recipe, ...) { +slather <- function(object, components, the_fit, the_recipe, new_data, ...) { UseMethod("slather") } diff --git a/man/apply_frosting.Rd b/man/apply_frosting.Rd index c9b636f45..b0d685afb 100644 --- a/man/apply_frosting.Rd +++ b/man/apply_frosting.Rd @@ -11,7 +11,7 @@ apply_frosting(workflow, ...) \method{apply_frosting}{default}(workflow, components, ...) -\method{apply_frosting}{epi_workflow}(workflow, components, the_fit, the_recipe, ...) +\method{apply_frosting}{epi_workflow}(workflow, components, the_fit, the_recipe, new_data, ...) } \arguments{ \item{workflow}{An object of class workflow} @@ -35,6 +35,9 @@ here for ease. \item{the_fit}{the fitted model object as returned by calling \code{parsnip::fit()}} \item{the_recipe}{the \code{epi_recipe} preprocessor} + +\item{new_data}{A data frame containing the new predictors to preprocess +and predict on} } \description{ This function is intended for internal use. It implements postprocessing diff --git a/man/create_layer.Rd b/man/create_layer.Rd index 81f5e33b0..399d62efa 100644 --- a/man/create_layer.Rd +++ b/man/create_layer.Rd @@ -7,11 +7,8 @@ create_layer(name = NULL, open = rlang::is_interactive()) } \arguments{ -\item{name}{Either a name without extension, or \code{NULL} to create the -paired file based on currently open file in the script editor. If -the \verb{R/} file is open, \code{use_test()} will create/open the corresponding -test file; if the test file is open, \code{use_r()} will create/open the -corresponding \verb{R/} file.} +\item{name}{Either a string giving a file name (without directory) or +\code{NULL} to take the name from the currently open file in RStudio.} \item{open}{Whether to open the file for interactive editing.} } diff --git a/man/slather.Rd b/man/slather.Rd index 8fb357e2a..f40784610 100644 --- a/man/slather.Rd +++ b/man/slather.Rd @@ -4,7 +4,7 @@ \alias{slather} \title{Spread a layer of frosting on a fitted workflow} \usage{ -slather(object, components, the_fit, the_recipe, ...) +slather(object, components, the_fit, the_recipe, new_data, ...) } \arguments{ \item{object}{a workflow with \code{frosting} postprocessing steps} @@ -27,6 +27,9 @@ here for ease. \item{the_recipe}{the \code{epi_recipe} preprocessor} +\item{new_data}{A data frame containing the new predictors to preprocess +and predict on} + \item{...}{additional arguments used by methods. Currently unused.} } \value{ From 827d106c998911f30bb18e058d9ea1a127ebb86b Mon Sep 17 00:00:00 2001 From: rachlobay Date: Mon, 17 Jul 2023 12:08:07 -0700 Subject: [PATCH 02/15] Add code to include workflow in slather signature --- R/frosting.R | 2 +- R/layer_naomit.R | 5 ++++- R/layers.R | 5 +++-- man/apply_frosting.Rd | 2 +- man/slather.Rd | 6 ++++-- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/R/frosting.R b/R/frosting.R index 41546d57b..0a6f841eb 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -260,7 +260,7 @@ apply_frosting.epi_workflow <- for (l in seq_along(layers)) { la <- layers[[l]] - components <- slather(la, components, the_fit, the_recipe, new_data) + components <- slather(la, components, the_fit, the_recipe, workflow, new_data) } return(components) diff --git a/R/layer_naomit.R b/R/layer_naomit.R index 86ce46e6b..597308cad 100644 --- a/R/layer_naomit.R +++ b/R/layer_naomit.R @@ -46,8 +46,11 @@ layer_naomit_new <- function(terms, id) { } #' @export -slather.layer_naomit <- function(object, components, the_fit, the_recipe, new_data, ...) { +slather.layer_naomit <- function(object, components, the_fit, the_recipe, workflow, new_data, ...) { newd1 <<- new_data + obj2 <<- object + compon <<- components + wf2 <<- workflow exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layers.R b/R/layers.R index 64d481bb6..1f3ec97e3 100644 --- a/R/layers.R +++ b/R/layers.R @@ -125,13 +125,14 @@ detect_layer.workflow <- function(x, name, ...) { #' here for ease. #' @param the_fit the fitted model object as returned by calling `parsnip::fit()` #' @param the_recipe the `epi_recipe` preprocessor -#' @param new_data A data frame containing the new predictors to preprocess +#' @param workflow an object of class workflow +#' @param new_data a data frame containing the new predictors to preprocess #' and predict on #' #' @param ... additional arguments used by methods. Currently unused. #' #' @return The `components` list. In the same format after applying any updates. #' @export -slather <- function(object, components, the_fit, the_recipe, new_data, ...) { +slather <- function(object, components, the_fit, the_recipe, workflow, new_data, ...) { UseMethod("slather") } diff --git a/man/apply_frosting.Rd b/man/apply_frosting.Rd index b0d685afb..6221e8f59 100644 --- a/man/apply_frosting.Rd +++ b/man/apply_frosting.Rd @@ -36,7 +36,7 @@ here for ease. \item{the_recipe}{the \code{epi_recipe} preprocessor} -\item{new_data}{A data frame containing the new predictors to preprocess +\item{new_data}{a data frame containing the new predictors to preprocess and predict on} } \description{ diff --git a/man/slather.Rd b/man/slather.Rd index f40784610..6e748372f 100644 --- a/man/slather.Rd +++ b/man/slather.Rd @@ -4,7 +4,7 @@ \alias{slather} \title{Spread a layer of frosting on a fitted workflow} \usage{ -slather(object, components, the_fit, the_recipe, new_data, ...) +slather(object, components, the_fit, the_recipe, workflow, new_data, ...) } \arguments{ \item{object}{a workflow with \code{frosting} postprocessing steps} @@ -27,7 +27,9 @@ here for ease. \item{the_recipe}{the \code{epi_recipe} preprocessor} -\item{new_data}{A data frame containing the new predictors to preprocess +\item{workflow}{an object of class workflow} + +\item{new_data}{a data frame containing the new predictors to preprocess and predict on} \item{...}{additional arguments used by methods. Currently unused.} From a16dd380f50952ac92400efff89c02e4a8468fef Mon Sep 17 00:00:00 2001 From: rachlobay Date: Mon, 17 Jul 2023 12:10:53 -0700 Subject: [PATCH 03/15] change wf test obj --- R/layer_naomit.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/layer_naomit.R b/R/layer_naomit.R index 597308cad..7c61fb7f9 100644 --- a/R/layer_naomit.R +++ b/R/layer_naomit.R @@ -50,7 +50,7 @@ slather.layer_naomit <- function(object, components, the_fit, the_recipe, workfl newd1 <<- new_data obj2 <<- object compon <<- components - wf2 <<- workflow + wf3 <<- workflow exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) From d9fc1688f017518b9a946649d3c1d335f07df5e1 Mon Sep 17 00:00:00 2001 From: rachlobay Date: Mon, 17 Jul 2023 12:27:09 -0700 Subject: [PATCH 04/15] Remove instances of the_recipe, the_fit --- R/epi_workflow.R | 4 +--- R/frosting.R | 6 ++++-- R/layer_add_forecast_date.R | 2 +- R/layer_add_target_date.R | 4 +++- R/layer_naomit.R | 6 +----- R/layer_point_from_distn.R | 2 +- R/layer_population_scaling.R | 2 +- R/layer_predict.R | 4 +++- R/layer_predictive_distn.R | 4 +++- R/layer_quantile_distn.R | 2 +- R/layer_residual_quantiles.R | 5 ++++- R/layer_threshold_preds.R | 2 +- R/layer_unnest.R | 2 +- R/layers.R | 4 +--- man/apply_frosting.Rd | 6 +----- man/slather.Rd | 6 +----- 16 files changed, 28 insertions(+), 33 deletions(-) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 2f18a4e5a..ddd40fc9d 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -113,14 +113,12 @@ predict.epi_workflow <- function(object, new_data, ...) { i = "Do you need to call `fit()`?")) } components <- list() - the_fit <- workflows::extract_fit_parsnip(object) - the_recipe <- workflows::extract_recipe(object) components$mold <- workflows::extract_mold(object) components$forged <- hardhat::forge(new_data, blueprint = components$mold$blueprint) components$keys <- grab_forged_keys(components$forged, components$mold, new_data) - components <- apply_frosting(object, components, the_fit, the_recipe, new_data, ...) + components <- apply_frosting(object, components, new_data, ...) components$predictions } diff --git a/R/frosting.R b/R/frosting.R index 0a6f841eb..dc407d23b 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -227,7 +227,9 @@ apply_frosting.default <- function(workflow, components, ...) { #' @importFrom rlang abort #' @export apply_frosting.epi_workflow <- - function(workflow, components, the_fit, the_recipe, new_data, ...) { + function(workflow, components, new_data, ...) { + + the_fit <- workflows::extract_fit_parsnip(object) if (!has_postprocessor(workflow)) { components$predictions <- predict( @@ -260,7 +262,7 @@ apply_frosting.epi_workflow <- for (l in seq_along(layers)) { la <- layers[[l]] - components <- slather(la, components, the_fit, the_recipe, workflow, new_data) + components <- slather(la, components, workflow, new_data) } return(components) diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 34f34303e..d39e3612e 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -74,7 +74,7 @@ layer_add_forecast_date_new <- function(forecast_date, id) { } #' @export -slather.layer_add_forecast_date <- function(object, components, the_fit, the_recipe, ...) { +slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { if (is.null(object$forecast_date)) { max_time_value <- max(components$keys$time_value) diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index 3429b735e..aec425a24 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -63,7 +63,9 @@ layer_add_target_date_new <- function(id = id, target_date = target_date) { } #' @export -slather.layer_add_target_date <- function(object, components, the_fit, the_recipe, ...) { +slather.layer_add_target_date <- function(object, components, workflow, new_data, ...) { + + the_recipe <- workflows::extract_recipe(workflow) if (is.null(object$target_date)) { max_time_value <- max(components$keys$time_value) diff --git a/R/layer_naomit.R b/R/layer_naomit.R index 7c61fb7f9..ba1081e8d 100644 --- a/R/layer_naomit.R +++ b/R/layer_naomit.R @@ -46,11 +46,7 @@ layer_naomit_new <- function(terms, id) { } #' @export -slather.layer_naomit <- function(object, components, the_fit, the_recipe, workflow, new_data, ...) { - newd1 <<- new_data - obj2 <<- object - compon <<- components - wf3 <<- workflow +slather.layer_naomit <- function(object, components, workflow, new_data, ...) { exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layer_point_from_distn.R b/R/layer_point_from_distn.R index 7861cacf1..855d8b194 100644 --- a/R/layer_point_from_distn.R +++ b/R/layer_point_from_distn.R @@ -76,7 +76,7 @@ layer_point_from_distn_new <- function(type, name, id) { #' @export slather.layer_point_from_distn <- - function(object, components, the_fit, the_recipe, ...) { + function(object, components, workflow, new_data, ...) { dstn <- components$predictions$.pred if (!inherits(dstn, "distribution")) { diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 2fb6dfa13..448812bd8 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -129,7 +129,7 @@ layer_population_scaling_new <- #' @export slather.layer_population_scaling <- - function(object, components, the_fit, the_recipe, ...) { + function(object, components, workflow, new_data, ...) { stopifnot("Only one population column allowed for scaling" = length(object$df_pop_col) == 1) diff --git a/R/layer_predict.R b/R/layer_predict.R index bfe8d219f..e60f0595c 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -62,7 +62,9 @@ layer_predict_new <- function(type, opts, dots_list, id) { } #' @export -slather.layer_predict <- function(object, components, the_fit, the_recipe, ...) { +slather.layer_predict <- function(object, components, workflow, new_data, ...) { + + the_fit <- workflows::extract_fit_parsnip(workflow) components$predictions <- predict( the_fit, diff --git a/R/layer_predictive_distn.R b/R/layer_predictive_distn.R index 9c03a488a..815ff883e 100644 --- a/R/layer_predictive_distn.R +++ b/R/layer_predictive_distn.R @@ -71,7 +71,9 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) { #' @export slather.layer_predictive_distn <- - function(object, components, the_fit, the_recipe, ...) { + function(object, components, workflow, new_data, ...) { + + the_fit <- workflows::extract_fit_parsnip(object) m <- components$predictions$.pred r <- grab_residuals(the_fit, components) diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index bc9f4a448..97d546ed1 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -71,7 +71,7 @@ layer_quantile_distn_new <- function(levels, truncate, name, id) { #' @export slather.layer_quantile_distn <- - function(object, components, the_fit, the_recipe, ...) { + function(object, components, workflow, new_data, ...) { dstn <- components$predictions$.pred if (!inherits(dstn, "distribution")) { diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index d4edca457..fddc4ce20 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -72,7 +72,10 @@ layer_residual_quantiles_new <- function(probs, symmetrize, by_key, name, id) { #' @export slather.layer_residual_quantiles <- - function(object, components, the_fit, the_recipe, ...) { + function(object, components, workflow, new_data, ...) { + + the_fit <- workflows::extract_fit_parsnip(object) + if (is.null(object$probs)) return(components) s <- ifelse(object$symmetrize, -1, NA) diff --git a/R/layer_threshold_preds.R b/R/layer_threshold_preds.R index 8528d20e1..eb1cb0577 100644 --- a/R/layer_threshold_preds.R +++ b/R/layer_threshold_preds.R @@ -99,7 +99,7 @@ snap.dist_quantiles <- function(x, lower, upper, ...) { #' @export slather.layer_threshold <- - function(object, components, the_fit, the_recipe, ...) { + function(object, components, workflow, new_data, ...) { exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layer_unnest.R b/R/layer_unnest.R index ef82e1308..8b545c9cd 100644 --- a/R/layer_unnest.R +++ b/R/layer_unnest.R @@ -27,7 +27,7 @@ layer_unnest_new <- function(terms, id) { #' @export slather.layer_unnest <- - function(object, components, the_fit, the_recipe, ...) { + function(object, components, workflow, new_data, ...) { exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layers.R b/R/layers.R index 1f3ec97e3..ab3167609 100644 --- a/R/layers.R +++ b/R/layers.R @@ -123,8 +123,6 @@ detect_layer.workflow <- function(x, name, ...) { #' `keys`). #' * `keys` - we put the keys (`time_value`, `geo_value`, and any others) #' here for ease. -#' @param the_fit the fitted model object as returned by calling `parsnip::fit()` -#' @param the_recipe the `epi_recipe` preprocessor #' @param workflow an object of class workflow #' @param new_data a data frame containing the new predictors to preprocess #' and predict on @@ -133,6 +131,6 @@ detect_layer.workflow <- function(x, name, ...) { #' #' @return The `components` list. In the same format after applying any updates. #' @export -slather <- function(object, components, the_fit, the_recipe, workflow, new_data, ...) { +slather <- function(object, components, workflow, new_data, ...) { UseMethod("slather") } diff --git a/man/apply_frosting.Rd b/man/apply_frosting.Rd index 6221e8f59..fc01a3461 100644 --- a/man/apply_frosting.Rd +++ b/man/apply_frosting.Rd @@ -11,7 +11,7 @@ apply_frosting(workflow, ...) \method{apply_frosting}{default}(workflow, components, ...) -\method{apply_frosting}{epi_workflow}(workflow, components, the_fit, the_recipe, new_data, ...) +\method{apply_frosting}{epi_workflow}(workflow, components, new_data, ...) } \arguments{ \item{workflow}{An object of class workflow} @@ -32,10 +32,6 @@ have three components \code{predictors}, \code{outcomes} (if these were in the here for ease. }} -\item{the_fit}{the fitted model object as returned by calling \code{parsnip::fit()}} - -\item{the_recipe}{the \code{epi_recipe} preprocessor} - \item{new_data}{a data frame containing the new predictors to preprocess and predict on} } diff --git a/man/slather.Rd b/man/slather.Rd index 6e748372f..dd556b629 100644 --- a/man/slather.Rd +++ b/man/slather.Rd @@ -4,7 +4,7 @@ \alias{slather} \title{Spread a layer of frosting on a fitted workflow} \usage{ -slather(object, components, the_fit, the_recipe, workflow, new_data, ...) +slather(object, components, workflow, new_data, ...) } \arguments{ \item{object}{a workflow with \code{frosting} postprocessing steps} @@ -23,10 +23,6 @@ have three components \code{predictors}, \code{outcomes} (if these were in the here for ease. }} -\item{the_fit}{the fitted model object as returned by calling \code{parsnip::fit()}} - -\item{the_recipe}{the \code{epi_recipe} preprocessor} - \item{workflow}{an object of class workflow} \item{new_data}{a data frame containing the new predictors to preprocess From d99133710e58645aa4fbea30d966898986164c64 Mon Sep 17 00:00:00 2001 From: rachlobay Date: Mon, 17 Jul 2023 12:31:28 -0700 Subject: [PATCH 05/15] Changed default forecast and target dates --- R/layer_add_forecast_date.R | 2 +- R/layer_add_target_date.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index d39e3612e..e8dbabd20 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -77,7 +77,7 @@ layer_add_forecast_date_new <- function(forecast_date, id) { slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { if (is.null(object$forecast_date)) { - max_time_value <- max(components$keys$time_value) + max_time_value <- max(new_data$time_value) object$forecast_date <- max_time_value } diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index aec425a24..29a3558df 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -68,7 +68,7 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data the_recipe <- workflows::extract_recipe(workflow) if (is.null(object$target_date)) { - max_time_value <- max(components$keys$time_value) + max_time_value <- max(new_data$time_value) ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") if (is.null(ahead)){ From bdb27e75480a6b3dc3c1d5f0abf41911fcfca84c Mon Sep 17 00:00:00 2001 From: rachlobay Date: Mon, 17 Jul 2023 12:45:15 -0700 Subject: [PATCH 06/15] Update tests and documentaiton --- R/frosting.R | 2 +- R/layer_add_forecast_date.R | 7 +++++++ man/layer_add_forecast_date.Rd | 7 +++++++ tests/testthat/test-layer_add_forecast_date.R | 2 +- tests/testthat/test-layer_add_target_date.R | 2 +- 5 files changed, 17 insertions(+), 3 deletions(-) diff --git a/R/frosting.R b/R/frosting.R index dc407d23b..f0c554935 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -229,7 +229,7 @@ apply_frosting.default <- function(workflow, components, ...) { apply_frosting.epi_workflow <- function(workflow, components, new_data, ...) { - the_fit <- workflows::extract_fit_parsnip(object) + the_fit <- workflows::extract_fit_parsnip(workflow) if (!has_postprocessor(workflow)) { components$predictions <- predict( diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index e8dbabd20..19d63ef24 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -28,6 +28,13 @@ #' latest <- jhu %>% #' dplyr::filter(time_value >= max(time_value) - 14) #' +#' # Don't specify `forecast_date` (by default, this should be last date in latest) +#' f <- frosting() %>% layer_predict() %>% +#' layer_naomit(.pred) +#' wf0 <- wf %>% add_frosting(f) +#' p0 <- predict(wf0, latest) +#' p0 +#' #' # Specify a `forecast_date` that is greater than or equal to `as_of` date #' f <- frosting() %>% layer_predict() %>% #' layer_add_forecast_date(forecast_date = "2022-05-31") %>% diff --git a/man/layer_add_forecast_date.Rd b/man/layer_add_forecast_date.Rd index 52af21866..96558e403 100644 --- a/man/layer_add_forecast_date.Rd +++ b/man/layer_add_forecast_date.Rd @@ -45,6 +45,13 @@ wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) latest <- jhu \%>\% dplyr::filter(time_value >= max(time_value) - 14) + # Don't specify `forecast_date` (by default, this should be last date in latest) +f <- frosting() \%>\% layer_predict() \%>\% + layer_naomit(.pred) +wf0 <- wf \%>\% add_frosting(f) +p0 <- predict(wf0, latest) +p0 + # Specify a `forecast_date` that is greater than or equal to `as_of` date f <- frosting() \%>\% layer_predict() \%>\% layer_add_forecast_date(forecast_date = "2022-05-31") \%>\% diff --git a/tests/testthat/test-layer_add_forecast_date.R b/tests/testthat/test-layer_add_forecast_date.R index 06d65da48..5d965e7b3 100644 --- a/tests/testthat/test-layer_add_forecast_date.R +++ b/tests/testthat/test-layer_add_forecast_date.R @@ -65,6 +65,6 @@ test_that("Do not specify a forecast_date in `layer_add_forecast_date()`", { expect_equal(ncol(p3), 4L) expect_s3_class(p3, "epi_df") expect_equal(nrow(p3), 3L) - expect_equal(p3$forecast_date, rep(as.Date("2022-01-14"), times = 3)) + expect_equal(p3$forecast_date, rep(as.Date("2021-12-31"), times = 3)) expect_named(p3, c("geo_value", "time_value", ".pred", "forecast_date")) }) diff --git a/tests/testthat/test-layer_add_target_date.R b/tests/testthat/test-layer_add_target_date.R index 2c80fe657..49ba1d997 100644 --- a/tests/testthat/test-layer_add_target_date.R +++ b/tests/testthat/test-layer_add_target_date.R @@ -19,7 +19,7 @@ test_that("Use ahead from preprocessing", { expect_equal(ncol(p), 4L) expect_s3_class(p, "epi_df") expect_equal(nrow(p), 3L) - expect_equal(p$target_date, rep(as.Date("2022-01-21"), times = 3)) + expect_equal(p$target_date, rep(as.Date("2022-01-07"), times = 3)) expect_named(p, c("geo_value", "time_value", ".pred", "target_date")) }) From 47d7518c44536d660010ad914bbf8f0aeb42df01 Mon Sep 17 00:00:00 2001 From: rachlobay Date: Mon, 17 Jul 2023 12:50:02 -0700 Subject: [PATCH 07/15] object to workflow --- R/layer_predictive_distn.R | 2 +- R/layer_residual_quantiles.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/layer_predictive_distn.R b/R/layer_predictive_distn.R index 815ff883e..c951d9ccd 100644 --- a/R/layer_predictive_distn.R +++ b/R/layer_predictive_distn.R @@ -73,7 +73,7 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) { slather.layer_predictive_distn <- function(object, components, workflow, new_data, ...) { - the_fit <- workflows::extract_fit_parsnip(object) + the_fit <- workflows::extract_fit_parsnip(workflow) m <- components$predictions$.pred r <- grab_residuals(the_fit, components) diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index fddc4ce20..c97525b41 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -74,7 +74,7 @@ layer_residual_quantiles_new <- function(probs, symmetrize, by_key, name, id) { slather.layer_residual_quantiles <- function(object, components, workflow, new_data, ...) { - the_fit <- workflows::extract_fit_parsnip(object) + the_fit <- workflows::extract_fit_parsnip(workflow) if (is.null(object$probs)) return(components) From 06c71d9dad8d4e4b92460917040223f0d2f7b6ee Mon Sep 17 00:00:00 2001 From: rachlobay Date: Sat, 22 Jul 2023 17:57:20 -0700 Subject: [PATCH 08/15] Doc. for fit.epi_workflow and trying different things for mtv --- NAMESPACE | 1 + R/epi_recipe.R | 2 ++ R/epi_workflow.R | 43 +++++++++++++++++++++++++++++++++++++ R/frosting.R | 18 ++++++++++++++++ R/layer_add_forecast_date.R | 4 ++-- man/fit-epi_workflow.Rd | 42 ++++++++++++++++++++++++++++++++++++ 6 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 man/fit-epi_workflow.Rd diff --git a/NAMESPACE b/NAMESPACE index e35124d55..de3a11e3e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -32,6 +32,7 @@ S3method(extract_layers,workflow) S3method(extrapolate_quantiles,dist_default) S3method(extrapolate_quantiles,dist_quantiles) S3method(extrapolate_quantiles,distribution) +S3method(fit,epi_workflow) S3method(format,dist_quantiles) S3method(is.na,dist_quantiles) S3method(is.na,distribution) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index a6fb616ed..9b8568689 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -131,6 +131,7 @@ epi_recipe.epi_df <- term_info = var_info, steps = NULL, template = x[1,], + mtv = max(x$time_value), #%% levels = NULL, retained = NA ) @@ -374,6 +375,7 @@ prep.epi_recipe <- function( } else { x$template <- training[0, ] } + x$mtv <- max(training$time_value) #%% x$tr_info <- tr_data x$levels <- lvls x$orig_lvls <- orig_lvls diff --git a/R/epi_workflow.R b/R/epi_workflow.R index ddd40fc9d..7952893ba 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -58,6 +58,47 @@ is_epi_workflow <- function(x) { inherits(x, "epi_workflow") } +#' Fit an `epi_workflow` object +#' +#' @description +#' This is the `fit()` method for an `epi_workflow` object that +#' estimates parameters for a given model from a set of data. +#' Fitting an `epi_workflow` involves two main steps, which are +#' preprocessing the data and fitting the underlying parsnip model. +#' +#' @inheritParams workflows::fit.workflow +#' +#' @param object an `epi_workflow` object +#' +#' @param x an `epi_df` of predictors and outcomes to use when +#' fitting the `epi_workflow` +#' +#' @return The `epi_workflow` object, updated with a fit parsnip +#' model in the `object$fit$fit` slot. +#' +#' @seealso workflows::fit-workflow +#' +#' @name fit-epi_workflow +#' @export +#' @examples +#' jhu <- case_death_rate_subset %>% +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' +#' r <- epi_recipe(jhu) %>% +#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% +#' step_epi_ahead(death_rate, ahead = 7) +#' +#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf +#' +#' @export +fit.epi_workflow <- function(object, x, ...){ + + object$fit$meta <- list(mtv = max(x$time_value)) + #object$fit$as_of <- attributes(x)$metadata$as_of + + NextMethod() +} #' Predict from an epi_workflow #' @@ -112,6 +153,7 @@ predict.epi_workflow <- function(object, new_data, ...) { c("Can't predict on an untrained epi_workflow.", i = "Do you need to call `fit()`?")) } + components <- list() components$mold <- workflows::extract_mold(object) components$forged <- hardhat::forge(new_data, @@ -119,6 +161,7 @@ predict.epi_workflow <- function(object, new_data, ...) { components$keys <- grab_forged_keys(components$forged, components$mold, new_data) components <- apply_frosting(object, components, new_data, ...) + components$predictions } diff --git a/R/frosting.R b/R/frosting.R index f0c554935..62be492cb 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -229,6 +229,9 @@ apply_frosting.default <- function(workflow, components, ...) { apply_frosting.epi_workflow <- function(workflow, components, new_data, ...) { + #%% wf1$post$meta <<- list(mtv = max(new_data$time_value)) #%% change wf1 and possibly <<- + # assign("workflow$post$meta", list(mtv = max(new_data$time_value)), envir = .GlobalEnv) + the_fit <- workflows::extract_fit_parsnip(workflow) if (!has_postprocessor(workflow)) { @@ -268,6 +271,21 @@ apply_frosting.epi_workflow <- return(components) } +#%% change_workflow = function(x){ +# assign(deparse(substitute(x)), "changed", env=.GlobalEnv) +#} + +#%% add_meta_post <- function(workflow, new_data){ +# +# workflow$post$meta <- list(mtv = max(new_data$time_value)) +# +# workflow +#} + +#%% changeMe = function(x){ +# assign(deparse(substitute(x)), "changed", env=.GlobalEnv) +#} + #' @export print.frosting <- function(x, form_width = 30, ...) { cli::cli_div( diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 19d63ef24..89ea2dd30 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -82,9 +82,9 @@ layer_add_forecast_date_new <- function(forecast_date, id) { #' @export slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { - + wf <<- workflow if (is.null(object$forecast_date)) { - max_time_value <- max(new_data$time_value) + max_time_value <- max(workflows::extract_preprocessor(wf)$mtv, wf$fit$meta$mtv, max(new_data$time_value))#wf$post$meta$mtv) # workflow$fit$max_train_time #max(new_data$time_value) object$forecast_date <- max_time_value } diff --git a/man/fit-epi_workflow.Rd b/man/fit-epi_workflow.Rd new file mode 100644 index 000000000..a5dcdee79 --- /dev/null +++ b/man/fit-epi_workflow.Rd @@ -0,0 +1,42 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/epi_workflow.R +\name{fit-epi_workflow} +\alias{fit-epi_workflow} +\alias{fit.epi_workflow} +\title{Fit an \code{epi_workflow} object} +\usage{ +\method{fit}{epi_workflow}(object, x, ...) +} +\arguments{ +\item{object}{an \code{epi_workflow} object} + +\item{x}{an \code{epi_df} of predictors and outcomes to use when +fitting the \code{epi_workflow}} + +\item{...}{Not used} +} +\value{ +The \code{epi_workflow} object, updated with a fit parsnip +model in the \code{object$fit$fit} slot. +} +\description{ +This is the \code{fit()} method for an \code{epi_workflow} object that +estimates parameters for a given model from a set of data. +Fitting an \code{epi_workflow} involves two main steps, which are +preprocessing the data and fitting the underlying parsnip model. +} +\examples{ +jhu <- case_death_rate_subset \%>\% +filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + +r <- epi_recipe(jhu) \%>\% + step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% + step_epi_ahead(death_rate, ahead = 7) + +wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf + +} +\seealso{ +workflows::fit-workflow +} From 106b2ad67aa916bfbf89cdbe7eefcd63c72895b3 Mon Sep 17 00:00:00 2001 From: rachlobay Date: Mon, 24 Jul 2023 23:32:35 -0700 Subject: [PATCH 09/15] Update doc + code to pick max pre/fit/post --- R/epi_workflow.R | 11 +++++------ R/frosting.R | 9 +++++++++ R/layer_add_forecast_date.R | 22 ++++++++++++++-------- R/layer_add_target_date.R | 7 ++++--- man/fit-epi_workflow.Rd | 6 ++++-- man/layer_add_forecast_date.Rd | 11 ++++++----- man/layer_add_target_date.Rd | 5 +++-- 7 files changed, 45 insertions(+), 26 deletions(-) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 7952893ba..97171b55b 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -70,9 +70,11 @@ is_epi_workflow <- function(x) { #' #' @param object an `epi_workflow` object #' -#' @param x an `epi_df` of predictors and outcomes to use when +#' @param data an `epi_df` of predictors and outcomes to use when #' fitting the `epi_workflow` #' +#' @param control A [workflows::control_workflow()] object +#' #' @return The `epi_workflow` object, updated with a fit parsnip #' model in the `object$fit$fit` slot. #' @@ -92,10 +94,9 @@ is_epi_workflow <- function(x) { #' wf #' #' @export -fit.epi_workflow <- function(object, x, ...){ +fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()){ - object$fit$meta <- list(mtv = max(x$time_value)) - #object$fit$as_of <- attributes(x)$metadata$as_of + object$fit$meta <- list(mtv = max(data$time_value), as_of = attributes(data)$metadata$as_of) NextMethod() } @@ -153,7 +154,6 @@ predict.epi_workflow <- function(object, new_data, ...) { c("Can't predict on an untrained epi_workflow.", i = "Do you need to call `fit()`?")) } - components <- list() components$mold <- workflows::extract_mold(object) components$forged <- hardhat::forge(new_data, @@ -161,7 +161,6 @@ predict.epi_workflow <- function(object, new_data, ...) { components$keys <- grab_forged_keys(components$forged, components$mold, new_data) components <- apply_frosting(object, components, new_data, ...) - components$predictions } diff --git a/R/frosting.R b/R/frosting.R index 62be492cb..7622002c3 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -267,10 +267,19 @@ apply_frosting.epi_workflow <- la <- layers[[l]] components <- slather(la, components, workflow, new_data) } + #%% mtv <- max(new_data$time_value) + #%% update_workflow_post(workflow, mtv) return(components) } +#%% #' @export +# update_workflow_post <- function(x, mtv) { +# substitute(x) <- "changed" +# #assign(deparse(substitute(x)), "changed", env=.GlobalEnv) +# #workflow$post$meta <- list(mtv = max(new_data$time_value)) +# } + #%% change_workflow = function(x){ # assign(deparse(substitute(x)), "changed", env=.GlobalEnv) #} diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 89ea2dd30..3884eacd7 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -4,17 +4,18 @@ #' @param forecast_date The forecast date to add as a column to the `epi_df`. #' For most cases, this should be specified in the form "yyyy-mm-dd". Note that #' when the forecast date is left unspecified, it is set to the maximum time -#' value in the test data after any processing (ex. leads and lags) has been -#' applied. +#' value from the data used in pre-processing, fitting the model, and +#' postprocessing. #' @param id a random id string #' #' @return an updated `frosting` postprocessor #' #' @details To use this function, either specify a forecast date or leave the #' forecast date unspecifed here. In the latter case, the forecast date will -#' be set as the maximum time value in the processed test data. In any case, -#' when the forecast date is less than the most recent update date of the data -#' (ie. the `as_of` value), an appropriate warning will be thrown. +#' be set as the maximum time value from the data used in pre-processing, +#' fitting the model, and postprocessing. In any case, when the forecast date is +#' less than the maximum `as_of` value (from the data used pre-processing, +#' model fitting, and postprocessing), an appropriate warning will be thrown. #' #' @export #' @examples @@ -82,14 +83,19 @@ layer_add_forecast_date_new <- function(forecast_date, id) { #' @export slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { - wf <<- workflow + #%% wf <<- workflow + #%% comp <<- components if (is.null(object$forecast_date)) { - max_time_value <- max(workflows::extract_preprocessor(wf)$mtv, wf$fit$meta$mtv, max(new_data$time_value))#wf$post$meta$mtv) # workflow$fit$max_train_time #max(new_data$time_value) + max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, workflow$fit$meta$mtv, max(new_data$time_value)) #%% wf$post$meta$mtv) # workflow$fit$max_train_time #max(new_data$time_value) object$forecast_date <- max_time_value } + as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of + as_of_fit <- workflow$fit$meta$as_of + as_of_post <- attributes(new_data)$metadata$as_of - as_of_date <- as.Date(attributes(components$keys)$metadata$as_of) + as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post)) #%% as.Date(attributes(components$keys)$metadata$as_of) + # It would be nice to say that forecast_date is >= to the max of all of them. if (object$forecast_date < as_of_date) { cli_warn( c("The forecast_date is less than the most ", diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index 29a3558df..e9178737d 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -2,8 +2,9 @@ #' #' @param frosting a `frosting` postprocessor #' @param target_date The target date to add as a column to the `epi_df`. -#' By default, this is the maximum `time_value` from the processed test -#' data plus `ahead`, where `ahead` has been specified in preprocessing +#' By default, this is the maximum `time_value` (from the data used in +#' pre-processing, fitting the model, and postprocessing) plus `ahead`, +#' where `ahead` has been specified in preprocessing #' (most likely in `step_epi_ahead`). The user may override this with a #' date of their own (that will usually be in the form "yyyy-mm-dd"). #' @param id a random id string @@ -68,7 +69,7 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data the_recipe <- workflows::extract_recipe(workflow) if (is.null(object$target_date)) { - max_time_value <- max(new_data$time_value) + max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, workflow$fit$meta$mtv, max(new_data$time_value)) ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") if (is.null(ahead)){ diff --git a/man/fit-epi_workflow.Rd b/man/fit-epi_workflow.Rd index a5dcdee79..fb1c3af28 100644 --- a/man/fit-epi_workflow.Rd +++ b/man/fit-epi_workflow.Rd @@ -5,15 +5,17 @@ \alias{fit.epi_workflow} \title{Fit an \code{epi_workflow} object} \usage{ -\method{fit}{epi_workflow}(object, x, ...) +\method{fit}{epi_workflow}(object, data, ..., control = workflows::control_workflow()) } \arguments{ \item{object}{an \code{epi_workflow} object} -\item{x}{an \code{epi_df} of predictors and outcomes to use when +\item{data}{an \code{epi_df} of predictors and outcomes to use when fitting the \code{epi_workflow}} \item{...}{Not used} + +\item{control}{A \code{\link[workflows:control_workflow]{workflows::control_workflow()}} object} } \value{ The \code{epi_workflow} object, updated with a fit parsnip diff --git a/man/layer_add_forecast_date.Rd b/man/layer_add_forecast_date.Rd index 96558e403..421978eb5 100644 --- a/man/layer_add_forecast_date.Rd +++ b/man/layer_add_forecast_date.Rd @@ -16,8 +16,8 @@ layer_add_forecast_date( \item{forecast_date}{The forecast date to add as a column to the \code{epi_df}. For most cases, this should be specified in the form "yyyy-mm-dd". Note that when the forecast date is left unspecified, it is set to the maximum time -value in the test data after any processing (ex. leads and lags) has been -applied.} +value from the data used in pre-processing, fitting the model, and +postprocessing.} \item{id}{a random id string} } @@ -30,9 +30,10 @@ Postprocessing step to add the forecast date \details{ To use this function, either specify a forecast date or leave the forecast date unspecifed here. In the latter case, the forecast date will -be set as the maximum time value in the processed test data. In any case, -when the forecast date is less than the most recent update date of the data -(ie. the \code{as_of} value), an appropriate warning will be thrown. +be set as the maximum time value from the data used in pre-processing, +fitting the model, and postprocessing. In any case, when the forecast date is +less than the maximum \code{as_of} value (from the data used pre-processing, +model fitting, and postprocessing), an appropriate warning will be thrown. } \examples{ jhu <- case_death_rate_subset \%>\% diff --git a/man/layer_add_target_date.Rd b/man/layer_add_target_date.Rd index db73066fb..c75f480ba 100644 --- a/man/layer_add_target_date.Rd +++ b/man/layer_add_target_date.Rd @@ -14,8 +14,9 @@ layer_add_target_date( \item{frosting}{a \code{frosting} postprocessor} \item{target_date}{The target date to add as a column to the \code{epi_df}. -By default, this is the maximum \code{time_value} from the processed test -data plus \code{ahead}, where \code{ahead} has been specified in preprocessing +By default, this is the maximum \code{time_value} (from the data used in +pre-processing, fitting the model, and postprocessing) plus \code{ahead}, +where \code{ahead} has been specified in preprocessing (most likely in \code{step_epi_ahead}). The user may override this with a date of their own (that will usually be in the form "yyyy-mm-dd").} From a42de5b82c3d30763a943e5eb695b45f50e5eeb0 Mon Sep 17 00:00:00 2001 From: rachlobay Date: Tue, 25 Jul 2023 21:59:33 -0700 Subject: [PATCH 10/15] Fix layer_add_target date and clean up --- R/epi_recipe.R | 4 +- R/frosting.R | 27 --------- R/layer_add_forecast_date.R | 10 +-- R/layer_add_target_date.R | 52 +++++++++++----- man/layer_add_target_date.Rd | 30 ++++++--- tests/testthat/test-layer_add_target_date.R | 67 +++++++++++++++++++-- 6 files changed, 125 insertions(+), 65 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 9b8568689..d157b0d62 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -131,7 +131,7 @@ epi_recipe.epi_df <- term_info = var_info, steps = NULL, template = x[1,], - mtv = max(x$time_value), #%% + mtv = max(x$time_value), levels = NULL, retained = NA ) @@ -375,7 +375,7 @@ prep.epi_recipe <- function( } else { x$template <- training[0, ] } - x$mtv <- max(training$time_value) #%% + x$mtv <- max(training$time_value) x$tr_info <- tr_data x$levels <- lvls x$orig_lvls <- orig_lvls diff --git a/R/frosting.R b/R/frosting.R index 7622002c3..f0c554935 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -229,9 +229,6 @@ apply_frosting.default <- function(workflow, components, ...) { apply_frosting.epi_workflow <- function(workflow, components, new_data, ...) { - #%% wf1$post$meta <<- list(mtv = max(new_data$time_value)) #%% change wf1 and possibly <<- - # assign("workflow$post$meta", list(mtv = max(new_data$time_value)), envir = .GlobalEnv) - the_fit <- workflows::extract_fit_parsnip(workflow) if (!has_postprocessor(workflow)) { @@ -267,34 +264,10 @@ apply_frosting.epi_workflow <- la <- layers[[l]] components <- slather(la, components, workflow, new_data) } - #%% mtv <- max(new_data$time_value) - #%% update_workflow_post(workflow, mtv) return(components) } -#%% #' @export -# update_workflow_post <- function(x, mtv) { -# substitute(x) <- "changed" -# #assign(deparse(substitute(x)), "changed", env=.GlobalEnv) -# #workflow$post$meta <- list(mtv = max(new_data$time_value)) -# } - -#%% change_workflow = function(x){ -# assign(deparse(substitute(x)), "changed", env=.GlobalEnv) -#} - -#%% add_meta_post <- function(workflow, new_data){ -# -# workflow$post$meta <- list(mtv = max(new_data$time_value)) -# -# workflow -#} - -#%% changeMe = function(x){ -# assign(deparse(substitute(x)), "changed", env=.GlobalEnv) -#} - #' @export print.frosting <- function(x, form_width = 30, ...) { cli::cli_div( diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 3884eacd7..e512cf159 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -83,19 +83,19 @@ layer_add_forecast_date_new <- function(forecast_date, id) { #' @export slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { - #%% wf <<- workflow - #%% comp <<- components + if (is.null(object$forecast_date)) { - max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, workflow$fit$meta$mtv, max(new_data$time_value)) #%% wf$post$meta$mtv) # workflow$fit$max_train_time #max(new_data$time_value) + max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, + workflow$fit$meta$mtv, + max(new_data$time_value)) object$forecast_date <- max_time_value } as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of as_of_fit <- workflow$fit$meta$as_of as_of_post <- attributes(new_data)$metadata$as_of - as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post)) #%% as.Date(attributes(components$keys)$metadata$as_of) + as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post)) - # It would be nice to say that forecast_date is >= to the max of all of them. if (object$forecast_date < as_of_date) { cli_warn( c("The forecast_date is less than the most ", diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index e9178737d..d49e2481c 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -2,11 +2,10 @@ #' #' @param frosting a `frosting` postprocessor #' @param target_date The target date to add as a column to the `epi_df`. -#' By default, this is the maximum `time_value` (from the data used in -#' pre-processing, fitting the model, and postprocessing) plus `ahead`, -#' where `ahead` has been specified in preprocessing -#' (most likely in `step_epi_ahead`). The user may override this with a -#' date of their own (that will usually be in the form "yyyy-mm-dd"). +#' By default, this is the forecast date plus `ahead` (from `step_epi_ahead` +#' in the `epi_recipe`) if there is a `layer_add_forecast_date` in the +#' `epi_workflow`. If there's no such layer, then the user may specify +#' their own target date with a date (of the form "yyyy-mm-dd"). #' @param id a random id string #' #' @return an updated `frosting` postprocessor @@ -28,8 +27,9 @@ #' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) #' latest <- get_test_data(r, jhu) #' -#' # Use ahead from preprocessing +#' # Use ahead + forecast date #' f <- frosting() %>% layer_predict() %>% +#' layer_add_forecast_date(forecast_date = "2022-05-31") %>% #' layer_add_target_date() %>% #' layer_naomit(.pred) #' wf1 <- wf %>% add_frosting(f) @@ -37,15 +37,25 @@ #' p <- predict(wf1, latest) #' p #' -#' # Override default behaviour by specifying own target date -#' f2 <- frosting() %>% -#' layer_predict() %>% -#' layer_add_target_date(target_date = "2022-01-08") %>% +#' # Use ahead + max time value from pre, fit, post +#' # which is the same if include `layer_add_forecast_date()` +#' f <- frosting() %>% layer_predict() %>% +#' layer_add_target_date() %>% #' layer_naomit(.pred) #' wf2 <- wf %>% add_frosting(f2) #' #' p2 <- predict(wf2, latest) #' p2 +#' +#' # Specify own target date +#' f3 <- frosting() %>% +#' layer_predict() %>% +#' layer_add_target_date(target_date = "2022-01-08") %>% +#' layer_naomit(.pred) +#' wf3 <- wf %>% add_frosting(f3) +#' +#' p3 <- predict(wf3, latest) +#' p3 layer_add_target_date <- function(frosting, target_date = NULL, id = rand_id("add_target_date")) { target_date <- arg_to_date(target_date, allow_null = TRUE) @@ -67,14 +77,26 @@ layer_add_target_date_new <- function(id = id, target_date = target_date) { slather.layer_add_target_date <- function(object, components, workflow, new_data, ...) { the_recipe <- workflows::extract_recipe(workflow) + the_frosting <- extract_frosting(workflow) + + if (detect_layer(the_frosting, "layer_add_forecast_date") && + !is.null(extract_argument(the_frosting, + "layer_add_forecast_date", "forecast_date"))) { + forecast_date <- extract_argument(the_frosting, + "layer_add_forecast_date", "forecast_date") + + ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") + + target_date = forecast_date + ahead + + } else if (is.null(object$target_date) || + detect_layer(the_frosting, "layer_add_forecast_date")) { + max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, + workflow$fit$meta$mtv, + max(new_data$time_value)) - if (is.null(object$target_date)) { - max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, workflow$fit$meta$mtv, max(new_data$time_value)) ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") - if (is.null(ahead)){ - stop("`ahead` must be specified in preprocessing.") - } target_date = max_time_value + ahead } else{ target_date = as.Date(object$target_date) diff --git a/man/layer_add_target_date.Rd b/man/layer_add_target_date.Rd index c75f480ba..b09a2f05a 100644 --- a/man/layer_add_target_date.Rd +++ b/man/layer_add_target_date.Rd @@ -14,11 +14,10 @@ layer_add_target_date( \item{frosting}{a \code{frosting} postprocessor} \item{target_date}{The target date to add as a column to the \code{epi_df}. -By default, this is the maximum \code{time_value} (from the data used in -pre-processing, fitting the model, and postprocessing) plus \code{ahead}, -where \code{ahead} has been specified in preprocessing -(most likely in \code{step_epi_ahead}). The user may override this with a -date of their own (that will usually be in the form "yyyy-mm-dd").} +By default, this is the forecast date plus \code{ahead} (from \code{step_epi_ahead} +in the \code{epi_recipe}) if there is a \code{layer_add_forecast_date} in the +\code{epi_workflow}. If there's no such layer, then the user may specify +their own target date with a date (of the form "yyyy-mm-dd").} \item{id}{a random id string} } @@ -45,8 +44,9 @@ r <- epi_recipe(jhu) \%>\% wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) latest <- get_test_data(r, jhu) -# Use ahead from preprocessing +# Use ahead + forecast date f <- frosting() \%>\% layer_predict() \%>\% + layer_add_forecast_date(forecast_date = "2022-05-31") \%>\% layer_add_target_date() \%>\% layer_naomit(.pred) wf1 <- wf \%>\% add_frosting(f) @@ -54,13 +54,23 @@ wf1 <- wf \%>\% add_frosting(f) p <- predict(wf1, latest) p -# Override default behaviour by specifying own target date -f2 <- frosting() \%>\% - layer_predict() \%>\% - layer_add_target_date(target_date = "2022-01-08") \%>\% +# Use ahead + max time value from pre, fit, post +# which is the same if include `layer_add_forecast_date()` +f <- frosting() \%>\% layer_predict() \%>\% + layer_add_target_date() \%>\% layer_naomit(.pred) wf2 <- wf \%>\% add_frosting(f2) p2 <- predict(wf2, latest) p2 + +# Specify own target date +f3 <- frosting() \%>\% + layer_predict() \%>\% + layer_add_target_date(target_date = "2022-01-08") \%>\% + layer_naomit(.pred) +wf3 <- wf \%>\% add_frosting(f3) + +p3 <- predict(wf3, latest) +p3 } diff --git a/tests/testthat/test-layer_add_target_date.R b/tests/testthat/test-layer_add_target_date.R index 49ba1d997..a8878b9b8 100644 --- a/tests/testthat/test-layer_add_target_date.R +++ b/tests/testthat/test-layer_add_target_date.R @@ -9,10 +9,12 @@ wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14) -test_that("Use ahead from preprocessing", { +test_that("Use ahead + max time value from pre, fit, post", { - f <- frosting() %>% layer_predict() %>% - layer_add_target_date() %>% layer_naomit(.pred) + f <- frosting() %>% + layer_predict() %>% + layer_add_target_date() %>% + layer_naomit(.pred) wf1 <- wf %>% add_frosting(f) expect_silent(p <- predict(wf1, latest)) @@ -21,12 +23,48 @@ test_that("Use ahead from preprocessing", { expect_equal(nrow(p), 3L) expect_equal(p$target_date, rep(as.Date("2022-01-07"), times = 3)) expect_named(p, c("geo_value", "time_value", ".pred", "target_date")) + + # Should be same dates as above + f2 <- frosting() %>% + layer_predict() %>% + layer_add_forecast_date() %>% + layer_add_target_date() %>% + layer_naomit(.pred) + wf2 <- wf %>% add_frosting(f2) + + expect_warning(p2 <- predict(wf2, latest)) + expect_equal(ncol(p2), 5L) + expect_s3_class(p2, "epi_df") + expect_equal(nrow(p2), 3L) + expect_equal(p2$target_date, rep(as.Date("2022-01-07"), times = 3)) + expect_named(p2, c("geo_value", "time_value", ".pred", "forecast_date", "target_date")) + }) -test_that("Override default behaviour and specify own target date", { +test_that("Use ahead + specified forecast date", { - f <- frosting() %>% layer_predict() %>% - layer_add_target_date(target_date = "2022-01-08") %>% layer_naomit(.pred) + f <- frosting() %>% + layer_predict() %>% + layer_add_forecast_date(forecast_date = "2022-05-31") %>% + layer_add_target_date() %>% + layer_naomit(.pred) + wf1 <- wf %>% add_frosting(f) + + expect_silent(p <- predict(wf1, latest)) + expect_equal(ncol(p), 5L) + expect_s3_class(p, "epi_df") + expect_equal(nrow(p), 3L) + expect_equal(p$target_date, rep(as.Date("2022-06-07"), times = 3)) + expect_named(p, c("geo_value", "time_value", ".pred", "forecast_date", "target_date")) + +}) + +test_that("Specify own target date", { + + f <- frosting() %>% + layer_predict() %>% + layer_add_target_date(target_date = "2022-01-08") %>% + layer_naomit(.pred) wf1 <- wf %>% add_frosting(f) expect_silent(p2 <- predict(wf1, latest)) @@ -36,3 +74,20 @@ test_that("Override default behaviour and specify own target date", { expect_equal(p2$target_date, rep(as.Date("2022-01-08"), times = 3)) expect_named(p2, c("geo_value", "time_value", ".pred", "target_date")) }) + +test_that("Specify own target date, but have a forecast date layer", { + + f <- frosting() %>% + layer_predict() %>% + layer_add_forecast_date() %>% + layer_add_target_date(target_date = "2022-01-08") %>% + layer_naomit(.pred) + wf1 <- wf %>% add_frosting(f) + + expect_warning(p2 <- predict(wf1, latest)) + expect_equal(ncol(p2), 5L) + expect_s3_class(p2, "epi_df") + expect_equal(nrow(p2), 3L) + expect_equal(p2$target_date, rep(as.Date("2022-01-07"), times = 3)) + expect_named(p2, c("geo_value", "time_value", ".pred", "forecast_date", "target_date")) +}) From 506b81f8387b2eb1473333ff0734c3a4b12c3005 Mon Sep 17 00:00:00 2001 From: rachlobay Date: Tue, 25 Jul 2023 22:06:40 -0700 Subject: [PATCH 11/15] Update some doc --- R/layer_add_target_date.R | 10 +++++----- man/layer_add_target_date.Rd | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index d49e2481c..fce8d8490 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -1,11 +1,11 @@ #' Postprocessing step to add the target date #' #' @param frosting a `frosting` postprocessor -#' @param target_date The target date to add as a column to the `epi_df`. -#' By default, this is the forecast date plus `ahead` (from `step_epi_ahead` -#' in the `epi_recipe`) if there is a `layer_add_forecast_date` in the -#' `epi_workflow`. If there's no such layer, then the user may specify -#' their own target date with a date (of the form "yyyy-mm-dd"). +#' @param target_date The target date to add as a column to the +#' `epi_df`. By default, this is the forecast date plus `ahead` +#' (from `step_epi_ahead` in the `epi_recipe`). If there's no +#' forecast date layer, then the user can specify their own +#' target date (of the form "yyyy-mm-dd"). #' @param id a random id string #' #' @return an updated `frosting` postprocessor diff --git a/man/layer_add_target_date.Rd b/man/layer_add_target_date.Rd index b09a2f05a..c07bc4d14 100644 --- a/man/layer_add_target_date.Rd +++ b/man/layer_add_target_date.Rd @@ -13,11 +13,11 @@ layer_add_target_date( \arguments{ \item{frosting}{a \code{frosting} postprocessor} -\item{target_date}{The target date to add as a column to the \code{epi_df}. -By default, this is the forecast date plus \code{ahead} (from \code{step_epi_ahead} -in the \code{epi_recipe}) if there is a \code{layer_add_forecast_date} in the -\code{epi_workflow}. If there's no such layer, then the user may specify -their own target date with a date (of the form "yyyy-mm-dd").} +\item{target_date}{The target date to add as a column to the +\code{epi_df}. By default, this is the forecast date plus \code{ahead} +(from \code{step_epi_ahead} in the \code{epi_recipe}). If there's no +forecast date layer, then the user can specify their own +target date (of the form "yyyy-mm-dd").} \item{id}{a random id string} } From ec7314bf76a127ac4b446560ae072731a21a0cba Mon Sep 17 00:00:00 2001 From: rachlobay Date: Tue, 25 Jul 2023 22:14:16 -0700 Subject: [PATCH 12/15] f2 --- R/layer_add_target_date.R | 2 +- man/layer_add_target_date.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index fce8d8490..dd5ae189f 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -39,7 +39,7 @@ #' #' # Use ahead + max time value from pre, fit, post #' # which is the same if include `layer_add_forecast_date()` -#' f <- frosting() %>% layer_predict() %>% +#' f2 <- frosting() %>% layer_predict() %>% #' layer_add_target_date() %>% #' layer_naomit(.pred) #' wf2 <- wf %>% add_frosting(f2) diff --git a/man/layer_add_target_date.Rd b/man/layer_add_target_date.Rd index c07bc4d14..747637616 100644 --- a/man/layer_add_target_date.Rd +++ b/man/layer_add_target_date.Rd @@ -56,7 +56,7 @@ p # Use ahead + max time value from pre, fit, post # which is the same if include `layer_add_forecast_date()` -f <- frosting() \%>\% layer_predict() \%>\% +f2 <- frosting() \%>\% layer_predict() \%>\% layer_add_target_date() \%>\% layer_naomit(.pred) wf2 <- wf \%>\% add_frosting(f2) From 195a816e0601b0ac882e1aaa015f773c3a74ab30 Mon Sep 17 00:00:00 2001 From: rachlobay Date: Thu, 27 Jul 2023 21:11:27 -0700 Subject: [PATCH 13/15] Update layer_add_target_date --- R/layer_add_target_date.R | 45 +++++++++++---------- man/layer_add_target_date.Rd | 11 +++-- tests/testthat/test-layer_add_target_date.R | 29 +++++++------ 3 files changed, 45 insertions(+), 40 deletions(-) diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index dd5ae189f..652951e10 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -2,10 +2,13 @@ #' #' @param frosting a `frosting` postprocessor #' @param target_date The target date to add as a column to the -#' `epi_df`. By default, this is the forecast date plus `ahead` -#' (from `step_epi_ahead` in the `epi_recipe`). If there's no -#' forecast date layer, then the user can specify their own -#' target date (of the form "yyyy-mm-dd"). +#' `epi_df`. If there's a forecast date specified in a layer, then +#' it is the forecast date plus `ahead` (from `step_epi_ahead` in +#' the `epi_recipe`). Otherwise, it is the maximum `time_value` +#' (from the data used in pre-processing, fitting the model, and +#' postprocessing) plus `ahead`, where `ahead` has been specified in +#' preprocessing. The user may override these by specifying a +#' target date of their own (of the form "yyyy-mm-dd"). #' @param id a random id string #' #' @return an updated `frosting` postprocessor @@ -79,27 +82,27 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data the_recipe <- workflows::extract_recipe(workflow) the_frosting <- extract_frosting(workflow) - if (detect_layer(the_frosting, "layer_add_forecast_date") && - !is.null(extract_argument(the_frosting, - "layer_add_forecast_date", "forecast_date"))) { - forecast_date <- extract_argument(the_frosting, - "layer_add_forecast_date", "forecast_date") - - ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") + if (!is.null(object$target_date)) { + target_date = as.Date(object$target_date) + } else { # null target date case + if (detect_layer(the_frosting, "layer_add_forecast_date") && + !is.null(extract_argument(the_frosting, + "layer_add_forecast_date", "forecast_date"))) { + forecast_date <- extract_argument(the_frosting, + "layer_add_forecast_date", "forecast_date") - target_date = forecast_date + ahead + ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") - } else if (is.null(object$target_date) || - detect_layer(the_frosting, "layer_add_forecast_date")) { - max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, - workflow$fit$meta$mtv, - max(new_data$time_value)) + target_date = forecast_date + ahead + } else { + max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, + workflow$fit$meta$mtv, + max(new_data$time_value)) - ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") + ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") - target_date = max_time_value + ahead - } else{ - target_date = as.Date(object$target_date) + target_date = max_time_value + ahead + } } components$predictions <- dplyr::bind_cols(components$predictions, diff --git a/man/layer_add_target_date.Rd b/man/layer_add_target_date.Rd index 747637616..58ff7770f 100644 --- a/man/layer_add_target_date.Rd +++ b/man/layer_add_target_date.Rd @@ -14,10 +14,13 @@ layer_add_target_date( \item{frosting}{a \code{frosting} postprocessor} \item{target_date}{The target date to add as a column to the -\code{epi_df}. By default, this is the forecast date plus \code{ahead} -(from \code{step_epi_ahead} in the \code{epi_recipe}). If there's no -forecast date layer, then the user can specify their own -target date (of the form "yyyy-mm-dd").} +\code{epi_df}. If there's a forecast date specified in a layer, then +it is the forecast date plus \code{ahead} (from \code{step_epi_ahead} in +the \code{epi_recipe}). Otherwise, it is the maximum \code{time_value} +(from the data used in pre-processing, fitting the model, and +postprocessing) plus \code{ahead}, where \code{ahead} has been specified in +preprocessing. The user may override these by specifying a +target date of their own (of the form "yyyy-mm-dd").} \item{id}{a random id string} } diff --git a/tests/testthat/test-layer_add_target_date.R b/tests/testthat/test-layer_add_target_date.R index a8878b9b8..b8627571c 100644 --- a/tests/testthat/test-layer_add_target_date.R +++ b/tests/testthat/test-layer_add_target_date.R @@ -61,33 +61,32 @@ test_that("Use ahead + specified forecast date", { test_that("Specify own target date", { + # No forecast date layer f <- frosting() %>% layer_predict() %>% layer_add_target_date(target_date = "2022-01-08") %>% layer_naomit(.pred) wf1 <- wf %>% add_frosting(f) - expect_silent(p2 <- predict(wf1, latest)) - expect_equal(ncol(p2), 4L) - expect_s3_class(p2, "epi_df") - expect_equal(nrow(p2), 3L) - expect_equal(p2$target_date, rep(as.Date("2022-01-08"), times = 3)) - expect_named(p2, c("geo_value", "time_value", ".pred", "target_date")) -}) - -test_that("Specify own target date, but have a forecast date layer", { + expect_silent(p1 <- predict(wf1, latest)) + expect_equal(ncol(p1), 4L) + expect_s3_class(p1, "epi_df") + expect_equal(nrow(p1), 3L) + expect_equal(p1$target_date, rep(as.Date("2022-01-08"), times = 3)) + expect_named(p1, c("geo_value", "time_value", ".pred", "target_date")) - f <- frosting() %>% + # Include forecast date layer - should be same results as previous + f2 <- frosting() %>% layer_predict() %>% layer_add_forecast_date() %>% layer_add_target_date(target_date = "2022-01-08") %>% layer_naomit(.pred) - wf1 <- wf %>% add_frosting(f) + wf2 <- wf %>% add_frosting(f2) - expect_warning(p2 <- predict(wf1, latest)) - expect_equal(ncol(p2), 5L) + expect_silent(p2 <- predict(wf1, latest)) + expect_equal(ncol(p2), 4L) expect_s3_class(p2, "epi_df") expect_equal(nrow(p2), 3L) - expect_equal(p2$target_date, rep(as.Date("2022-01-07"), times = 3)) - expect_named(p2, c("geo_value", "time_value", ".pred", "forecast_date", "target_date")) + expect_equal(p2$target_date, rep(as.Date("2022-01-08"), times = 3)) + expect_named(p2, c("geo_value", "time_value", ".pred", "target_date")) }) From e26894f2f3d08b91786bf98dd2cd6c0ac59b3fde Mon Sep 17 00:00:00 2001 From: rachlobay Date: Thu, 27 Jul 2023 22:24:53 -0700 Subject: [PATCH 14/15] mtv --> max_time_value --- R/epi_recipe.R | 4 ++-- R/epi_workflow.R | 2 +- R/layer_add_forecast_date.R | 4 ++-- R/layer_add_target_date.R | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index d157b0d62..4caea7476 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -131,7 +131,7 @@ epi_recipe.epi_df <- term_info = var_info, steps = NULL, template = x[1,], - mtv = max(x$time_value), + max_time_value = max(x$time_value), levels = NULL, retained = NA ) @@ -375,7 +375,7 @@ prep.epi_recipe <- function( } else { x$template <- training[0, ] } - x$mtv <- max(training$time_value) + x$max_time_value <- max(training$time_value) x$tr_info <- tr_data x$levels <- lvls x$orig_lvls <- orig_lvls diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 97171b55b..1379fef86 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -96,7 +96,7 @@ is_epi_workflow <- function(x) { #' @export fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()){ - object$fit$meta <- list(mtv = max(data$time_value), as_of = attributes(data)$metadata$as_of) + object$fit$meta <- list(max_time_value = max(data$time_value), as_of = attributes(data)$metadata$as_of) NextMethod() } diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index e512cf159..0b522ef65 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -85,8 +85,8 @@ layer_add_forecast_date_new <- function(forecast_date, id) { slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { if (is.null(object$forecast_date)) { - max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, - workflow$fit$meta$mtv, + max_time_value <- max(workflows::extract_preprocessor(workflow)$max_time_value, + workflow$fit$meta$max_time_value, max(new_data$time_value)) object$forecast_date <- max_time_value } diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index 652951e10..1fe151bce 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -95,8 +95,8 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data target_date = forecast_date + ahead } else { - max_time_value <- max(workflows::extract_preprocessor(workflow)$mtv, - workflow$fit$meta$mtv, + max_time_value <- max(workflows::extract_preprocessor(workflow)$max_time_value, + workflow$fit$meta$max_time_value, max(new_data$time_value)) ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") From 7566fcbaa0aa8c581281ada88411cdca9c403568 Mon Sep 17 00:00:00 2001 From: rachlobay Date: Fri, 11 Aug 2023 16:35:37 -0700 Subject: [PATCH 15/15] Update inst/templates/layer.R --- inst/templates/layer.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inst/templates/layer.R b/inst/templates/layer.R index 14dc4a82b..3fecb3c33 100644 --- a/inst/templates/layer.R +++ b/inst/templates/layer.R @@ -28,7 +28,7 @@ layer_{{{ name }}}_new <- function(terms, args, more_args, id) { #' @export slather.layer_{{{ name }}} <- - function(object, components, the_fit, the_recipe, ...) { + function(object, components, workflow, new_data, ...) { # if layer_ used ... in tidyselect, we need to evaluate it now exprs <- rlang::expr(c(!!!object$terms))