Skip to content

feat: check_enough_train_data #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
^renv$
^renv\.lock$
^epipredict\.Rproj$
^\.Rproj\.user$
^LICENSE\.md$
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ inst/doc
.DS_Store
/doc/
/Meta/
.Rprofile
renv.lock
renv/
10 changes: 5 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.7
Version: 0.0.8
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand All @@ -22,11 +22,11 @@ License: MIT + file LICENSE
URL: https://github.com/cmu-delphi/epipredict/,
https://cmu-delphi.github.io/epipredict
BugReports: https://github.com/cmu-delphi/epipredict/issues/
Depends:
Depends:
epiprocess (>= 0.6.0),
parsnip (>= 1.0.0),
R (>= 3.5.0)
Imports:
Imports:
cli,
distributional,
dplyr,
Expand All @@ -48,7 +48,7 @@ Imports:
usethis,
vctrs,
workflows (>= 1.0.0)
Suggests:
Suggests:
covidcast,
data.table,
epidatr (>= 1.0.0),
Expand All @@ -61,7 +61,7 @@ Suggests:
rmarkdown,
testthat (>= 3.0.0),
xgboost
VignetteBuilder:
VignetteBuilder:
knitr
Remotes:
cmu-delphi/epidatr,
Expand Down
12 changes: 12 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(adjust_frosting,frosting)
S3method(apply_frosting,default)
S3method(apply_frosting,epi_workflow)
S3method(augment,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,epi_recipe)
S3method(bake,step_epi_ahead)
S3method(bake,step_epi_lag)
Expand Down Expand Up @@ -48,6 +49,7 @@ S3method(mean,dist_quantiles)
S3method(median,dist_quantiles)
S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,check_enough_train_data)
S3method(prep,epi_recipe)
S3method(prep,step_epi_ahead)
S3method(prep,step_epi_lag)
Expand All @@ -60,6 +62,7 @@ S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,cdc_baseline_fcast)
S3method(print,check_enough_train_data)
S3method(print,epi_recipe)
S3method(print,epi_workflow)
S3method(print,flat_fcast)
Expand Down Expand Up @@ -104,6 +107,7 @@ S3method(snap,default)
S3method(snap,dist_default)
S3method(snap,dist_quantiles)
S3method(snap,distribution)
S3method(tidy,check_enough_train_data)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
Expand All @@ -127,6 +131,7 @@ export(arx_forecaster)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(check_enough_train_data)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
Expand Down Expand Up @@ -191,6 +196,12 @@ import(epiprocess)
import(parsnip)
import(recipes)
importFrom(cli,cli_abort)
importFrom(dplyr,across)
importFrom(dplyr,all_of)
importFrom(dplyr,group_by)
importFrom(dplyr,n)
importFrom(dplyr,summarise)
importFrom(dplyr,ungroup)
importFrom(epiprocess,growth_rate)
importFrom(generics,augment)
importFrom(generics,fit)
Expand Down Expand Up @@ -225,6 +236,7 @@ importFrom(stats,residuals)
importFrom(tibble,as_tibble)
importFrom(tibble,is_tibble)
importFrom(tibble,tibble)
importFrom(tidyr,drop_na)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
Expand Down
55 changes: 30 additions & 25 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,44 +1,49 @@
# epipredict (development)

# epipredict 0.0.8

- add `check_enough_train_data` that will error if training data is too small
- added `check_enough_train_data` to `arx_forecaster`

# epipredict 0.0.7

* simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`
- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`

# epipredict 0.0.6

* rename the `dist_quantiles()` to be more descriptive, breaking change)
* removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
* add `pivot_quantiles_wider()` for easier plotting
* add complement `pivot_quantiles_longer()`
* add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`
- rename the `dist_quantiles()` to be more descriptive, breaking change)
- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
- add `pivot_quantiles_wider()` for easier plotting
- add complement `pivot_quantiles_longer()`
- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`

# epipredict 0.0.5

* add `smooth_quantile_reg()`
* improved printing of various methods / internals
* canned forecasters get a class
* fixed quantile bug in `flatline_forecaster()`
* add functionality to output the unfit workflow from the canned forecasters
- add `smooth_quantile_reg()`
- improved printing of various methods / internals
- canned forecasters get a class
- fixed quantile bug in `flatline_forecaster()`
- add functionality to output the unfit workflow from the canned forecasters

# epipredict 0.0.4

* add quantile_reg()
* clean up documentation bugs
* add smooth_quantile_reg()
* add classifier
* training window step debugged
* `min_train_window` argument removed from canned forecasters
- add quantile_reg()
- clean up documentation bugs
- add smooth_quantile_reg()
- add classifier
- training window step debugged
- `min_train_window` argument removed from canned forecasters

# epipredict 0.0.3

* add forecasters
* implement postprocessing
* vignettes avaliable
* arx_forecaster
* pkgdown
- add forecasters
- implement postprocessing
- vignettes avaliable
- arx_forecaster
- pkgdown

# epipredict 0.0.0.9000

* Publish public for easy navigation
* Two simple forecasters as test beds
* Working vignette
- Publish public for easy navigation
- Two simple forecasters as test beds
- Working vignette
29 changes: 27 additions & 2 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,21 @@ arx_class_epi_workflow <- function(
role = "outcome"
) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
step_training_window(n_recent = args_list$n_training) %>%
{
if (!is.null(args_list$check_enough_data_n)) {
check_enough_train_data(
.,
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
} else {
.
}
}

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
Expand Down Expand Up @@ -228,6 +242,11 @@ arx_class_epi_workflow <- function(
#' @param additional_gr_args List. Optional arguments controlling growth rate
#' calculation. See [epiprocess::growth_rate()] and the related Vignette for
#' more details.
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
#' epi_key that are required for training. If `NULL`, this check is ignored.
#' @param check_enough_data_epi_keys Character vector. A character vector of
#' column names on which to group the data and check threshold within each
#' group. Useful if training per group (for example, per geo_value).
#'
#' @return A list containing updated parameter choices with class `arx_clist`.
#' @export
Expand All @@ -251,6 +270,8 @@ arx_class_args_list <- function(
log_scale = FALSE,
additional_gr_args = list(),
nafill_buffer = Inf,
check_enough_data_n = NULL,
check_enough_data_epi_keys = NULL,
...) {
rlang::check_dots_empty()
.lags <- lags
Expand All @@ -275,6 +296,8 @@ arx_class_args_list <- function(
)
)
}
arg_is_pos(check_enough_data_n, allow_null = TRUE)
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)

breaks <- sort(breaks)
if (min(breaks) > -Inf) breaks <- c(-Inf, breaks)
Expand All @@ -296,7 +319,9 @@ arx_class_args_list <- function(
method,
log_scale,
additional_gr_args,
nafill_buffer
nafill_buffer,
check_enough_data_n,
check_enough_data_epi_keys
),
class = c("arx_class", "alist")
)
Expand Down
29 changes: 27 additions & 2 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,21 @@ arx_fcast_epi_workflow <- function(
r <- r %>%
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
step_training_window(n_recent = args_list$n_training) %>%
{
if (!is.null(args_list$check_enough_data_n)) {
check_enough_train_data(
.,
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
} else {
.
}
}

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
Expand Down Expand Up @@ -199,6 +213,11 @@ arx_fcast_epi_workflow <- function(
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
#' will be treated as _additional_ allowed recent data rather than the
#' total amount of recent data to examine.
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
#' epi_key that are required for training. If `NULL`, this check is ignored.
#' @param check_enough_data_epi_keys Character vector. A character vector of
#' column names on which to group the data and check threshold within each
#' group. Useful if training per group (for example, per geo_value).
#' @param ... Space to handle future expansions (unused).
#'
#'
Expand All @@ -220,6 +239,8 @@ arx_args_list <- function(
nonneg = TRUE,
quantile_by_key = character(0L),
nafill_buffer = Inf,
check_enough_data_n = NULL,
check_enough_data_epi_keys = NULL,
...) {
# error checking if lags is a list
rlang::check_dots_empty()
Expand All @@ -236,6 +257,8 @@ arx_args_list <- function(
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
arg_is_pos(check_enough_data_n, allow_null = TRUE)
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)

max_lags <- max(lags)
structure(
Expand All @@ -250,7 +273,9 @@ arx_args_list <- function(
nonneg,
max_lags,
quantile_by_key,
nafill_buffer
nafill_buffer,
check_enough_data_n,
check_enough_data_epi_keys
),
class = c("arx_fcast", "alist")
)
Expand Down
Loading