diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 57272ee88..abcb06167 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -45,10 +45,9 @@ epi_recipe.default <- function(x, ...) { #' library(dplyr) #' library(recipes) #' -#' jhu <- jhu_csse_daily_subset %>% +#' jhu <- case_death_rate_subset %>% #' filter(time_value > "2021-08-01") %>% -#' select(geo_value:death_rate_7d_av) %>% -#' rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av) +#' dplyr::arrange(geo_value, time_value) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% @@ -257,10 +256,9 @@ is_epi_recipe <- function(x) { #' library(dplyr) #' library(recipes) #' -#' jhu <- jhu_csse_daily_subset %>% +#' jhu <- case_death_rate_subset %>% #' filter(time_value > "2021-08-01") %>% -#' select(geo_value:death_rate_7d_av) %>% -#' rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av) +#' dplyr::arrange(geo_value, time_value) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 72fe2b1f3..4be82dcc6 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -42,7 +42,7 @@ epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL) if (is_epi_recipe(preprocessor)) { out <- add_epi_recipe(out, preprocessor) - }else if (!is_null(preprocessor)) { + } else if (!is_null(preprocessor)) { out <- workflows:::add_preprocessor(out, preprocessor) } if (!is_null(postprocessor)) { diff --git a/R/frosting.R b/R/frosting.R index b8bb75e8a..436ea2319 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -35,10 +35,23 @@ add_frosting <- function(x, frosting, ...) { rlang::check_dots_empty() action <- workflows:::new_action_post(frosting = frosting) - workflows:::add_action(x, action, "frosting") + epi_add_action(x, action, "frosting") } -order_stage_post <- function() "frosting" + +# Hacks around workflows `order_stage_post <- charcter(0)` ---------------- +epi_add_action <- function(x, action, name, ..., call = caller_env()) { + workflows:::validate_is_workflow(x, call = call) + add_action_frosting(x, action, name, ..., call = call) +} +add_action_frosting <- function(x, action, name, ..., call = caller_env()) { + workflows:::check_singleton(x$post$actions, name, call = call) + x$post <- workflows:::add_action_to_stage(x$post, action, name, order_stage_frosting()) + x +} +order_stage_frosting <- function() "frosting" +# End hacks. See cmu-delphi/epipredict#75 + #' @rdname add_frosting #' @export diff --git a/man/add_epi_recipe.Rd b/man/add_epi_recipe.Rd index 6202e4670..55508eefb 100644 --- a/man/add_epi_recipe.Rd +++ b/man/add_epi_recipe.Rd @@ -35,10 +35,9 @@ library(epiprocess) library(dplyr) library(recipes) -jhu <- jhu_csse_daily_subset \%>\% +jhu <- case_death_rate_subset \%>\% filter(time_value > "2021-08-01") \%>\% - select(geo_value:death_rate_7d_av) \%>\% - rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av) + dplyr::arrange(geo_value, time_value) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% diff --git a/man/create_layer.Rd b/man/create_layer.Rd index 81f5e33b0..7917e8854 100644 --- a/man/create_layer.Rd +++ b/man/create_layer.Rd @@ -9,9 +9,9 @@ create_layer(name = NULL, open = rlang::is_interactive()) \arguments{ \item{name}{Either a name without extension, or \code{NULL} to create the paired file based on currently open file in the script editor. If -the \verb{R/} file is open, \code{use_test()} will create/open the corresponding +the R file is open, \code{use_test()} will create/open the corresponding test file; if the test file is open, \code{use_r()} will create/open the -corresponding \verb{R/} file.} +corresponding R file.} \item{open}{Whether to open the file for interactive editing.} } diff --git a/man/epi_recipe.Rd b/man/epi_recipe.Rd index 8f5eb4e11..1c74426bf 100644 --- a/man/epi_recipe.Rd +++ b/man/epi_recipe.Rd @@ -61,10 +61,9 @@ library(epiprocess) library(dplyr) library(recipes) -jhu <- jhu_csse_daily_subset \%>\% +jhu <- case_death_rate_subset \%>\% filter(time_value > "2021-08-01") \%>\% - select(geo_value:death_rate_7d_av) \%>\% - rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av) + dplyr::arrange(geo_value, time_value) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% diff --git a/musings/example-recipe.R b/musings/example-recipe.R index d2bf9f46c..00458e3ca 100644 --- a/musings/example-recipe.R +++ b/musings/example-recipe.R @@ -2,8 +2,8 @@ library(tidyverse) library(covidcast) library(delphi.epidata) library(epiprocess) -# library(epipredict) library(tidymodels) + x <- covidcast( data_source = "jhu-csse", signals = "confirmed_7dav_incidence_prop", @@ -40,10 +40,7 @@ r <- epi_recipe(x) %>% # if we add this as a class, maybe we get better step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% - step_naomit(all_predictors()) %>% - # below, `skip` means we don't do this at predict time - # we should probably do something useful here to avoid user error - step_naomit(all_outcomes(), skip = TRUE) + step_epi_naomit() # specify trainer, this uses stats::lm() by default, but doing # slm <- linear_reg() %>% use_engine("glmnet", penalty = 0.1) diff --git a/musings/updated-example.Rmd b/musings/updated-example.Rmd index 84b8ea803..6191cce04 100644 --- a/musings/updated-example.Rmd +++ b/musings/updated-example.Rmd @@ -10,15 +10,14 @@ knitr::opts_chunk$set(echo = TRUE) library(tidyverse) library(tidymodels) library(epiprocess) +# devtools::install_github("cmu-delphi/epipredict") library(epipredict) - ``` ```{r small-data} -jhu <- jhu_csse_daily_subset %>% +jhu <- case_death_rate_subset %>% filter(time_value > "2021-08-01") %>% - select(geo_value:death_rate_7d_av) %>% - rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av) + dplyr::arrange(geo_value, time_value) jhu_latest <- jhu %>% filter(!is.na(case_rate), !is.na(death_rate)) %>% @@ -58,13 +57,19 @@ Predict gives a new `epi_df` ```{r predict} pp <- predict(wf, new_data = jhu_latest) -pp +pp ``` Can add a `forecast_date` (should be a post processing step) ```{r predict2} -predict(wf, new_data = jhu_latest, forecast_date = "2021-12-31") %>% +# Want: +# predict(wf, new_data = jhu_latest, forecast_date = "2021-12-31") %>% +# filter(!is.na(.pred)) + +# Intended output: +predict(wf, new_data = jhu_latest) %>% + mutate(forecast_date = as.Date("2021-12-31")) %>% filter(!is.na(.pred)) ``` diff --git a/musings/updated-example.html b/musings/updated-example.html index c88fb95ba..eb9a039b4 100644 --- a/musings/updated-example.html +++ b/musings/updated-example.html @@ -15,24 +15,11 @@