Skip to content

Commit c9b4667

Browse files
committed
working cdc baseline
1 parent 16f6c2c commit c9b4667

8 files changed

+419
-16
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ S3method(print,alist)
5252
S3method(print,arx_class)
5353
S3method(print,arx_fcast)
5454
S3method(print,canned_epipred)
55+
S3method(print,cdc_baseline_fcast)
5556
S3method(print,epi_workflow)
5657
S3method(print,flat_fcast)
5758
S3method(print,flatline)
@@ -107,6 +108,8 @@ export(arx_classifier)
107108
export(arx_fcast_epi_workflow)
108109
export(arx_forecaster)
109110
export(bake)
111+
export(cdc_baseline_args_list)
112+
export(cdc_baseline_forecaster)
110113
export(create_layer)
111114
export(default_epi_recipe_blueprint)
112115
export(detect_layer)

R/cdc_baseline_forecaster.R

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
#' Predict the future with the most recent value
2+
#'
3+
#' This is a simple forecasting model for
4+
#' [epiprocess::epi_df] data. It uses the most recent observation as the
5+
#' forecast for any future date, and produces intervals by shuffling the quantiles
6+
#' of the residuals of such a "flatline" forecast and incrementing these
7+
#' forward over all available training data.
8+
#'
9+
#' By default, the predictive intervals are computed separately for each
10+
#' combination of `geo_value` in the `epi_data` argument.
11+
#'
12+
#' This forecaster is meant to produce exactly the CDC Baseline used for
13+
#' [COVID19ForecastHub](https://covid19forecasthub.org)
14+
#'
15+
#' @param epi_data An [epiprocess::epi_df]
16+
#' @param outcome A scalar character for the column name we wish to predict.
17+
#' @param args_list A list of additional arguments as created by the
18+
#' [cdc_baseline_args_list()] constructor function.
19+
#'
20+
#' @return A data frame of point and interval forecasts for all
21+
#' aheads (unique horizons) for each unique combination of `key_vars`.
22+
#' @export
23+
#'
24+
#' @examples
25+
#' library(dplyr)
26+
#' deaths <- case_death_rate_subset %>%
27+
#' select(geo_value, time_value, death_rate) %>%
28+
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
29+
#' mutate(deaths = pmax(death_rate / 1e5 * pop, 0)) %>%
30+
#' select(-pop, -death_rate) %>%
31+
#' group_by(geo_value) %>%
32+
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
33+
#' ungroup() %>%
34+
#' filter(weekdays(time_value) == "Saturday")
35+
#'
36+
#' cdc <- cdc_baseline_forecaster(deaths, "deaths")
37+
#' preds <- pivot_quantiles(cdc$predictions, .pred_distn)
38+
#'
39+
#' if (require(ggplot2)) {
40+
#' forecast_date <- unique(preds$forecast_date)
41+
#' four_states <- c("ca", "pa", "wa", "ny")
42+
#' preds %>%
43+
#' filter(geo_value %in% four_states) %>%
44+
#' ggplot(aes(target_date)) +
45+
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
46+
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
47+
#' geom_line(aes(y = .pred), color = "orange") +
48+
#' geom_line(
49+
#' data = deaths %>% filter(geo_value %in% four_states),
50+
#' aes(x = time_value, y = deaths)
51+
#' ) +
52+
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
53+
#' labs(x = "Date", y = "Weekly deaths") +
54+
#' facet_wrap(~geo_value, scales = "free_y") +
55+
#' theme_bw() +
56+
#' geom_vline(xintercept = forecast_date)
57+
#' }
58+
cdc_baseline_forecaster <- function(
    epi_data,
    outcome,
    args_list = cdc_baseline_args_list()) {
  validate_forecaster_inputs(epi_data, outcome, "time_value")

  # `cdc_baseline_args_list()` tags its result with both of these classes.
  # Checking for the generic "alist" class alone (or for a class name the
  # constructor never produces) would let unrelated lists through, so both
  # classes are required here.
  from_constructor <- inherits(args_list, "cdc_baseline_fcast") &&
    inherits(args_list, "alist")
  if (!from_constructor) {
    cli_stop("args_list was not created using `cdc_baseline_args_list()`.")
  }

  keys <- epi_keys(epi_data)
  outcome <- rlang::sym(outcome)

  # Preprocessing: the outcome shifted ahead by one data period is added as a
  # step, while the outcome itself and the key columns are given the
  # "predictor" role.
  r <- epi_recipe(epi_data) %>%
    step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
    recipes::update_role(!!outcome, new_role = "predictor") %>%
    recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
    step_training_window(n_recent = args_list$n_training)

  # Without a user-supplied forecast date, use the latest date in the data.
  forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
  # target_date <- args_list$target_date %||% forecast_date + args_list$ahead

  latest <- get_test_data(
    epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
    forecast_date
  )

  # Postprocessing: point prediction, then the CDC flatline quantiles for every
  # requested ahead, then the forecast date, then one row per ahead.
  f <- frosting() %>%
    layer_predict() %>%
    layer_cdc_flatline_quantiles(
      aheads = args_list$aheads,
      quantile_levels = args_list$quantile_levels,
      nsims = args_list$nsims,
      by_key = args_list$quantile_by_key,
      symmetrize = args_list$symmetrize,
      nonneg = args_list$nonneg
    ) %>%
    layer_add_forecast_date(forecast_date = forecast_date) %>%
    layer_unnest(.pred_distn_all)
  # layer_add_target_date(target_date = target_date)
  if (args_list$nonneg) {
    f <- layer_threshold(f, ".pred")
  }

  eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")

  wf <- epi_workflow(r, eng, f)
  wf <- generics::fit(wf, epi_data)

  # `ahead` counts data periods, so it is scaled by the data frequency (in
  # days) to obtain the calendar target date.
  preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
    tibble::as_tibble() %>%
    dplyr::select(-time_value) %>%
    dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency)

  structure(
    list(
      predictions = preds,
      epi_workflow = wf,
      metadata = list(
        training = attr(epi_data, "metadata"),
        forecast_created = Sys.time()
      )
    ),
    class = c("cdc_baseline_fcast", "canned_epipred")
  )
}
122+
123+
124+
125+
#' CDC baseline forecaster argument constructor
126+
#'
127+
#' Constructs a list of arguments for [cdc_baseline_forecaster()].
128+
#'
129+
#' @inheritParams arx_args_list
130+
#' @param data_frequency Integer or string. This describes the frequency of the
131+
#' input `epi_df`. For typical FluSight forecasts, this would be `"1 week"`.
132+
#' Allowable arguments are integers (taken to mean numbers of days) or a
133+
#' string like `"7 days"` or `"2 weeks"`. Currently, all other periods
134+
#' (other than days or weeks) result in an error.
135+
#' @param aheads Integer vector. Unlike [arx_forecaster()], this doesn't have
136+
#' any effect on the predicted values.
137+
#' Predictions are always the most recent observation. This determines the
138+
#' set of prediction horizons for [layer_cdc_flatline_quantiles()]. It interacts
139+
#' with the `data_frequency` argument. So, for example, if the data is daily
140+
#' and you want forecasts for 1:4 days ahead, then you would use `1:4`. However,
141+
#' if you want one-week predictions, you would set this as `c(7, 14, 21, 28)`.
142+
#' But if `data_frequency` is `"1 week"`, then you would set it as `1:4`.
143+
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
144+
#' prediction intervals. These are created by computing the quantiles of
145+
#' training residuals. A `NULL` value will result in point forecasts only.
146+
#' @param nsims Positive integer. The number of draws from the empirical CDF.
147+
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting
148+
#' in linear interpolation on the X scale. This is achieved with
149+
#' [stats::quantile()] Type 7 (the default for that function).
150+
#' @param nonneg Logical. Force all predictive intervals to be non-negative.
151+
#' Because non-negativity is forced _before_ propagating forward, this
152+
#' has slightly different behaviour than would occur if using
153+
#' [layer_threshold()].
154+
#'
155+
#' @return A list containing updated parameter choices with class `cdc_baseline_fcast`.
156+
#' @export
157+
#'
158+
#' @examples
159+
#' cdc_baseline_args_list()
160+
#' cdc_baseline_args_list(symmetrize = FALSE)
161+
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
162+
cdc_baseline_args_list <- function(
    data_frequency = "1 week",
    aheads = 1:4,
    n_training = Inf,
    forecast_date = NULL,
    quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
    nsims = 1e3L,
    symmetrize = TRUE,
    nonneg = TRUE,
    quantile_by_key = "geo_value",
    nafill_buffer = Inf) {
  arg_is_scalar(n_training, nsims, data_frequency)
  # Normalise the frequency to an integer number of days before validating it.
  data_frequency <- parse_period(data_frequency)
  arg_is_pos_int(data_frequency)
  arg_is_chr(quantile_by_key, allow_empty = TRUE)
  arg_is_scalar(forecast_date, allow_null = TRUE)
  arg_is_date(forecast_date, allow_null = TRUE)
  arg_is_nonneg_int(aheads)
  # `nsims` is documented as a positive integer and
  # `layer_cdc_flatline_quantiles()` checks it with `arg_is_pos_int()`, so
  # reject zero here rather than failing later inside the layer.
  arg_is_pos_int(nsims)
  arg_is_lgl(symmetrize, nonneg)
  arg_is_probabilities(quantile_levels, allow_null = TRUE)
  arg_is_pos(n_training)
  # `Inf` means "no limit", so only finite values must be whole numbers.
  if (is.finite(n_training)) {
    arg_is_pos_int(n_training)
  }
  # `is.finite(NULL)` is `logical(0)`, which `if ()` rejects, so guard the
  # NULL case explicitly to make the permitted NULL value actually usable.
  arg_is_scalar(nafill_buffer, allow_null = TRUE)
  if (!is.null(nafill_buffer) && is.finite(nafill_buffer)) {
    arg_is_pos_int(nafill_buffer)
  }

  structure(
    enlist(
      data_frequency,
      aheads,
      n_training,
      forecast_date,
      quantile_levels,
      nsims,
      symmetrize,
      nonneg,
      quantile_by_key,
      nafill_buffer
    ),
    class = c("cdc_baseline_fcast", "alist")
  )
}
202+
203+
#' @export
204+
print.cdc_baseline_fcast <- function(x, ...) {
  # Hand off to the next print method in the class chain, supplying the
  # display name for this forecaster through the `name` argument.
  forecaster_label <- "CDC Baseline"
  NextMethod(name = forecaster_label, ...)
}
208+
209+
# Convert a data-frequency specification into an integer number of days.
#
# `x` is either a number (taken as days) or a string such as "7", "7 days" or
# "2 weeks". Only day and week units are understood; anything else is an
# error. Returns a single integer.
parse_period <- function(x) {
  arg_is_scalar(x)
  if (is.character(x)) {
    # Split on runs of whitespace so padded or double-spaced input still parses.
    parts <- strsplit(trimws(x), "\\s+")[[1]]
    if (length(parts) < 1L || length(parts) > 2L) {
      cli::cli_abort("incompatible timespan in `data_frequency`.")
    }
    # Detect a non-numeric count explicitly rather than letting an NA (and a
    # coercion warning) leak out to the caller.
    count <- suppressWarnings(as.numeric(parts[1]))
    if (is.na(count)) {
      cli::cli_abort("incompatible timespan in `data_frequency`.")
    }
    mult <- 1L
    if (length(parts) == 2L) {
      # Match on the first three letters so "day"/"days" and "week"/"weeks"
      # are both accepted, regardless of case.
      unit <- tolower(substr(parts[2], 1, 3))
      mult <- switch(unit,
        day = 1L,
        wee = 7L,
        cli::cli_abort("incompatible timespan in `data_frequency`.")
      )
    }
    x <- count * mult
  }
  stopifnot(rlang::is_integerish(x))
  as.integer(x)
}

R/layer_cdc_flatline_quantiles.R

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ layer_cdc_flatline_quantiles <- function(
9797
frosting,
9898
...,
9999
aheads = 1:4,
100-
quantiles = c(.01, .025, 1:19 / 20, .975, .99),
100+
quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
101101
nsims = 1e3,
102102
by_key = "geo_value",
103103
symmetrize = FALSE,
@@ -106,7 +106,7 @@ layer_cdc_flatline_quantiles <- function(
106106
rlang::check_dots_empty()
107107

108108
arg_is_int(aheads)
109-
arg_is_probabilities(quantiles)
109+
arg_is_probabilities(quantile_levels, allow_null = TRUE)
110110
arg_is_pos_int(nsims)
111111
arg_is_scalar(nsims)
112112
arg_is_chr_scalar(id)
@@ -117,7 +117,7 @@ layer_cdc_flatline_quantiles <- function(
117117
frosting,
118118
layer_cdc_flatline_quantiles_new(
119119
aheads = aheads,
120-
quantiles = quantiles,
120+
quantile_levels = quantile_levels,
121121
nsims = nsims,
122122
by_key = by_key,
123123
symmetrize = symmetrize,
@@ -129,7 +129,7 @@ layer_cdc_flatline_quantiles <- function(
129129

130130
layer_cdc_flatline_quantiles_new <- function(
131131
aheads,
132-
quantiles,
132+
quantile_levels,
133133
nsims,
134134
by_key,
135135
symmetrize,
@@ -138,7 +138,7 @@ layer_cdc_flatline_quantiles_new <- function(
138138
layer(
139139
"cdc_flatline_quantiles",
140140
aheads = aheads,
141-
quantiles = quantiles,
141+
quantile_levels = quantile_levels,
142142
nsims = nsims,
143143
by_key = by_key,
144144
symmetrize = symmetrize,
@@ -150,6 +150,7 @@ layer_cdc_flatline_quantiles_new <- function(
150150
#' @export
151151
slather.layer_cdc_flatline_quantiles <-
152152
function(object, components, workflow, new_data, ...) {
153+
if (is.null(object$quantile_levels)) return(components)
153154
the_fit <- workflows::extract_fit_parsnip(workflow)
154155
if (!inherits(the_fit, "_flatline")) {
155156
cli::cli_warn(
@@ -213,7 +214,7 @@ slather.layer_cdc_flatline_quantiles <-
213214
dplyr::rowwise() %>%
214215
dplyr::mutate(
215216
.pred_distn_all = propogate_samples(
216-
.resid, .pred, object$quantiles,
217+
.resid, .pred, object$quantile_levels,
217218
object$aheads, object$nsim, object$symmetrize, object$nonneg
218219
)
219220
) %>%
@@ -229,7 +230,7 @@ slather.layer_cdc_flatline_quantiles <-
229230
}
230231

231232
propogate_samples <- function(
232-
r, p, quantiles, aheads, nsim, symmetrize, nonneg) {
233+
r, p, quantile_levels, aheads, nsim, symmetrize, nonneg) {
233234
max_ahead <- max(aheads)
234235
samp <- quantile(r, probs = c(0, seq_len(nsim - 1)) / (nsim - 1), na.rm = TRUE)
235236
res <- list()
@@ -254,7 +255,7 @@ propogate_samples <- function(
254255
list(tibble::tibble(
255256
ahead = aheads,
256257
.pred_distn = map_vec(
257-
res, ~ dist_quantiles(quantile(.x, quantiles), tau = quantiles)
258+
res, ~ dist_quantiles(quantile(.x, quantile_levels), quantile_levels)
258259
)
259260
))
260261
}

0 commit comments

Comments
 (0)