Skip to content

Commit c94d5f9

Browse files
authored
Merge pull request #297 from cmu-delphi/291-date-period
291 date period
2 parents 7a4ea55 + c1e3ff9 commit c94d5f9

16 files changed

+258
-102
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.10
3+
Version: 0.0.11
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -46,6 +46,7 @@ Imports:
4646
tibble,
4747
tidyr,
4848
tidyselect,
49+
tsibble,
4950
usethis,
5051
vctrs,
5152
workflows (>= 1.0.0)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
3333
- Working vignette
3434
- use `checkmate` for input validation
3535
- refactor quantile extrapolation (possibly creates different results)
36+
- force `target_date` + `forecast_date` handling to match the time_type of
37+
the epi_df. allows for annual and weekly data

R/dist_quantiles.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,10 @@ quantile_extrapolate <- function(x, tau_out, middle) {
233233
dplyr::arrange(q)
234234
}
235235
if (any(indl)) {
236-
qvals_out[indl] <- tail_extrapolate(tau_out[indl], head(qv, 2))
236+
qvals_out[indl] <- tail_extrapolate(tau_out[indl], utils::head(qv, 2))
237237
}
238238
if (any(indr)) {
239-
qvals_out[indr] <- tail_extrapolate(tau_out[indr], tail(qv, 2))
239+
qvals_out[indr] <- tail_extrapolate(tau_out[indr], utils::tail(qv, 2))
240240
}
241241
qvals_out
242242
}

R/epi_workflow.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,10 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
251251
#' preds
252252
predict.epi_workflow <- function(object, new_data, ...) {
253253
if (!workflows::is_trained_workflow(object)) {
254-
rlang::abort(
255-
c("Can't predict on an untrained epi_workflow.",
256-
i = "Do you need to call `fit()`?"
257-
)
258-
)
254+
cli::cli_abort(c(
255+
"Can't predict on an untrained epi_workflow.",
256+
i = "Do you need to call `fit()`?"
257+
))
259258
}
260259
components <- list()
261260
components$mold <- workflows::extract_mold(object)

R/layer_add_forecast_date.R

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
#' p3
6969
layer_add_forecast_date <-
7070
function(frosting, forecast_date = NULL, id = rand_id("add_forecast_date")) {
71+
arg_is_chr_scalar(id)
72+
arg_is_scalar(forecast_date, allow_null = TRUE)
73+
# can't validate the type of forecast_date until we know the time_type
7174
add_layer(
7275
frosting,
7376
layer_add_forecast_date_new(
@@ -78,8 +81,6 @@ layer_add_forecast_date <-
7881
}
7982

8083
layer_add_forecast_date_new <- function(forecast_date, id) {
81-
forecast_date <- arg_to_date(forecast_date, allow_null = TRUE)
82-
arg_is_chr_scalar(id)
8384
layer("add_forecast_date", forecast_date = forecast_date, id = id)
8485
}
8586

@@ -91,26 +92,25 @@ slather.layer_add_forecast_date <- function(object, components, workflow, new_da
9192
workflow$fit$meta$max_time_value,
9293
max(new_data$time_value)
9394
)
94-
object$forecast_date <- max_time_value
95+
forecast_date <- max_time_value
96+
} else {
97+
forecast_date <- object$forecast_date
9598
}
96-
as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of
97-
as_of_fit <- workflow$fit$meta$as_of
98-
as_of_post <- attributes(new_data)$metadata$as_of
9999

100-
as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post))
101-
102-
if (object$forecast_date < as_of_date) {
103-
cli_warn(
104-
c("The forecast_date is less than the most ",
105-
"recent update date of the data: ",
106-
i = "forecast_date = {object$forecast_date} while data is from {as_of_date}."
107-
)
108-
)
109-
}
100+
expected_time_type <- attr(
101+
workflows::extract_preprocessor(workflow)$template, "metadata"
102+
)$time_type
103+
if (expected_time_type == "week") expected_time_type <- "day"
104+
validate_date(forecast_date, expected_time_type,
105+
call = expr(layer_add_forecast_date())
106+
)
107+
forecast_date <- coerce_time_type(forecast_date, expected_time_type)
108+
object$forecast_date <- forecast_date
110109
components$predictions <- dplyr::bind_cols(
111110
components$predictions,
112-
forecast_date = as.Date(object$forecast_date)
111+
forecast_date = forecast_date
113112
)
113+
114114
components
115115
}
116116

R/layer_add_target_date.R

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@
6363
#' p3
6464
layer_add_target_date <-
6565
function(frosting, target_date = NULL, id = rand_id("add_target_date")) {
66-
target_date <- arg_to_date(target_date, allow_null = TRUE)
6766
arg_is_chr_scalar(id)
67+
arg_is_scalar(target_date, allow_null = TRUE)
68+
# can't validate the type of target_date until we know the time_type
6869
add_layer(
6970
frosting,
7071
layer_add_target_date_new(
@@ -84,35 +85,39 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data
8485
the_recipe <- workflows::extract_recipe(workflow)
8586
the_frosting <- extract_frosting(workflow)
8687

88+
expected_time_type <- attr(
89+
workflows::extract_preprocessor(workflow)$template, "metadata"
90+
)$time_type
91+
if (expected_time_type == "week") expected_time_type <- "day"
92+
8793
if (!is.null(object$target_date)) {
88-
target_date <- as.Date(object$target_date)
89-
} else { # null target date case
90-
if (detect_layer(the_frosting, "layer_add_forecast_date") &&
91-
!is.null(extract_argument(
92-
the_frosting,
93-
"layer_add_forecast_date", "forecast_date"
94+
target_date <- object$target_date
95+
validate_date(target_date, expected_time_type,
96+
call = expr(layer_add_target_date())
97+
)
98+
target_date <- coerce_time_type(target_date, expected_time_type)
99+
} else if (
100+
detect_layer(the_frosting, "layer_add_forecast_date") &&
101+
!is.null(forecast_date <- extract_argument(
102+
the_frosting, "layer_add_forecast_date", "forecast_date"
94103
))) {
95-
forecast_date <- extract_argument(
96-
the_frosting,
97-
"layer_add_forecast_date", "forecast_date"
98-
)
99-
100-
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
101-
102-
target_date <- forecast_date + ahead
103-
} else {
104-
max_time_value <- max(
105-
workflows::extract_preprocessor(workflow)$max_time_value,
106-
workflow$fit$meta$max_time_value,
107-
max(new_data$time_value)
108-
)
109-
110-
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
111-
112-
target_date <- max_time_value + ahead
113-
}
104+
validate_date(forecast_date, expected_time_type,
105+
call = expr(layer_add_forecast_date())
106+
)
107+
forecast_date <- coerce_time_type(forecast_date, expected_time_type)
108+
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
109+
target_date <- forecast_date + ahead
110+
} else {
111+
max_time_value <- max(
112+
workflows::extract_preprocessor(workflow)$max_time_value,
113+
workflow$fit$meta$max_time_value,
114+
max(new_data$time_value)
115+
)
116+
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
117+
target_date <- max_time_value + ahead
114118
}
115119

120+
object$target_date <- target_date
116121
components$predictions <- dplyr::bind_cols(components$predictions,
117122
target_date = target_date
118123
)

R/layer_population_scaling.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ slather.layer_population_scaling <-
144144
length(object$df_pop_col) == 1
145145
)
146146

147+
if (is.null(object$by)) {
148+
object$by <- intersect(
149+
kill_time_value(epi_keys(components$predictions)),
150+
colnames(dplyr::select(object$df, !object$df_pop_col))
151+
)
152+
}
147153
try_join <- try(
148154
dplyr::left_join(components$predictions, object$df,
149155
by = object$by
@@ -157,8 +163,8 @@ slather.layer_population_scaling <-
157163
))
158164
}
159165

160-
object$df <- object$df %>%
161-
dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower))
166+
# object$df <- object$df %>%
167+
# dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower))
162168
pop_col <- rlang::sym(object$df_pop_col)
163169
exprs <- rlang::expr(c(!!!object$terms))
164170
pos <- tidyselect::eval_select(exprs, components$predictions)

R/step_epi_shift.R

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,20 @@
5353
step_epi_lag <-
5454
function(recipe,
5555
...,
56+
lag,
5657
role = "predictor",
5758
trained = FALSE,
58-
lag,
5959
prefix = "lag_",
6060
default = NA,
6161
columns = NULL,
6262
skip = FALSE,
6363
id = rand_id("epi_lag")) {
6464
if (!is_epi_recipe(recipe)) {
65-
rlang::abort("This recipe step can only operate on an `epi_recipe`.")
65+
cli::cli_abort("This step can only operate on an `epi_recipe`.")
6666
}
6767

6868
if (missing(lag)) {
69-
rlang::abort(
69+
cli::cli_abort(
7070
c("The `lag` argument must not be empty.",
7171
i = "Did you perhaps pass an integer in `...` accidentally?"
7272
)
@@ -75,7 +75,8 @@ step_epi_lag <-
7575
arg_is_nonneg_int(lag)
7676
arg_is_chr_scalar(prefix, id)
7777
if (!is.null(columns)) {
78-
rlang::abort(c("The `columns` argument must be `NULL.",
78+
cli::cli_abort(c(
79+
"The `columns` argument must be `NULL.",
7980
i = "Use `tidyselect` methods to choose columns to lag."
8081
))
8182
}
@@ -85,7 +86,7 @@ step_epi_lag <-
8586
terms = dplyr::enquos(...),
8687
role = role,
8788
trained = trained,
88-
lag = lag,
89+
lag = as.integer(lag),
8990
prefix = prefix,
9091
default = default,
9192
keys = epi_keys(recipe),
@@ -104,24 +105,23 @@ step_epi_lag <-
104105
step_epi_ahead <-
105106
function(recipe,
106107
...,
108+
ahead,
107109
role = "outcome",
108110
trained = FALSE,
109-
ahead,
110111
prefix = "ahead_",
111112
default = NA,
112113
columns = NULL,
113114
skip = FALSE,
114115
id = rand_id("epi_ahead")) {
115116
if (!is_epi_recipe(recipe)) {
116-
rlang::abort("This recipe step can only operate on an `epi_recipe`.")
117+
cli::cli_abort("This step can only operate on an `epi_recipe`.")
117118
}
118119

119120
if (missing(ahead)) {
120-
rlang::abort(
121-
c("The `ahead` argument must not be empty.",
122-
i = "Did you perhaps pass an integer in `...` accidentally?"
123-
)
124-
)
121+
cli::cli_abort(c(
122+
"The `ahead` argument must not be empty.",
123+
i = "Did you perhaps pass an integer in `...` accidentally?"
124+
))
125125
}
126126
arg_is_nonneg_int(ahead)
127127
arg_is_chr_scalar(prefix, id)
@@ -136,7 +136,7 @@ step_epi_ahead <-
136136
terms = dplyr::enquos(...),
137137
role = role,
138138
trained = trained,
139-
ahead = ahead,
139+
ahead = as.integer(ahead),
140140
prefix = prefix,
141141
default = default,
142142
keys = epi_keys(recipe),

R/time_types.R

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
guess_time_type <- function(time_value) {
2+
# similar to epiprocess:::guess_time_type() but w/o the gap handling
3+
arg_is_scalar(time_value)
4+
if (is.character(time_value)) {
5+
if (nchar(time_value) <= "10") {
6+
new_time_value <- tryCatch(
7+
{
8+
as.Date(time_value)
9+
},
10+
error = function(e) NULL
11+
)
12+
} else {
13+
new_time_value <- tryCatch(
14+
{
15+
as.POSIXct(time_value)
16+
},
17+
error = function(e) NULL
18+
)
19+
}
20+
if (!is.null(new_time_value)) time_value <- new_time_value
21+
}
22+
if (inherits(time_value, "POSIXct")) {
23+
return("day-time")
24+
}
25+
if (inherits(time_value, "Date")) {
26+
return("day")
27+
}
28+
if (inherits(time_value, "yearweek")) {
29+
return("yearweek")
30+
}
31+
if (inherits(time_value, "yearmonth")) {
32+
return("yearmonth")
33+
}
34+
if (inherits(time_value, "yearquarter")) {
35+
return("yearquarter")
36+
}
37+
if (is.numeric(time_value) && all(time_value == as.integer(time_value)) &&
38+
all(time_value >= 1582)) {
39+
return("year")
40+
}
41+
return("custom")
42+
}
43+
44+
coerce_time_type <- function(x, target_type) {
45+
if (target_type == "year") {
46+
if (is.numeric(x)) {
47+
return(as.integer(x))
48+
} else {
49+
return(as.POSIXlt(x)$year + 1900L)
50+
}
51+
}
52+
switch(target_type,
53+
"day-time" = as.POSIXct(x),
54+
"day" = as.Date(x),
55+
"week" = as.Date(x),
56+
"yearweek" = tsibble::yearweek(x),
57+
"yearmonth" = tsibble::yearmonth(x),
58+
"yearquarter" = tsibble::yearquarter(x)
59+
)
60+
}
61+
62+
validate_date <- function(x, expected, arg = rlang::caller_arg(x),
63+
call = rlang::caller_env()) {
64+
time_type_x <- guess_time_type(x)
65+
ok <- time_type_x == expected
66+
if (!ok) {
67+
cli::cli_abort(c(
68+
"The {.arg {arg}} was given as a {.val {time_type_x}} while the",
69+
`!` = "`time_type` of the training data was {.val {expected}}.",
70+
i = "See {.topic epiprocess::epi_df} for descriptions of these are determined."
71+
), call = call)
72+
}
73+
}

0 commit comments

Comments
 (0)