Skip to content

Commit 5b77c1a

Browse files
committed
predict layer debugged
1 parent bd754ca commit 5b77c1a

File tree

5 files changed

+36
-10
lines changed

5 files changed

+36
-10
lines changed

R/epi_workflow.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,7 @@ predict.epi_workflow <- function(object, new_data, ...) {
130130
components$keys <- grab_forged_keys(components$forged,
131131
components$mold, new_data)
132132
components <- apply_frosting(object, components, the_fit, ...)
133-
out <- dplyr::bind_cols(components$keys, components$preds)
134-
out
133+
components$predictions
135134
}
136135

137136
grab_forged_keys <- function(forged, mold, new_data) {

R/frosting.R

+11-8
Original file line numberDiff line numberDiff line change
@@ -79,23 +79,26 @@ apply_frosting.default <- function(workflow, components, ...) {
7979
#' @export
8080
apply_frosting.epi_workflow <- function(workflow, components, the_fit, ...) {
8181
if (!has_postprocessor(workflow)) {
82-
components$preds <- predict(the_fit, components$forged$predictors, ...)
82+
components$predictions <- predict(the_fit, components$forged$predictors, ...)
83+
components$predictions <- dplyr::bind_cols(components$keys, components$predictions)
8384
return(components)
8485
}
8586
if (!has_postprocessor_frosting(workflow)) {
8687
rlang::warn(c("Only postprocessors of class frosting are allowed.",
8788
"Returning unpostprocessed predictions."))
88-
components$preds <- predict(the_fit, components$forged$predictors, ...)
89+
components$predictions <- predict(the_fit, components$forged$predictors, ...)
90+
components$predictions <- dplyr::bind_cols(components$keys, components$predictions)
8991
return(components)
9092
}
91-
layers <- workflow$post$actions$frosting$frosting
92-
for (l in seq_along(layers$layers)) {
93-
layer <- layers$layers[[l]]
94-
components <- slather(layer, components = components, the_fit)
93+
layers <- workflow$post$actions$frosting$frosting$layers
94+
for (l in seq_along(layers)) {
95+
la <- layers[[l]]
96+
components <- slather(la, components = components, the_fit)
9597
}
9698
# last for the moment, anticipating that layer[1] will do the prediction...
97-
if (is_null(components$preds)) {
98-
components$preds <- predict(the_fit, components$forged$predictors, ...)
99+
if (is_null(components$predictions)) {
100+
components$predictions <- predict(the_fit, components$forged$predictors, ...)
101+
components$predictions <- dplyr::bind_cols(components$keys, components$predictions)
99102
}
100103
return(components)
101104
}

R/layer_predict.R

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ layer_predict_new <- function(id) {
1616
#' @export
1717
slather.layer_predict <- function(object, components, the_fit) {
1818
components$predictions <- predict(the_fit, components$forged$predictors)
19+
components$predictions <- dplyr::bind_cols(components$keys, components$predictions)
1920
components
2021
}
2122

tests/testthat/test-frosting.R

+2
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,5 @@ test_that("prediction works without any postprocessor", {
3434
expect_equal(tail(p$time_value, 1), as.Date("2021-12-31"))
3535
expect_equal(unique(p$geo_value), c("ak", "ca", "ny"))
3636
})
37+
38+

tests/testthat/test-layer_predict.R

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
test_that("predict layer works alone", {
2+
jhu <- case_death_rate_subset %>%
3+
dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
4+
r <- epi_recipe(jhu) %>%
5+
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
6+
step_epi_ahead(death_rate, ahead = 7) %>%
7+
step_naomit(all_predictors()) %>%
8+
step_naomit(all_outcomes(), skip = TRUE)
9+
wf <- epi_workflow(r, linear_reg()) %>% fit(jhu)
10+
latest <- jhu %>% filter(time_value >= max(time_value) - 14)
11+
12+
f <- frosting() %>% layer_predict()
13+
wf <- wf %>% add_frosting(f)
14+
15+
expect_silent(p <- predict(wf, latest))
16+
expect_equal(ncol(p), 3L)
17+
expect_s3_class(p, "epi_df")
18+
expect_equal(nrow(p), 108L)
19+
expect_named(p, c("geo_value", "time_value", ".pred"))
20+
21+
})

0 commit comments

Comments
 (0)