@@ -73,11 +73,7 @@ epi_recipe.default <- function(x, ...) {
73
73
# '
74
74
# ' r
75
75
epi_recipe.epi_df <-
76
- function (x ,
77
- formula = NULL ,
78
- ... ,
79
- vars = NULL ,
80
- roles = NULL ) {
76
+ function (x , formula = NULL , ... , vars = NULL , roles = NULL ) {
81
77
if (! is.null(formula )) {
82
78
if (! is.null(vars )) {
83
79
rlang :: abort(
@@ -115,12 +111,9 @@ epi_recipe.epi_df <-
115
111
# # Check and add roles when available
116
112
if (! is.null(roles )) {
117
113
if (length(roles ) != length(vars )) {
118
- rlang :: abort(
119
- paste0(
114
+ rlang :: abort(c(
120
115
" The number of roles should be the same as the number of " ,
121
- " variables"
122
- )
123
- )
116
+ " variables." ))
124
117
}
125
118
var_info $ role <- roles
126
119
} else {
@@ -161,6 +154,7 @@ epi_recipe.epi_df <-
161
154
162
155
163
156
# ' @rdname epi_recipe
157
+ # ' @importFrom rlang abort
164
158
# ' @export
165
159
epi_recipe.formula <- function (formula , data , ... ) {
166
160
# we ensure that there's only 1 row in the template
@@ -170,9 +164,9 @@ epi_recipe.formula <- function(formula, data, ...) {
170
164
return (recipes :: recipe(formula , data , ... ))
171
165
}
172
166
173
- f_funcs <- fun_calls(formula )
167
+ f_funcs <- recipes ::: fun_calls(formula )
174
168
if (any(f_funcs == " -" )) {
175
- Abort (" `-` is not allowed in a recipe formula. Use `step_rm()` instead." )
169
+ abort (" `-` is not allowed in a recipe formula. Use `step_rm()` instead." )
176
170
}
177
171
178
172
# Check for other in-line functions
@@ -193,11 +187,11 @@ epi_form2args <- function(formula, data, ...) {
193
187
if (! rlang :: is_formula(formula )) formula <- as.formula(formula )
194
188
195
189
# # check for in-line formulas
196
- inline_check(formula )
190
+ recipes ::: inline_check(formula )
197
191
198
192
# # use rlang to get both sides of the formula
199
- outcomes <- get_lhs_vars(formula , data )
200
- predictors <- get_rhs_vars(formula , data , no_lhs = TRUE )
193
+ outcomes <- recipes ::: get_lhs_vars(formula , data )
194
+ predictors <- recipes ::: get_rhs_vars(formula , data , no_lhs = TRUE )
201
195
keys <- epi_keys(data )
202
196
203
197
# # if . was used on the rhs, subtract out the outcomes
@@ -316,3 +310,109 @@ default_epi_recipe_blueprint <-
316
310
hardhat :: default_recipe_blueprint(
317
311
intercept , allow_novel_levels , fresh , bake_dependent_roles , composition )
318
312
}
313
+
314
+
315
# Train an `epi_recipe` on a training set.
#
# Unfortunately, everything here is the same as in `prep.recipe()` except the
# string/factor handling: the level information of the key columns reported
# by `epi_keys()` is blanked out with `kill_levels()` before it is used for
# string-to-factor conversion or stored on the recipe.
#
# Arguments:
#   x                   The recipe to train.
#   training            Training data; validated together with `x` and
#                       `fresh` by recipes' internal `check_training_set()`.
#   fresh               If TRUE, reset `x$term_info` to `x$var_info` and
#                       re-train steps that are already trained.
#   verbose             If TRUE, print one progress line per step and the
#                       size of the retained training set.
#   retain              If TRUE, store the processed training set in
#                       `x$template`; otherwise store a zero-row slice of it.
#   log_changes         Forwarded to recipes' internal `changelog()`.
#   strings_as_factors  If TRUE, pass the training set through recipes'
#                       internal `strings2factors()` and record the factor
#                       levels found after all steps have run.
#   ...                 Not used in this method.
#
# Returns `x` with its steps trained and with `template`, `tr_info`,
# `levels`, `orig_lvls`, `retained`, `term_info` and `last_term_info` updated.
#' @export
prep.epi_recipe <- function(
    x, training = NULL, fresh = FALSE, verbose = FALSE,
    retain = TRUE, log_changes = FALSE, strings_as_factors = TRUE, ...) {
  training <- recipes:::check_training_set(training, x, fresh)
  tr_data <- recipes:::train_info(training)
  keys <- epi_keys(training)

  # Levels of the incoming data, with the key columns blanked out.
  orig_lvls <- lapply(training, recipes:::get_levels)
  orig_lvls <- kill_levels(orig_lvls, keys)
  if (strings_as_factors) {
    # The same level list is what drives the conversion; no need to
    # recompute it from the (still unmodified) training set.
    training <- recipes:::strings2factors(training, orig_lvls)
  }

  skippers <- map_lgl(x$steps, recipes:::is_skipable)
  if (any(skippers) && !retain) {
    # A single string: a character vector would be rendered by rlang as a
    # header followed by separate lines, splitting the sentence.
    rlang::warn(paste0(
      "Since some operations have `skip = TRUE`, using ",
      "`retain = TRUE` will allow those steps results to ",
      "be accessible."
    ))
  }
  if (fresh) x$term_info <- x$var_info

  # `running_info` accumulates a snapshot of the term info after every
  # trained step; `number = 0` marks the state before any step.
  running_info <- x$term_info %>% dplyr::mutate(number = 0, skip = FALSE)
  for (i in seq_along(x$steps)) {
    needs_tuning <- map_lgl(x$steps[[i]], recipes:::is_tune)
    if (any(needs_tuning)) {
      arg <- names(needs_tuning)[needs_tuning]
      arg <- paste0("'", arg, "'", collapse = ", ")
      msg <- paste0(
        "You cannot `prep()` a tuneable recipe. Argument(s) with `tune()`: ",
        arg, ". Do you want to use a tuning function such as `tune_grid()`?"
      )
      rlang::abort(msg)
    }
    note <- paste("oper", i, gsub("_", " ", class(x$steps[[i]])[1]))
    if (!x$steps[[i]]$trained || fresh) {
      if (verbose) {
        cat(note, "[training]", "\n")
      }
      before_nms <- names(training)
      # Train the step, then immediately apply it so that the next step
      # sees the transformed data.
      x$steps[[i]] <- prep(x$steps[[i]],
        training = training,
        info = x$term_info
      )
      training <- bake(x$steps[[i]], new_data = training)
      if (!tibble::is_tibble(training)) {
        rlang::abort("bake() methods should always return tibbles")
      }
      # NOTE(review): `get_types()` is called unqualified while the other
      # recipes helpers here use `recipes:::` -- confirm it is in scope.
      x$term_info <- recipes:::merge_term_info(get_types(training), x$term_info)
      if (!is.na(x$steps[[i]]$role)) {
        # Variables that first appear after this step inherit the step's
        # role (if they have none) and are marked as derived.
        new_vars <- setdiff(x$term_info$variable, running_info$variable)
        pos_new_var <- x$term_info$variable %in% new_vars
        pos_new_and_na_role <- pos_new_var & is.na(x$term_info$role)
        pos_new_and_na_source <- pos_new_var & is.na(x$term_info$source)
        x$term_info$role[pos_new_and_na_role] <- x$steps[[i]]$role
        x$term_info$source[pos_new_and_na_source] <- "derived"
      }
      recipes:::changelog(log_changes, before_nms, names(training), x$steps[[i]])
      running_info <- rbind(
        running_info,
        dplyr::mutate(x$term_info, number = i, skip = x$steps[[i]]$skip)
      )
    } else {
      if (verbose) cat(note, "[pre-trained]\n")
    }
  }

  # Levels after all steps have run, again with the key columns blanked out.
  if (strings_as_factors) {
    lvls <- lapply(training, recipes:::get_levels)
    lvls <- kill_levels(lvls, keys)
    check_lvls <- recipes:::has_lvls(lvls)
    if (!any(check_lvls)) lvls <- NULL
  } else {
    lvls <- NULL
  }

  if (retain) {
    if (verbose) {
      cat(
        "The retained training set is ~",
        format(object.size(training), units = "Mb", digits = 2),
        " in memory.\n\n"
      )
    }
    x$template <- training
  } else {
    x$template <- training[0, ]
  }
  x$tr_info <- tr_data
  x$levels <- lvls
  x$orig_lvls <- orig_lvls
  x$retained <- retain
  # One row per variable, taken from the most recent step that saw it;
  # roles are collected across all steps.
  x$last_term_info <- running_info %>%
    dplyr::group_by(variable) %>%
    dplyr::arrange(dplyr::desc(number)) %>%
    dplyr::summarise(
      type = dplyr::first(type),
      role = as.list(unique(unlist(role))),
      source = dplyr::first(source),
      number = dplyr::first(number),
      skip = dplyr::first(skip),
      .groups = "keep"
    )
  x
}
414
+
415
# Blank out the level information of key columns.
#
# `x` is a named list with one entry per column; `keys` is a character
# vector of column names. Every entry of `x` whose name appears in `keys` is
# replaced by `list(values = NA, ordered = NA)`; all other entries are
# returned untouched.
kill_levels <- function(x, keys) {
  no_levels <- list(values = NA, ordered = NA)
  key_positions <- which(names(x) %in% keys)
  for (pos in key_positions) {
    x[[pos]] <- no_levels
  }
  x
}
0 commit comments