diff --git a/DESCRIPTION b/DESCRIPTION index e068fe7e3..9329382fd 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.10 +Version: 0.0.11 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), @@ -46,6 +46,7 @@ Imports: tibble, tidyr, tidyselect, + tsibble, usethis, vctrs, workflows (>= 1.0.0) diff --git a/NEWS.md b/NEWS.md index 04dc78e4f..5f629f45f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -33,3 +33,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Working vignette - use `checkmate` for input validation - refactor quantile extrapolation (possibly creates different results) +- force `target_date` + `forecast_date` handling to match the time_type of + the epi_df. allows for annual and weekly data diff --git a/R/dist_quantiles.R b/R/dist_quantiles.R index 7f7af40f4..f20ebe3dc 100644 --- a/R/dist_quantiles.R +++ b/R/dist_quantiles.R @@ -233,10 +233,10 @@ quantile_extrapolate <- function(x, tau_out, middle) { dplyr::arrange(q) } if (any(indl)) { - qvals_out[indl] <- tail_extrapolate(tau_out[indl], head(qv, 2)) + qvals_out[indl] <- tail_extrapolate(tau_out[indl], utils::head(qv, 2)) } if (any(indr)) { - qvals_out[indr] <- tail_extrapolate(tau_out[indr], tail(qv, 2)) + qvals_out[indr] <- tail_extrapolate(tau_out[indr], utils::tail(qv, 2)) } qvals_out } diff --git a/R/epi_workflow.R b/R/epi_workflow.R index d5e7d13a2..e64e0f7bc 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -251,11 +251,10 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' preds predict.epi_workflow <- function(object, new_data, ...) { if (!workflows::is_trained_workflow(object)) { - rlang::abort( - c("Can't predict on an untrained epi_workflow.", - i = "Do you need to call `fit()`?" 
- ) - ) + cli::cli_abort(c( + "Can't predict on an untrained epi_workflow.", + i = "Do you need to call `fit()`?" + )) } components <- list() components$mold <- workflows::extract_mold(object) diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 6bb2cf572..5bd6b6918 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -68,6 +68,9 @@ #' p3 layer_add_forecast_date <- function(frosting, forecast_date = NULL, id = rand_id("add_forecast_date")) { + arg_is_chr_scalar(id) + arg_is_scalar(forecast_date, allow_null = TRUE) + # can't validate the type of forecast_date until we know the time_type add_layer( frosting, layer_add_forecast_date_new( @@ -78,8 +81,6 @@ layer_add_forecast_date <- } layer_add_forecast_date_new <- function(forecast_date, id) { - forecast_date <- arg_to_date(forecast_date, allow_null = TRUE) - arg_is_chr_scalar(id) layer("add_forecast_date", forecast_date = forecast_date, id = id) } @@ -91,26 +92,25 @@ slather.layer_add_forecast_date <- function(object, components, workflow, new_da workflow$fit$meta$max_time_value, max(new_data$time_value) ) - object$forecast_date <- max_time_value + forecast_date <- max_time_value + } else { + forecast_date <- object$forecast_date } - as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of - as_of_fit <- workflow$fit$meta$as_of - as_of_post <- attributes(new_data)$metadata$as_of - as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post)) - - if (object$forecast_date < as_of_date) { - cli_warn( - c("The forecast_date is less than the most ", - "recent update date of the data: ", - i = "forecast_date = {object$forecast_date} while data is from {as_of_date}." 
- ) - ) - } + expected_time_type <- attr( + workflows::extract_preprocessor(workflow)$template, "metadata" + )$time_type + if (expected_time_type == "week") expected_time_type <- "day" + validate_date(forecast_date, expected_time_type, + call = expr(layer_add_forecast_date()) + ) + forecast_date <- coerce_time_type(forecast_date, expected_time_type) + object$forecast_date <- forecast_date components$predictions <- dplyr::bind_cols( components$predictions, - forecast_date = as.Date(object$forecast_date) + forecast_date = forecast_date ) + components } diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index f2fee889f..a50b6042c 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -63,8 +63,9 @@ #' p3 layer_add_target_date <- function(frosting, target_date = NULL, id = rand_id("add_target_date")) { - target_date <- arg_to_date(target_date, allow_null = TRUE) arg_is_chr_scalar(id) + arg_is_scalar(target_date, allow_null = TRUE) + # can't validate the type of target_date until we know the time_type add_layer( frosting, layer_add_target_date_new( @@ -84,35 +85,39 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data the_recipe <- workflows::extract_recipe(workflow) the_frosting <- extract_frosting(workflow) + expected_time_type <- attr( + workflows::extract_preprocessor(workflow)$template, "metadata" + )$time_type + if (expected_time_type == "week") expected_time_type <- "day" + if (!is.null(object$target_date)) { - target_date <- as.Date(object$target_date) - } else { # null target date case - if (detect_layer(the_frosting, "layer_add_forecast_date") && - !is.null(extract_argument( - the_frosting, - "layer_add_forecast_date", "forecast_date" + target_date <- object$target_date + validate_date(target_date, expected_time_type, + call = expr(layer_add_target_date()) + ) + target_date <- coerce_time_type(target_date, expected_time_type) + } else if ( + detect_layer(the_frosting, "layer_add_forecast_date") 
&& + !is.null(forecast_date <- extract_argument( + the_frosting, "layer_add_forecast_date", "forecast_date" ))) { - forecast_date <- extract_argument( - the_frosting, - "layer_add_forecast_date", "forecast_date" - ) - - ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") - - target_date <- forecast_date + ahead - } else { - max_time_value <- max( - workflows::extract_preprocessor(workflow)$max_time_value, - workflow$fit$meta$max_time_value, - max(new_data$time_value) - ) - - ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") - - target_date <- max_time_value + ahead - } + validate_date(forecast_date, expected_time_type, + call = expr(layer_add_forecast_date()) + ) + forecast_date <- coerce_time_type(forecast_date, expected_time_type) + ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") + target_date <- forecast_date + ahead + } else { + max_time_value <- max( + workflows::extract_preprocessor(workflow)$max_time_value, + workflow$fit$meta$max_time_value, + max(new_data$time_value) + ) + ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") + target_date <- max_time_value + ahead } + object$target_date <- target_date components$predictions <- dplyr::bind_cols(components$predictions, target_date = target_date ) diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 2b7057bef..f89160794 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -144,6 +144,12 @@ slather.layer_population_scaling <- length(object$df_pop_col) == 1 ) + if (is.null(object$by)) { + object$by <- intersect( + kill_time_value(epi_keys(components$predictions)), + colnames(dplyr::select(object$df, !object$df_pop_col)) + ) + } try_join <- try( dplyr::left_join(components$predictions, object$df, by = object$by @@ -157,8 +163,8 @@ slather.layer_population_scaling <- )) } - object$df <- object$df %>% - dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) + # object$df <- object$df %>% + 
# dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) pop_col <- rlang::sym(object$df_pop_col) exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index ec5428d8f..52f51de16 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -53,20 +53,20 @@ step_epi_lag <- function(recipe, ..., + lag, role = "predictor", trained = FALSE, - lag, prefix = "lag_", default = NA, columns = NULL, skip = FALSE, id = rand_id("epi_lag")) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli::cli_abort("This step can only operate on an `epi_recipe`.") } if (missing(lag)) { - rlang::abort( + cli::cli_abort( c("The `lag` argument must not be empty.", i = "Did you perhaps pass an integer in `...` accidentally?" ) @@ -75,7 +75,8 @@ step_epi_lag <- arg_is_nonneg_int(lag) arg_is_chr_scalar(prefix, id) if (!is.null(columns)) { - rlang::abort(c("The `columns` argument must be `NULL.", + cli::cli_abort(c( + "The `columns` argument must be `NULL.", i = "Use `tidyselect` methods to choose columns to lag." )) } @@ -85,7 +86,7 @@ step_epi_lag <- terms = dplyr::enquos(...), role = role, trained = trained, - lag = lag, + lag = as.integer(lag), prefix = prefix, default = default, keys = epi_keys(recipe), @@ -104,24 +105,23 @@ step_epi_lag <- step_epi_ahead <- function(recipe, ..., + ahead, role = "outcome", trained = FALSE, - ahead, prefix = "ahead_", default = NA, columns = NULL, skip = FALSE, id = rand_id("epi_ahead")) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli::cli_abort("This step can only operate on an `epi_recipe`.") } if (missing(ahead)) { - rlang::abort( - c("The `ahead` argument must not be empty.", - i = "Did you perhaps pass an integer in `...` accidentally?" 
- ) - ) + cli::cli_abort(c( + "The `ahead` argument must not be empty.", + i = "Did you perhaps pass an integer in `...` accidentally?" + )) } arg_is_nonneg_int(ahead) arg_is_chr_scalar(prefix, id) @@ -136,7 +136,7 @@ step_epi_ahead <- terms = dplyr::enquos(...), role = role, trained = trained, - ahead = ahead, + ahead = as.integer(ahead), prefix = prefix, default = default, keys = epi_keys(recipe),
diff --git a/R/time_types.R b/R/time_types.R
new file mode 100644
index 000000000..7fe3e47b4
--- /dev/null
+++ b/R/time_types.R
@@ -0,0 +1,89 @@
+# Classify a scalar time value into a `time_type` label.
+#
+# Similar to epiprocess:::guess_time_type() but w/o the gap handling.
+# Character input is parsed first: strings of at most 10 characters are tried
+# as a Date, longer ones as a POSIXct. If parsing errors, the string is kept
+# as is and ends up classified as "custom".
+#
+# Returns one of "day-time", "day", "yearweek", "yearmonth", "yearquarter",
+# "year" or "custom".
+guess_time_type <- function(time_value) {
+  arg_is_scalar(time_value)
+  if (is.character(time_value)) {
+    # Compare with the integer 10L, not the string "10": a character operand
+    # coerces nchar() to character and the comparison becomes lexicographic.
+    if (nchar(time_value) <= 10L) {
+      new_time_value <- tryCatch(
+        as.Date(time_value),
+        error = function(e) NULL
+      )
+    } else {
+      new_time_value <- tryCatch(
+        as.POSIXct(time_value),
+        error = function(e) NULL
+      )
+    }
+    if (!is.null(new_time_value)) time_value <- new_time_value
+  }
+  if (inherits(time_value, "POSIXct")) {
+    return("day-time")
+  }
+  if (inherits(time_value, "Date")) {
+    return("day")
+  }
+  if (inherits(time_value, "yearweek")) {
+    return("yearweek")
+  }
+  if (inherits(time_value, "yearmonth")) {
+    return("yearmonth")
+  }
+  if (inherits(time_value, "yearquarter")) {
+    return("yearquarter")
+  }
+  # Whole numbers from 1582 onwards are treated as calendar years.
+  if (is.numeric(time_value) && all(time_value == as.integer(time_value)) &&
+    all(time_value >= 1582)) {
+    return("year")
+  }
+  "custom"
+}
+
+# Convert `x` to the class that corresponds to `target_type`.
+#
+# "year" yields an integer year: numeric input goes through as.integer(),
+# anything else through as.POSIXlt(). Unrecognised target types return `x`
+# unchanged rather than the invisible NULL of a switch() without a default.
+coerce_time_type <- function(x, target_type) {
+  if (target_type == "year") {
+    if (is.numeric(x)) {
+      return(as.integer(x))
+    }
+    return(as.POSIXlt(x)$year + 1900L)
+  }
+  switch(target_type,
+    "day-time" = as.POSIXct(x),
+    "day" = as.Date(x),
+    "week" = as.Date(x),
+    "yearweek" = tsibble::yearweek(x),
+    "yearmonth" = tsibble::yearmonth(x),
+    "yearquarter" = tsibble::yearquarter(x),
+    x
+  )
+}
+
+# Abort with a cli error unless guess_time_type(x) equals `expected`.
+#
+# `arg` labels `x` in the error message; `call` is forwarded to
+# cli::cli_abort().
+validate_date <- function(x, expected, arg = rlang::caller_arg(x),
+                          call = rlang::caller_env()) {
+  time_type_x <- guess_time_type(x)
+  ok <- time_type_x == expected
+  if (!ok) {
+    cli::cli_abort(c(
+      "The {.arg {arg}} was given as a {.val {time_type_x}} while the",
+      `!` = "`time_type` of the training data was {.val {expected}}.",
+      i = "See {.topic epiprocess::epi_df} for descriptions of how these are determined."
+    ), call = call)
+  }
+}
diff --git a/man/step_epi_shift.Rd b/man/step_epi_shift.Rd index bf135346e..f4419b831 100644 --- a/man/step_epi_shift.Rd +++ b/man/step_epi_shift.Rd @@ -8,9 +8,9 @@ step_epi_lag( recipe, ..., + lag, role = "predictor", trained = FALSE, - lag, prefix = "lag_", default = NA, columns = NULL, @@ -21,9 +21,9 @@ step_epi_lag( step_epi_ahead( recipe, ..., + ahead, role = "outcome", trained = FALSE, - ahead, prefix = "ahead_", default = NA, columns = NULL, @@ -38,16 +38,16 @@ sequence of operations for this recipe.} \item{...}{One or more selector functions to choose variables for this step. See \code{\link[recipes:selections]{recipes::selections()}} for more details.} +\item{lag, ahead}{A vector of integers. Each specified column will +be the lag or lead for each value in the vector. Lag integers must be +nonnegative, while ahead integers must be positive.} + \item{role}{For model terms created by this step, what analysis role should they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} \item{trained}{A logical to indicate if the quantities for preprocessing have been estimated.} -\item{lag, ahead}{A vector of integers. Each specified column will -be the lag or lead for each value in the vector. 
Lag integers must be -nonnegative, while ahead integers must be positive.} - \item{prefix}{A prefix to indicate what type of variable this is} \item{default}{Determines what fills empty rows diff --git a/tests/testthat/test-extract_argument.R b/tests/testthat/test-extract_argument.R index 0654304ba..3250b2991 100644 --- a/tests/testthat/test-extract_argument.R +++ b/tests/testthat/test-extract_argument.R @@ -43,19 +43,19 @@ test_that("recipe argument extractor works", { expect_error(extract_argument(r$steps[[1]], "uhoh", "bubble")) expect_error(extract_argument(r$steps[[1]], "step_epi_lag", "bubble")) - expect_identical(extract_argument(r$steps[[2]], "step_epi_ahead", "ahead"), 7) + expect_identical(extract_argument(r$steps[[2]], "step_epi_ahead", "ahead"), 7L) expect_error(extract_argument(r, "step_lightly", "quantile_levels")) expect_identical( extract_argument(r, "step_epi_lag", "lag"), - list(c(0, 7, 14), c(0, 7, 14)) + list(c(0L, 7L, 14L), c(0L, 7L, 14L)) ) wf <- epi_workflow(preprocessor = r) expect_error(extract_argument(epi_workflow(), "step_epi_lag", "lag")) expect_identical( extract_argument(wf, "step_epi_lag", "lag"), - list(c(0, 7, 14), c(0, 7, 14)) + list(c(0L, 7L, 14L), c(0L, 7L, 14L)) ) }) diff --git a/tests/testthat/test-layer_add_forecast_date.R b/tests/testthat/test-layer_add_forecast_date.R index 1830118dc..9595b47b6 100644 --- a/tests/testthat/test-layer_add_forecast_date.R +++ b/tests/testthat/test-layer_add_forecast_date.R @@ -11,8 +11,9 @@ latest <- jhu %>% test_that("layer validation works", { f <- frosting() - expect_error(layer_add_forecast_date(f, "a")) - expect_error(layer_add_forecast_date(f, "2022-05-31", id = c("a", "b"))) + expect_error(layer_add_forecast_date(f, c("2022-05-31", "2022-05-31"))) # multiple forecast_dates + expect_error(layer_add_forecast_date(f, "2022-05-31", id = 2)) # id is not a character + expect_error(layer_add_forecast_date(f, "2022-05-31", id = c("a", "b"))) # multiple ids 
expect_silent(layer_add_forecast_date(f, "2022-05-31")) expect_silent(layer_add_forecast_date(f)) expect_silent(layer_add_forecast_date(f, as.Date("2022-05-31"))) @@ -41,10 +42,12 @@ test_that("Specify a `forecast_date` that is less than `as_of` date", { layer_naomit(.pred) wf2 <- wf %>% add_frosting(f2) - expect_warning( - p2 <- predict(wf2, latest), - "forecast_date is less than the most recent update date of the data." - ) + # this warning has been removed + # expect_warning( + # p2 <- predict(wf2, latest), + # "forecast_date is less than the most recent update date of the data." + # ) + expect_silent(p2 <- predict(wf2, latest)) expect_equal(ncol(p2), 4L) expect_s3_class(p2, "epi_df") expect_equal(nrow(p2), 3L) @@ -59,13 +62,51 @@ test_that("Do not specify a forecast_date in `layer_add_forecast_date()`", { layer_naomit(.pred) wf3 <- wf %>% add_frosting(f3) - expect_warning( - p3 <- predict(wf3, latest), - "forecast_date is less than the most recent update date of the data." - ) + # this warning has been removed + # expect_warning( + # p3 <- predict(wf3, latest), + # "forecast_date is less than the most recent update date of the data." 
+ # ) + expect_silent(p3 <- predict(wf3, latest)) expect_equal(ncol(p3), 4L) expect_s3_class(p3, "epi_df") expect_equal(nrow(p3), 3L) expect_equal(p3$forecast_date, rep(as.Date("2021-12-31"), times = 3)) expect_named(p3, c("geo_value", "time_value", ".pred", "forecast_date")) }) + + +test_that("forecast date works for daily", { + f <- frosting() %>% + layer_predict() %>% + layer_add_forecast_date() %>% + layer_naomit(.pred) + + wf1 <- add_frosting(wf, f) + p <- predict(wf1, latest) + # both forecast_date and epi_df are dates + expect_identical(p$forecast_date[1], as.Date("2021-12-31")) + + # the error happens at predict time because the + # time_value train/test types don't match + latest_yearly <- latest %>% + unclass() %>% + as.data.frame() %>% + mutate(time_value = as.POSIXlt(time_value)$year + 1900L) %>% + as_epi_df() + expect_error(predict(wf1, latest_yearly)) + + # forecast_date is a string, gets correctly converted to date + wf2 <- add_frosting( + wf, + adjust_frosting(f, "layer_add_forecast_date", forecast_date = "2022-01-01") + ) + expect_silent(predict(wf2, latest)) + + # forecast_date is a year/int while the epi_df is a date + wf3 <- add_frosting( + wf, + adjust_frosting(f, "layer_add_forecast_date", forecast_date = 2022L) + ) + expect_error(predict(wf3, latest)) +}) diff --git a/tests/testthat/test-layer_add_target_date.R b/tests/testthat/test-layer_add_target_date.R index 287956612..e5349839b 100644 --- a/tests/testthat/test-layer_add_target_date.R +++ b/tests/testthat/test-layer_add_target_date.R @@ -31,7 +31,8 @@ test_that("Use ahead + max time value from pre, fit, post", { layer_naomit(.pred) wf2 <- wf %>% add_frosting(f2) - expect_warning(p2 <- predict(wf2, latest)) + # expect_warning(p2 <- predict(wf2, latest)) # this warning has been removed + expect_silent(p2 <- predict(wf2, latest)) expect_equal(ncol(p2), 5L) expect_s3_class(p2, "epi_df") expect_equal(nrow(p2), 3L) @@ -85,3 +86,38 @@ test_that("Specify own target date", { 
expect_equal(p2$target_date, rep(as.Date("2022-01-08"), times = 3)) expect_named(p2, c("geo_value", "time_value", ".pred", "target_date")) }) + +test_that("target date works for daily and yearly", { + f <- frosting() %>% + layer_predict() %>% + layer_add_target_date() %>% + layer_naomit(.pred) + + wf1 <- add_frosting(wf, f) + p <- predict(wf1, latest) + # both target_date and epi_df are dates + expect_identical(p$target_date[1], as.Date("2021-12-31") + 7L) + + # the error happens at predict time because the + # time_value train/test types don't match + latest_bad <- latest %>% + unclass() %>% + as.data.frame() %>% + mutate(time_value = as.POSIXlt(time_value)$year + 1900L) %>% + as_epi_df() + expect_error(predict(wf1, latest_bad)) + + # target_date is a string (gets correctly converted to Date) + wf1 <- add_frosting( + wf, + adjust_frosting(f, "layer_add_target_date", target_date = "2022-01-07") + ) + expect_silent(predict(wf1, latest)) + + # target_date is a year/int while the epi_df is a date + wf1 <- add_frosting( + wf, + adjust_frosting(f, "layer_add_target_date", target_date = 2022L) + ) + expect_error(predict(wf1, latest)) # wrong time type of target_date +}) diff --git a/tests/testthat/test-propagate_samples.R b/tests/testthat/test-propagate_samples.R deleted file mode 100644 index 5278ab385..000000000 --- a/tests/testthat/test-propagate_samples.R +++ /dev/null @@ -1,7 +0,0 @@ -test_that("propagate_samples", { - r <- -30:50 - p <- 40 - quantiles <- 1:9 / 10 - aheads <- c(2, 4, 7) - nsim <- 100 -}) diff --git a/tests/testthat/test-step_growth_rate.R b/tests/testthat/test-step_growth_rate.R index d0dec170e..052141710 100644 --- a/tests/testthat/test-step_growth_rate.R +++ b/tests/testthat/test-step_growth_rate.R @@ -34,7 +34,7 @@ test_that("step_growth_rate works for a single signal", { res <- r %>% step_growth_rate(value, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$gr_1_rel_change_value, c(NA, 1 / 6:9)) @@ -46,7 +46,7 @@ 
test_that("step_growth_rate works for a single signal", { r <- epi_recipe(edf) res <- r %>% step_growth_rate(value, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$gr_1_rel_change_value, rep(c(NA, 1 / 6:9), each = 2)) }) @@ -63,7 +63,7 @@ test_that("step_growth_rate works for a two signals", { res <- r %>% step_growth_rate(v1, v2, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$gr_1_rel_change_v1, c(NA, 1 / 6:9)) expect_equal(res$gr_1_rel_change_v2, c(NA, 1 / 1:4)) @@ -76,7 +76,7 @@ test_that("step_growth_rate works for a two signals", { r <- epi_recipe(edf) res <- r %>% step_growth_rate(v1, v2, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$gr_1_rel_change_v1, rep(c(NA, 1 / 6:9), each = 2)) expect_equal(res$gr_1_rel_change_v2, rep(c(NA, 1 / 1:4), each = 2)) diff --git a/tests/testthat/test-step_lag_difference.R b/tests/testthat/test-step_lag_difference.R index dc61d12d4..c0fd377e6 100644 --- a/tests/testthat/test-step_lag_difference.R +++ b/tests/testthat/test-step_lag_difference.R @@ -27,13 +27,13 @@ test_that("step_lag_difference works for a single signal", { res <- r %>% step_lag_difference(value, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_value, c(NA, rep(1, 4))) res <- r %>% step_lag_difference(value, horizon = 1:2) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_value, c(NA, rep(1, 4))) expect_equal(res$lag_diff_2_value, c(NA, NA, rep(2, 3))) @@ -48,7 +48,7 @@ test_that("step_lag_difference works for a single signal", { r <- epi_recipe(edf) res <- r %>% step_lag_difference(value, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_value, c(NA, NA, rep(1, 8))) }) @@ -65,7 +65,7 @@ test_that("step_lag_difference works for a two signals", { res <- r %>% step_lag_difference(v1, v2, horizon = 1:2) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_v1, c(NA, rep(1, 4))) 
expect_equal(res$lag_diff_2_v1, c(NA, NA, rep(2, 3))) @@ -80,7 +80,7 @@ test_that("step_lag_difference works for a two signals", { r <- epi_recipe(edf) res <- r %>% step_lag_difference(v1, v2, horizon = 1:2) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_v1, rep(c(NA, rep(1, 4)), each = 2)) expect_equal(res$lag_diff_2_v1, rep(c(NA, NA, rep(2, 3)), each = 2))