diff --git a/.Rbuildignore b/.Rbuildignore index 79f39d67e..3a77bb347 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -1,3 +1,5 @@ +^renv$ +^renv\.lock$ ^epipredict\.Rproj$ ^\.Rproj\.user$ ^LICENSE\.md$ diff --git a/.gitignore b/.gitignore index 292b968f2..3375a07f1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,6 @@ inst/doc .DS_Store /doc/ /Meta/ +.Rprofile +renv.lock +renv/ \ No newline at end of file diff --git a/DESCRIPTION b/DESCRIPTION index f91b64d28..b0b592e2a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.7 +Version: 0.0.8 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), @@ -22,11 +22,11 @@ License: MIT + file LICENSE URL: https://github.com/cmu-delphi/epipredict/, https://cmu-delphi.github.io/epipredict BugReports: https://github.com/cmu-delphi/epipredict/issues/ -Depends: +Depends: epiprocess (>= 0.6.0), parsnip (>= 1.0.0), R (>= 3.5.0) -Imports: +Imports: cli, distributional, dplyr, @@ -48,7 +48,7 @@ Imports: usethis, vctrs, workflows (>= 1.0.0) -Suggests: +Suggests: covidcast, data.table, epidatr (>= 1.0.0), @@ -61,7 +61,7 @@ Suggests: rmarkdown, testthat (>= 3.0.0), xgboost -VignetteBuilder: +VignetteBuilder: knitr Remotes: cmu-delphi/epidatr, diff --git a/NAMESPACE b/NAMESPACE index 9beb39a88..fc7a7ea00 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -10,6 +10,7 @@ S3method(adjust_frosting,frosting) S3method(apply_frosting,default) S3method(apply_frosting,epi_workflow) S3method(augment,epi_workflow) +S3method(bake,check_enough_train_data) S3method(bake,epi_recipe) S3method(bake,step_epi_ahead) S3method(bake,step_epi_lag) @@ -48,6 +49,7 @@ S3method(mean,dist_quantiles) S3method(median,dist_quantiles) S3method(predict,epi_workflow) S3method(predict,flatline) +S3method(prep,check_enough_train_data) S3method(prep,epi_recipe) S3method(prep,step_epi_ahead) 
S3method(prep,step_epi_lag) @@ -60,6 +62,7 @@ S3method(print,arx_class) S3method(print,arx_fcast) S3method(print,canned_epipred) S3method(print,cdc_baseline_fcast) +S3method(print,check_enough_train_data) S3method(print,epi_recipe) S3method(print,epi_workflow) S3method(print,flat_fcast) @@ -104,6 +107,7 @@ S3method(snap,default) S3method(snap,dist_default) S3method(snap,dist_quantiles) S3method(snap,distribution) +S3method(tidy,check_enough_train_data) S3method(tidy,frosting) S3method(tidy,layer) S3method(update,layer) @@ -127,6 +131,7 @@ export(arx_forecaster) export(bake) export(cdc_baseline_args_list) export(cdc_baseline_forecaster) +export(check_enough_train_data) export(create_layer) export(default_epi_recipe_blueprint) export(detect_layer) @@ -191,6 +196,12 @@ import(epiprocess) import(parsnip) import(recipes) importFrom(cli,cli_abort) +importFrom(dplyr,across) +importFrom(dplyr,all_of) +importFrom(dplyr,group_by) +importFrom(dplyr,n) +importFrom(dplyr,summarise) +importFrom(dplyr,ungroup) importFrom(epiprocess,growth_rate) importFrom(generics,augment) importFrom(generics,fit) @@ -225,6 +236,7 @@ importFrom(stats,residuals) importFrom(tibble,as_tibble) importFrom(tibble,is_tibble) importFrom(tibble,tibble) +importFrom(tidyr,drop_na) importFrom(vctrs,as_list_of) importFrom(vctrs,field) importFrom(vctrs,new_rcrd) diff --git a/NEWS.md b/NEWS.md index d14ff9db9..d2cbc0d29 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,44 +1,49 @@ # epipredict (development) +# epipredict 0.0.8 + +- add `check_enough_train_data` that will error if training data is too small +- added `check_enough_train_data` to `arx_forecaster` + # epipredict 0.0.7 -* simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()` +- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()` # epipredict 0.0.6 -* rename the `dist_quantiles()` to be more descriptive, breaking change) -* removes previous `pivot_quantiles()` (now `*_wider()`, breaking change) -* add 
`pivot_quantiles_wider()` for easier plotting -* add complement `pivot_quantiles_longer()` -* add `cdc_baseline_forecaster()` and `flusight_hub_formatter()` +- rename the `dist_quantiles()` to be more descriptive (breaking change) +- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change) +- add `pivot_quantiles_wider()` for easier plotting +- add complement `pivot_quantiles_longer()` +- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()` # epipredict 0.0.5 -* add `smooth_quantile_reg()` -* improved printing of various methods / internals -* canned forecasters get a class -* fixed quantile bug in `flatline_forecaster()` -* add functionality to output the unfit workflow from the canned forecasters +- add `smooth_quantile_reg()` +- improved printing of various methods / internals +- canned forecasters get a class +- fixed quantile bug in `flatline_forecaster()` +- add functionality to output the unfit workflow from the canned forecasters # epipredict 0.0.4 -* add quantile_reg() -* clean up documentation bugs -* add smooth_quantile_reg() -* add classifier -* training window step debugged -* `min_train_window` argument removed from canned forecasters +- add quantile_reg() +- clean up documentation bugs +- add smooth_quantile_reg() +- add classifier +- training window step debugged +- `min_train_window` argument removed from canned forecasters # epipredict 0.0.3 -* add forecasters -* implement postprocessing -* vignettes avaliable -* arx_forecaster -* pkgdown +- add forecasters +- implement postprocessing +- vignettes available +- arx_forecaster +- pkgdown # epipredict 0.0.0.9000 -* Publish public for easy navigation -* Two simple forecasters as test beds -* Working vignette +- Publish public for easy navigation +- Two simple forecasters as test beds +- Working vignette diff --git a/R/arx_classifier.R b/R/arx_classifier.R index a03ee072b..d42247426 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -180,7 +180,21 @@
arx_class_epi_workflow <- function( role = "outcome" ) %>% step_epi_naomit() %>% - step_training_window(n_recent = args_list$n_training) + step_training_window(n_recent = args_list$n_training) %>% + { + if (!is.null(args_list$check_enough_data_n)) { + check_enough_train_data( + ., + all_predictors(), + !!outcome, + n = args_list$check_enough_data_n, + epi_keys = args_list$check_enough_data_epi_keys, + drop_na = FALSE + ) + } else { + . + } + } forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) target_date <- args_list$target_date %||% forecast_date + args_list$ahead @@ -228,6 +242,11 @@ arx_class_epi_workflow <- function( #' @param additional_gr_args List. Optional arguments controlling growth rate #' calculation. See [epiprocess::growth_rate()] and the related Vignette for #' more details. +#' @param check_enough_data_n Integer. A lower limit for the number of rows per +#' epi_key that are required for training. If `NULL`, this check is ignored. +#' @param check_enough_data_epi_keys Character vector. A character vector of +#' column names on which to group the data and check threshold within each +#' group. Useful if training per group (for example, per geo_value). #' #' @return A list containing updated parameter choices with class `arx_clist`. #' @export @@ -251,6 +270,8 @@ arx_class_args_list <- function( log_scale = FALSE, additional_gr_args = list(), nafill_buffer = Inf, + check_enough_data_n = NULL, + check_enough_data_epi_keys = NULL, ...) 
{ rlang::check_dots_empty() .lags <- lags @@ -275,6 +296,8 @@ arx_class_args_list <- function( ) ) } + arg_is_pos(check_enough_data_n, allow_null = TRUE) + arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE) breaks <- sort(breaks) if (min(breaks) > -Inf) breaks <- c(-Inf, breaks) @@ -296,7 +319,9 @@ arx_class_args_list <- function( method, log_scale, additional_gr_args, - nafill_buffer + nafill_buffer, + check_enough_data_n, + check_enough_data_epi_keys ), class = c("arx_class", "alist") ) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index ea1e891a5..ce2fa57b0 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -126,7 +126,21 @@ arx_fcast_epi_workflow <- function( r <- r %>% step_epi_ahead(!!outcome, ahead = args_list$ahead) %>% step_epi_naomit() %>% - step_training_window(n_recent = args_list$n_training) + step_training_window(n_recent = args_list$n_training) %>% + { + if (!is.null(args_list$check_enough_data_n)) { + check_enough_train_data( + ., + all_predictors(), + !!outcome, + n = args_list$check_enough_data_n, + epi_keys = args_list$check_enough_data_epi_keys, + drop_na = FALSE + ) + } else { + . + } + } forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) target_date <- args_list$target_date %||% forecast_date + args_list$ahead @@ -199,6 +213,11 @@ arx_fcast_epi_workflow <- function( #' create a prediction. For this reason, setting `nafill_buffer < min(lags)` #' will be treated as _additional_ allowed recent data rather than the #' total amount of recent data to examine. +#' @param check_enough_data_n Integer. A lower limit for the number of rows per +#' epi_key that are required for training. If `NULL`, this check is ignored. +#' @param check_enough_data_epi_keys Character vector. A character vector of +#' column names on which to group the data and check threshold within each +#' group. Useful if training per group (for example, per geo_value). #' @param ... Space to handle future expansions (unused). 
#' #' @@ -220,6 +239,8 @@ arx_args_list <- function( nonneg = TRUE, quantile_by_key = character(0L), nafill_buffer = Inf, + check_enough_data_n = NULL, + check_enough_data_epi_keys = NULL, ...) { # error checking if lags is a list rlang::check_dots_empty() @@ -236,6 +257,8 @@ arx_args_list <- function( arg_is_pos(n_training) if (is.finite(n_training)) arg_is_pos_int(n_training) if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE) + arg_is_pos(check_enough_data_n, allow_null = TRUE) + arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE) max_lags <- max(lags) structure( @@ -250,7 +273,9 @@ arx_args_list <- function( nonneg, max_lags, quantile_by_key, - nafill_buffer + nafill_buffer, + check_enough_data_n, + check_enough_data_epi_keys ), class = c("arx_fcast", "alist") ) diff --git a/R/check_enough_train_data.R b/R/check_enough_train_data.R new file mode 100644 index 000000000..af2183d15 --- /dev/null +++ b/R/check_enough_train_data.R @@ -0,0 +1,156 @@ +#' Check the dataset contains enough data points. +#' +#' `check_enough_train_data` creates a *specification* of a recipe +#' operation that will check if variables contain enough data. +#' +#' @param recipe A recipe object. The check will be added to the +#' sequence of operations for this recipe. +#' @param ... One or more selector functions to choose variables for this check. +#' See [selections()] for more details. You will usually want to use +#' [recipes::all_predictors()] here. +#' @param n The minimum number of data points required for training. If this is +#' `NULL`, the number of columns selected by `...` is used. +#' @param epi_keys A character vector of column names on which to group the data +#' and check threshold within each group. Useful if your forecaster trains +#' per group (for example, per geo_value). +#' @param drop_na A logical; if `TRUE`, rows with any `NA` are not counted. +#' @param role Not used by this check since no new variables are +#' created.
+#' @param trained A logical for whether the selectors in `...` +#' have been resolved by [prep()]. +#' @param columns An internal argument that tracks which columns are evaluated +#' for this check. Should not be used by the user. +#' @param id A character string that is unique to this check to identify it. +#' @param skip A logical. Should the check be skipped when the +#' recipe is baked by [bake()]? While all operations are baked +#' when [prep()] is run, some operations may not be able to be +#' conducted on new data (e.g. processing the outcome variable(s)). +#' Care should be taken when using `skip = TRUE` as it may affect +#' the computations for subsequent operations. +#' @family checks +#' @export +#' @details This check will break the `prep` function if any of the checked +#' columns do not have enough data. If the check passes, nothing is +#' changed to the data. +#' +#' # tidy() results +#' +#' When you [`tidy()`][tidy.recipe()] this check, a tibble with column +#' `terms` (the selectors or variables selected) is returned.
+#' +check_enough_train_data <- + function(recipe, + ..., + n = NULL, + epi_keys = NULL, + drop_na = TRUE, + role = NA, + trained = FALSE, + columns = NULL, + skip = TRUE, + id = rand_id("enough_train_data")) { + add_check( + recipe, + check_enough_train_data_new( + n = n, + epi_keys = epi_keys, + drop_na = drop_na, + terms = rlang::enquos(...), + role = role, + trained = trained, + columns = columns, + skip = skip, + id = id + ) + ) + } + +check_enough_train_data_new <- + function(n, epi_keys, drop_na, terms, role, trained, columns, skip, id) { + check( + subclass = "enough_train_data", + prefix = "check_", + n = n, + epi_keys = epi_keys, + drop_na = drop_na, + terms = terms, + role = role, + trained = trained, + columns = columns, + skip = skip, + id = id + ) + } + +#' @export +#' @importFrom dplyr group_by summarise ungroup across all_of n +#' @importFrom tidyr drop_na +prep.check_enough_train_data <- function(x, training, info = NULL, ...) { + col_names <- recipes_eval_select(x$terms, training, info) + # Default threshold: one row per selected column. + if (is.null(x$n)) { + x$n <- length(col_names) + } + + # Flag every column if any `epi_keys` group has fewer than `n` rows. + cols_not_enough_data <- training %>% + { + if (x$drop_na) { + drop_na(.) + } else { + . + } + } %>% + group_by(across(all_of(.env$x$epi_keys))) %>% + summarise(across(all_of(.env$col_names), ~ n() < .env$x$n), .groups = "drop") %>% + summarise(across(all_of(.env$col_names), any), .groups = "drop") %>% + unlist() %>% + names(.)[.] + + if (length(cols_not_enough_data) > 0) { + cli::cli_abort( + "The following columns don't have enough data to predict: {cols_not_enough_data}." + ) + } + + check_enough_train_data_new( + n = x$n, + epi_keys = x$epi_keys, + drop_na = x$drop_na, + terms = x$terms, + role = x$role, + trained = TRUE, + columns = col_names, + skip = x$skip, + id = x$id + ) +} + +#' @export +bake.check_enough_train_data <- function(object, new_data, ...) { + new_data +} + +#' @export +print.check_enough_train_data <- function(x, width = max(20, options()$width - 30), ...)
{ + title <- paste0("Check enough data (n = ", x$n, ") for ") + print_step(x$columns, x$terms, x$trained, title, width) + invisible(x) +} + +#' @export +tidy.check_enough_train_data <- function(x, ...) { + if (is_trained(x)) { + res <- tibble(terms = unname(x$columns)) + } else { + res <- tibble(terms = sel2char(x$terms)) + } + res$id <- x$id + res$n <- x$n + # One string for all keys; a vector would be recycled against the `terms` rows. + if (!is.null(x$epi_keys)) { + res$epi_keys <- toString(x$epi_keys) + } + res$drop_na <- x$drop_na + res +} diff --git a/man/arx_args_list.Rd b/man/arx_args_list.Rd index e5d2391c8..c9ae4e733 100644 --- a/man/arx_args_list.Rd +++ b/man/arx_args_list.Rd @@ -15,6 +15,8 @@ arx_args_list( nonneg = TRUE, quantile_by_key = character(0L), nafill_buffer = Inf, + check_enough_data_n = NULL, + check_enough_data_epi_keys = NULL, ... ) } @@ -65,6 +67,13 @@ create a prediction. For this reason, setting \code{nafill_buffer < min(lags)} will be treated as \emph{additional} allowed recent data rather than the total amount of recent data to examine.} +\item{check_enough_data_n}{Integer. A lower limit for the number of rows per +epi_key that are required for training. If \code{NULL}, this check is ignored.} + +\item{check_enough_data_epi_keys}{Character vector. A character vector of +column names on which to group the data and check threshold within each +group. Useful if training per group (for example, per geo_value).} + \item{...}{Space to handle future expansions (unused).} } \value{ diff --git a/man/arx_class_args_list.Rd b/man/arx_class_args_list.Rd index fa7a407f0..a1205c71a 100644 --- a/man/arx_class_args_list.Rd +++ b/man/arx_class_args_list.Rd @@ -17,6 +17,8 @@ arx_class_args_list( log_scale = FALSE, additional_gr_args = list(), nafill_buffer = Inf, + check_enough_data_n = NULL, + check_enough_data_epi_keys = NULL, ... ) } @@ -84,6 +86,13 @@ create a prediction.
For this reason, setting \code{nafill_buffer < min(lags)} will be treated as \emph{additional} allowed recent data rather than the total amount of recent data to examine.} +\item{check_enough_data_n}{Integer. A lower limit for the number of rows per +epi_key that are required for training. If \code{NULL}, this check is ignored.} + +\item{check_enough_data_epi_keys}{Character vector. A character vector of +column names on which to group the data and check threshold within each +group. Useful if training per group (for example, per geo_value).} + \item{...}{Space to handle future expansions (unused).} } \value{ diff --git a/man/check_enough_train_data.Rd b/man/check_enough_train_data.Rd new file mode 100644 index 000000000..57a4a9d78 --- /dev/null +++ b/man/check_enough_train_data.Rd @@ -0,0 +1,69 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/check_enough_train_data.R +\name{check_enough_train_data} +\alias{check_enough_train_data} +\title{Check the dataset contains enough data points.} +\usage{ +check_enough_train_data( + recipe, + ..., + n = NULL, + epi_keys = NULL, + drop_na = TRUE, + role = NA, + trained = FALSE, + columns = NULL, + skip = TRUE, + id = rand_id("enough_train_data") +) +} +\arguments{ +\item{recipe}{A recipe object. The check will be added to the +sequence of operations for this recipe.} + +\item{...}{One or more selector functions to choose variables for this check. +See \code{\link[=selections]{selections()}} for more details. You will usually want to use +\code{\link[recipes:has_role]{recipes::all_predictors()}} here.} + +\item{n}{The minimum number of data points required for training. If this is +\code{NULL}, the number of columns selected by \code{...} is used.} + +\item{epi_keys}{A character vector of column names on which to group the data +and check threshold within each group.
Useful if your forecaster trains +per group (for example, per geo_value).} + +\item{drop_na}{A logical; if \code{TRUE}, rows with any \code{NA} are not counted.} + +\item{role}{Not used by this check since no new variables are +created.} + +\item{trained}{A logical for whether the selectors in \code{...} +have been resolved by \code{\link[=prep]{prep()}}.} + +\item{columns}{An internal argument that tracks which columns are evaluated +for this check. Should not be used by the user.} + +\item{skip}{A logical. Should the check be skipped when the +recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked +when \code{\link[=prep]{prep()}} is run, some operations may not be able to be +conducted on new data (e.g. processing the outcome variable(s)). +Care should be taken when using \code{skip = TRUE} as it may affect +the computations for subsequent operations.} + +\item{id}{A character string that is unique to this check to identify it.} +} +\description{ +\code{check_enough_train_data} creates a \emph{specification} of a recipe +operation that will check if variables contain enough data. +} +\details{ +This check will break the \code{prep} function if any of the checked +columns do not have enough data. If the check passes, nothing is +changed to the data. +} +\section{tidy() results}{ +When you \code{\link[=tidy.recipe]{tidy()}} this check, a tibble with column +\code{terms} (the selectors or variables selected) is returned.
+} + +\concept{checks} diff --git a/tests/testthat/test-arx_args_list.R b/tests/testthat/test-arx_args_list.R index 138a75e87..7566fd90d 100644 --- a/tests/testthat/test-arx_args_list.R +++ b/tests/testthat/test-arx_args_list.R @@ -22,6 +22,9 @@ test_that("arx_args checks inputs", { arx_args_list(target_date = as.Date("2022-01-01"))$target_date, as.Date("2022-01-01") ) + + expect_error(arx_args_list(check_enough_data_n = "de")) + expect_error(arx_args_list(check_enough_data_epi_keys = 1)) }) test_that("arx forecaster disambiguates quantiles", { diff --git a/tests/testthat/test-arx_cargs_list.R b/tests/testthat/test-arx_cargs_list.R index 31ed7cd10..69901220c 100644 --- a/tests/testthat/test-arx_cargs_list.R +++ b/tests/testthat/test-arx_cargs_list.R @@ -16,4 +16,7 @@ test_that("arx_class_args checks inputs", { arx_class_args_list(target_date = as.Date("2022-01-01"))$target_date, as.Date("2022-01-01") ) + + expect_error(arx_class_args_list(check_enough_data_n = "de")) + expect_error(arx_class_args_list(check_enough_data_epi_keys = 1)) }) diff --git a/tests/testthat/test-check_enough_train_data.R b/tests/testthat/test-check_enough_train_data.R new file mode 100644 index 000000000..5eae01bb2 --- /dev/null +++ b/tests/testthat/test-check_enough_train_data.R @@ -0,0 +1,124 @@ +# Setup toy data +n <- 10 +toy_epi_df <- tibble::tibble( + time_value = rep( + seq( + as.Date("2020-01-01"), + by = 1, + length.out = n + ), + times = 2 + ), + geo_value = rep(c("ca", "hi"), each = n), + x = c(1:n, c(1:(n - 2), NA, NA)), + y = 1:(2 * n) +) %>% epiprocess::as_epi_df() + +test_that("check_enough_train_data works on pooled data", { + # Check both columns have enough data + expect_no_error( + epi_recipe(toy_epi_df) %>% + check_enough_train_data(x, y, n = 2 * n, drop_na = FALSE) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL) + ) + # Check both columns don't have enough data + expect_error( + epi_recipe(toy_epi_df) %>% + check_enough_train_data(x, y, n = 2 * n + 1, drop_na = FALSE) %>% +
recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL), + regexp = "The following columns don't have enough data" + ) + # Check drop_na works + expect_error( + epi_recipe(toy_epi_df) %>% + check_enough_train_data(x, y, n = 2 * n - 1, drop_na = TRUE) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL) + ) +}) + +test_that("check_enough_train_data works on unpooled data", { + # Check both columns have enough data + expect_no_error( + epi_recipe(toy_epi_df) %>% + check_enough_train_data(x, y, n = n, epi_keys = "geo_value", drop_na = FALSE) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL) + ) + # Check one column doesn't have enough data + expect_error( + epi_recipe(toy_epi_df) %>% + check_enough_train_data(x, y, n = n + 1, epi_keys = "geo_value", drop_na = FALSE) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL), + regexp = "The following columns don't have enough data" + ) + # Check drop_na works (the hi group has only n - 2 complete rows) + expect_error( + epi_recipe(toy_epi_df) %>% + check_enough_train_data(x, y, n = n - 1, epi_keys = "geo_value", drop_na = TRUE) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL) + ) +}) + +test_that("check_enough_train_data outputs the correct recipe values", { + expect_no_error( + p <- epi_recipe(toy_epi_df) %>% + check_enough_train_data(x, y, n = 2 * n - 2) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL) + ) + + expect_equal(nrow(p), 2 * n) + expect_equal(ncol(p), 4L) + expect_s3_class(p, "epi_df") + expect_named(p, c("time_value", "geo_value", "x", "y")) + expect_equal( + p$time_value, + rep(seq(as.Date("2020-01-01"), by = 1, length.out = n), times = 2) + ) + expect_equal(p$geo_value, rep(c("ca", "hi"), each = n)) +}) + +test_that("check_enough_train_data only checks train data", { + # Check that the train data has enough data, the test data does not, but + # the check passes anyway (because it should be applied to training data) + toy_test_data <- toy_epi_df %>% +
dplyr::group_by(geo_value) %>% + dplyr::slice(3:10) %>% + epiprocess::as_epi_df() + expect_no_error( + epi_recipe(toy_epi_df) %>% + check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value") %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = toy_test_data) + ) + # Same thing, but skip = FALSE + expect_no_error( + epi_recipe(toy_epi_df) %>% + check_enough_train_data(y, n = n - 2, epi_keys = "geo_value", skip = FALSE) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = toy_test_data) + ) +}) + +test_that("check_enough_train_data works with all_predictors() downstream of constructed terms", { + # With a lag of 2, we will get 2 * n - 6 non-NA rows + expect_no_error( + epi_recipe(toy_epi_df) %>% + step_epi_lag(x, lag = c(1, 2)) %>% + check_enough_train_data(all_predictors(), y, n = 2 * n - 6) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL) + ) + expect_error( + epi_recipe(toy_epi_df) %>% + step_epi_lag(x, lag = c(1, 2)) %>% + check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>% + recipes::prep(toy_epi_df) %>% + recipes::bake(new_data = NULL) + ) +})