Skip to content

feat: check_enough_train_data #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(adjust_frosting,frosting)
S3method(apply_frosting,default)
S3method(apply_frosting,epi_workflow)
S3method(augment,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,epi_recipe)
S3method(bake,step_epi_ahead)
S3method(bake,step_epi_lag)
Expand Down Expand Up @@ -48,6 +49,7 @@ S3method(mean,dist_quantiles)
S3method(median,dist_quantiles)
S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,check_enough_train_data)
S3method(prep,epi_recipe)
S3method(prep,step_epi_ahead)
S3method(prep,step_epi_lag)
Expand Down Expand Up @@ -104,6 +106,7 @@ S3method(snap,default)
S3method(snap,dist_default)
S3method(snap,dist_quantiles)
S3method(snap,distribution)
S3method(tidy,check_enough_train_data)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
Expand All @@ -127,6 +130,7 @@ export(arx_forecaster)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(check_enough_train_data)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
Expand Down
146 changes: 146 additions & 0 deletions R/check_enough_train_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#' Check the dataset contains enough data points.
#'
#' `check_enough_train_data` creates a *specification* of a recipe
#' operation that will check if variables contain enough data.
#'
#' @param recipe A recipe object. The check will be added to the
#' sequence of operations for this recipe.
#' @param ... One or more selector functions to choose variables
#' for this check. See [selections()] for more details.
#' @param n The minimum number of data points required for training.
#' @param epi_keys A character vector of column names on which to group the data
#' and check threshold within each group. Useful if your forecaster trains
#' per group (for example, per geo_value).
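#' @param drop_na A logical for whether rows containing NA values should be
#'   dropped before the rows are counted. If `TRUE` (the default), missing
#'   values do not count toward `n`; if `FALSE`, every row is counted.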
#' @param role Not used by this check since no new variables are
#' created.
#' @param trained A logical for whether the selectors in `...`
#' have been resolved by [prep()].
#' @param id A character string that is unique to this check to identify it.
#' @param skip A logical. Should the check be skipped when the
#' recipe is baked by [bake()]? While all operations are baked
#' when [prep()] is run, some operations may not be able to be
#' conducted on new data (e.g. processing the outcome variable(s)).
#' Care should be taken when using `skip = TRUE` as it may affect
#' the computations for subsequent operations.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if skip=TRUE by default would solve the issue about running during fit vs predict? looks like you have a test demonstrating it does do that! So we definitely have a functional check for training data, if not test data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup! I'd like to handle test data checking next, unclear if that will be possible.

#' @param columns A character vector of the selected variable names. This is
#'   `NULL` until the check has been trained by [prep()].
#' @family checks
#' @export
#' @details This check will break the `prep` function if any of the checked
#'   columns does not have enough non-NA values (or enough rows, when
#'   `drop_na = FALSE`). If the check passes, nothing is changed to the data.
#'
#' # tidy() results
#'
#' When you [`tidy()`][tidy.recipe()] this check, a tibble with column
#' `terms` (the selectors or variables selected) is returned, together with
#' the settings of the check (`id`, `n`, `drop_na` and, if available,
#' `epi_keys`).
#'
check_enough_train_data <-
  function(recipe,
           ...,
           n,
           epi_keys = NULL,
           drop_na = TRUE,
           role = NA,
           trained = FALSE,
           columns = NULL,
           skip = TRUE,
           id = rand_id("enough_train_data")) {
    # Validate the settings up front so that a bad value is reported when the
    # check is specified rather than deep inside the counting code in prep().
    if (!is.numeric(n) || length(n) != 1L || is.na(n)) {
      cli::cli_abort("`n` must be a single non-missing number.")
    }
    if (!is.null(epi_keys) && !is.character(epi_keys)) {
      cli::cli_abort("`epi_keys` must be `NULL` or a character vector.")
    }
    if (!isTRUE(drop_na) && !isFALSE(drop_na)) {
      cli::cli_abort("`drop_na` must be `TRUE` or `FALSE`.")
    }

    add_check(
      recipe,
      check_enough_train_data_new(
        n = n,
        epi_keys = epi_keys,
        drop_na = drop_na,
        # Capture the selectors unevaluated; they are resolved in prep().
        terms = rlang::enquos(...),
        role = role,
        trained = trained,
        columns = columns,
        skip = skip,
        id = id
      )
    )
  }

# Internal constructor: bundles the settings of the check into a
# `check_enough_train_data` object by delegating to `check()`. It is used both
# when the check is first specified and when prep() returns the trained check.
#
# NOTE(review): `prefix` is forwarded as an ordinary named argument. If
# `check()` names its class-prefix argument `.prefix` instead, this value is
# stored as a plain field of the object rather than used as the prefix --
# confirm against the recipes version in use.
check_enough_train_data_new <- function(n, epi_keys, drop_na, terms, role,
                                        trained, columns, skip, id) {
  check(
    subclass = "enough_train_data", prefix = "check_",
    n = n, epi_keys = epi_keys, drop_na = drop_na,
    terms = terms, role = role, trained = trained,
    columns = columns, skip = skip, id = id
  )
}

# Train the check: resolve the selected columns and abort if any of them has
# fewer than `x$n` usable rows, either overall or within any group defined by
# `x$epi_keys`. Returns the check marked as trained, with `columns` filled in.
#
# Rows are counted per group on the full set of groups present in the data, so
# a group whose values are all missing is reported with a count of zero rather
# than silently disappearing from the comparison.
#' @export
prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
  col_names <- recipes_eval_select(x$terms, training, info)

  # Pull the settings out of `x` once so that nothing below depends on how
  # `x` or `n` would be looked up inside a data mask (the training data may
  # itself contain columns with those names).
  key_cols <- epi_keys(training)
  group_cols <- x$epi_keys
  min_rows <- x$n

  has_too_few_rows <- purrr::map_lgl(col_names, function(col) {
    checked <- training %>%
      dplyr::select(dplyr::all_of(c(key_cols, col))) %>%
      tibble::as_tibble()

    if (x$drop_na) {
      # Rows with a missing key cannot be assigned to a group; drop them, then
      # count only the rows where the checked column is observed.
      checked <- tidyr::drop_na(checked, dplyr::all_of(key_cols))
      is_counted <- !is.na(checked[[col]])
    } else {
      is_counted <- rep(TRUE, nrow(checked))
    }
    checked[[".is_counted"]] <- is_counted

    group_counts <- checked %>%
      dplyr::group_by(dplyr::across(dplyr::all_of(group_cols))) %>%
      dplyr::summarise(.n_counted = sum(.is_counted), .groups = "drop")

    any(group_counts[[".n_counted"]] < min_rows)
  })

  cols_not_enough_data <- unname(col_names[has_too_few_rows])

  if (length(cols_not_enough_data) > 0) {
    cli::cli_abort(
      "The following columns don't have enough data to predict: {cols_not_enough_data}."
    )
  }

  check_enough_train_data_new(
    n = x$n,
    epi_keys = x$epi_keys,
    drop_na = x$drop_na,
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    columns = col_names,
    skip = x$skip,
    id = x$id
  )
}

# Baking performs no verification for this check: `object` is not consulted
# and the incoming data are handed back exactly as received.
#' @export
bake.check_enough_train_data <- function(object, new_data, ...) new_data

# Print a one-line summary of the check: the required number of rows followed
# by the selected columns (once trained) or the selectors (before training).
# Returns `x` invisibly so the method can be used in a pipeline.
#
# The `@export` tag registers this as an S3 method; without it the method is
# not listed in NAMESPACE (which must be regenerated after this change) and
# `print()` cannot dispatch to it from outside the package.
#' @export
print.check_enough_train_data <-
  function(x, width = max(20, options()$width - 30), ...) {
    title <- paste0("Check enough data (n = ", x$n, ") for ")
    print_step(x$columns, x$terms, x$trained, title, width)
    invisible(x)
  }

# Summarise the check as a tibble with one row per term. `terms` holds the
# resolved column names once the check is trained and the selector text
# otherwise; `id`, `n`, `epi_keys` and `drop_na` repeat the check's settings.
#' @rdname tidy.recipe
#' @export
tidy.check_enough_train_data <- function(x, ...) {
  if (is_trained(x)) {
    res <- tibble(terms = unname(x$columns))
  } else {
    res <- tibble(terms = sel2char(x$terms))
  }
  res$id <- x$id
  res$n <- x$n
  # `epi_keys` may be NULL (no grouping) or hold several names. Assigning it
  # directly would drop the column in the first case and either fail or pair
  # keys with terms row-by-row in the second, so it is reduced to exactly one
  # string: NA when unset, otherwise the key names joined by ", ". A single
  # key is therefore reported exactly as before.
  res$epi_keys <- if (length(x$epi_keys) == 0L) {
    NA_character_
  } else {
    paste(x$epi_keys, collapse = ", ")
  }
  res$drop_na <- x$drop_na
  res
}
64 changes: 64 additions & 0 deletions man/check_enough_train_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

125 changes: 125 additions & 0 deletions tests/testthat/test-check_enough_train_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Toy data shared by the tests below: two locations ("ca" and "hi"), each with
# `n` consecutive daily observations. `x` is missing on the last two "hi"
# rows, so it has 2 * n - 2 observed values overall (n for "ca", n - 2 for
# "hi"); `y` is fully observed.
n <- 10
toy_epi_df <- epiprocess::as_epi_df(
  tibble::tibble(
    time_value = rep(
      seq(as.Date("2020-01-01"), by = 1, length.out = n),
      times = 2
    ),
    geo_value = rep(c("ca", "hi"), each = n),
    x = c(seq_len(n), seq_len(n - 2), NA, NA),
    y = seq_len(2 * n)
  )
)

test_that("check_enough_train_data works on pooled data", {
  # Adds one check (configured through `...`) to the toy recipe, then preps
  # and bakes it on the training data. No `epi_keys` are given in this test,
  # so rows are counted across both locations together.
  prep_and_bake <- function(...) {
    epi_recipe(y ~ x, data = toy_epi_df) %>%
      check_enough_train_data(...) %>%
      recipes::prep(toy_epi_df) %>%
      recipes::bake(new_data = NULL)
  }

  # `x` has exactly 2 * n - 2 observed values, so both columns clear this bar.
  expect_no_error(prep_and_bake(x, y, n = 2 * n - 2))

  # One more required row than `x` can provide: the check must fail.
  expect_error(
    prep_and_bake(x, y, n = 2 * n - 1),
    regexp = "The following columns don't have enough data"
  )

  # The short column is not an error as long as it is not among the checked
  # columns.
  expect_no_error(prep_and_bake(y, n = 2 * n - 1))
})
test_that("check_enough_train_data works on unpooled data", {
  # Adds one check (configured through `...`) to the toy recipe, counting rows
  # separately for each `geo_value`, then preps and bakes it on the training
  # data.
  prep_and_bake_by_geo <- function(...) {
    epi_recipe(y ~ x, data = toy_epi_df) %>%
      check_enough_train_data(..., epi_keys = "geo_value") %>%
      recipes::prep(toy_epi_df) %>%
      recipes::bake(new_data = NULL)
  }

  # The sparsest group ("hi") has n - 2 observed `x` values, so both columns
  # clear this bar in every location.
  expect_no_error(prep_and_bake_by_geo(x, y, n = n - 2))

  # One more required row than "hi" can provide for `x`: the check must fail.
  expect_error(
    prep_and_bake_by_geo(x, y, n = n - 1),
    regexp = "The following columns don't have enough data"
  )

  # The short column is not an error as long as it is not among the checked
  # columns.
  expect_no_error(prep_and_bake_by_geo(y, n = n - 1))

  # With drop_na = FALSE the NA rows of `x` are counted too, so every
  # location has n rows and the threshold of n - 1 is met.
  expect_no_error(prep_and_bake_by_geo(x, n = n - 1, drop_na = FALSE))
})

test_that("check_enough_train_data outputs the correct values", {
  # A passing check must hand the training data back unchanged; keep the
  # baked result for the comparisons below.
  expect_no_error(
    p <- epi_recipe(y ~ x, data = toy_epi_df) %>%
      check_enough_train_data(x, y, n = 2 * n - 2) %>%
      recipes::prep(toy_epi_df) %>%
      recipes::bake(new_data = NULL)
  )

  # Shape, class and column names match the toy data.
  expect_equal(nrow(p), 2 * n)
  expect_equal(ncol(p), 4L)
  expect_s3_class(p, "epi_df")
  expect_named(p, c("time_value", "geo_value", "x", "y"))

  # The key columns carry the same values, in the same order, as the input.
  expected_dates <- rep(
    seq(as.Date("2020-01-01"), by = 1, length.out = n),
    times = 2
  )
  expected_geos <- rep(c("ca", "hi"), each = n)
  expect_equal(p$time_value, expected_dates)
  expect_equal(p$geo_value, expected_geos)
})

test_that("check_enough_train_data only checks train data", {
  # The training data satisfies the threshold but the (shorter) test data
  # would not; baking the test data must still succeed because the check is
  # applied to the training data only.
  n_minus <- n - 2
  toy_test_data <- epiprocess::as_epi_df(
    tibble::tibble(
      time_value = rep(
        seq(as.Date("2020-01-01"), by = 1, length.out = n_minus),
        times = 2
      ),
      geo_value = rep(c("ca", "hi"), each = n_minus),
      x = c(seq_len(n_minus), seq_len(n_minus - 2), NA, NA),
      y = seq_len(2 * n_minus)
    )
  )

  # Preps the checked recipe on the training data, then bakes the test data.
  # Extra settings for the check are passed through `...`.
  bake_test_data <- function(...) {
    epi_recipe(y ~ x, data = toy_epi_df) %>%
      check_enough_train_data(y, n = n - 1, epi_keys = "geo_value", ...) %>%
      recipes::prep(toy_epi_df) %>%
      recipes::bake(new_data = toy_test_data)
  }

  # Default settings (skip = TRUE).
  expect_no_error(bake_test_data())
  # Same thing, but skip = FALSE
  expect_no_error(bake_test_data(skip = FALSE))
})