Skip to content

Commit bd754ca

Browse files
committed
create a point prediction layer.
* currently bugs in the workflow hierarchy. For some reason the layers are deeply nested.
1 parent 298d837 commit bd754ca

File tree

4 files changed

+31
-3
lines changed

4 files changed

+31
-3
lines changed

NAMESPACE

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ S3method(prep,step_epi_ahead)
1818
S3method(prep,step_epi_lag)
1919
S3method(print,step_epi_ahead)
2020
S3method(print,step_epi_lag)
21+
S3method(slather,layer_naomit)
22+
S3method(slather,layer_predict)
2123
export("%>%")
2224
export(add_epi_recipe)
2325
export(apply_frosting)

R/frosting.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ apply_frosting.epi_workflow <- function(workflow, components, the_fit, ...) {
8888
components$preds <- predict(the_fit, components$forged$predictors, ...)
8989
return(components)
9090
}
91-
layers <- workflow$post$actions$frosting
92-
for (l in seq_along(layers)) {
91+
layers <- workflow$post$actions$frosting$frosting
92+
for (l in seq_along(layers$layers)) {
9393
layer <- layers$layers[[l]]
9494
components <- slather(layer, components = components, the_fit)
9595
}

R/layer_naomit.R

+6-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@ layer_naomit_new <- function(terms, id) {
1212
layer("naomit", terms = terms, id = id)
1313
}
1414

15+
#' @export
1516
slather.layer_naomit <- function(object, components, the_fit) {
17+
exprs <- rlang::expr(c(!!!object$terms))
18+
pos <- tidyselect::eval_select(exprs, components$predictions)
19+
col_names <- names(pos)
1620
components$predictions <- components$predictions %>%
17-
filter()
21+
dplyr::filter(dplyr::if_any(dplyr::all_of(col_names), ~ !is.na(.x)))
22+
components
1823
}

R/layer_predict.R

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
layer_predict <-
2+
function(frosting, id = rand_id("predict_default")) {
3+
add_layer(
4+
frosting,
5+
layer_predict_new(
6+
id = id
7+
)
8+
)
9+
}
10+
11+
12+
layer_predict_new <- function(id) {
13+
layer("predict", id = id)
14+
}
15+
16+
#' @export
17+
slather.layer_predict <- function(object, components, the_fit) {
18+
components$predictions <- predict(the_fit, components$forged$predictors)
19+
components
20+
}
21+

0 commit comments

Comments
 (0)