Skip to content

Commit 76a1111

Browse files
committed
fix: #290
1 parent c94d5f9 commit 76a1111

File tree

5 files changed

+93
-5
lines changed

5 files changed

+93
-5
lines changed

R/arx_classifier.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ arx_class_epi_workflow <- function(
197197
}
198198

199199
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
200-
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
200+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
201201

202202
# --- postprocessor
203203
f <- frosting() %>% layer_predict() # %>% layer_naomit()

R/arx_forecaster.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ arx_fcast_epi_workflow <- function(
143143
}
144144

145145
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
146-
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
146+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
147147

148148
# --- postprocessor
149149
f <- frosting() %>% layer_predict() # %>% layer_naomit()

R/cdc_baseline_forecaster.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ cdc_baseline_forecaster <- function(
7575
step_training_window(n_recent = args_list$n_training)
7676

7777
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
78-
# target_date <- args_list$target_date %||% forecast_date + args_list$ahead
78+
# target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
7979

8080

8181
latest <- get_test_data(

R/flatline_forecaster.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ flatline_forecaster <- function(
4747
step_training_window(n_recent = args_list$n_training)
4848

4949
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
50-
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
51-
50+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
5251

5352
latest <- get_test_data(
5453
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,

tests/testthat/test-target_date_bug.R

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
library(dplyr)
2+
train <- jhu_csse_daily_subset |>
3+
filter(time_value >= as.Date("2021-10-01")) |>
4+
select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av)
5+
ngeos <- n_distinct(train$geo_value)
6+
7+
test_that("flatline determines target_date where forecast_date exists", {
8+
9+
flat <- flatline_forecaster(
10+
train, "dr",
11+
args_list = flatline_args_list(
12+
forecast_date = as.Date("2021-12-31"),
13+
target_date = as.Date("2022-01-01"),
14+
ahead = 1L
15+
)
16+
)
17+
18+
# previously, if target_date existed, it could be
19+
# erroneously incremented by the ahead
20+
expect_identical(
21+
flat$predictions$target_date,
22+
rep(as.Date("2022-01-01"), ngeos)
23+
)
24+
expect_identical(
25+
flat$predictions$forecast_date,
26+
rep(as.Date("2021-12-31"), ngeos)
27+
)
28+
29+
# potentially resulted in NA predictions
30+
# see #290 https://github.com/cmu-delphi/epipredict/issues/290
31+
expect_true(all(!is.na(flat$predictions$.pred_distn)))
32+
expect_true(all(!is.na(flat$predictions$.pred)))
33+
})
34+
35+
test_that("arx_forecaster determines target_date where forecast_date exists", {
36+
37+
arx <- arx_forecaster(
38+
train, "dr", c("dr", "cr"),
39+
args_list = arx_args_list(
40+
forecast_date = as.Date("2021-12-31"),
41+
target_date = as.Date("2022-01-01"),
42+
ahead = 1L
43+
)
44+
)
45+
# previously, if target_date existed, it could be
46+
# erroneously incremented by the ahead
47+
expect_identical(
48+
arx$predictions$target_date,
49+
rep(as.Date("2022-01-01"), ngeos)
50+
)
51+
expect_identical(
52+
arx$predictions$forecast_date,
53+
rep(as.Date("2021-12-31"), ngeos)
54+
)
55+
56+
# potentially resulted in NA predictions
57+
# see #290 https://github.com/cmu-delphi/epipredict/issues/290
58+
expect_true(all(!is.na(arx$predictions$.pred_distn)))
59+
expect_true(all(!is.na(arx$predictions$.pred)))
60+
})
61+
62+
test_that("arx_classifier determines target_date where forecast_date exists", {
63+
64+
65+
arx <- arx_classifier(
66+
train, "dr", c("dr"),
67+
trainer = parsnip::nearest_neighbor(mode = "classification"),
68+
args_list = arx_class_args_list(
69+
forecast_date = as.Date("2021-12-31"),
70+
target_date = as.Date("2022-01-01"),
71+
ahead = 1L
72+
)
73+
)
74+
75+
# previously, if target_date existed, it could be
76+
# erroneously incremented by the ahead
77+
expect_identical(
78+
arx$predictions$target_date,
79+
rep(as.Date("2022-01-01"), ngeos)
80+
)
81+
expect_identical(
82+
arx$predictions$forecast_date,
83+
rep(as.Date("2021-12-31"), ngeos)
84+
)
85+
86+
# potentially resulted in NA predictions
87+
# see #290 https://github.com/cmu-delphi/epipredict/issues/290
88+
expect_true(all(!is.na(arx$predictions$.pred_class)))
89+
})

0 commit comments

Comments
 (0)