Skip to content

Commit 02659f2

Browse files
committed
refactor step_adjust_ahead to be early step
1 parent 2f8090d commit 02659f2

21 files changed

+265
-434
lines changed

NAMESPACE

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,9 @@ importFrom(dplyr,all_of)
229229
importFrom(dplyr,group_by)
230230
importFrom(dplyr,join_by)
231231
importFrom(dplyr,left_join)
232+
importFrom(dplyr,mutate)
232233
importFrom(dplyr,n)
233234
importFrom(dplyr,pull)
234-
importFrom(dplyr,rowwise)
235235
importFrom(dplyr,select)
236236
importFrom(dplyr,summarise)
237237
importFrom(dplyr,tibble)
@@ -242,6 +242,7 @@ importFrom(generics,fit)
242242
importFrom(generics,forecast)
243243
importFrom(ggplot2,autoplot)
244244
importFrom(glue,glue)
245+
importFrom(hardhat,extract_recipe)
245246
importFrom(hardhat,refresh_blueprint)
246247
importFrom(hardhat,run_mold)
247248
importFrom(magrittr,"%>%")
@@ -250,6 +251,7 @@ importFrom(recipes,bake)
250251
importFrom(recipes,detect_step)
251252
importFrom(recipes,prep)
252253
importFrom(rlang,"!!!")
254+
importFrom(recipes,recipes_eval_select)
253255
importFrom(rlang,"!!")
254256
importFrom(rlang,"%@%")
255257
importFrom(rlang,"%||%")
@@ -276,6 +278,7 @@ importFrom(stats,quantile)
276278
importFrom(stats,residuals)
277279
importFrom(tibble,tibble)
278280
importFrom(tidyr,drop_na)
281+
importFrom(tidyr,expand_grid)
279282
importFrom(tidyr,unnest)
280283
importFrom(vctrs,as_list_of)
281284
importFrom(vctrs,field)

R/arx_classifier.R

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,6 @@ arx_class_epi_workflow <- function(
193193
}
194194
}
195195
o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o))
196-
r <- r %>%
197-
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome")
198196
method_adjust_latency <- args_list$adjust_latency
199197
if (!is.null(method_adjust_latency)) {
200198
# only extend_ahead is supported atm
@@ -203,6 +201,8 @@ arx_class_epi_workflow <- function(
203201
method = method_adjust_latency
204202
)
205203
}
204+
r <- r %>%
205+
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome")
206206
r <- r %>%
207207
step_mutate(
208208
outcome_class = cut(!!o2, breaks = args_list$breaks),
@@ -220,10 +220,6 @@ arx_class_epi_workflow <- function(
220220
drop_na = FALSE
221221
)
222222
}
223-
if (!is.null) {
224-
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
225-
}
226-
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
227223

228224
# --- postprocessor
229225
f <- frosting() %>% layer_predict() # %>% layer_naomit()
@@ -308,9 +304,6 @@ arx_class_args_list <- function(
308304

309305
arg_is_scalar(ahead, n_training, horizon, log_scale)
310306
arg_is_scalar(forecast_date, target_date, adjust_latency, allow_null = TRUE)
311-
if (adjust_latency == "adjust_lags") {
312-
cli::cli_abort("step_adjust_latency is not yet implemented for lagged differences and growth rates")
313-
}
314307
arg_is_date(forecast_date, target_date, allow_null = TRUE)
315308
arg_is_nonneg_int(ahead, lags, horizon)
316309
arg_is_numeric(breaks)

R/arx_forecaster.R

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,7 @@ arx_fcast_epi_workflow <- function(
142142

143143
# --- preprocessor
144144
r <- epi_recipe(epi_data)
145-
for (l in seq_along(lags)) {
146-
p <- predictors[l]
147-
r <- step_epi_lag(r, !!p, lag = lags[[l]])
148-
}
149-
r <- r %>%
150-
step_epi_ahead(!!outcome, ahead = args_list$ahead)
145+
# adjust latency if the user asks
151146
method_adjust_latency <- args_list$adjust_latency
152147
if (!is.null(method_adjust_latency)) {
153148
if (method_adjust_latency == "extend_ahead") {
@@ -162,6 +157,12 @@ arx_fcast_epi_workflow <- function(
162157
)
163158
}
164159
}
160+
for (l in seq_along(lags)) {
161+
p <- predictors[l]
162+
r <- step_epi_lag(r, !!p, lag = lags[[l]])
163+
}
164+
r <- r %>%
165+
step_epi_ahead(!!outcome, ahead = args_list$ahead)
165166
r <- r %>%
166167
step_epi_naomit() %>%
167168
step_training_window(n_recent = args_list$n_training)

R/canned-epipred.R

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ print.alist <- function(x, ...) {
6363
}
6464

6565
#' @export
66+
#' @importFrom hardhat extract_recipe
6667
print.canned_epipred <- function(x, name, ...) {
6768
d <- cli::cli_div(theme = list(rule = list("line-type" = "double")))
6869
cli::cli_rule("A basic forecaster of type {name}")
@@ -110,19 +111,33 @@ print.canned_epipred <- function(x, name, ...) {
110111
"At forecast date{?s}: {.val {fds}},",
111112
"For target date{?s}: {.val {tds}},"
112113
))
113-
if (detect_step(x$epi_workflow$pre$actions$recipe$recipe, "adjust_latency")) {
114-
latency_step <- keep(x$epi_workflow$pre$mold$blueprint$recipe$steps,
115-
\(x) inherits(x, "step_adjust_latency"))[[1]]
116-
latency_per_base_col <- latency_step$shift_cols %>%
117-
group_by(latency) %>%
118-
reframe(variable = parent_name) %>%
119-
distinct() %>%
120-
mutate(latency = abs(latency)) %>%
121-
relocate(variable, latency)
122-
if (nrow(latency_per_base_col)>1) {
123-
intro_text <- "Latency adjusted per column: "
114+
fit_recipe <- extract_recipe(x$epi_workflow)
115+
if (detect_step(fit_recipe, "adjust_latency")) {
116+
is_adj_latency <- map_lgl(fit_recipe$steps, \(x) inherits(x, "step_adjust_latency"))
117+
latency_step <- fit_recipe$steps[is_adj_latency][[1]]
118+
# all steps after adjust_latency
119+
later_steps <- fit_recipe$steps[-(1:which(is_adj_latency))]
120+
if (latency_step$method == "extend_ahead") {
121+
step_names <- "step_epi_ahead"
122+
type_str <- "Aheads"
123+
} else if (latency_step$method == "extend_lags") {
124+
step_names <- "step_epi_lag"
125+
type_str <- "Lags"
124126
} else {
125-
intro_text <- "Latency adjusted for "
127+
step_names <- ""
128+
type_str <- "columns locf"
129+
}
130+
later_steps[[1]]$columns
131+
valid_columns <- later_steps %>%
132+
keep(\(x) inherits(x, step_names)) %>%
133+
purrr::map("columns") %>%
134+
reduce(c)
135+
latency_per_base_col <- latency_step$latency_table %>%
136+
filter(col_name %in% valid_columns) %>% mutate(latency = abs(latency))
137+
if (latency_step$method != "locf" && nrow(latency_per_base_col) > 1) {
138+
intro_text <- glue::glue("{type_str} adjusted per column: ")
139+
} else if (latency_step$method != "locf") {
140+
intro_text <- glue::glue("{type_str} adjusted for ")
126141
}
127142
latency_info <- paste0(intro_text, paste(apply(latency_per_base_col, 1, paste0, collapse = "="), collapse = ", "))
128143
cli::cli_ul(latency_info)

R/epi_shift.R

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ epi_shift_single <- function(x, col, shift_val, newname, key_cols) {
2121
#' the future back to today
2222
#' @keywords internal
2323
get_sign <- function(object) {
24-
if (object$prefix == "lag_") {
24+
if (!is.null(object$prefix)) {
25+
if (object$prefix == "lag_") {
26+
return(1)
27+
} else {
28+
return(-1)
29+
}
30+
} else if (object$method == "extend_lags") {
2531
return(1)
2632
} else {
2733
return(-1)
@@ -31,13 +37,28 @@ get_sign <- function(object) {
3137
#' backend for both `bake.step_epi_ahead` and `bake.step_epi_lag`, performs the
3238
#' checks missing in `epi_shift_single`
3339
#' @keywords internal
40+
#' @importFrom tidyr expand_grid
41+
#' @importFrom dplyr mutate left_join join_by
3442
add_shifted_columns <- function(new_data, object, amount) {
3543
sign_shift <- get_sign(object)
36-
grid <- tidyr::expand_grid(col = object$columns, amount = amount) %>%
44+
latency_table <- attributes(new_data)$metadata$latency_table
45+
shift_sign_lat <- attributes(new_data)$metadata$shift_sign
46+
if (!is.null(latency_table) &&
47+
shift_sign_lat == sign_shift) {
48+
#TODO this doesn't work on lags of transforms
49+
rel_latency <- latency_table %>% filter(col_name %in% object$columns)
50+
} else {
51+
rel_latency <- tibble(col_name = object$columns, latency = 0L)
52+
}
53+
grid <- expand_grid(col = object$columns, amount = sign_shift *amount) %>%
54+
left_join(rel_latency, by = join_by(col == col_name), ) %>%
55+
tidyr::replace_na(list(latency = 0)) %>%
3756
dplyr::mutate(
38-
newname = glue::glue("{object$prefix}{amount}_{col}"),
39-
shift_val = sign_shift * amount,
40-
amount = NULL
57+
shift_val = amount + latency) %>%
58+
mutate(
59+
newname = glue::glue("{object$prefix}{abs(shift_val)}_{col}"), # name is always positive
60+
amount = NULL,
61+
latency = NULL
4162
)
4263

4364
## ensure no name clashes
@@ -59,8 +80,12 @@ add_shifted_columns <- function(new_data, object, amount) {
5980
dplyr::full_join,
6081
by = ok
6182
)
62-
dplyr::full_join(new_data, shifted, by = ok) %>%
83+
processed <- new_data %>%
84+
dplyr::full_join(shifted, by = ok) %>%
6385
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
6486
dplyr::arrange(time_value) %>%
65-
dplyr::ungroup()
87+
dplyr::ungroup() %>%
88+
as_epi_df()
89+
attributes(processed)$metadata <- attributes(new_data)$metadata
90+
return(processed)
6691
}

R/layer_add_forecast_date.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ slather.layer_add_forecast_date <- function(object, components, workflow, new_da
9292
rlang::check_dots_empty()
9393
forecast_date <- object$forecast_date %||%
9494
get_forecast_date_in_layer(
95-
extract_preprocessor(workflow),
95+
extract_recipe(workflow),
9696
workflow$fit$meta$max_time_value,
9797
new_data
9898
)

0 commit comments

Comments
 (0)