Skip to content

Commit 20f8b4f

Browse files
authored
Merge pull request #90 from cmu-delphi/frosting
merge to pass tests.
2 parents 18e1ee1 + a7a83c5 commit 20f8b4f

8 files changed

+85
-84
lines changed

R/epi_recipe.R

+4-6
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,9 @@ epi_recipe.default <- function(x, ...) {
4545
#' library(dplyr)
4646
#' library(recipes)
4747
#'
48-
#' jhu <- jhu_csse_daily_subset %>%
48+
#' jhu <- case_death_rate_subset %>%
4949
#' filter(time_value > "2021-08-01") %>%
50-
#' select(geo_value:death_rate_7d_av) %>%
51-
#' rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av)
50+
#' dplyr::arrange(geo_value, time_value)
5251
#'
5352
#' r <- epi_recipe(jhu) %>%
5453
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
@@ -257,10 +256,9 @@ is_epi_recipe <- function(x) {
257256
#' library(dplyr)
258257
#' library(recipes)
259258
#'
260-
#' jhu <- jhu_csse_daily_subset %>%
259+
#' jhu <- case_death_rate_subset %>%
261260
#' filter(time_value > "2021-08-01") %>%
262-
#' select(geo_value:death_rate_7d_av) %>%
263-
#' rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av)
261+
#' dplyr::arrange(geo_value, time_value)
264262
#'
265263
#' r <- epi_recipe(jhu) %>%
266264
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%

R/epi_workflow.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL)
4242

4343
if (is_epi_recipe(preprocessor)) {
4444
out <- add_epi_recipe(out, preprocessor)
45-
}else if (!is_null(preprocessor)) {
45+
} else if (!is_null(preprocessor)) {
4646
out <- workflows:::add_preprocessor(out, preprocessor)
4747
}
4848
if (!is_null(postprocessor)) {

R/frosting.R

+15-2
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,23 @@
3535
add_frosting <- function(x, frosting, ...) {
3636
rlang::check_dots_empty()
3737
action <- workflows:::new_action_post(frosting = frosting)
38-
workflows:::add_action(x, action, "frosting")
38+
epi_add_action(x, action, "frosting")
3939
}
4040

41-
order_stage_post <- function() "frosting"
41+
42+
# Hacks around workflows `order_stage_post <- charcter(0)` ----------------
43+
epi_add_action <- function(x, action, name, ..., call = caller_env()) {
44+
workflows:::validate_is_workflow(x, call = call)
45+
add_action_frosting(x, action, name, ..., call = call)
46+
}
47+
add_action_frosting <- function(x, action, name, ..., call = caller_env()) {
48+
workflows:::check_singleton(x$post$actions, name, call = call)
49+
x$post <- workflows:::add_action_to_stage(x$post, action, name, order_stage_frosting())
50+
x
51+
}
52+
order_stage_frosting <- function() "frosting"
53+
# End hacks. See cmu-delphi/epipredict#75
54+
4255

4356
#' @rdname add_frosting
4457
#' @export

man/add_epi_recipe.Rd

+2-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/epi_recipe.Rd

+2-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

musings/example-recipe.R

+2-5
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ library(tidyverse)
22
library(covidcast)
33
library(delphi.epidata)
44
library(epiprocess)
5-
# library(epipredict)
65
library(tidymodels)
6+
77
x <- covidcast(
88
data_source = "jhu-csse",
99
signals = "confirmed_7dav_incidence_prop",
@@ -40,10 +40,7 @@ r <- epi_recipe(x) %>% # if we add this as a class, maybe we get better
4040
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
4141
step_epi_ahead(death_rate, ahead = 7) %>%
4242
step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
43-
step_naomit(all_predictors()) %>%
44-
# below, `skip` means we don't do this at predict time
45-
# we should probably do something useful here to avoid user error
46-
step_naomit(all_outcomes(), skip = TRUE)
43+
step_epi_naomit()
4744

4845
# specify trainer, this uses stats::lm() by default, but doing
4946
# slm <- linear_reg() %>% use_engine("glmnet", penalty = 0.1)

musings/updated-example.Rmd

+11-6
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@ knitr::opts_chunk$set(echo = TRUE)
1010
library(tidyverse)
1111
library(tidymodels)
1212
library(epiprocess)
13+
# devtools::install_github("cmu-delphi/epipredict")
1314
library(epipredict)
14-
1515
```
1616

1717
```{r small-data}
18-
jhu <- jhu_csse_daily_subset %>%
18+
jhu <- case_death_rate_subset %>%
1919
filter(time_value > "2021-08-01") %>%
20-
select(geo_value:death_rate_7d_av) %>%
21-
rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av)
20+
dplyr::arrange(geo_value, time_value)
2221
2322
jhu_latest <- jhu %>%
2423
filter(!is.na(case_rate), !is.na(death_rate)) %>%
@@ -58,13 +57,19 @@ Predict gives a new `epi_df`
5857

5958
```{r predict}
6059
pp <- predict(wf, new_data = jhu_latest)
61-
pp
60+
pp
6261
```
6362

6463
Can add a `forecast_date` (should be a post processing step)
6564

6665
```{r predict2}
67-
predict(wf, new_data = jhu_latest, forecast_date = "2021-12-31") %>%
66+
# Want:
67+
# predict(wf, new_data = jhu_latest, forecast_date = "2021-12-31") %>%
68+
# filter(!is.na(.pred))
69+
70+
# Intended output:
71+
predict(wf, new_data = jhu_latest) %>%
72+
mutate(forecast_date = as.Date("2021-12-31")) %>%
6873
filter(!is.na(.pred))
6974
```
7075

musings/updated-example.html

+48-58
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)