Skip to content

Commit 714b972

Browse files
authored
Merge pull request #53 from cmu-delphi/36-step_training_window
Created a preprocessing step that limits the size of the training window
2 parents a0d7cbd + 57fcae7 commit 714b972

10 files changed

+445
-4
lines changed

NAMESPACE

+9
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ S3method(Ops,dist_quantiles)
55
S3method(apply_frosting,default)
66
S3method(apply_frosting,epi_workflow)
77
S3method(augment,epi_workflow)
8+
S3method(bake,epi_recipe)
89
S3method(bake,step_epi_ahead)
910
S3method(bake,step_epi_lag)
1011
S3method(bake,step_population_scaling)
12+
S3method(bake,step_training_window)
1113
S3method(detect_layer,frosting)
1214
S3method(detect_layer,workflow)
1315
S3method(epi_keys,default)
@@ -36,10 +38,12 @@ S3method(prep,epi_recipe)
3638
S3method(prep,step_epi_ahead)
3739
S3method(prep,step_epi_lag)
3840
S3method(prep,step_population_scaling)
41+
S3method(prep,step_training_window)
3942
S3method(print,epi_workflow)
4043
S3method(print,frosting)
4144
S3method(print,step_epi_ahead)
4245
S3method(print,step_epi_lag)
46+
S3method(print,step_training_window)
4347
S3method(quantile,dist_quantiles)
4448
S3method(refresh_blueprint,default_epi_recipe_blueprint)
4549
S3method(residuals,flatline)
@@ -105,6 +109,7 @@ export(step_epi_ahead)
105109
export(step_epi_lag)
106110
export(step_epi_naomit)
107111
export(step_population_scaling)
112+
export(step_training_window)
108113
export(validate_layer)
109114
import(distributional)
110115
import(recipes)
@@ -119,7 +124,9 @@ importFrom(rlang,":=")
119124
importFrom(rlang,`%||%`)
120125
importFrom(rlang,abort)
121126
importFrom(rlang,caller_env)
127+
importFrom(rlang,is_empty)
122128
importFrom(rlang,is_null)
129+
importFrom(rlang,quos)
123130
importFrom(stats,as.formula)
124131
importFrom(stats,family)
125132
importFrom(stats,lm)
@@ -130,4 +137,6 @@ importFrom(stats,predict)
130137
importFrom(stats,qnorm)
131138
importFrom(stats,quantile)
132139
importFrom(stats,residuals)
140+
importFrom(stats,setNames)
141+
importFrom(tibble,is_tibble)
133142
importFrom(tibble,tibble)

R/bake.epi_recipe.R

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#' Bake an epi_recipe
2+
#'
3+
#' @param object A trained object such as a [recipe()] with at least
4+
#' one preprocessing operation.
5+
#' @param new_data An `epi_df`, data frame or tibble for whom the
6+
#' preprocessing will be applied. If `NULL` is given to `new_data`,
7+
#' the pre-processed _training data_ will be returned.
8+
#' @param ... One or more selector functions to choose which variables will be
9+
#' returned by the function. See \code{\link[recipes]{selections}} for
10+
#' more details. If no selectors are given, the default is to
11+
#' use [everything()].
12+
#' @return An `epi_df` that may have different columns than the
13+
#' original columns in `new_data`.
14+
#' @importFrom rlang is_empty quos
15+
#' @importFrom tibble is_tibble
16+
#' @rdname bake
17+
#' @export
18+
bake.epi_recipe <- function(object, new_data, ...) {
  # `new_data` has no default, so the caller must pass it explicitly, either
  # as a data set or as `NULL`.
  if (rlang::is_missing(new_data)) {
    rlang::abort("'new_data' must be either an epi_df or NULL. No value is not allowed.")
  }

  # `NULL` asks for the processed training data; `epi_juice()` performs its
  # own "trained" and "retained" checks.
  if (is.null(new_data)) {
    return(epi_juice(object, ...))
  }

  if (!fully_trained(object)) {
    rlang::abort("At least one step has not been trained. Please run `prep`.")
  }

  # Column selectors supplied through `...`; default to every column.
  terms <- quos(...)
  if (is_empty(terms)) {
    terms <- quos(everything())
  }

  # `new_data` cannot be NULL here (handled above), so only objects without
  # columns reach this branch. If the caller also passed the deprecated
  # `newdata` name through `...`, point them at the right argument.
  if (is.null(ncol(new_data))) {
    if (any(names(terms) == "newdata")) {
      rlang::abort("Please use `new_data` instead of `newdata` with `bake`.")
    } else {
      rlang::abort("Please pass a data set to `new_data`.")
    }
  }

  # Namespace-qualified so the call does not depend on `as_tibble` being
  # imported into the package namespace.
  if (!is_tibble(new_data)) {
    new_data <- tibble::as_tibble(new_data)
  }

  recipes:::check_role_requirements(object, new_data)

  recipes:::check_nominal_type(new_data, object$orig_lvls)

  # Drop completely new columns from `new_data` and reorder columns that do
  # still exist to match the ordering used when training
  original_names <- names(new_data)
  original_training_names <- unique(object$var_info$variable)
  bakeable_names <- intersect(original_training_names, original_names)
  new_data <- new_data[, bakeable_names]

  # Apply every step in order, except those flagged as skipable.
  for (step in object$steps) {
    if (recipes:::is_skipable(step)) {
      next
    }

    new_data <- bake(step, new_data = new_data)

    if (!is_tibble(new_data)) {
      rlang::abort("bake() methods should always return tibbles")
    }
  }

  # Use `last_term_info`, which maintains info on all columns that got added
  # and removed from the training data. This is important for skipped steps
  # which might have resulted in columns not being added/removed in the test
  # set.
  info <- object$last_term_info

  # Now reduce to only user selected columns
  out_names <- recipes_eval_select(terms, new_data, info,
                                   check_case_weights = FALSE)
  new_data <- new_data[, out_names]

  # When level information was stored on the recipe, convert the selected
  # columns that have at least one non-NA level back to factors.
  if (!is.null(object$levels)) {
    var_levels <- object$levels[out_names]
    has_levels <- vapply(
      var_levels,
      function(x) !all(is.na(x)),
      logical(1)
    )
    var_levels <- var_levels[has_levels]
    if (length(var_levels) > 0) {
      new_data <- recipes:::strings2factors(new_data, var_levels)
    }
  }

  new_data
}

R/epi_juice.R

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#' Extract transformed training set
2+
#'
3+
#' @inheritParams bake.epi_recipe
4+
epi_juice <- function(object, ...) {
  # The recipe must be fully prepped before anything can be extracted.
  if (!fully_trained(object)) {
    rlang::abort("At least one step has not been trained. Please run `prep()`.")
  }

  # The processed training set is read from `object$template`, which is only
  # usable when the recipe was prepped with `retain = TRUE`.
  if (!isTRUE(object$retained)) {
    rlang::abort(paste0(
      "Use `retain = TRUE` in `prep()` to be able ",
      "to extract the training set"
    ))
  }

  # Selectors given through `...`; fall back to every column.
  selectors <- quos(...)
  if (is_empty(selectors)) {
    selectors <- quos(everything())
  }

  # Restrict the stored training data to the requested columns.
  juiced <- object$template
  keep <- recipes_eval_select(selectors, juiced, object$term_info,
                              check_case_weights = FALSE)
  juiced <- juiced[, keep]

  # No stored levels means no character -> factor conversion is needed.
  if (is.null(object$levels)) {
    return(juiced)
  }

  # Since most models require factors, convert the selected columns whose
  # stored levels are not entirely NA.
  lvls <- object$levels[keep]
  usable <- vapply(lvls, function(x) !all(is.na(x)), logical(1))
  lvls <- lvls[usable]
  if (length(lvls) > 0) {
    juiced <- recipes:::strings2factors(juiced, lvls)
  }

  juiced
}

R/training_window.R

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#' Limits the size of the training window to the most recent observations
2+
#'
3+
#' `step_training_window` creates a *specification* of a recipe step that
4+
#' limits the size of the training window to the `n_recent` most recent
5+
#' observations in `time_value` per group, where the groups are formed
6+
#' based on the remaining `epi_keys`.
7+
#'
8+
#' @param recipe A recipe object. The step will be added to the
9+
#' sequence of operations for this recipe.
10+
#' @param role Not used by this step since no new variables are created.
11+
#' @param trained A logical to indicate if the quantities for
12+
#' preprocessing have been estimated.
13+
#' @param n_recent An integer value that represents the number of most recent
14+
#' observations that are to be kept in the training window per location.
15+
#' The default value is 50.
16+
#' @param id A character string that is unique to this step to identify it.
17+
#' @template step-return
18+
#'
19+
#' @details Note that `step_epi_ahead()` and `step_epi_lag()` should come
20+
#' after any filtering step.
21+
#'
22+
#' @export
23+
#'
24+
#' @examples
25+
#' tib <- tibble::tibble(
26+
#' x = 1:10, y = 1:10,
27+
#' time_value = rep(seq(as.Date("2020-01-01"), by = 1,
28+
#' length.out = 5), times = 2),
29+
#' geo_value = rep(c("ca", "hi"), each = 5)
30+
#' ) %>% epiprocess::as_epi_df()
31+
#'
32+
#' library(recipes)
33+
#' epi_recipe(y ~ x, data = tib) %>%
34+
#' step_training_window(n_recent = 3) %>%
35+
#' prep(tib) %>%
36+
#' bake(new_data = NULL)
37+
step_training_window <-
  function(recipe,
           role = NA,
           trained = FALSE,
           n_recent = 50,
           id = rand_id("training_window")) {
    # Append the new step to `recipe`. `skip` is fixed to TRUE here and is
    # not exposed as an argument of this function.
    add_step(
      recipe,
      step_training_window_new(
        role = role, trained = trained, n_recent = n_recent,
        skip = TRUE, id = id
      )
    )
  }
55+
56+
# Low-level constructor for a `training_window` step object.
#
# `terms` is accepted for interface compatibility but is never used or
# forwarded; the callers in this file do not supply it. `id` previously had
# the self-referential default `id = id`, which can only fail with a
# "promise already under evaluation" error when relied upon, so it is now a
# plain required argument.
step_training_window_new <-
  function(terms, role, trained, n_recent, skip, id) {
    step(
      subclass = "training_window",
      role = role,
      trained = trained,
      n_recent = n_recent,
      skip = skip,
      id = id
    )
  }
67+
68+
#' @export
69+
prep.step_training_window <- function(x, training, info = NULL, ...) {
  # Nothing is estimated from `training`: the step is rebuilt with the same
  # settings and only flagged as trained. `...` is accepted (and ignored) so
  # the method signature is consistent with the `prep()` generic.
  step_training_window_new(
    role = x$role,
    trained = TRUE,
    n_recent = x$n_recent,
    skip = x$skip,
    id = x$id
  )
}
79+
80+
#' @export
81+
bake.step_training_window <- function(object, new_data, ...) {
  # `...` is accepted (and ignored) so the method signature is consistent
  # with the `bake()` generic.
  n_recent <- object$n_recent

  # Validate up front so that non-numeric, NA, infinite or vector values
  # produce this message rather than an opaque `if (NA)` failure.
  is_whole_scalar <- is.numeric(n_recent) &&
    length(n_recent) == 1L &&
    is.finite(n_recent) &&
    n_recent == trunc(n_recent)
  if (!is_whole_scalar) {
    rlang::abort("step_training_window requires 'n_recent' to be integer valued.")
  }
  # A negative `n` makes `slice_tail()` drop rows instead of keeping them,
  # which is the opposite of what this step documents.
  if (n_recent < 0) {
    rlang::abort("step_training_window requires 'n_recent' to be non-negative.")
  }

  # Group by every epi key except `time_value`.
  keys <- epi_keys(new_data)
  ek <- keys[keys != "time_value"]

  # Sort by time, then keep the last `n_recent` rows within each group.
  new_data %>%
    dplyr::group_by(dplyr::across(dplyr::all_of(ek))) %>%
    dplyr::arrange(time_value) %>%
    dplyr::slice_tail(n = n_recent) %>%
    dplyr::ungroup()
}
94+
95+
#' @export
96+
print.step_training_window <-
  function(x, width = max(20, options()$width - 30), ...) {
    # Text shown before the step's value when the recipe is printed.
    title <- "Number of most recent observations per location used in training window "
    # The local keeps the name `n_recent` because it is captured by
    # `rlang::enquos()` below.
    n_recent <- x$n_recent
    tr_obj <- format_selectors(rlang::enquos(n_recent), width)
    recipes::print_step(
      tr_obj, rlang::enquos(n_recent), x$trained, title, width
    )
    # Return the step invisibly, as print methods conventionally do.
    invisible(x)
  }

man/bake.Rd

+27
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/epi_juice.Rd

+19
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/epi_workflow.Rd

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/flatline.Rd

+2-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)