Skip to content

Commit 3dd7d65

Browse files
authored
Merge pull request #84 from cmu-delphi/72-refactor-epi_workflow-builder
Refactor epi_workflow() builder
2 parents 0b0ceb7 + 98eba97 commit 3dd7d65

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

R/epi_workflow.R

+3-5
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,13 @@
3636
#' wf <- epi_workflow(r, linear_reg())
3737
#'
3838
#' wf
39-
epi_workflow <- function(preprocessor = NULL, spec = NULL,
40-
postprocessor = NULL) {
39+
epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL) {
4140
out <- workflows::workflow(spec = spec)
4241
class(out) <- c("epi_workflow", class(out))
4342

4443
if (is_epi_recipe(preprocessor)) {
45-
return(add_epi_recipe(out, preprocessor))
46-
}
47-
if (!is_null(preprocessor)) {
44+
out <- add_epi_recipe(out, preprocessor)
45+
}else if (!is_null(preprocessor)) {
4846
out <- workflows:::add_preprocessor(out, preprocessor)
4947
}
5048
if (!is_null(postprocessor)) {

tests/testthat/test-epi_workflow.R

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
2+
test_that("postprocesser was evaluated", {
3+
r <- epi_recipe(case_death_rate_subset)
4+
s <- parsnip::linear_reg()
5+
f <- frosting()
6+
7+
ef <- epi_workflow(r, s, f)
8+
ef2 <- epi_workflow(r, s) %>% add_frosting(f)
9+
10+
expect_true(epipredict:::has_postprocessor(ef))
11+
expect_true(epipredict:::has_postprocessor(ef2))
12+
})
13+
14+
15+
test_that("outcome of the two methods are the same", {
16+
jhu <- case_death_rate_subset
17+
18+
r <- epi_recipe(jhu) %>%
19+
step_epi_lag(death_rate, lag = c(0, 7)) %>%
20+
step_epi_ahead(death_rate, ahead = 7) %>%
21+
step_epi_lag(case_rate, lag = c(7)) %>%
22+
step_naomit(all_predictors()) %>%
23+
step_naomit(all_outcomes())
24+
25+
s <- parsnip::linear_reg()
26+
f <- frosting() %>%
27+
layer_predict() %>%
28+
layer_naomit(.pred) %>%
29+
layer_residual_quantile()
30+
31+
ef <- epi_workflow(r, s, f)
32+
ef2 <- epi_workflow(r, s) %>% add_frosting(f)
33+
34+
expect_equal(ef,ef2)
35+
})

0 commit comments

Comments
 (0)