Skip to content

Commit 3ef79b9

Browse files
authored
Merge pull request #245 from cmu-delphi/cdc-baseline
Cdc baseline
2 parents dacf6e7 + bdbd3ee commit 3ef79b9

8 files changed

+248
-31
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ S3method(extrapolate_quantiles,dist_default)
3434
S3method(extrapolate_quantiles,dist_quantiles)
3535
S3method(extrapolate_quantiles,distribution)
3636
S3method(fit,epi_workflow)
37+
S3method(flusight_hub_formatter,canned_epipred)
38+
S3method(flusight_hub_formatter,data.frame)
3739
S3method(format,dist_quantiles)
3840
S3method(is.na,dist_quantiles)
3941
S3method(is.na,distribution)
@@ -126,6 +128,7 @@ export(fit)
126128
export(flatline)
127129
export(flatline_args_list)
128130
export(flatline_forecaster)
131+
export(flusight_hub_formatter)
129132
export(frosting)
130133
export(get_test_data)
131134
export(grab_names)

R/cdc_baseline_forecaster.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212
#' This forecaster is meant to produce exactly the CDC Baseline used for
1313
#' [COVID19ForecastHub](https://covid19forecasthub.org)
1414
#'
15-
#' @param epi_data An [epiprocess::epi_df]
15+
#' @param epi_data An [`epiprocess::epi_df`]
1616
#' @param outcome A scalar character for the column name we wish to predict.
1717
#' @param args_list A list of additional arguments as created by the
1818
#' [cdc_baseline_args_list()] constructor function.
1919
#'
20-
#' @return A data frame of point and interval forecasts at for all
21-
#' aheads (unique horizons) for each unique combination of `key_vars`.
20+
#' @return A data frame of point and interval forecasts for all aheads (unique
21+
#' horizons) for each unique combination of `key_vars`.
2222
#' @export
2323
#'
2424
#' @examples
2525
#' library(dplyr)
2626
#' weekly_deaths <- case_death_rate_subset %>%
2727
#' select(geo_value, time_value, death_rate) %>%
2828
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
29-
#' mutate(deaths = pmax(death_rate / 1e5 * pop, 0)) %>%
29+
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
3030
#' select(-pop, -death_rate) %>%
3131
#' group_by(geo_value) %>%
3232
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%

R/flusight_hub_formatter.R

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
abbr_to_fips <- function(abbr) {
2+
fi <- dplyr::left_join(
3+
tibble::tibble(abbr = tolower(abbr)),
4+
state_census, by = "abbr") %>%
5+
dplyr::mutate(fips = as.character(fips), fips = case_when(
6+
fips == "0" ~ "US",
7+
nchar(fips) < 2L ~ paste0("0", fips),
8+
TRUE ~ fips
9+
)) %>%
10+
pull(.data$fips)
11+
names(fi) <- NULL
12+
fi
13+
}
14+
15+
#' Format predictions for submission to FluSight forecast Hub
16+
#'
17+
#' This function converts predictions from any of the included forecasters into
18+
#' a format (nearly) ready for submission to the 2023-24
19+
#' [FluSight-forecast-hub](https://github.com/cdcepi/FluSight-forecast-hub).
20+
#' See there for documentation of the required columns. Currently, only
21+
#' "quantile" forcasts are supported, but the intention is to support both
22+
#' "quantile" and "pmf". For this reason, adding the `output_type` column should
23+
#' be done via the `...` argument. See the examples below. The specific required
24+
#' format for this forecast task is [here](https://github.com/cdcepi/FluSight-forecast-hub/blob/main/model-output/README.md).
25+
#'
26+
#' @param object a data.frame of predictions or an object of class
27+
#' `canned_epipred` as created by, e.g., [arx_forecaster()]
28+
#' @param ... <[`dynamic-dots`][rlang::dyn-dots]> Name = value pairs of constant
29+
#' columns (or mutations) to perform to the results. See examples.
30+
#' @param .fcast_period Control whether the `horizon` should represent days or
31+
#' weeks. Depending on whether the forecaster output has target dates
32+
#' from [layer_add_target_date()] or not, we may need to compute the horizon
33+
#' and/or the `target_end_date` from the other available columns in the predictions.
34+
#' When both `ahead` and `target_date` are available, this is ignored. If only
35+
#' `ahead` or `aheads` exists, then the target date may need to be multiplied
36+
#' if the `ahead` represents weekly forecasts. Alternatively, if only, the
37+
#' `target_date` is available, then the `horizon` will be in days, unless
38+
#' this argument is `"weekly"`. Note that these can be adjusted later by the
39+
#' `...` argument.
40+
#'
41+
#' @return A [tibble::tibble]. If `...` is empty, the result will contain the
42+
#' columns `reference_date`, `horizon`, `target_end_date`, `location`,
43+
#' `output_type_id`, and `value`. The `...` can perform mutations on any of
44+
#' these.
45+
#' @export
46+
#'
47+
#' @examples
48+
#' library(dplyr)
49+
#' weekly_deaths <- case_death_rate_subset %>%
50+
#' select(geo_value, time_value, death_rate) %>%
51+
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
52+
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
53+
#' select(-pop, -death_rate) %>%
54+
#' group_by(geo_value) %>%
55+
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
56+
#' ungroup() %>%
57+
#' filter(weekdays(time_value) == "Saturday")
58+
#'
59+
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
60+
#' flusight_hub_formatter(cdc)
61+
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths")
62+
#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths"))
63+
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile")
64+
flusight_hub_formatter <- function(
65+
object, ...,
66+
.fcast_period = c("daily", "weekly")) {
67+
UseMethod("flusight_hub_formatter")
68+
}
69+
70+
#' @export
71+
flusight_hub_formatter.canned_epipred <- function(
72+
object, ...,
73+
.fcast_period = c("daily", "weekly")) {
74+
flusight_hub_formatter(object$predictions, ..., .fcast_period = .fcast_period)
75+
}
76+
77+
#' @export
78+
flusight_hub_formatter.data.frame <- function(
79+
object, ...,
80+
.fcast_period = c("daily", "weekly")) {
81+
required_names <- c(".pred", ".pred_distn", "forecast_date", "geo_value")
82+
optional_names <- c("ahead", "target_date")
83+
hardhat::validate_column_names(object, required_names)
84+
if (!any(optional_names %in% names(object))) {
85+
cli::cli_abort("At least one of {.val {optional_names}} must be present.")
86+
}
87+
88+
dots <- enquos(..., .named = TRUE)
89+
names <- names(dots)
90+
91+
object <- object %>%
92+
# combine the predictions and the distribution
93+
dplyr::mutate(.pred_distn = nested_quantiles(.pred_distn)) %>%
94+
dplyr::rowwise() %>%
95+
dplyr::mutate(
96+
.pred_distn = list(add_row(.pred_distn, q = .pred, tau = NA)),
97+
.pred = NULL
98+
) %>%
99+
tidyr::unnest(.pred_distn) %>%
100+
# now we create the correct column names
101+
dplyr::rename(
102+
value = q,
103+
output_type_id = tau,
104+
reference_date = forecast_date
105+
) %>%
106+
# convert to fips codes, and add any constant cols passed in ...
107+
dplyr::mutate(location = abbr_to_fips(tolower(geo_value)), geo_value = NULL)
108+
109+
# create target_end_date / horizon, depending on what is available
110+
pp <- ifelse(match.arg(.fcast_period) == "daily", 1L, 7L)
111+
has_ahead <- charmatch("ahead", names(object))
112+
if ("target_date" %in% names(object) && !is.na(has_ahead)) {
113+
object <- object %>%
114+
dplyr::rename(
115+
target_end_date = target_date,
116+
horizon = !!names(object)[has_ahead]
117+
)
118+
} else if (!is.na(has_ahead)) { # ahead present, not target date
119+
object <- object %>%
120+
dplyr::rename(horizon = !!names(object)[has_ahead]) %>%
121+
dplyr::mutate(target_end_date = horizon * pp + reference_date)
122+
} else { # target_date present, not ahead
123+
object <- object %>%
124+
dplyr::rename(target_end_date = target_date) %>%
125+
dplyr::mutate(horizon = as.integer((target_end_date - reference_date)) / pp)
126+
}
127+
object %>% dplyr::relocate(
128+
reference_date, horizon, target_end_date, location, output_type_id, value
129+
) %>%
130+
dplyr::mutate(!!!dots)
131+
}

R/layer_cdc_flatline_quantiles.R

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,30 @@
2222
#' @inheritParams layer_residual_quantiles
2323
#' @param aheads Numeric vector of desired forecast horizons. These should be
2424
#' given in the "units of the training data". So, for example, for data
25-
#' typically observed daily (possibly with missing values), but
26-
#' with weekly forecast targets, you would use `c(7, 14, 21, 28)`. But with
27-
#' weekly data, you would use `1:4`.
25+
#' typically observed daily (possibly with missing values), but with weekly
26+
#' forecast targets, you would use `c(7, 14, 21, 28)`. But with weekly data,
27+
#' you would use `1:4`.
2828
#' @param quantile_levels Numeric vector of probabilities with values in (0,1)
2929
#' referring to the desired predictive intervals. The default is the standard
3030
#' set for the COVID Forecast Hub.
3131
#' @param nsims Positive integer. The number of draws from the empirical CDF.
32-
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting
33-
#' in linear interpolation on the X scale. This is achieved with
32+
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting in
33+
#' linear interpolation on the X scale. This is achieved with
3434
#' [stats::quantile()] Type 7 (the default for that function).
35-
#' @param nonneg Logical. Force all predictive intervals be non-negative.
36-
#' Because non-negativity is forced _before_ propagating forward, this
37-
#' has slightly different behaviour than would occur if using
38-
#' [layer_threshold()].
35+
#' @param symmetrize Scalar logical. If `TRUE`, does two things: (i) forces the
36+
#' "empirical" CDF of residuals to be symmetric by pretending that for every
37+
#' actually-observed residual X we also observed another residual -X, and (ii)
38+
#' at each ahead, forces the median simulated value to be equal to the point
39+
#' prediction by adding or subtracting the same amount to every simulated
40+
#' value. Adjustments in (ii) take place before propagating forward and
41+
#' simulating the next ahead. This forces any 1-ahead predictive intervals to
42+
#' be symmetric about the point prediction, and encourages larger aheads to be
43+
#' more symmetric.
44+
#' @param nonneg Scalar logical. Force all predictive intervals be non-negative.
45+
#' Because non-negativity is forced _before_ propagating forward, this has
46+
#' slightly different behaviour than would occur if using [layer_threshold()].
47+
#' Thresholding at each ahead takes place after any shifting from
48+
#' `symmetrize`.
3949
#'
4050
#' @return an updated `frosting` postprocessor. Calling [predict()] will result
4151
#' in an additional `<list-col>` named `.pred_distn_all` containing 2-column
@@ -213,7 +223,7 @@ slather.layer_cdc_flatline_quantiles <-
213223
res <- dplyr::left_join(p, r, by = avail_grps) %>%
214224
dplyr::rowwise() %>%
215225
dplyr::mutate(
216-
.pred_distn_all = propogate_samples(
226+
.pred_distn_all = propagate_samples(
217227
.resid, .pred, object$quantile_levels,
218228
object$aheads, object$nsim, object$symmetrize, object$nonneg
219229
)
@@ -229,10 +239,14 @@ slather.layer_cdc_flatline_quantiles <-
229239
components
230240
}
231241

232-
propogate_samples <- function(
242+
propagate_samples <- function(
233243
r, p, quantile_levels, aheads, nsim, symmetrize, nonneg) {
234244
max_ahead <- max(aheads)
235-
samp <- quantile(r, probs = c(0, seq_len(nsim - 1)) / (nsim - 1), na.rm = TRUE)
245+
if (symmetrize) {
246+
r <- c(r, -r)
247+
}
248+
samp <- quantile(r, probs = c(0, seq_len(nsim - 1)) / (nsim - 1),
249+
na.rm = TRUE, names = FALSE)
236250
res <- list()
237251

238252
raw <- samp + p

man/cdc_baseline_forecaster.Rd

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/flusight_hub_formatter.Rd

Lines changed: 60 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_cdc_flatline_quantiles.Rd

Lines changed: 19 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-propogate_samples.R renamed to tests/testthat/test-propagate_samples.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
test_that("propogate_samples", {
1+
test_that("propagate_samples", {
22
r <- -30:50
33
p <- 40
44
quantiles <- 1:9 / 10

0 commit comments

Comments
 (0)