Skip to content

Commit 62868ff

Browse files
committed
create frosting, add simple tests, prepare naomit layer
1 parent 98f5326 commit 62868ff

12 files changed

+910
-84
lines changed

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(augment,epi_workflow)
34
S3method(bake,step_epi_ahead)
45
S3method(bake,step_epi_lag)
56
S3method(epi_keys,default)
@@ -9,6 +10,7 @@ S3method(epi_recipe,default)
910
S3method(epi_recipe,epi_df)
1011
S3method(epi_recipe,formula)
1112
S3method(predict,epi_workflow)
13+
S3method(prep,epi_recipe)
1214
S3method(prep,step_epi_ahead)
1315
S3method(prep,step_epi_lag)
1416
S3method(print,step_epi_ahead)
@@ -40,6 +42,8 @@ import(recipes)
4042
importFrom(magrittr,"%>%")
4143
importFrom(rlang,"!!")
4244
importFrom(rlang,":=")
45+
importFrom(rlang,abort)
46+
importFrom(rlang,caller_env)
4347
importFrom(rlang,is_null)
4448
importFrom(stats,as.formula)
4549
importFrom(stats,lm)

R/compat-purrr.R

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# See https://github.com/r-lib/rlang/blob/main/R/compat-purrr.R
2+
3+
4+
map <- function(.x, .f, ...) {
5+
.f <- rlang::as_function(.f, env = global_env())
6+
lapply(.x, .f, ...)
7+
}
8+
walk <- function(.x, .f, ...) {
9+
map(.x, .f, ...)
10+
invisible(.x)
11+
}
12+
13+
map_lgl <- function(.x, .f, ...) {
14+
.rlang_purrr_map_mold(.x, .f, logical(1), ...)
15+
}
16+
map_int <- function(.x, .f, ...) {
17+
.rlang_purrr_map_mold(.x, .f, integer(1), ...)
18+
}
19+
map_dbl <- function(.x, .f, ...) {
20+
.rlang_purrr_map_mold(.x, .f, double(1), ...)
21+
}
22+
map_chr <- function(.x, .f, ...) {
23+
.rlang_purrr_map_mold(.x, .f, character(1), ...)
24+
}
25+
.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) {
26+
.f <- rlang::as_function(.f, env = global_env())
27+
out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE)
28+
names(out) <- names(.x)
29+
out
30+
}
31+
32+
map2 <- function(.x, .y, .f, ...) {
33+
.f <- as_function(.f, env = global_env())
34+
out <- mapply(.f, .x, .y, MoreArgs = list(...), SIMPLIFY = FALSE)
35+
if (length(out) == length(.x)) {
36+
set_names(out, names(.x))
37+
} else {
38+
set_names(out, NULL)
39+
}
40+
}
41+
map2_lgl <- function(.x, .y, .f, ...) {
42+
as.vector(map2(.x, .y, .f, ...), "logical")
43+
}
44+
map2_int <- function(.x, .y, .f, ...) {
45+
as.vector(map2(.x, .y, .f, ...), "integer")
46+
}
47+
map2_dbl <- function(.x, .y, .f, ...) {
48+
as.vector(map2(.x, .y, .f, ...), "double")
49+
}
50+
map2_chr <- function(.x, .y, .f, ...) {
51+
as.vector(map2(.x, .y, .f, ...), "character")
52+
}
53+
imap <- function(.x, .f, ...) {
54+
map2(.x, names(.x) %||% seq_along(.x), .f, ...)
55+
}
56+
57+
pmap <- function(.l, .f, ...) {
58+
.f <- as.function(.f)
59+
args <- .rlang_purrr_args_recycle(.l)
60+
do.call("mapply", c(
61+
FUN = list(quote(.f)),
62+
args, MoreArgs = quote(list(...)),
63+
SIMPLIFY = FALSE, USE.NAMES = FALSE
64+
))
65+
}

R/epi_recipe.R

Lines changed: 115 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,7 @@ epi_recipe.default <- function(x, ...) {
7373
#'
7474
#' r
7575
epi_recipe.epi_df <-
76-
function(x,
77-
formula = NULL,
78-
...,
79-
vars = NULL,
80-
roles = NULL) {
76+
function(x, formula = NULL, ..., vars = NULL, roles = NULL) {
8177
if (!is.null(formula)) {
8278
if (!is.null(vars)) {
8379
rlang::abort(
@@ -115,12 +111,9 @@ epi_recipe.epi_df <-
115111
## Check and add roles when available
116112
if (!is.null(roles)) {
117113
if (length(roles) != length(vars)) {
118-
rlang::abort(
119-
paste0(
114+
rlang::abort(c(
120115
"The number of roles should be the same as the number of ",
121-
"variables"
122-
)
123-
)
116+
"variables."))
124117
}
125118
var_info$role <- roles
126119
} else {
@@ -161,6 +154,7 @@ epi_recipe.epi_df <-
161154

162155

163156
#' @rdname epi_recipe
157+
#' @importFrom rlang abort
164158
#' @export
165159
epi_recipe.formula <- function(formula, data, ...) {
166160
# we ensure that there's only 1 row in the template
@@ -170,9 +164,9 @@ epi_recipe.formula <- function(formula, data, ...) {
170164
return(recipes::recipe(formula, data, ...))
171165
}
172166

173-
f_funcs <- fun_calls(formula)
167+
f_funcs <- recipes:::fun_calls(formula)
174168
if (any(f_funcs == "-")) {
175-
Abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.")
169+
abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.")
176170
}
177171

178172
# Check for other in-line functions
@@ -193,11 +187,11 @@ epi_form2args <- function(formula, data, ...) {
193187
if (! rlang::is_formula(formula)) formula <- as.formula(formula)
194188

195189
## check for in-line formulas
196-
inline_check(formula)
190+
recipes:::inline_check(formula)
197191

198192
## use rlang to get both sides of the formula
199-
outcomes <- get_lhs_vars(formula, data)
200-
predictors <- get_rhs_vars(formula, data, no_lhs = TRUE)
193+
outcomes <- recipes:::get_lhs_vars(formula, data)
194+
predictors <- recipes:::get_rhs_vars(formula, data, no_lhs = TRUE)
201195
keys <- epi_keys(data)
202196

203197
## if . was used on the rhs, subtract out the outcomes
@@ -316,3 +310,109 @@ default_epi_recipe_blueprint <-
316310
hardhat::default_recipe_blueprint(
317311
intercept, allow_novel_levels, fresh, bake_dependent_roles, composition)
318312
}
313+
314+
315+
# unfortunately, everything the same as in prep.recipe except string/fctr handling
316+
#' @export
317+
prep.epi_recipe <- function(
318+
x, training = NULL, fresh = FALSE, verbose = FALSE,
319+
retain = TRUE, log_changes = FALSE, strings_as_factors = TRUE, ...) {
320+
training <- recipes:::check_training_set(training, x, fresh)
321+
tr_data <- recipes:::train_info(training)
322+
keys <- epi_keys(training)
323+
orig_lvls <- lapply(training, recipes:::get_levels)
324+
orig_lvls <- kill_levels(orig_lvls, keys)
325+
if (strings_as_factors) {
326+
lvls <- lapply(training, recipes:::get_levels)
327+
lvls <- kill_levels(lvls, keys)
328+
training <- recipes:::strings2factors(training, lvls)
329+
} else {
330+
lvls <- NULL
331+
}
332+
skippers <- map_lgl(x$steps, recipes:::is_skipable)
333+
if (any(skippers) & !retain) {
334+
rlang::warn(c("Since some operations have `skip = TRUE`, using ",
335+
"`retain = TRUE` will allow those steps results to ",
336+
"be accessible."))
337+
}
338+
if (fresh) x$term_info <- x$var_info
339+
340+
running_info <- x$term_info %>% dplyr::mutate(number = 0, skip = FALSE)
341+
for (i in seq(along.with = x$steps)) {
342+
needs_tuning <- map_lgl(x$steps[[i]], recipes:::is_tune)
343+
if (any(needs_tuning)) {
344+
arg <- names(needs_tuning)[needs_tuning]
345+
arg <- paste0("'", arg, "'", collapse = ", ")
346+
msg <- paste0(
347+
"You cannot `prep()` a tuneable recipe. Argument(s) with `tune()`: ",
348+
arg, ". Do you want to use a tuning function such as `tune_grid()`?")
349+
rlang::abort(msg)
350+
}
351+
note <- paste("oper", i, gsub("_", " ", class(x$steps[[i]])[1]))
352+
if (!x$steps[[i]]$trained | fresh) {
353+
if (verbose) {
354+
cat(note, "[training]", "\n")
355+
}
356+
before_nms <- names(training)
357+
x$steps[[i]] <- prep(x$steps[[i]], training = training,
358+
info = x$term_info)
359+
training <- bake(x$steps[[i]], new_data = training)
360+
if (!tibble::is_tibble(training)) {
361+
abort("bake() methods should always return tibbles")
362+
}
363+
x$term_info <- recipes:::merge_term_info(get_types(training), x$term_info)
364+
if (!is.na(x$steps[[i]]$role)) {
365+
new_vars <- setdiff(x$term_info$variable, running_info$variable)
366+
pos_new_var <- x$term_info$variable %in% new_vars
367+
pos_new_and_na_role <- pos_new_var & is.na(x$term_info$role)
368+
pos_new_and_na_source <- pos_new_var & is.na(x$term_info$source)
369+
x$term_info$role[pos_new_and_na_role] <- x$steps[[i]]$role
370+
x$term_info$source[pos_new_and_na_source] <- "derived"
371+
}
372+
recipes:::changelog(log_changes, before_nms, names(training), x$steps[[i]])
373+
running_info <- rbind(
374+
running_info,
375+
dplyr::mutate(x$term_info, number = i, skip = x$steps[[i]]$skip))
376+
} else {
377+
if (verbose) cat(note, "[pre-trained]\n")
378+
}
379+
}
380+
if (strings_as_factors) {
381+
lvls <- lapply(training, recipes:::get_levels)
382+
lvls <- kill_levels(lvls, keys)
383+
check_lvls <- recipes:::has_lvls(lvls)
384+
if (!any(check_lvls)) lvls <- NULL
385+
} else {
386+
lvls <- NULL
387+
}
388+
if (retain) {
389+
if (verbose) {
390+
cat("The retained training set is ~",
391+
format(object.size(training), units = "Mb", digits = 2),
392+
" in memory.\n\n")
393+
}
394+
x$template <- training
395+
} else {
396+
x$template <- training[0, ]
397+
}
398+
x$tr_info <- tr_data
399+
x$levels <- lvls
400+
x$orig_lvls <- orig_lvls
401+
x$retained <- retain
402+
x$last_term_info <- running_info %>%
403+
dplyr::group_by(variable) %>%
404+
dplyr::arrange(dplyr::desc(number)) %>%
405+
dplyr::summarise(
406+
type = dplyr::first(type),
407+
role = as.list(unique(unlist(role))),
408+
source = dplyr::first(source),
409+
number = dplyr::first(number),
410+
skip = dplyr::first(skip),
411+
.groups = "keep")
412+
x
413+
}
414+
415+
kill_levels <- function(x, keys) {
416+
for (i in which(names(x) %in% keys)) x[[i]] <- list(values = NA, ordered = NA)
417+
x
418+
}

R/epi_workflow.R

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,21 @@
3535
#' wf <- epi_workflow(r, linear_reg())
3636
#'
3737
#' wf
38-
epi_workflow <- function(preprocessor = NULL, spec = NULL) {
38+
epi_workflow <- function(preprocessor = NULL, spec = NULL,
39+
postprocessor = NULL) {
3940
out <- workflows::workflow(spec = spec)
4041
class(out) <- c("epi_workflow", class(out))
4142

4243
if (is_epi_recipe(preprocessor)) {
4344
return(add_epi_recipe(out, preprocessor))
4445
}
45-
4646
if (!is_null(preprocessor)) {
47-
return(workflows:::add_preprocessor(out, preprocessor))
47+
out <- workflows:::add_preprocessor(out, preprocessor)
48+
}
49+
if (!is_null(postprocessor)) {
50+
out <- add_postprocessor(out, postprocessor)
4851
}
52+
4953
out
5054
}
5155

@@ -95,17 +99,11 @@ is_epi_workflow <- function(x) {
9599
#' @export
96100
#' @examples
97101
#'
98-
#' library(epiprocess)
99102
#' library(dplyr)
100103
#' library(parsnip)
101104
#' library(recipes)
102105
#'
103-
#' jhu <- jhu_csse_daily_subset %>%
104-
#' filter(time_value > "2021-08-01") %>%
105-
#' select(geo_value:death_rate_7d_av) %>%
106-
#' rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av)
107-
#'
108-
#' r <- epi_recipe(jhu) %>%
106+
#' r <- epi_recipe(case_death_rate_subset) %>%
109107
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
110108
#' step_epi_ahead(death_rate, ahead = 7) %>%
111109
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
@@ -114,33 +112,29 @@ is_epi_workflow <- function(x) {
114112
#'
115113
#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu)
116114
#'
117-
#' jhu_latest <- jhu %>%
118-
#' filter(!is.na(case_rate), !is.na(death_rate)) %>%
119-
#' group_by(geo_value) %>%
120-
#' slice_tail(n = 15) %>% # have lags 0,...,14, so need 15 for a complete case
121-
#' ungroup()
115+
#' latest <- get_test_data(r, case_death_rate_subset)
122116
#'
123-
#' preds <- predict(wf, jhu_latest, forecast_date = "2021-12-31") %>%
117+
#' preds <- predict(wf, latest) %>%
124118
#' filter(!is.na(.pred))
125119
#'
126120
#' preds
127-
predict.epi_workflow <-
128-
function(object, new_data, type = NULL, opts = list(),
129-
forecast_date = NULL, ...) {
130-
if (!workflows::is_trained_workflow(object)) {
131-
rlang::abort(
132-
c("Can't predict on an untrained epi_workflow.",
133-
i = "Do you need to call `fit()`?"))
134-
}
135-
if (!is_null(forecast_date)) forecast_date <- as.Date(forecast_date)
136-
the_fit <- workflows::extract_fit_parsnip(object)
137-
mold <- workflows::extract_mold(object)
138-
forged <- hardhat::forge(new_data, blueprint = mold$blueprint)
139-
preds <- predict(the_fit, forged$predictors, type = type, opts = opts, ...)
140-
keys <- grab_forged_keys(forged, mold, new_data)
141-
out <- dplyr::bind_cols(keys, forecast_date = forecast_date, preds)
142-
out
121+
predict.epi_workflow <- function(object, new_data, ...) {
122+
if (!workflows::is_trained_workflow(object)) {
123+
rlang::abort(
124+
c("Can't predict on an untrained epi_workflow.",
125+
i = "Do you need to call `fit()`?"))
143126
}
127+
components <- list()
128+
the_fit <- workflows::extract_fit_parsnip(object)
129+
components$mold <- workflows::extract_mold(object)
130+
components$forged <- hardhat::forge(new_data,
131+
blueprint = components$mold$blueprint)
132+
components$keys <- grab_forged_keys(components$forged,
133+
components$mold, new_data)
134+
components <- apply_frosting(object, components, the_fit, ...)
135+
out <- dplyr::bind_cols(components$keys, components$preds)
136+
out
137+
}
144138

145139
grab_forged_keys <- function(forged, mold, new_data) {
146140
keys <- c("time_value", "geo_value", "key")

0 commit comments

Comments
 (0)