Skip to content

Commit 30ddb54

Browse files
committed
fix string/fct coercion
1 parent 62868ff commit 30ddb54

File tree

8 files changed

+42
-14
lines changed

8 files changed

+42
-14
lines changed

DESCRIPTION

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Depends:
1414
R (>= 3.5.0)
1515
Imports:
1616
assertthat,
17+
broom,
1718
cli,
1819
dplyr,
1920
epiprocess,

NAMESPACE

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(apply_frosting,default)
4+
S3method(apply_frosting,epi_workflow)
5+
S3method(as_tibble,epi_df)
36
S3method(augment,epi_workflow)
47
S3method(bake,step_epi_ahead)
58
S3method(bake,step_epi_lag)
@@ -17,6 +20,7 @@ S3method(print,step_epi_ahead)
1720
S3method(print,step_epi_lag)
1821
export("%>%")
1922
export(add_epi_recipe)
23+
export(apply_frosting)
2024
export(arx_args_list)
2125
export(arx_forecaster)
2226
export(create_lags_and_leads)
@@ -39,6 +43,7 @@ export(smooth_arx_forecaster)
3943
export(step_epi_ahead)
4044
export(step_epi_lag)
4145
import(recipes)
46+
importFrom(broom,augment)
4247
importFrom(magrittr,"%>%")
4348
importFrom(rlang,"!!")
4449
importFrom(rlang,":=")

R/epi_recipe.R

+6
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,9 @@ kill_levels <- function(x, keys) {
416416
for (i in which(names(x) %in% keys)) x[[i]] <- list(values = NA, ordered = NA)
417417
x
418418
}
419+
420+
#' @export
421+
as_tibble.epi_df <- function(x, ...) {
422+
# so that downstream calls to as_tibble don't clobber our metadata
423+
return(x)
424+
}

R/epi_workflow.R

+6-7
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,11 @@
1515
#' @importFrom rlang is_null
1616
#' @export
1717
#' @examples
18-
#' library(epiprocess)
1918
#' library(dplyr)
2019
#' library(parsnip)
2120
#' library(recipes)
2221
#'
23-
#' jhu <- jhu_csse_daily_subset %>%
24-
#' filter(time_value > "2021-08-01") %>%
25-
#' select(geo_value:death_rate_7d_av) %>%
26-
#' rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av)
22+
#' jhu <- case_death_rate_subset
2723
#'
2824
#' r <- epi_recipe(jhu) %>%
2925
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
@@ -103,7 +99,9 @@ is_epi_workflow <- function(x) {
10399
#' library(parsnip)
104100
#' library(recipes)
105101
#'
106-
#' r <- epi_recipe(case_death_rate_subset) %>%
102+
#' jhu <- case_death_rate_subset
103+
#'
104+
#' r <- epi_recipe(jhu) %>%
107105
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
108106
#' step_epi_ahead(death_rate, ahead = 7) %>%
109107
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
@@ -112,7 +110,7 @@ is_epi_workflow <- function(x) {
112110
#'
113111
#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu)
114112
#'
115-
#' latest <- get_test_data(r, case_death_rate_subset)
113+
#' latest <- get_test_data(r, jhu)
116114
#'
117115
#' preds <- predict(wf, latest) %>%
118116
#' filter(!is.na(.pred))
@@ -171,6 +169,7 @@ grab_forged_keys <- function(forged, mold, new_data) {
171169
#' @param ... Arguments passed on to the predict method.
172170
#'
173171
#' @return new_data with additional columns containing the predicted values
172+
#' @importFrom broom augment
174173
#' @export
175174
augment.epi_workflow <- function (x, new_data, ...) {
176175
predictions <- predict(x, new_data, ...)

R/frosting.R

+3
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ frosting <- function(layers = NULL, requirements = NULL) {
5858
out <- new_frosting()
5959
}
6060

61+
#' @export
6162
apply_frosting <- function(workflow, ...) {
6263
UseMethod("apply_frosting")
6364
}
6465

66+
#' @export
6567
apply_frosting.default <- function(workflow, components, ...) {
6668
if (has_postprocessor(workflow)) {
6769
abort(c("Postprocessing is only available for epi_workflows currently.",
@@ -74,6 +76,7 @@ apply_frosting.default <- function(workflow, components, ...) {
7476

7577
#' @importFrom rlang is_null
7678
#' @importFrom rlang abort
79+
#' @export
7780
apply_frosting.epi_workflow <- function(workflow, components, the_fit, ...) {
7881
if (!has_postprocessor(workflow)) {
7982
components$preds <- predict(the_fit, components$forged$predictors, ...)

man/epi_workflow.Rd

+1-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/predict-epi_workflow.Rd

+4-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-frosting.R

+16
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,20 @@ test_that("frosting validators / constructors work", {
1717

1818
test_that("prediction works without any postprocessor", {
1919

20+
jhu <- case_death_rate_subset %>%
21+
dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
22+
r <- epi_recipe(jhu) %>%
23+
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
24+
step_epi_ahead(death_rate, ahead = 7) %>%
25+
step_naomit(all_predictors()) %>%
26+
step_naomit(all_outcomes(), skip = TRUE)
27+
wf <- epi_workflow(r, linear_reg()) %>% fit(jhu)
28+
latest <- get_test_data(r, jhu)
29+
30+
expect_silent(predict(wf, latest))
31+
p <- predict(wf, latest) %>% dplyr::filter(!is.na(.pred))
32+
expect_equal(nrow(p), 3)
33+
expect_s3_class(p, "epi_df")
34+
expect_equal(tail(p$time_value, 1), as.Date("2021-12-31"))
35+
expect_equal(unique(p$geo_value), c("ak", "ca", "ny"))
2036
})

0 commit comments

Comments
 (0)