Skip to content

Commit 9c86846

Browse files
committed
refactor: improve default ahead and target_date handling
1 parent 37756ed commit 9c86846

6 files changed

+52
-28
lines changed

R/arx_classifier.R

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ arx_class_epi_workflow <- function(
128128
}
129129
lags <- arx_lags_validator(predictors, args_list$lags)
130130

131+
args_list$forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
132+
if (is.null(args_list$ahead) && is.null(args_list$target_date)) {
133+
args_list$ahead <- 7L
134+
args_list$target_date <- args_list$forecast_date + args_list$ahead
135+
} else if (is.null(args_list$ahead)) {
136+
args_list$ahead <- as.integer(difftime(args_list$target_date, args_list$forecast_date, units = "days"))
137+
} else if (is.null(args_list$target_date)) {
138+
args_list$target_date <- args_list$forecast_date + args_list$ahead
139+
}
140+
131141
# --- preprocessor
132142
# ------- predictors
133143
r <- epi_recipe(epi_data) %>%
@@ -196,13 +206,10 @@ arx_class_epi_workflow <- function(
196206
}
197207
}
198208

199-
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
200-
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
201-
202209
# --- postprocessor
203210
f <- frosting() %>% layer_predict() # %>% layer_naomit()
204-
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
205-
layer_add_target_date(target_date = target_date)
211+
f <- layer_add_forecast_date(f, forecast_date = args_list$forecast_date) %>%
212+
layer_add_target_date(target_date = args_list$target_date)
206213

207214
epi_workflow(r, trainer, f)
208215
}
@@ -259,7 +266,7 @@ arx_class_epi_workflow <- function(
259266
#' arx_class_args_list(breaks = c(-.2, .25))
260267
arx_class_args_list <- function(
261268
lags = c(0L, 7L, 14L),
262-
ahead = 7L,
269+
ahead = NULL,
263270
n_training = Inf,
264271
forecast_date = NULL,
265272
target_date = NULL,
@@ -279,10 +286,12 @@ arx_class_args_list <- function(
279286
method <- match.arg(method)
280287
outcome_transform <- match.arg(outcome_transform)
281288

282-
arg_is_scalar(ahead, n_training, horizon, log_scale)
289+
arg_is_scalar(n_training, horizon, log_scale)
283290
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
284291
arg_is_date(forecast_date, target_date, allow_null = TRUE)
285-
arg_is_nonneg_int(ahead, lags, horizon)
292+
arg_is_nonneg_int(lags, horizon)
293+
arg_is_scalar(ahead, allow_null = TRUE)
294+
arg_is_nonneg_int(ahead, allow_null = TRUE)
286295
arg_is_numeric(breaks)
287296
arg_is_lgl(log_scale)
288297
arg_is_pos(n_training)

R/arx_forecaster.R

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,16 @@ arx_fcast_epi_workflow <- function(
117117
}
118118
lags <- arx_lags_validator(predictors, args_list$lags)
119119

120+
args_list$forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
121+
if (is.null(args_list$ahead) && is.null(args_list$target_date)) {
122+
args_list$ahead <- 7L
123+
args_list$target_date <- args_list$forecast_date + args_list$ahead
124+
} else if (is.null(args_list$ahead)) {
125+
args_list$ahead <- as.integer(difftime(args_list$target_date, args_list$forecast_date, units = "days"))
126+
} else if (is.null(args_list$target_date)) {
127+
args_list$target_date <- args_list$forecast_date + args_list$ahead
128+
}
129+
120130
# --- preprocessor
121131
r <- epi_recipe(epi_data)
122132
for (l in seq_along(lags)) {
@@ -142,9 +152,6 @@ arx_fcast_epi_workflow <- function(
142152
}
143153
}
144154

145-
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
146-
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
147-
148155
# --- postprocessor
149156
f <- frosting() %>% layer_predict() # %>% layer_naomit()
150157
if (inherits(trainer, "quantile_reg")) {
@@ -165,8 +172,8 @@ arx_fcast_epi_workflow <- function(
165172
by_key = args_list$quantile_by_key
166173
)
167174
}
168-
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
169-
layer_add_target_date(target_date = target_date)
175+
f <- layer_add_forecast_date(f, forecast_date = args_list$forecast_date) %>%
176+
layer_add_target_date(target_date = args_list$target_date)
170177
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
171178

172179
epi_workflow(r, trainer, f)
@@ -230,7 +237,7 @@ arx_fcast_epi_workflow <- function(
230237
#' arx_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
231238
arx_args_list <- function(
232239
lags = c(0L, 7L, 14L),
233-
ahead = 7L,
240+
ahead = NULL,
234241
n_training = Inf,
235242
forecast_date = NULL,
236243
target_date = NULL,
@@ -247,11 +254,13 @@ arx_args_list <- function(
247254
.lags <- lags
248255
if (is.list(lags)) lags <- unlist(lags)
249256

250-
arg_is_scalar(ahead, n_training, symmetrize, nonneg)
257+
arg_is_scalar(n_training, symmetrize, nonneg)
251258
arg_is_chr(quantile_by_key, allow_empty = TRUE)
252259
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
253260
arg_is_date(forecast_date, target_date, allow_null = TRUE)
254-
arg_is_nonneg_int(ahead, lags)
261+
arg_is_nonneg_int(lags)
262+
arg_is_scalar(ahead, allow_null = TRUE)
263+
arg_is_nonneg_int(ahead, allow_null = TRUE)
255264
arg_is_lgl(symmetrize, nonneg)
256265
arg_is_probabilities(quantile_levels, allow_null = TRUE)
257266
arg_is_pos(n_training)

R/cdc_baseline_forecaster.R

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,25 @@ cdc_baseline_forecaster <- function(
6767
ek <- kill_time_value(keys)
6868
outcome <- rlang::sym(outcome)
6969

70+
args_list$forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
71+
if (is.null(args_list$aheads) && is.null(args_list$target_date)) {
72+
args_list$aheads <- 1:5
73+
# args_list$target_date <- args_list$forecast_date + args_list$aheads
74+
} else if (is.null(args_list$aheads)) {
75+
args_list$aheads <- as.integer(difftime(args_list$target_date, args_list$forecast_date, units = "days"))
76+
} else if (is.null(args_list$target_date)) {
77+
# args_list$target_date <- args_list$forecast_date + args_list$aheads
78+
}
7079

7180
r <- epi_recipe(epi_data) %>%
7281
step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
7382
recipes::update_role(!!outcome, new_role = "predictor") %>%
7483
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
7584
step_training_window(n_recent = args_list$n_training)
7685

77-
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
78-
# target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
79-
80-
8186
latest <- get_test_data(
8287
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
83-
forecast_date
88+
args_list$forecast_date
8489
)
8590

8691
f <- frosting() %>%
@@ -93,7 +98,7 @@ cdc_baseline_forecaster <- function(
9398
symmetrize = args_list$symmetrize,
9499
nonneg = args_list$nonneg
95100
) %>%
96-
layer_add_forecast_date(forecast_date = forecast_date) %>%
101+
layer_add_forecast_date(forecast_date = args_list$forecast_date) %>%
97102
layer_unnest(.pred_distn_all)
98103
# layer_add_target_date(target_date = target_date)
99104
if (args_list$nonneg) f <- layer_threshold(f, ".pred")
@@ -105,7 +110,7 @@ cdc_baseline_forecaster <- function(
105110
preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
106111
tibble::as_tibble() %>%
107112
dplyr::select(-time_value) %>%
108-
dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency)
113+
dplyr::mutate(target_date = args_list$forecast_date + ahead * args_list$data_frequency)
109114

110115
structure(
111116
list(
@@ -161,7 +166,7 @@ cdc_baseline_forecaster <- function(
161166
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
162167
cdc_baseline_args_list <- function(
163168
data_frequency = "1 week",
164-
aheads = 1:5,
169+
aheads = NULL,
165170
n_training = Inf,
166171
forecast_date = NULL,
167172
quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
@@ -178,7 +183,8 @@ cdc_baseline_args_list <- function(
178183
arg_is_chr(quantile_by_key, allow_empty = TRUE)
179184
arg_is_scalar(forecast_date, allow_null = TRUE)
180185
arg_is_date(forecast_date, allow_null = TRUE)
181-
arg_is_nonneg_int(aheads, nsims)
186+
arg_is_nonneg_int(nsims)
187+
arg_is_nonneg_int(aheads, allow_null = TRUE)
182188
arg_is_lgl(symmetrize, nonneg)
183189
arg_is_probabilities(quantile_levels, allow_null = TRUE)
184190
arg_is_pos(n_training)

man/arx_args_list.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/arx_class_args_list.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/cdc_baseline_args_list.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)