Skip to content

Commit f805ec1

Browse files
committed
moving locf to step_adjust_ahead instead of get_test_data
1 parent 02659f2 commit f805ec1

20 files changed

+223
-245
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Imports:
3939
quantreg,
4040
recipes (>= 1.0.4),
4141
rlang (>= 1.0.0),
42+
purrr,
4243
smoothqr,
4344
stats,
4445
tibble,

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ importFrom(dplyr,"%>%")
227227
importFrom(dplyr,across)
228228
importFrom(dplyr,all_of)
229229
importFrom(dplyr,group_by)
230+
importFrom(dplyr,group_by_at)
230231
importFrom(dplyr,join_by)
231232
importFrom(dplyr,left_join)
232233
importFrom(dplyr,mutate)
@@ -279,6 +280,7 @@ importFrom(stats,residuals)
279280
importFrom(tibble,tibble)
280281
importFrom(tidyr,drop_na)
281282
importFrom(tidyr,expand_grid)
283+
importFrom(tidyr,fill)
282284
importFrom(tidyr,unnest)
283285
importFrom(vctrs,as_list_of)
284286
importFrom(vctrs,field)

R/arx_classifier.R

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ arx_classifier <- function(
6666
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
6767
preds <- forecast(
6868
wf,
69-
fill_locf = is.null(args_list$adjust_latency),
70-
n_recent = args_list$nafill_buffer,
71-
forecast_date = forecast_date
7269
) %>%
7370
tibble::as_tibble() %>%
7471
dplyr::select(-time_value)
@@ -292,7 +289,6 @@ arx_class_args_list <- function(
292289
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
293290
log_scale = FALSE,
294291
additional_gr_args = list(),
295-
nafill_buffer = Inf,
296292
check_enough_data_n = NULL,
297293
check_enough_data_epi_keys = NULL,
298294
...) {
@@ -310,7 +306,6 @@ arx_class_args_list <- function(
310306
arg_is_lgl(log_scale)
311307
arg_is_pos(n_training)
312308
if (is.finite(n_training)) arg_is_pos_int(n_training)
313-
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
314309
if (!is.list(additional_gr_args)) {
315310
cli::cli_abort(
316311
c("`additional_gr_args` must be a {.cls list}.",
@@ -352,7 +347,6 @@ arx_class_args_list <- function(
352347
method,
353348
log_scale,
354349
additional_gr_args,
355-
nafill_buffer,
356350
check_enough_data_n,
357351
check_enough_data_epi_keys
358352
),

R/arx_forecaster.R

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,7 @@ arx_forecaster <- function(
5151
wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5252
wf <- generics::fit(wf, epi_data)
5353

54-
preds <- forecast(
55-
wf,
56-
fill_locf = is.null(args_list$adjust_latency),
57-
n_recent = args_list$nafill_buffer,
58-
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
59-
) %>%
54+
preds <- forecast(wf) %>%
6055
tibble::as_tibble() %>%
6156
dplyr::select(-time_value)
6257

@@ -251,15 +246,6 @@ arx_fcast_epi_workflow <- function(
251246
#' `character(0)` performs no grouping. This argument only applies when
252247
#' residual quantiles are used. It is not applicable with
253248
#' `trainer = quantile_reg()`, for example.
254-
#' @param nafill_buffer At predict time, recent values of the training data
255-
#' are used to create a forecast. However, these can be `NA` due to, e.g.,
256-
#' data latency issues. By default, any missing values will get filled with
257-
#' less recent data. Setting this value to `NULL` will result in 1 extra
258-
#' recent row (beyond those required for lag creation) to be used. Note that
259-
#' we require at least `min(lags)` rows of recent data per `geo_value` to
260-
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
261-
#' will be treated as _additional_ allowed recent data rather than the
262-
#' total amount of recent data to examine.
263249
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
264250
#' epi_key that are required for training. If `NULL`, this check is ignored.
265251
#' @param check_enough_data_epi_keys Character vector. A character vector of
@@ -286,7 +272,6 @@ arx_args_list <- function(
286272
symmetrize = TRUE,
287273
nonneg = TRUE,
288274
quantile_by_key = character(0L),
289-
nafill_buffer = Inf,
290275
check_enough_data_n = NULL,
291276
check_enough_data_epi_keys = NULL,
292277
...) {
@@ -304,7 +289,6 @@ arx_args_list <- function(
304289
arg_is_probabilities(quantile_levels, allow_null = TRUE)
305290
arg_is_pos(n_training)
306291
if (is.finite(n_training)) arg_is_pos_int(n_training)
307-
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
308292
arg_is_pos(check_enough_data_n, allow_null = TRUE)
309293
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)
310294

@@ -331,7 +315,6 @@ arx_args_list <- function(
331315
nonneg,
332316
max_lags,
333317
quantile_by_key,
334-
nafill_buffer,
335318
check_enough_data_n,
336319
check_enough_data_epi_keys
337320
),

R/cdc_baseline_forecaster.R

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,7 @@ cdc_baseline_forecaster <- function(
7979

8080

8181
latest <- get_test_data(
82-
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
83-
forecast_date
84-
)
82+
epi_recipe(epi_data), epi_data)
8583

8684
f <- frosting() %>%
8785
layer_predict() %>%
@@ -169,7 +167,6 @@ cdc_baseline_args_list <- function(
169167
symmetrize = TRUE,
170168
nonneg = TRUE,
171169
quantile_by_key = "geo_value",
172-
nafill_buffer = Inf,
173170
...) {
174171
rlang::check_dots_empty()
175172
arg_is_scalar(n_training, nsims, data_frequency)
@@ -183,7 +180,6 @@ cdc_baseline_args_list <- function(
183180
arg_is_probabilities(quantile_levels, allow_null = TRUE)
184181
arg_is_pos(n_training)
185182
if (is.finite(n_training)) arg_is_pos_int(n_training)
186-
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
187183

188184
structure(
189185
enlist(
@@ -195,8 +191,7 @@ cdc_baseline_args_list <- function(
195191
nsims,
196192
symmetrize,
197193
nonneg,
198-
quantile_by_key,
199-
nafill_buffer
194+
quantile_by_key
200195
),
201196
class = c("cdc_baseline_fcast", "alist")
202197
)

R/epi_workflow.R

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,9 @@ forecast.epi_workflow <- function(object, ..., fill_locf = FALSE, n_recent = NUL
269269
))
270270
}
271271
}
272-
273272
test_data <- get_test_data(
274273
hardhat::extract_preprocessor(object),
275-
object$original_data,
276-
fill_locf = fill_locf,
277-
n_recent = n_recent %||% Inf,
278-
forecast_date = forecast_date %||% frosting_fd %||% max(object$original_data$time_value)
274+
object$original_data
279275
)
280276

281277
predict(object, new_data = test_data)

R/flatline_forecaster.R

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ flatline_forecaster <- function(
6666
wf <- generics::fit(wf, epi_data)
6767
preds <- suppressWarnings(forecast(
6868
wf,
69-
fill_locf = TRUE,
70-
n_recent = args_list$nafill_buffer,
71-
forecast_date = forecast_date
69+
fill_locf = TRUE
7270
)) %>%
7371
tibble::as_tibble() %>%
7472
dplyr::select(-time_value)
@@ -116,7 +114,6 @@ flatline_args_list <- function(
116114
symmetrize = TRUE,
117115
nonneg = TRUE,
118116
quantile_by_key = character(0L),
119-
nafill_buffer = Inf,
120117
...) {
121118
rlang::check_dots_empty()
122119
arg_is_scalar(ahead, n_training)
@@ -128,7 +125,6 @@ flatline_args_list <- function(
128125
arg_is_probabilities(quantile_levels, allow_null = TRUE)
129126
arg_is_pos(n_training)
130127
if (is.finite(n_training)) arg_is_pos_int(n_training)
131-
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
132128

133129
if (!is.null(forecast_date) && !is.null(target_date)) {
134130
if (forecast_date + ahead != target_date) {
@@ -148,8 +144,7 @@ flatline_args_list <- function(
148144
quantile_levels,
149145
symmetrize,
150146
nonneg,
151-
quantile_by_key,
152-
nafill_buffer
147+
quantile_by_key
153148
),
154149
class = c("flat_fcast", "alist")
155150
)

R/get_test_data.R

Lines changed: 7 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
#' @param recipe A recipe object.
2121
#' @param x An epi_df. The typical usage is to
2222
#' pass the same data as that used for fitting the recipe.
23-
#' @param fill_locf Logical. Should we use `locf` to fill in missing data?
24-
#' @param n_recent Integer or NULL. If filling missing data with `locf = TRUE`,
25-
#' how far back are we willing to tolerate missing data? Larger values allow
26-
#' more filling. The default `NULL` will determine this from the
27-
#' the `recipe`. For example, suppose `n_recent = 3`, then if the
28-
#' 3 most recent observations in any `geo_value` are all `NA`’s, we won’t be
29-
#' able to fill anything, and an error message will be thrown. (See details.)
3023
#' @param forecast_date By default, this is set to the maximum
3124
#' `time_value` in `x`. But if there is data latency such that recent `NA`'s
3225
#' should be filled, this may be _after_ the last available `time_value`.
@@ -45,18 +38,8 @@
4538

4639
get_test_data <- function(
4740
recipe,
48-
x,
49-
fill_locf = FALSE,
50-
n_recent = NULL,
51-
forecast_date = max(x$time_value)) {
41+
x) {
5242
if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.")
53-
arg_is_lgl(fill_locf)
54-
arg_is_scalar(fill_locf)
55-
arg_is_scalar(n_recent, allow_null = TRUE)
56-
if (!is.null(n_recent) && is.finite(n_recent)) {
57-
arg_is_pos_int(n_recent, allow_null = TRUE)
58-
}
59-
if (!is.null(n_recent)) n_recent <- abs(n_recent) # in case they passed -Inf
6043

6144
check <- hardhat::check_column_names(x, colnames(recipe$template))
6245
if (!check$ok) {
@@ -66,106 +49,34 @@ get_test_data <- function(
6649
))
6750
}
6851

69-
if (class(forecast_date) != class(x$time_value)) {
70-
cli::cli_abort("`forecast_date` must be the same class as `x$time_value`.")
71-
}
72-
73-
74-
if (forecast_date < max(x$time_value)) {
75-
cli::cli_abort("`forecast_date` must be no earlier than `max(x$time_value)`")
76-
}
77-
7852
min_lags <- min(map_dbl(recipe$steps, ~ min(.x$lag %||% Inf)), Inf)
7953
max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)), 0)
8054
max_horizon <- max(map_dbl(recipe$steps, ~ max(.x$horizon %||% 0)), 0)
81-
min_required <- max_lags + max_horizon
82-
if (is.null(n_recent)) n_recent <- min_required + 1 # one extra for filling
83-
if (n_recent <= min_required) n_recent <- min_required + n_recent
55+
keep <- max_lags + max_horizon
8456

8557
# CHECK: Error out if insufficient training data
8658
# Probably needs a fix based on the time_type of the epi_df
8759
avail_recent <- diff(range(x$time_value))
88-
if (avail_recent < min_required) {
60+
if (avail_recent < keep) {
8961
cli::cli_abort(c(
9062
"You supplied insufficient recent data for this recipe. ",
9163
"!" = "You need at least {min_required} days of data,",
9264
"!" = "but `x` contains only {avail_recent}."
9365
))
9466
}
95-
67+
max_time_value <- x %>% na.omit %>% pull(time_value) %>% max
9668
x <- arrange(x, time_value)
9769
groups <- kill_time_value(epi_keys(recipe))
9870

9971
# If we skip NA completion, we remove undesirably early time values
10072
# Happens globally, over all groups
101-
keep <- max(n_recent, min_required + 1)
102-
x <- dplyr::filter(x, forecast_date - time_value <= keep)
103-
104-
# Pad with explicit missing values up to and including the forecast_date
105-
# x is grouped here
106-
x <- pad_to_end(x, groups, forecast_date) %>%
107-
epiprocess::group_by(dplyr::across(dplyr::all_of(groups)))
73+
x <- dplyr::filter(x, max_time_value - time_value <= keep)
10874

10975
# If all(lags > 0), then we get rid of recent data
11076
if (min_lags > 0 && min_lags < Inf) {
111-
x <- dplyr::filter(x, forecast_date - time_value >= min_lags)
77+
x <- dplyr::filter(x, max_time_value - time_value >= min_lags)
11278
}
11379

114-
# Now, fill forward missing data if requested
115-
if (fill_locf) {
116-
cannot_be_used <- x %>%
117-
dplyr::filter(forecast_date - time_value <= n_recent) %>%
118-
dplyr::mutate(fillers = forecast_date - time_value > min_required) %>%
119-
dplyr::summarise(
120-
dplyr::across(
121-
-tidyselect::any_of(epi_keys(recipe)),
122-
~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1))
123-
),
124-
.groups = "drop"
125-
) %>%
126-
dplyr::select(-fillers) %>%
127-
dplyr::summarise(dplyr::across(
128-
-tidyselect::any_of(epi_keys(recipe)), ~ any(.x)
129-
)) %>%
130-
unlist()
131-
if (any(cannot_be_used)) {
132-
bad_vars <- names(cannot_be_used)[cannot_be_used]
133-
if (recipes::is_trained(recipe)) {
134-
cli::cli_abort(c(
135-
"The variables {.var {bad_vars}} have too many recent missing",
136-
`!` = "values to be filled automatically. ",
137-
i = "You should either choose `n_recent` larger than its current ",
138-
i = "value {n_recent}, or perform NA imputation manually, perhaps with ",
139-
i = "{.code recipes::step_impute_*()} or with {.code tidyr::fill()}."
140-
))
141-
}
142-
}
143-
x <- tidyr::fill(x, !time_value)
144-
}
145-
146-
dplyr::filter(x, forecast_date - time_value <= min_required) %>%
80+
dplyr::filter(x, max_time_value - time_value <= keep) %>%
14781
epiprocess::ungroup()
14882
}
149-
150-
pad_to_end <- function(x, groups, end_date) {
151-
itval <- epiprocess:::guess_period(c(x$time_value, end_date), "time_value")
152-
completed_time_values <- x %>%
153-
dplyr::group_by(dplyr::across(tidyselect::all_of(groups))) %>%
154-
dplyr::summarise(
155-
time_value = rlang::list2(
156-
time_value = Seq(max(time_value) + itval, end_date, itval)
157-
)
158-
) %>%
159-
unnest("time_value") %>%
160-
mutate(time_value = vctrs::vec_cast(time_value, x$time_value))
161-
162-
dplyr::bind_rows(x, completed_time_values) %>%
163-
dplyr::arrange(dplyr::across(tidyselect::all_of(c("time_value", groups))))
164-
}
165-
166-
Seq <- function(from, to, by) {
167-
if (from > to) {
168-
return(NULL)
169-
}
170-
seq(from = from, to = to, by = by)
171-
}

0 commit comments

Comments
 (0)