Skip to content

feat: check_enough_train_data #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
^renv$
^renv\.lock$
^epipredict\.Rproj$
^\.Rproj\.user$
^LICENSE\.md$
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ inst/doc
.DS_Store
/doc/
/Meta/
.Rprofile
renv.lock
renv/
12 changes: 12 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(adjust_frosting,frosting)
S3method(apply_frosting,default)
S3method(apply_frosting,epi_workflow)
S3method(augment,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,epi_recipe)
S3method(bake,step_epi_ahead)
S3method(bake,step_epi_lag)
Expand Down Expand Up @@ -48,6 +49,7 @@ S3method(mean,dist_quantiles)
S3method(median,dist_quantiles)
S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,check_enough_train_data)
S3method(prep,epi_recipe)
S3method(prep,step_epi_ahead)
S3method(prep,step_epi_lag)
Expand All @@ -60,6 +62,7 @@ S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,cdc_baseline_fcast)
S3method(print,check_enough_train_data)
S3method(print,epi_recipe)
S3method(print,epi_workflow)
S3method(print,flat_fcast)
Expand Down Expand Up @@ -104,6 +107,7 @@ S3method(snap,default)
S3method(snap,dist_default)
S3method(snap,dist_quantiles)
S3method(snap,distribution)
S3method(tidy,check_enough_train_data)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
Expand All @@ -127,6 +131,7 @@ export(arx_forecaster)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(check_enough_train_data)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
Expand Down Expand Up @@ -191,6 +196,12 @@ import(epiprocess)
import(parsnip)
import(recipes)
importFrom(cli,cli_abort)
importFrom(dplyr,across)
importFrom(dplyr,all_of)
importFrom(dplyr,group_by)
importFrom(dplyr,n)
importFrom(dplyr,summarise)
importFrom(dplyr,ungroup)
importFrom(epiprocess,growth_rate)
importFrom(generics,augment)
importFrom(generics,fit)
Expand Down Expand Up @@ -225,6 +236,7 @@ importFrom(stats,residuals)
importFrom(tibble,as_tibble)
importFrom(tibble,is_tibble)
importFrom(tibble,tibble)
importFrom(tidyr,drop_na)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
Expand Down
55 changes: 30 additions & 25 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,44 +1,49 @@
# epipredict (development)

# epipredict 0.0.7.9000

- add `check_enough_train_data` that will error if training data is too small
- added `check_enough_train_data` to `arx_forecaster`

# epipredict 0.0.7

* simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`
- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`

# epipredict 0.0.6

* rename the `dist_quantiles()` to be more descriptive, breaking change)
* removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
* add `pivot_quantiles_wider()` for easier plotting
* add complement `pivot_quantiles_longer()`
* add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`
- rename the `dist_quantiles()` to be more descriptive (breaking change)
- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
- add `pivot_quantiles_wider()` for easier plotting
- add complement `pivot_quantiles_longer()`
- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`

# epipredict 0.0.5

* add `smooth_quantile_reg()`
* improved printing of various methods / internals
* canned forecasters get a class
* fixed quantile bug in `flatline_forecaster()`
* add functionality to output the unfit workflow from the canned forecasters
- add `smooth_quantile_reg()`
- improved printing of various methods / internals
- canned forecasters get a class
- fixed quantile bug in `flatline_forecaster()`
- add functionality to output the unfit workflow from the canned forecasters

# epipredict 0.0.4

* add quantile_reg()
* clean up documentation bugs
* add smooth_quantile_reg()
* add classifier
* training window step debugged
* `min_train_window` argument removed from canned forecasters
- add quantile_reg()
- clean up documentation bugs
- add smooth_quantile_reg()
- add classifier
- training window step debugged
- `min_train_window` argument removed from canned forecasters

# epipredict 0.0.3

* add forecasters
* implement postprocessing
* vignettes avaliable
* arx_forecaster
* pkgdown
- add forecasters
- implement postprocessing
- vignettes available
- arx_forecaster
- pkgdown

# epipredict 0.0.0.9000

* Publish public for easy navigation
* Two simple forecasters as test beds
* Working vignette
- Publish public for easy navigation
- Two simple forecasters as test beds
- Working vignette
23 changes: 21 additions & 2 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,14 @@ arx_fcast_epi_workflow <- function(
r <- r %>%
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
step_training_window(n_recent = args_list$n_training) %>%
check_enough_train_data(
all_predictors(),
!!outcome,
n = args_list$n_training_min,
epi_keys = args_list$epi_keys,
drop_na = FALSE
)

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
Expand Down Expand Up @@ -199,6 +206,12 @@ arx_fcast_epi_workflow <- function(
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
#' will be treated as _additional_ allowed recent data rather than the
#' total amount of recent data to examine.
#' @param n_training_min Integer. The minimum number of rows per
#' epi_key that are required for training. If `NULL`, this will be set to
#' `number of predictors + 5`.
#' @param epi_keys Character vector. A character vector of column names on
#' which to group the data and check threshold within each group. Useful if
#' training per group (for example, per geo_value).
#' @param ... Space to handle future expansions (unused).
#'
#'
Expand All @@ -220,6 +233,8 @@ arx_args_list <- function(
nonneg = TRUE,
quantile_by_key = character(0L),
nafill_buffer = Inf,
n_training_min = NULL,
epi_keys = NULL,
...) {
# error checking if lags is a list
rlang::check_dots_empty()
Expand All @@ -236,6 +251,8 @@ arx_args_list <- function(
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
arg_is_pos(n_training_min, allow_null = TRUE)
arg_is_chr(epi_keys, allow_null = TRUE)

max_lags <- max(lags)
structure(
Expand All @@ -250,7 +267,9 @@ arx_args_list <- function(
nonneg,
max_lags,
quantile_by_key,
nafill_buffer
nafill_buffer,
n_training_min,
epi_keys
),
class = c("arx_fcast", "alist")
)
Expand Down
149 changes: 149 additions & 0 deletions R/check_enough_train_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#' Check the dataset contains enough data points.
#'
#' `check_enough_train_data` creates a *specification* of a recipe
#' operation that will check if variables contain enough data.
#'
#' @param recipe A recipe object. The check will be added to the
#' sequence of operations for this recipe.
#' @param ... One or more selector functions to choose variables
#' for this check. See [selections()] for more details.
#' @param n The minimum number of data points required for training.
#' @param epi_keys A character vector of column names on which to group the data
#' and check threshold within each group. Useful if your forecaster trains
#' per group (for example, per geo_value).
#' @param drop_na A logical for whether to drop rows containing `NA` values
#'   before counting, so that only complete rows count toward `n`.
#' @param role Not used by this check since no new variables are
#' created.
#' @param trained A logical for whether the selectors in `...`
#' have been resolved by [prep()].
#' @param columns An internal argument that tracks which columns are evaluated
#' for this check. Should not be used by the user.
#' @param id A character string that is unique to this check to identify it.
#' @param skip A logical. Should the check be skipped when the
#' recipe is baked by [bake()]? While all operations are baked
#' when [prep()] is run, some operations may not be able to be
#' conducted on new data (e.g. processing the outcome variable(s)).
#' Care should be taken when using `skip = TRUE` as it may affect
#' the computations for subsequent operations.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if skip=TRUE by default would solve the issue about running during fit vs predict? looks like you have a test demonstrating it does do that! So we definitely have a functional check for training data, if not test data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup! I'd like to handle test data checking next, unclear if that will be possible.

#' @family checks
#' @export
#' @details This check will break the `prep` function if any of the checked
#'   columns do not have enough data (after dropping rows with `NA` values
#'   when `drop_na = TRUE`). If the check passes, the data is left unchanged;
#'   `bake` returns its input as is.
#'
#' # tidy() results
#'
#' When you [`tidy()`][tidy.recipe()] this check, a tibble with column
#' `terms` (the selectors or variables selected) is returned, along with the
#' check's `id` and its `n`, `epi_keys` and `drop_na` settings (when set).
#'
check_enough_train_data <- function(recipe,
                                    ...,
                                    n = NULL,
                                    epi_keys = NULL,
                                    drop_na = TRUE,
                                    role = NA,
                                    trained = FALSE,
                                    columns = NULL,
                                    skip = TRUE,
                                    id = rand_id("enough_train_data")) {
  # Capture the selectors unevaluated; they are resolved against the
  # training data later, when the check is prepped.
  selectors <- rlang::enquos(...)

  # Build the (untrained) check specification from the user's settings.
  new_check <- check_enough_train_data_new(
    n = n,
    epi_keys = epi_keys,
    drop_na = drop_na,
    terms = selectors,
    role = role,
    trained = trained,
    columns = columns,
    skip = skip,
    id = id
  )

  # Append the specification to the recipe's sequence of operations.
  add_check(recipe, new_check)
}

# Low-level constructor for the `check_enough_train_data` check object.
# All fields are passed through unchanged to `check()`.
check_enough_train_data_new <- function(n, epi_keys, drop_na, terms, role,
                                        trained, columns, skip, id) {
  # The order of the fields is kept exactly as listed here because it
  # determines the order of the elements in the resulting object.
  fields <- list(
    subclass = "enough_train_data",
    prefix = "check_",
    n = n,
    epi_keys = epi_keys,
    drop_na = drop_na,
    terms = terms,
    role = role,
    trained = trained,
    columns = columns,
    skip = skip,
    id = id
  )
  do.call(check, fields)
}

#' @export
#' @importFrom dplyr group_by summarise ungroup across all_of n
#' @importFrom tidyr drop_na
prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
  # Resolve the selectors stored in the check against the training data.
  col_names <- recipes_eval_select(x$terms, training, info)

  # Default threshold: number of selected columns plus 5.
  if (is.null(x$n)) {
    x$n <- length(col_names) + 5
  }

  # `usable` holds the rows that count toward the threshold; `reference`
  # is used only to enumerate the groups that must each meet the threshold.
  # Rows with a missing key are not treated as a group of their own when
  # NA rows are being dropped.
  usable <- training
  reference <- training
  if (x$drop_na) {
    usable <- drop_na(training)
    if (length(x$epi_keys) > 0) {
      reference <- drop_na(training, all_of(x$epi_keys))
    }
  }

  # Number of groups that are expected to have enough data. With no
  # `epi_keys` the data is ungrouped and this is 1, even for empty data.
  n_groups_expected <- reference %>%
    group_by(across(all_of(.env$x$epi_keys))) %>%
    dplyr::n_groups()

  # Number of usable rows in every group that still has any rows.
  group_sizes <- usable %>%
    group_by(across(all_of(.env$x$epi_keys))) %>%
    summarise(.n_rows = n(), .groups = "drop")

  # A group that lost all of its rows to `drop_na()` (or training data that
  # is empty to begin with) does not show up in `group_sizes` at all; it
  # must be counted as having zero rows rather than silently passing.
  if (nrow(group_sizes) == 0L || nrow(group_sizes) < n_groups_expected) {
    smallest_group <- 0L
  } else {
    smallest_group <- min(group_sizes[[".n_rows"]])
  }

  # The row count is the same for every selected column, so either all of
  # the selected columns are reported or none of them.
  cols_not_enough_data <- if (smallest_group < x$n) {
    unname(col_names)
  } else {
    character(0)
  }

  if (length(cols_not_enough_data) > 0) {
    cli::cli_abort(
      "The following columns don't have enough data to predict: {cols_not_enough_data}."
    )
  }

  # Return the trained version of the check, recording the resolved columns
  # and the threshold that was actually used.
  check_enough_train_data_new(
    n = x$n,
    epi_keys = x$epi_keys,
    drop_na = x$drop_na,
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    columns = col_names,
    skip = x$skip,
    id = x$id
  )
}

#' @export
bake.check_enough_train_data <- function(object, new_data, ...) {
  # Nothing is checked or modified at bake time: the data is handed back
  # exactly as it was received.
  baked <- new_data
  baked
}

#' @export
print.check_enough_train_data <- function(x, width = max(20, options()$width - 30), ...) {
  # Header shows the minimum row count stored on the check.
  min_rows <- x$n
  header <- paste0("Check enough data (n = ", min_rows, ") for ")

  # Delegate the listing of columns / selectors to `print_step()`.
  print_step(x$columns, x$terms, x$trained, header, width)

  # Return the check invisibly so that printing does not echo it again.
  invisible(x)
}

#' @export
tidy.check_enough_train_data <- function(x, ...) {
  # One row per term: the resolved column names once the check is trained,
  # otherwise the selectors as text.
  if (is_trained(x)) {
    res <- tibble(terms = unname(x$columns))
  } else {
    res <- tibble(terms = sel2char(x$terms))
  }
  res$id <- x$id
  # Assigning NULL leaves the column out, which is what happens for an
  # untrained check created with `n = NULL`.
  res$n <- x$n
  # `epi_keys` may hold several column names. Assigning the raw vector
  # would fail (or be matched row by row against `terms`) whenever its
  # length is not 1, so the keys are collapsed into a single string that
  # is repeated on every row. A single key is reported unchanged.
  if (!is.null(x$epi_keys)) {
    res$epi_keys <- paste(x$epi_keys, collapse = ", ")
  }
  res$drop_na <- x$drop_na
  res
}
10 changes: 10 additions & 0 deletions man/arx_args_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading