Skip to content

Commit 9d27064

Browse files
committed
feat+test: make a functional flatline forecaster
1 parent 9767d90 commit 9d27064

File tree

9 files changed

+184
-51
lines changed

9 files changed

+184
-51
lines changed

.Rbuildignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
^renv$
22
^renv\.lock$
33
^LICENSE\.md$
4+
^.lintr$
5+
^.renvignore$
6+
^.github$

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ export(confirm_insufficient_data)
99
export(covidhub_probs)
1010
export(evaluate_predictions)
1111
export(extend_ahead)
12+
export(flatline_fc)
1213
export(forecaster_pred)
1314
export(format_storage)
1415
export(interval_coverage)
@@ -39,6 +40,7 @@ importFrom(purrr,map)
3940
importFrom(purrr,transpose)
4041
importFrom(rlang,.data)
4142
importFrom(rlang,quo)
43+
importFrom(rlang,sym)
4244
importFrom(rlang,syms)
4345
importFrom(tibble,tibble)
4446
importFrom(tidyr,pivot_wider)

R/forecaster.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ arx_postprocess <- function(postproc,
123123
postproc %<>% layer_threshold(dplyr::starts_with(".pred"))
124124
}
125125

126-
postproc %<>% layer_naomit(dplyr::starts_with(".pred"))
127-
postproc %<>% layer_add_forecast_date(forecast_date = forecast_date) %>%
126+
postproc %<>% layer_naomit(dplyr::starts_with(".pred")) %>%
128127
layer_add_target_date(target_date = target_date)
129128
return(postproc)
130129
}

R/forecaster_flatline.R

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#' flatline forecaster (aka baseline)
2+
#' @description
3+
#' a minimal forecaster whose median is just the last value
4+
#' does not support `lags` as a parameter, but otherwise has the same parameters as `arx_forecaster`
5+
#' @inheritParams scaled_pop
6+
#' @importFrom rlang sym
7+
#' @export
8+
flatline_fc <- function(epi_data,
9+
outcome,
10+
extra_sources = "",
11+
ahead = 1,
12+
trainer = parsnip::linear_reg(),
13+
levels = covidhub_probs(),
14+
...) {
15+
# perform any preprocessing not supported by epipredict
16+
# one that every forecaster will need to handle: how to manage max(time_value)
17+
# that's older than the `as_of` date
18+
epidataAhead <- extend_ahead(epi_data, ahead)
19+
# see latency_adjusting for other examples
20+
# this next part is basically unavoidable boilerplate you'll want to copy
21+
epi_data <- epidataAhead[[1]]
22+
effective_ahead <- epidataAhead[[2]]
23+
args_input <- list(...)
24+
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
25+
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
26+
null_result <- tibble(
27+
geo_value = character(),
28+
forecast_date = Date(),
29+
target_end_date = Date(),
30+
quantile = numeric(),
31+
value = numeric()
32+
)
33+
return(null_result)
34+
}
35+
args_input[["ahead"]] <- effective_ahead
36+
args_input[["levels"]] <- levels
37+
args_list <- do.call(flatline_args_list, args_input)
38+
# if you want to ignore extra_sources, setting predictors is the way to do it
39+
predictors <- c(outcome, extra_sources)
40+
argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, NULL, args_list)
41+
args_list <- argsPredictorsTrainer[[1]]
42+
predictors <- argsPredictorsTrainer[[2]]
43+
# end of the copypasta
44+
# finally, any other pre-processing (e.g. smoothing) that isn't performed by
45+
# epipredict
46+
47+
# since this is just the flatline, we don't need much of anything
48+
res <- flatline_forecaster(epi_data, outcome = outcome, args_list = args_list)
49+
true_forecast_date <- attributes(epi_data)$metadata$as_of
50+
pred <- format_storage(res$predictions, true_forecast_date)
51+
# (geo_value, forecast_date, target_end_date, quantile, value)
52+
# finally, any postprocessing not supported by epipredict e.g. calibration
53+
return(pred)
54+
}

R/forecaster_scaled_pop.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
scaled_pop <- function(epi_data,
4646
outcome,
4747
extra_sources = "",
48-
ahead=1,
48+
ahead = 1,
4949
pop_scaling = TRUE,
5050
trainer = parsnip::linear_reg(),
5151
levels = covidhub_probs(),

R/formatters.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
1818
.dstn = nested_quantiles(.pred_distn)
1919
) %>%
2020
unnest(.dstn) %>%
21-
select(-.pred_distn, -.pred, -time_value) %>%
21+
select(-any_of(c(".pred_distn", ".pred", "time_value"))) %>%
2222
rename(quantile = tau, value = q, target_end_date = target_date) %>%
2323
relocate(geo_value, forecast_date, target_end_date, quantile, value)
2424
}

man/flatline_fc.Rd

Lines changed: 48 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-example_spec.R

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,55 @@
1-
test_that("scaled_pop", {
2-
library(epipredict)
3-
jhu <- case_death_rate_subset %>%
4-
dplyr::filter(time_value >= as.Date("2021-12-01"))
5-
# the as_of for this is wildly far in the future
6-
attributes(jhu)$metadata$as_of <- max(jhu$time_value) + 3
7-
expect_warning(res <- scaled_pop(jhu, "case_rate", c("death_rate"), -2L))
8-
expect_equal(
9-
names(res),
10-
c("geo_value", "forecast_date", "target_end_date", "quantile", "value")
11-
)
12-
expect_true(all(
13-
res$target_end_date ==
14-
as.Date("2022-01-01")
15-
))
16-
# confirm scaling produces different results
17-
expect_warning(res_unscaled <- scaled_pop(jhu,
18-
"case_rate",
19-
c("death_rate"),
20-
-2L,
21-
pop_scaling = FALSE
22-
))
23-
expect_false(res_unscaled %>%
24-
full_join(res,
25-
by = join_by(geo_value, forecast_date, target_end_date, quantile),
26-
suffix = c(".unscaled", ".scaled")
27-
) %>%
28-
mutate(equal = value.unscaled == value.scaled) %>%
29-
summarize(all(equal)) %>% pull(`all(equal)`))
30-
# confirming that it produces exactly the same result as arx_forecaster
31-
# test case where extra_sources is "empty"
32-
expect_warning(scaled_pop(
33-
jhu,
34-
"case_rate",
35-
c(""),
36-
1L
37-
))
38-
# test case where the epi_df is empty
39-
null_jhu <- jhu %>% filter(time_value < as.Date("0009-01-01"))
40-
expect_no_error(null_res <- scaled_pop(null_jhu, "case_rate", c("death_rate")))
41-
null_res <- scaled_pop(null_jhu, "case_rate", c("death_rate"))
42-
expect_identical(names(null_res), names(res))
43-
expect_equal(nrow(null_res), 0)
44-
expect_identical(null_res, tibble(geo_value = character(), forecast_date = Date(), target_end_date = Date(), quantile = numeric(), value = numeric()))
45-
})
1+
# TODO better way to do this than copypasta
2+
forecasters <- list(
3+
c("scaled_pop", scaled_pop),
4+
c("flatline_fc", flatline_fc)
5+
)
6+
forecaster <- c("flatline", flatline_fc)
7+
for (forecaster in forecasters) {
8+
test_that(forecaster[[1]], {
9+
jhu <- case_death_rate_subset %>%
10+
dplyr::filter(time_value >= as.Date("2021-12-01"))
11+
# the as_of for this is wildly far in the future
12+
attributes(jhu)$metadata$as_of <- max(jhu$time_value) + 3
13+
res <- forecaster[[2]](jhu, "case_rate", c("death_rate"), -2L)
14+
expect_equal(
15+
names(res),
16+
c("geo_value", "forecast_date", "target_end_date", "quantile", "value")
17+
)
18+
expect_true(all(
19+
res$target_end_date ==
20+
as.Date("2022-01-01")
21+
))
22+
# any forecaster specific tests
23+
if (forecaster[[1]] == "scaled_pop") {
24+
# confirm scaling produces different results
25+
res_unscaled <- forecaster[[2]](jhu,
26+
"case_rate",
27+
c("death_rate"),
28+
-2L,
29+
pop_scaling = FALSE
30+
)
31+
expect_false(res_unscaled %>%
32+
full_join(res,
33+
by = join_by(geo_value, forecast_date, target_end_date, quantile),
34+
suffix = c(".unscaled", ".scaled")
35+
) %>%
36+
mutate(equal = value.unscaled == value.scaled) %>%
37+
summarize(all(equal)) %>% pull(`all(equal)`))
38+
}
39+
# TODO confirming that it produces exactly the same result as arx_forecaster
40+
# test case where extra_sources is "empty"
41+
forecaster[[2]](
42+
jhu,
43+
"case_rate",
44+
c(""),
45+
1L
46+
)
47+
# test case where the epi_df is empty
48+
null_jhu <- jhu %>% filter(time_value < as.Date("0009-01-01"))
49+
expect_no_error(null_res <- forecaster[[2]](null_jhu, "case_rate", c("death_rate")))
50+
null_res <- forecaster[[2]](null_jhu, "case_rate", c("death_rate"))
51+
expect_identical(names(null_res), names(res))
52+
expect_equal(nrow(null_res), 0)
53+
expect_identical(null_res, tibble(geo_value = character(), forecast_date = Date(), target_end_date = Date(), quantile = numeric(), value = numeric()))
54+
})
55+
}

tests/testthat/test-forecasters.R

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ forecasters <- tribble(
33
~forecaster, ~extra_params, ~extra_params_names,
44
scaled_pop, list(1, TRUE), list("ahead", "pop_scaling"),
55
scaled_pop, list(1, FALSE), list("ahead", "pop_scaling"),
6+
flatline_fc, list(1), list("ahead")
67
)
78
synth_mean <- 25
89
synth_sd <- 2
@@ -16,6 +17,7 @@ constant <- as_epi_archive(tibble(
1617
version = simple_dates,
1718
a = synth_mean + approx_zero
1819
))
20+
ii <- 3
1921
# wrap a call that is made quite frequently
2022
# n_training_pad is set to avoid warnings from the trainer
2123
get_pred <- function(dataset,
@@ -42,6 +44,7 @@ test_that("constant", {
4244
a = 4 * synth_mean + approx_zero
4345
)
4446
))
47+
different_constants
4548
for (ii in 1:nrow(forecasters)) {
4649
res <- get_pred(different_constants, ii)
4750

@@ -77,7 +80,9 @@ test_that("white noise", {
7780
values <- res %>%
7881
filter(quantile == .5) %>%
7982
pull(value)
80-
expect_true(sd(values) < synth_sd)
83+
84+
# shouldn't expect the sample sd to actually match the true sd exactly, so giving it some leeway
85+
expect_true(sd(values) < 2*synth_sd)
8186
# how much is each quantile off from the expected value?
8287
# should be fairly generous here, we just want the right order of magnitude
8388
quantile_deviation <- res %>%
@@ -106,6 +111,7 @@ test_that("delayed state", {
106111
a = synth_mean + approx_zero
107112
)
108113
))
114+
missing_state$DT %>% filter(geo_value == "ca")
109115
for (ii in seq_len(nrow(forecasters))) {
110116
expect_no_error(res <- get_pred(missing_state, ii))
111117
expect_equal(length(unique(res$geo_value)), 2)
@@ -119,7 +125,15 @@ test_that("delayed state", {
119125
counts_al <- counts %>%
120126
filter(geo_value == "al") %>%
121127
pull(n)
122-
expect_true(counts_al > counts_ca)
128+
counts_al
129+
counts_ca
130+
res %>% filter(geo_value == "ca" & quantile == .5)
131+
# flatline is more aggressive about forecasting
132+
if (identical(forecasters$forecaster[[ii]], flatline_fc)) {
133+
expect_true(counts_al == counts_ca)
134+
} else {
135+
expect_true(counts_al > counts_ca)
136+
}
123137
expect_true(sum(state_delay == 0) > counts_ca)
124138
expect_true(counts_ca > 0)
125139
}
@@ -139,12 +153,15 @@ test_that("linear", {
139153
)
140154
)
141155
for (ii in seq_len(nrow(forecasters))) {
156+
#flatline will definitely fail this, so it's exempt
157+
if (!identical(forecasters$forecaster[[ii]], flatline_fc)) {
142158
res <- get_pred(linear, ii)
143159
# make sure that the median is on the sloped line
144160
median_err <- res %>%
145161
filter(quantile == .5) %>%
146162
mutate(err = value - as.integer(target_end_date - start_date + 1), .keep = "none") %>%
147163
mutate(is_right = near(err,0, tol=tiny_sd ^ 0.5), .keep = "none")
148164
expect_true(all(median_err))
165+
}
149166
}
150167
})

0 commit comments

Comments
 (0)