Skip to content

Commit 37756ed

Browse files
committed
fix: address n_train and ahead mismatch #290
1 parent 259ab4f commit 37756ed

File tree

4 files changed

+37
-22
lines changed

4 files changed

+37
-22
lines changed

R/flatline_forecaster.R

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,25 @@ flatline_forecaster <- function(
3939
ek <- kill_time_value(keys)
4040
outcome <- rlang::sym(outcome)
4141

42+
args_list$forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
43+
if (is.null(args_list$ahead) && is.null(args_list$target_date)) {
44+
args_list$ahead <- 7L
45+
args_list$target_date <- args_list$forecast_date + args_list$ahead
46+
} else if (is.null(args_list$ahead)) {
47+
args_list$ahead <- as.integer(difftime(args_list$target_date, args_list$forecast_date, units = "days"))
48+
} else if (is.null(args_list$target_date)) {
49+
args_list$target_date <- args_list$forecast_date + args_list$ahead
50+
}
4251

4352
r <- epi_recipe(epi_data) %>%
4453
step_epi_ahead(!!outcome, ahead = args_list$ahead, skip = TRUE) %>%
4554
recipes::update_role(!!outcome, new_role = "predictor") %>%
4655
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
4756
step_training_window(n_recent = args_list$n_training)
4857

49-
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
50-
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
51-
5258
latest <- get_test_data(
5359
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
54-
forecast_date
60+
args_list$forecast_date
5561
)
5662

5763
f <- frosting() %>%
@@ -61,8 +67,8 @@ flatline_forecaster <- function(
6167
symmetrize = args_list$symmetrize,
6268
by_key = args_list$quantile_by_key
6369
) %>%
64-
layer_add_forecast_date(forecast_date = forecast_date) %>%
65-
layer_add_target_date(target_date = target_date)
70+
layer_add_forecast_date(forecast_date = args_list$forecast_date) %>%
71+
layer_add_target_date(target_date = args_list$target_date)
6672
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
6773

6874
eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")
@@ -87,7 +93,6 @@ flatline_forecaster <- function(
8793
}
8894

8995

90-
9196
#' Flatline forecaster argument constructor
9297
#'
9398
#' Constructs a list of arguments for [flatline_forecaster()].
@@ -108,7 +113,7 @@ flatline_forecaster <- function(
108113
#' flatline_args_list(symmetrize = FALSE)
109114
#' flatline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
110115
flatline_args_list <- function(
111-
ahead = 7L,
116+
ahead = NULL,
112117
n_training = Inf,
113118
forecast_date = NULL,
114119
target_date = NULL,
@@ -119,11 +124,11 @@ flatline_args_list <- function(
119124
nafill_buffer = Inf,
120125
...) {
121126
rlang::check_dots_empty()
122-
arg_is_scalar(ahead, n_training)
127+
arg_is_scalar(n_training)
123128
arg_is_chr(quantile_by_key, allow_empty = TRUE)
124129
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
125130
arg_is_date(forecast_date, target_date, allow_null = TRUE)
126-
arg_is_nonneg_int(ahead)
131+
arg_is_nonneg_int(ahead, allow_null = TRUE)
127132
arg_is_lgl(symmetrize, nonneg)
128133
arg_is_probabilities(quantile_levels, allow_null = TRUE)
129134
arg_is_pos(n_training)

R/layer_residual_quantiles.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ slather.layer_residual_quantiles <-
123123
probs = object$quantile_levels, na.rm = TRUE
124124
))
125125
)
126+
# Check for NA
127+
if (any(sapply(r$dstn, is.na))) {
128+
cli::cli_abort("Quantiles could not be calculated due to missing residuals. Check your n_train and ahead values.")
129+
}
126130

127131
estimate <- components$predictions$.pred
128132
res <- tibble::tibble(

man/flatline_args_list.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-target_date_bug.R

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# These tests address #290:
2+
# https://github.com/cmu-delphi/epipredict/issues/290
3+
14
library(dplyr)
25
train <- jhu_csse_daily_subset |>
36
filter(time_value >= as.Date("2021-10-01")) |>
@@ -13,7 +16,6 @@ test_that("flatline determines target_date where forecast_date exists", {
1316
ahead = 1L
1417
)
1518
)
16-
1719
# previously, if target_date existed, it could be
1820
# erroneously incremented by the ahead
1921
expect_identical(
@@ -24,9 +26,6 @@ test_that("flatline determines target_date where forecast_date exists", {
2426
flat$predictions$forecast_date,
2527
rep(as.Date("2021-12-31"), ngeos)
2628
)
27-
28-
# potentially resulted in NA predictions
29-
# see #290 https://github.com/cmu-delphi/epipredict/issues/290
3029
expect_true(all(!is.na(flat$predictions$.pred_distn)))
3130
expect_true(all(!is.na(flat$predictions$.pred)))
3231
})
@@ -50,9 +49,6 @@ test_that("arx_forecaster determines target_date where forecast_date exists", {
5049
arx$predictions$forecast_date,
5150
rep(as.Date("2021-12-31"), ngeos)
5251
)
53-
54-
# potentially resulted in NA predictions
55-
# see #290 https://github.com/cmu-delphi/epipredict/issues/290
5652
expect_true(all(!is.na(arx$predictions$.pred_distn)))
5753
expect_true(all(!is.na(arx$predictions$.pred)))
5854
})
@@ -67,7 +63,6 @@ test_that("arx_classifier determines target_date where forecast_date exists", {
6763
ahead = 1L
6864
)
6965
)
70-
7166
# previously, if target_date existed, it could be
7267
# erroneously incremented by the ahead
7368
expect_identical(
@@ -78,8 +73,19 @@ test_that("arx_classifier determines target_date where forecast_date exists", {
7873
arx$predictions$forecast_date,
7974
rep(as.Date("2021-12-31"), ngeos)
8075
)
81-
82-
# potentially resulted in NA predictions
83-
# see #290 https://github.com/cmu-delphi/epipredict/issues/290
8476
expect_true(all(!is.na(arx$predictions$.pred_class)))
8577
})
78+
79+
test_that("n_training and ahead bugs", {
80+
expect_error(
81+
flatline_forecaster(
82+
train, "dr",
83+
args_list = flatline_args_list(
84+
forecast_date = as.Date("2021-12-31"),
85+
# n_training is less than default ahead which is 7
86+
n_training = 5L
87+
)
88+
),
89+
"Check your n_train and ahead values"
90+
)
91+
})

0 commit comments

Comments
 (0)