Skip to content

using check_enough_train_data in practice #452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 28, 2025
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.13
Version: 0.1.14
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Add `climatological_forecaster()` to automatically create climate baselines
- Replace `dist_quantiles()` with `hardhat::quantile_pred()`
- Allow `quantile()` to threshold to an interval if desired (#434)
- `arx_forecaster()` detects if there's enough data to predict

## Bug fixes

Expand Down
4 changes: 3 additions & 1 deletion R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ arx_fcast_epi_workflow <- function(
step_epi_ahead(!!outcome, ahead = args_list$ahead)
r <- r %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
step_training_window(n_recent = args_list$n_training) %>%
check_enough_train_data(all_predictors(), n = args_list$check_enough_data_n, skip = FALSE)

if (!is.null(args_list$check_enough_data_n)) {
r <- r %>% check_enough_train_data(
all_predictors(),
Expand Down
26 changes: 23 additions & 3 deletions R/check_enough_train_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ check_enough_train_data <-
role = NA,
trained = FALSE,
columns = NULL,
skip = TRUE,
skip = FALSE,
id = rand_id("enough_train_data")) {
recipes::add_check(
recipe,
Expand Down Expand Up @@ -90,7 +90,7 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
}

if (x$drop_na) {
training <- tidyr::drop_na(training)
training <- tidyr::drop_na(training, any_of(unname(col_names)))
}
cols_not_enough_data <- training %>%
group_by(across(all_of(.env$x$epi_keys))) %>%
Expand All @@ -101,7 +101,8 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {

if (length(cols_not_enough_data) > 0) {
cli_abort(
"The following columns don't have enough data to predict: {cols_not_enough_data}."
"The following columns don't have enough data to predict: {cols_not_enough_data}.",
class = "epipredict__not_enough_train_data"
)
}

Expand All @@ -120,6 +121,25 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {

#' @export
bake.check_enough_train_data <- function(object, new_data, ...) {
col_names <- object$columns
if (object$drop_na) {
non_na_data <- tidyr::drop_na(new_data, any_of(unname(col_names)))
} else {
non_na_data <- new_data
}
cols_not_enough_data <- non_na_data %>%
group_by(across(all_of(.env$object$epi_keys))) %>%
summarise(across(all_of(.env$col_names), ~ dplyr::n() < .env$object$n), .groups = "drop") %>%
summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
unlist() %>%
names(.)[.]

if (length(cols_not_enough_data) > 0) {
cli_abort(
"The following columns don't have enough data to predict: {cols_not_enough_data}.",
class = "epipredict__not_enough_train_data"
)
}
new_data
}

Expand Down
2 changes: 1 addition & 1 deletion R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,6 @@ forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date =
hardhat::extract_preprocessor(object),
object$original_data
)

test_data
predict(object, new_data = test_data)
}
2 changes: 1 addition & 1 deletion tests/testthat/_snaps/check_enough_train_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

Code
epi_recipe(toy_epi_df) %>% step_epi_lag(x, lag = c(1, 2)) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>% prep(
check_enough_train_data(all_predictors(), y, n = 2 * n - 4) %>% prep(
toy_epi_df) %>% bake(new_data = NULL)
Condition
Error in `prep()`:
Expand Down
19 changes: 19 additions & 0 deletions tests/testthat/test-arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,22 @@ test_that("arx_forecaster errors if forecast date, target date, and ahead are in
class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
)
})

test_that("warns if there's not enough data to predict", {
edf <- tibble(
geo_value = "ct",
time_value = seq(as.Date("2020-10-01"), as.Date("2023-05-31"), by = "day"),
) %>%
mutate(value = seq_len(nrow(.)) + rnorm(nrow(.))) %>%
# Oct to May (flu season, ish) only:
filter(!between(as.POSIXlt(time_value)$mon + 1L, 6L, 9L)) %>%
# and actually, pretend we're around mid-October 2022:
filter(time_value <= as.Date("2022-10-12")) %>%
as_epi_df(as_of = as.Date("2022-10-12"))
edf %>% filter(time_value > "2022-08-01")

expect_error(
edf %>% arx_forecaster("value"),
class = "epipredict__not_enough_train_data"
)
})
11 changes: 6 additions & 5 deletions tests/testthat/test-check_enough_train_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,33 +94,34 @@ test_that("check_enough_train_data only checks train data", {
epiprocess::as_epi_df()
expect_no_error(
epi_recipe(toy_epi_df) %>%
check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value") %>%
check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value", skip = TRUE) %>%
prep(toy_epi_df) %>%
bake(new_data = toy_test_data)
)
# Same thing, but skip = FALSE
expect_no_error(
epi_recipe(toy_epi_df) %>%
check_enough_train_data(y, n = n - 2, epi_keys = "geo_value", skip = FALSE) %>%
check_enough_train_data(y, n = n - 2, epi_keys = "geo_value") %>%
prep(toy_epi_df) %>%
bake(new_data = toy_test_data)
)
})

test_that("check_enough_train_data works with all_predictors() downstream of constructed terms", {
# With a lag of 2, we will get 2 * n - 6 non-NA rows
# With a lag of 2, we will get 2 * n - 5 non-NA rows (NA's in x but not in the
# lags don't count)
expect_no_error(
epi_recipe(toy_epi_df) %>%
step_epi_lag(x, lag = c(1, 2)) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 6) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>%
prep(toy_epi_df) %>%
bake(new_data = NULL)
)
expect_snapshot(
error = TRUE,
epi_recipe(toy_epi_df) %>%
step_epi_lag(x, lag = c(1, 2)) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>%
check_enough_train_data(all_predictors(), y, n = 2 * n - 4) %>%
prep(toy_epi_df) %>%
bake(new_data = NULL)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ test_that("Canned forecasters work with / without", {
)

expect_silent(
arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"))
arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), args_list = arx_args_list(check_enough_data_n = 1))
)
expect_silent(
flatline_forecaster(
Expand Down
Loading