Skip to content

Commit 9d5cc5d

Browse files
authored
Merge pull request #300 from cmu-delphi/djm/fix-290
fix: #290
2 parents 98bdc85 + 3c12fb7 commit 9d5cc5d

11 files changed

+143
-43
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
^.git-blame-ignore-revs$
1919
^doc$
2020
^Meta$
21+
^.lintr$

.lintr

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
linters: linters_with_defaults(
2+
line_length_linter(120),
3+
cyclocomp_linter = NULL,
4+
object_length_linter(length = 40L)
5+
)
6+
exclusions: list(
7+
"renv",
8+
"venv"
9+
)

NEWS.md

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,35 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
44

55
# epipredict 0.1
66

7-
- add `check_enough_train_data` that will error if training data is too small
8-
- added `check_enough_train_data` to `arx_forecaster`
9-
- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`
10-
- rename the `dist_quantiles()` to be more descriptive, breaking change
11-
- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
12-
- add `pivot_quantiles_wider()` for easier plotting
13-
- add complement `pivot_quantiles_longer()`
14-
- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`
15-
- add `smooth_quantile_reg()`
16-
- improved printing of various methods / internals
17-
- canned forecasters get a class
18-
- fixed quantile bug in `flatline_forecaster()`
19-
- add functionality to output the unfit workflow from the canned forecasters
20-
- add quantile_reg()
21-
- clean up documentation bugs
22-
- add smooth_quantile_reg()
23-
- add classifier
24-
- training window step debugged
25-
- `min_train_window` argument removed from canned forecasters
26-
- add forecasters
27-
- implement postprocessing
28-
- vignettes available
29-
- arx_forecaster
30-
- pkgdown
31-
- Publish public for easy navigation
32-
- Two simple forecasters as test beds
33-
- Working vignette
34-
- use `checkmate` for input validation
35-
- refactor quantile extrapolation (possibly creates different results)
36-
- force `target_date` + `forecast_date` handling to match the time_type of
37-
the epi_df. allows for annual and weekly data
7+
- `layer_residual_quantiles()` will now error if any of the residual quantiles are NA
8+
- add `check_enough_train_data` that will error if training data is too small
9+
- added `check_enough_train_data` to `arx_forecaster`
10+
- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`
11+
- rename the `dist_quantiles()` to be more descriptive, breaking change
12+
- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
13+
- add `pivot_quantiles_wider()` for easier plotting
14+
- add complement `pivot_quantiles_longer()`
15+
- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`
16+
- add `smooth_quantile_reg()`
17+
- improved printing of various methods / internals
18+
- canned forecasters get a class
19+
- fixed quantile bug in `flatline_forecaster()`
20+
- add functionality to output the unfit workflow from the canned forecasters
21+
- add `quantile_reg()`
22+
- clean up documentation bugs
23+
- add smooth_quantile_reg()
24+
- add classifier
25+
- training window step debugged
26+
- `min_train_window` argument removed from canned forecasters
27+
- add forecasters
28+
- implement postprocessing
29+
- vignettes available
30+
- arx_forecaster
31+
- pkgdown
32+
- Publish public for easy navigation
33+
- Two simple forecasters as test beds
34+
- Working vignette
35+
- use `checkmate` for input validation
36+
- refactor quantile extrapolation (possibly creates different results)
37+
- force `target_date` + `forecast_date` handling to match the time_type of
38+
the `epi_df`; this allows for annual and weekly data

R/arx_classifier.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ arx_class_epi_workflow <- function(
197197
}
198198

199199
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
200-
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
200+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
201201

202202
# --- postprocessor
203203
f <- frosting() %>% layer_predict() # %>% layer_naomit()

R/arx_forecaster.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ arx_fcast_epi_workflow <- function(
143143
}
144144

145145
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
146-
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
146+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
147147

148148
# --- postprocessor
149149
f <- frosting() %>% layer_predict() # %>% layer_naomit()
@@ -289,8 +289,8 @@ print.arx_fcast <- function(x, ...) {
289289
}
290290

291291
compare_quantile_args <- function(alist, tlist) {
292-
default_alist <- eval(formals(arx_args_list)$quantile_level)
293-
default_tlist <- eval(formals(quantile_reg)$quantile_level)
292+
default_alist <- eval(formals(arx_args_list)$quantile_levels)
293+
default_tlist <- eval(formals(quantile_reg)$quantile_levels)
294294
if (setequal(alist, default_alist)) {
295295
if (setequal(tlist, default_tlist)) {
296296
return(sort(unique(union(alist, tlist))))

R/cdc_baseline_forecaster.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ cdc_baseline_forecaster <- function(
7575
step_training_window(n_recent = args_list$n_training)
7676

7777
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
78-
# target_date <- args_list$target_date %||% forecast_date + args_list$ahead
78+
# target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
7979

8080

8181
latest <- get_test_data(

R/flatline_forecaster.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ flatline_forecaster <- function(
4747
step_training_window(n_recent = args_list$n_training)
4848

4949
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
50-
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
51-
50+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
5251

5352
latest <- get_test_data(
5453
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,

R/layer_residual_quantiles.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ slather.layer_residual_quantiles <-
123123
probs = object$quantile_levels, na.rm = TRUE
124124
))
125125
)
126+
# Check for NA
127+
if (any(sapply(r$dstn, is.na))) {
128+
cli::cli_abort(c(
129+
"Residual quantiles could not be calculated due to missing residuals.",
130+
i = "This may be due to `n_train` < `ahead` in your {.cls epi_recipe}."
131+
))
132+
}
126133

127134
estimate <- components$predictions$.pred
128135
res <- tibble::tibble(

tests/testthat/test-dist_quantiles.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,16 @@ test_that("single dist_quantiles works, quantiles are accessible", {
2828

2929
test_that("quantile extrapolator works", {
3030
dstn <- dist_normal(c(10, 2), c(5, 10))
31-
qq <- extrapolate_quantiles(dstn, p = c(.25, 0.5, .75))
31+
qq <- extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75))
3232
expect_s3_class(qq, "distribution")
3333
expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles")
34-
expect_length(parameters(qq[1])$q[[1]], 3L)
35-
34+
expect_length(parameters(qq[1])$quantile_levels[[1]], 3L)
3635

3736
dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
38-
qq <- extrapolate_quantiles(dstn, p = c(.25, 0.5, .75))
37+
qq <- extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75))
3938
expect_s3_class(qq, "distribution")
4039
expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles")
41-
expect_length(parameters(qq[1])$q[[1]], 7L)
40+
expect_length(parameters(qq[1])$quantile_levels[[1]], 7L)
4241
})
4342

4443
test_that("small deviations of quantile requests work", {

tests/testthat/test-layer_residual_quantiles.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,10 @@ test_that("Canned forecasters work with / without", {
9898
)
9999
)
100100
})
101+
102+
test_that("flatline_forecaster correctly errors when n_training < ahead", {
103+
expect_error(
104+
flatline_forecaster(jhu, "death_rate", args_list = flatline_args_list(ahead = 10, n_training = 9)),
105+
"This may be due to `n_train` < `ahead`"
106+
)
107+
})

tests/testthat/test-target_date_bug.R

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# These tests address #290:
2+
# https://github.com/cmu-delphi/epipredict/issues/290
3+
4+
library(dplyr)
5+
train <- jhu_csse_daily_subset |>
6+
filter(time_value >= as.Date("2021-10-01")) |>
7+
select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av)
8+
ngeos <- n_distinct(train$geo_value)
9+
10+
test_that("flatline determines target_date where forecast_date exists", {
11+
flat <- flatline_forecaster(
12+
train, "dr",
13+
args_list = flatline_args_list(
14+
forecast_date = as.Date("2021-12-31"),
15+
target_date = as.Date("2022-01-01"),
16+
ahead = 1L
17+
)
18+
)
19+
# previously, if target_date existed, it could be
20+
# erroneously incremented by the ahead
21+
expect_identical(
22+
flat$predictions$target_date,
23+
rep(as.Date("2022-01-01"), ngeos)
24+
)
25+
expect_identical(
26+
flat$predictions$forecast_date,
27+
rep(as.Date("2021-12-31"), ngeos)
28+
)
29+
expect_true(all(!is.na(flat$predictions$.pred_distn)))
30+
expect_true(all(!is.na(flat$predictions$.pred)))
31+
})
32+
33+
test_that("arx_forecaster determines target_date where forecast_date exists", {
34+
arx <- arx_forecaster(
35+
train, "dr", c("dr", "cr"),
36+
args_list = arx_args_list(
37+
forecast_date = as.Date("2021-12-31"),
38+
target_date = as.Date("2022-01-01"),
39+
ahead = 1L
40+
)
41+
)
42+
# previously, if target_date existed, it could be
43+
# erroneously incremented by the ahead
44+
expect_identical(
45+
arx$predictions$target_date,
46+
rep(as.Date("2022-01-01"), ngeos)
47+
)
48+
expect_identical(
49+
arx$predictions$forecast_date,
50+
rep(as.Date("2021-12-31"), ngeos)
51+
)
52+
expect_true(all(!is.na(arx$predictions$.pred_distn)))
53+
expect_true(all(!is.na(arx$predictions$.pred)))
54+
})
55+
56+
test_that("arx_classifier determines target_date where forecast_date exists", {
57+
arx <- arx_classifier(
58+
train, "dr", c("dr"),
59+
trainer = parsnip::boost_tree(mode = "classification", trees = 5),
60+
args_list = arx_class_args_list(
61+
forecast_date = as.Date("2021-12-31"),
62+
target_date = as.Date("2022-01-01"),
63+
ahead = 1L
64+
)
65+
)
66+
# previously, if target_date existed, it could be
67+
# erroneously incremented by the ahead
68+
expect_identical(
69+
arx$predictions$target_date,
70+
rep(as.Date("2022-01-01"), ngeos)
71+
)
72+
expect_identical(
73+
arx$predictions$forecast_date,
74+
rep(as.Date("2021-12-31"), ngeos)
75+
)
76+
expect_true(all(!is.na(arx$predictions$.pred_class)))
77+
})

0 commit comments

Comments
 (0)