Skip to content

Commit 83fc279

Browse files
committed
moving shift detection earlier, dropping stringr/stringi dependencies
1 parent def8f37 commit 83fc279

10 files changed

+151
-169
lines changed

DESCRIPTION

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ Imports:
4141
rlang,
4242
smoothqr,
4343
stats,
44-
stringr,
45-
stringi,
4644
tibble,
4745
tidyr,
4846
tidyselect,

NAMESPACE

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ importFrom(dplyr,n)
229229
importFrom(dplyr,pull)
230230
importFrom(dplyr,rowwise)
231231
importFrom(dplyr,summarise)
232+
importFrom(dplyr,tibble)
232233
importFrom(dplyr,ungroup)
233234
importFrom(epiprocess,growth_rate)
234235
importFrom(generics,augment)
@@ -237,8 +238,8 @@ importFrom(generics,forecast)
237238
importFrom(ggplot2,autoplot)
238239
importFrom(hardhat,refresh_blueprint)
239240
importFrom(hardhat,run_mold)
240-
importFrom(magrittr,"%<>%")
241241
importFrom(magrittr,"%>%")
242+
importFrom(purrr,map_lgl)
242243
importFrom(quantreg,rq)
243244
importFrom(recipes,bake)
244245
importFrom(recipes,prep)
@@ -261,12 +262,11 @@ importFrom(stats,predict)
261262
importFrom(stats,qnorm)
262263
importFrom(stats,quantile)
263264
importFrom(stats,residuals)
264-
importFrom(stringi,stri_replace_all_regex)
265-
importFrom(stringr,str_match)
266265
importFrom(tibble,as_tibble)
267266
importFrom(tibble,is_tibble)
268267
importFrom(tibble,tibble)
269268
importFrom(tidyr,drop_na)
269+
importFrom(tidyr,unnest)
270270
importFrom(vctrs,as_list_of)
271271
importFrom(vctrs,field)
272272
importFrom(vctrs,new_rcrd)

R/step_adjust_latency.R

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ step_adjust_latency <-
9090
fixed_asof = NULL,
9191
default = NA,
9292
skip = FALSE,
93-
prefix = NULL,
9493
columns = NULL,
9594
id = recipes::rand_id("epi_lag")) {
95+
arg_is_chr_scalar(id, method)
9696
if (!is_epi_recipe(recipe)) {
9797
cli::cli_abort("This recipe step can only operate on an `epi_recipe`.")
9898
}
@@ -103,24 +103,28 @@ step_adjust_latency <-
103103
}
104104

105105
method <- rlang::arg_match(method)
106+
terms_used <- recipes_eval_select(enquos(...), recipe$template, recipe$term_info)
107+
if (length(terms_used) == 0) {
108+
terms_used <- recipe$term_info %>%
109+
filter(role == "raw") %>%
110+
pull(variable)
111+
}
106112
if (method == "extend_ahead") {
107-
prefix <- "ahead_"
108-
if (!any(map_lgl(
109-
recipe$steps,
110-
function(recipe_step) inherits(recipe_step, "step_epi_ahead")
111-
))) {
112-
cli:cli_abort("There is no `step_epi_ahead` defined before this. For the method `extend_ahead` of `step_adjust_latency`, at least one ahead must be previously defined.")
113-
}
113+
rel_step_type <- "step_epi_ahead"
114+
shift_name <- "ahead"
114115
} else if (method == "extend_lags") {
115-
prefix <- "lag_"
116-
if (!any(map_lgl(
117-
recipe$steps,
118-
function(recipe_step) inherits(recipe_step, "step_epi_lag")
119-
))) {
120-
cli:cli_abort("There is no `step_epi_lag` defined before this. For the method `extend_lags` of `step_adjust_latency`, at least one lag must be previously defined.")
121-
}
116+
rel_step_type <- "step_epi_lag"
117+
shift_name <- "lag"
118+
}
119+
relevant_shifts <- construct_shift_tibble(terms_used, recipe, rel_step_type, shift_name)
120+
121+
if (!any(map_lgl(
122+
recipe$steps,
123+
function(recipe_step) inherits(recipe_step, rel_step_type)
124+
))) {
125+
cli:cli_abort("there is no `{rel_step_type}` defined before this. for the method `extend_{shift_name}` of `step_adjust_latency`, at least one {shift_name} must be previously defined.")
122126
}
123-
arg_is_chr_scalar(prefix, id, method)
127+
124128
recipes::add_step(
125129
recipe,
126130
step_adjust_latency_new(
@@ -130,8 +134,7 @@ step_adjust_latency <-
130134
trained = trained,
131135
as_of = fixed_asof,
132136
latency = fixed_latency,
133-
shift_cols = NULL,
134-
prefix = prefix,
137+
shift_cols = relevant_shifts,
135138
default = default,
136139
keys = epi_keys(recipe),
137140
skip = skip,
@@ -141,7 +144,7 @@ step_adjust_latency <-
141144
}
142145

143146
step_adjust_latency_new <-
144-
function(terms, role, trained, as_of, latency, shift_cols, prefix, time_type, default,
147+
function(terms, role, trained, as_of, latency, shift_cols, time_type, default,
145148
keys, method, skip, id) {
146149
step(
147150
subclass = "adjust_latency",
@@ -152,7 +155,6 @@ step_adjust_latency_new <-
152155
as_of = as_of,
153156
latency = latency,
154157
shift_cols = shift_cols,
155-
prefix = prefix,
156158
default = default,
157159
keys = keys,
158160
skip = skip,
@@ -163,6 +165,7 @@ step_adjust_latency_new <-
163165
# lags introduces max(lags) NA's after the max_time_value.
164166
# TODO all of the shifting happens before NA removal, which saves all the data I might possibly want; I should probably add a bit that makes sure this operation is happening before NA removal so data doesn't get dropped
165167
#' @export
168+
#' @importFrom glue glue
166169
prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
167170
if ((x$method == "extend_ahead") && (!("outcome" %in% info$role))) {
168171
cli::cli_abort('If `method` is `"extend_ahead"`, then a step ",
@@ -172,7 +175,6 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
172175
"must have already added a predictor.')
173176
}
174177

175-
sign_shift <- get_sign(x)
176178
# get the columns used, even if it's all of them
177179
terms_used <- x$columns
178180
if (length(terms_used) == 0) {
@@ -185,16 +187,18 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
185187

186188
# infer the correct columns to be working with from the previous
187189
# transformations
188-
shift_cols <- get_shifted_column_tibble(
189-
x$prefix, training, terms_used,
190-
as_of, x$latency, sign_shift, info
190+
x$prefix <- x$shift_cols$prefix[[1]]
191+
sign_shift <- get_sign(x)
192+
latency_cols <- get_latent_column_tibble(
193+
x$shift_cols, training, as_of,
194+
x$latency, sign_shift, info
191195
)
192196

193197
if ((x$method == "extend_ahead") || (x$method == "extend_lags")) {
194198
# check that the shift amount isn't too extreme
195-
latency <- max(shift_cols$latency)
199+
latency <- max(latency_cols$latency)
196200
time_type <- attributes(training)$metadata$time_type
197-
i_latency <- which.max(shift_cols$latency)
201+
i_latency <- which.max(latency_cols$latency)
198202
if (
199203
(grepl("day", time_type) && (latency >= 10)) ||
200204
(grepl("week", time_type) && (latency >= 4)) ||
@@ -203,26 +207,25 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
203207
((time_type == "year") && (latency >= 1))
204208
) {
205209
cli::cli_warn(c(
206-
"!" = glue::glue(
210+
"!" = glue(
207211
"The shift has been adjusted by {latency}, ",
208212
"which is questionable for it's `time_type` of ",
209213
"{time_type}"
210214
),
211-
"i" = "input shift: {shift_cols$shifts[[i_latency]]}",
212-
"i" = "latency adjusted shift: {shift_cols$effective_shift[[i_latency]]}",
215+
"i" = "input shift: {latency_cols$shift[[i_latency]]}",
216+
"i" = "latency adjusted shift: {latency_cols$effective_shift[[i_latency]]}",
213217
"i" = "max_time = {max_time} -> as_of = {as_of}"
214218
))
215219
}
216220
}
217221

218222
step_adjust_latency_new(
219-
terms = shift_cols$original_name,
220-
role = shift_cols$role[[1]],
223+
terms = latency_cols$original_name,
224+
role = latency_cols$role[[1]],
221225
trained = TRUE,
222-
prefix = x$prefix,
223-
shift_cols = shift_cols,
226+
shift_cols = latency_cols,
224227
as_of = as_of,
225-
latency = unique(shift_cols$latency),
228+
latency = unique(latency_cols$latency),
226229
default = x$default,
227230
keys = x$keys,
228231
method = x$method,

R/utils-latency.R

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
extend_either <- function(new_data, shift_cols, keys) {
1111
shifted <-
1212
shift_cols %>%
13-
select(-any_of(c("shifts", "effective_shift", "type", "role", "source"))) %>%
14-
pmap(\(original_name, latency, new_name) {
13+
select(original_name, latency, new_name) %>%
14+
pmap(function(original_name, latency, new_name) {
1515
epi_shift_single(
1616
x = new_data,
1717
col = original_name,
@@ -20,7 +20,7 @@ extend_either <- function(new_data, shift_cols, keys) {
2020
key_cols = keys
2121
)
2222
}) %>%
23-
map(\(x) zoo::na.trim(x)) %>%
23+
map(function(x) zoo::na.trim(x)) %>%
2424
reduce(
2525
dplyr::full_join,
2626
by = keys
@@ -34,60 +34,66 @@ extend_either <- function(new_data, shift_cols, keys) {
3434
dplyr::ungroup())
3535
}
3636

37+
#' create a table of the columns to modify, their shifts, and their prefixes
#'
#' Walks `recipe$steps`, keeps the steps that inherit from `rel_step_type` and
#' whose evaluated column selection overlaps `terms_used`, and collects their
#' column names, shift amounts and prefixes into a single tibble.
#'
#' @param terms_used values compared (via `%in%`) against each step's
#'   evaluated column names to decide whether the step is relevant
#' @param recipe object whose `steps`, `template` and `term_info` elements are
#'   read
#' @param rel_step_type class name handed to `inherits()` to pick out the
#'   relevant steps (the visible caller passes `"step_epi_ahead"` or
#'   `"step_epi_lag"`)
#' @param shift_name name of the step element holding the shift amounts (the
#'   visible caller passes `"ahead"` or `"lag"`)
#' @return a tibble with columns `terms`, `shift` and `prefix`, after `terms`
#'   and `shift` have been unnested (presumably one row per term/shift pair;
#'   TODO confirm)
38+
#' @keywords internal
39+
#' @importFrom dplyr tibble
40+
#' @importFrom tidyr unnest
41+
construct_shift_tibble <- function(terms_used, recipe, rel_step_type, shift_name) {
42+
# for the right step types (either "step_epi_lag" or "step_epi_ahead"), grab
43+
# the useful parameters, including the evaluated column names
44+
extract_named_rates <- function(recipe_step) {
45+
if (inherits(recipe_step, rel_step_type)) {
46+
recipe_columns <- recipes_eval_select(recipe_step$terms, recipe$template, recipe$term_info)
47+
# a step is kept in full as soon as ANY of its columns is in `terms_used`,
# so columns of that step that are not in `terms_used` are returned too
if (any(recipe_columns %in% terms_used)) {
48+
# `recipe_step[shift_name]` is a single-bracket subset, so `shift` is a
# one-element list wrapping the shift values rather than the bare vector
return(list(term = recipe_columns, shift = recipe_step[shift_name], prefix = recipe_step$prefix))
49+
}
50+
}
51+
# steps of another type, or touching none of `terms_used`, contribute nothing
return(NULL)
52+
}
53+
# flatten the per-step lists one level, then regroup the elements by the
# recycled labels "term", "shift", "prefix"; this relies on every kept step
# contributing exactly those three elements in that order
# NOTE(review): behaviour when no step matches (all NULL) is not obvious from
# here -- verify; the visible caller runs this before its "no such step" check
rel_list <- recipe$steps %>%
54+
purrr::map(extract_named_rates) %>%
55+
unlist(recursive = FALSE) %>%
56+
split(c("term", "shift", "prefix"))
57+
# one list-column entry per kept step for `terms` and `shift`; names are
# dropped from each entry
relevant_shifts <- tibble(
58+
terms = lapply(rel_list$term, unname),
59+
shift = lapply(rel_list$shift, unname),
60+
prefix = unname(unlist(rel_list$prefix))
61+
) %>%
62+
unnest(c(terms, shift)) %>%
63+
# presumably `shift` is still list-wrapped after the first unnest (see the
# single-bracket subset above), hence the second unnest -- TODO confirm
unnest(shift)
64+
return(relevant_shifts)
65+
}
66+
3767
#' find the columns added with the lags or aheads, and the amounts they have
3868
#' been changed
3969
#' @param prefix the prefix indicating if we are adjusting lags or aheads
4070
#' @param new_data the data transformed so far
4171
#' @return a tibble with columns `column` (relevant shifted names), `shift` (the
4272
#' amount that one is shifted), `latency` (original columns difference between
4373
#' max_time_value and as_of (on a per-initial column basis)),
44-
#' `effective_shift` (shifts+latency), and `new_name` (adjusted names with the
74+
#' `effective_shift` (shift+latency), and `new_name` (adjusted names with the
4575
#' effective_shift)
4676
#' @keywords internal
47-
#' @importFrom stringr str_match
4877
#' @importFrom dplyr rowwise %>%
49-
#' @importFrom magrittr %<>%
50-
get_shifted_column_tibble <- function(
51-
prefix, new_data, terms_used, as_of, latency,
78+
#' @importFrom purrr map_lgl
79+
#' @importFrom glue glue
80+
get_latent_column_tibble <- function(
81+
shift_cols, new_data, as_of, latency,
5282
sign_shift, info, call = caller_env()) {
53-
relevant_columns <- names(new_data)[grepl(prefix, names(new_data))]
54-
to_keep <- rep(FALSE, length(relevant_columns))
55-
for (col_name in terms_used) {
56-
to_keep <- to_keep | grepl(col_name, relevant_columns)
57-
}
58-
relevant_columns <- relevant_columns[to_keep]
59-
if (length(relevant_columns) == 0) {
60-
cli::cli_abort("There is no column(s) {terms_used}.",
61-
current_column_names = names(new_data),
62-
class = "epipredict_adjust_latency_nonexistent_column_used",
63-
call = call
64-
)
65-
}
66-
# this pulls text that is any number of digits between two _, e.g. _3557_, and
67-
# converts them to an integer
68-
shift_amounts <- stringr::str_match(relevant_columns, "_(\\d+)_") %>%
69-
`[`(, 2) %>%
70-
as.integer()
71-
72-
shift_cols <- dplyr::tibble(
73-
original_name = relevant_columns,
74-
shifts = shift_amounts
75-
)
83+
shift_cols <- shift_cols %>% mutate(original_name = glue("{prefix}{shift}_{terms}"))
7684
if (is.null(latency)) {
7785
shift_cols <- shift_cols %>%
7886
rowwise() %>%
7987
# add the latencies to shift_cols
8088
mutate(latency = get_latency(
81-
new_data, as_of, original_name, shifts, sign_shift
89+
new_data, as_of, original_name, shift, sign_shift
8290
)) %>%
8391
ungroup()
8492
} else if (length(latency) > 1) {
93+
# if latency has more than one element, it is treated as a named vector: each row gets the entry whose name equals that row's `terms` value
8594
shift_cols <- shift_cols %>%
8695
rowwise() %>%
87-
mutate(latency = unname(latency[purrr::map_lgl(
88-
names(latency),
89-
\(x) grepl(x, original_name)
90-
)])) %>%
96+
mutate(latency = unname(latency[names(latency) == terms])) %>%
9197
ungroup()
9298
} else {
9399
shift_cols <- shift_cols %>% mutate(latency = latency)
@@ -96,10 +102,10 @@ get_shifted_column_tibble <- function(
96102
# add the updated names to shift_cols
97103
shift_cols <- shift_cols %>%
98104
mutate(
99-
effective_shift = shifts + abs(latency)
105+
effective_shift = shift + abs(latency)
100106
) %>%
101107
mutate(
102-
new_name = adjust_name(prefix, original_name, effective_shift)
108+
new_name = glue("{prefix}{effective_shift}_{terms}")
103109
)
104110
info <- info %>% select(variable, type, role)
105111
shift_cols <- left_join(shift_cols, info, by = join_by(original_name == variable))
@@ -166,19 +172,6 @@ set_asof <- function(new_data, info) {
166172
return(as_of)
167173
}
168174

169-
#' adjust the shifts by latency for the names in column assumes e.g.
170-
#' `"lag_6_case_rate"` and returns something like `"lag_10_case_rate"`
171-
#' @keywords internal
172-
#' @importFrom stringi stri_replace_all_regex
173-
adjust_name <- function(prefix, column, effective_shift) {
174-
pattern <- paste0(prefix, "\\d+", "_")
175-
adjusted_shifts <- paste0(prefix, effective_shift, "_")
176-
stringi::stri_replace_all_regex(
177-
column,
178-
pattern, adjusted_shifts
179-
)
180-
}
181-
182175
#' the latency is also the amount the shift is off by
183176
#' @param sign_shift integer. 1 if lag and -1 if ahead. These represent how you
184177
#' need to shift the data to bring the 3 day lagged value to today.

man/adjust_name.Rd

Lines changed: 0 additions & 14 deletions
This file was deleted.

man/construct_shift_tibble.Rd

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)