Skip to content

Commit 2f61069

Browse files
committed
consistent name, only smooth non-smoothed, init forecaster
1 parent 4c8515f commit 2f61069

File tree

7 files changed

+180
-30
lines changed

7 files changed

+180
-30
lines changed

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ export(scaled_pop)
3838
export(sharpness)
3939
export(single_id)
4040
export(slide_forecaster)
41-
export(smooth_scaled)
41+
export(smoothed_scaled)
4242
export(underprediction)
4343
export(weighted_interval_score)
4444
importFrom(assertthat,assert_that)

R/data_transforms.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# various reusable transforms to apply before handing to epipredict
22

3-
#' extract the non-key columns from epi_data
3+
#' extract the non-key, non-smoothed columns from epi_data
44
#' @keywords internal
55
#' @param epi_data the epi_data tibble
66
#' @param cols vector of column names to use. If `NULL`, fill with all non-key columns
77
get_trainable_names <- function(epi_data, cols) {
  # an explicit selection is passed through untouched; only `NULL` triggers
  # the "everything trainable" default
  if (is.null(cols)) {
    key_cols <- c("geo_value", "time_value", attr(epi_data, "metadata")$other_keys)
    cols <- names(epi_data)
    cols <- cols[!(cols %in% key_cols)]
    # exclude anything with the same naming schema as the rolling average/sd
    # created below, i.e. a trailing `_<1-2 letter tag><width>` such as `_m7`
    # or `_SD28`. The pattern is anchored and restricted to letters so that raw
    # signals that merely contain an underscore followed by digits (e.g.
    # `covid_19_cases`, `value_2020`) are not dropped by accident.
    is_derived <- grepl("_[A-Za-z]{1,2}\\d+$", cols)
    cols <- cols[!is_derived]
  }
  return(cols)
}
@@ -27,7 +29,7 @@ rolling_mean <- function(epi_data, width = 7L, cols_to_mean = NULL) {
2729
cols_to_mean <- get_trainable_names(epi_data, cols_to_mean)
2830
epi_data %<>% group_by(geo_value)
2931
for (col in cols_to_mean) {
30-
mean_name <- paste0(col, width)
32+
mean_name <- paste0(col, "_m", width)
3133
epi_data %<>% mutate({{ mean_name }} := slider::slide_dbl(.data[[col]], mean, .before = width))
3234
}
3335
epi_data %<>% ungroup()

R/forecaster_smoothed_scaled.R

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#' predict on smoothed data and the standard deviation
#' @description
#' This is a variant of `scaled_pop`, which predicts on a smoothed version of
#'   the data. Even if the target is smoothed when used as a /predictor/, as a
#'   /target/ it still uses the raw value (this captures some of the noise). It
#'   also uses a rolling standard deviation as an auxiliary signal, window of
#'   width `sd_width`, which by default is 28 days.
#' @param epi_data the actual data used
#' @param outcome the name of the target variable
#' @param extra_sources the name of any extra columns to use. This list could be
#'   empty
#' @param ahead (this is relative to the `as_of` field of the `epi_df`, which is
#'   likely *not* the same as the `ahead` used by epipredict, which is relative
#'   to the max time value of the `epi_df`. how to handle this is a modelling
#'   question left up to each forecaster; see latency_adjusting.R for the
#'   existing examples)
#' @param pop_scaling if `TRUE`, all numeric columns are population scaled
#'   before fitting, and the predictions are scaled back afterwards
#' @param trainer the model used to fit; defaults to `parsnip::linear_reg()`
#' @param smooth_width the number of days over which to do smoothing. If `NULL`,
#'   then no smoothing is applied.
#' @param smooth_cols the names of the columns to smooth. If `NULL` it smooths
#'   everything
#' @param sd_width the number of days over which to take a moving average of the
#'   standard deviation. If `NULL`, the standard deviation isn't included.
#' @param sd_mean_width to calculate the sd, we need a window size for the mean
#'   used.
#' @param sd_cols the names of the columns to take the standard deviation of.
#'   If `NULL` it includes the sd of everything
#' @param quantile_levels The quantile levels to predict. Defaults to those
#'   required by covidhub.
#' @param ... any additional arguments, which are handed to `arx_args_list`
#' @return a tibble with the columns `geo_value`, `forecast_date`,
#'   `target_end_date`, `quantile` and `value`; it has zero rows if there is
#'   insufficient data to fit on.
#' @seealso some utilities for making forecasters: [format_storage],
#'   [perform_sanity_checks]
#' @importFrom epipredict epi_recipe step_population_scaling frosting arx_args_list layer_population_scaling
#' @importFrom tibble tibble
#' @importFrom recipes all_numeric
#' @export
smoothed_scaled <- function(epi_data,
                            outcome,
                            extra_sources = "",
                            ahead = 1,
                            pop_scaling = TRUE,
                            trainer = parsnip::linear_reg(),
                            quantile_levels = covidhub_probs(),
                            smooth_width = 7,
                            smooth_cols = NULL,
                            sd_width = 28,
                            sd_mean_width = 14,
                            sd_cols = NULL,
                            ...) {
  # perform any preprocessing not supported by epipredict
  # this is a temp fix until a real fix gets put into epipredict
  epi_data <- clear_lastminute_nas(epi_data)
  # one that every forecaster will need to handle: how to manage max(time_value)
  # that's older than the `as_of` date
  epidataAhead <- extend_ahead(epi_data, ahead)
  # see latency_adjusting for other examples
  # this next part is basically unavoidable boilerplate you'll want to copy
  epi_data <- epidataAhead[[1]]
  effective_ahead <- epidataAhead[[2]]
  args_input <- list(...)
  # edge case where there is no data or less data than the lags; eventually epipredict will handle this
  if (!confirm_sufficient_data(epi_data, effective_ahead, args_input)) {
    null_result <- tibble(
      geo_value = character(),
      forecast_date = lubridate::Date(),
      target_end_date = lubridate::Date(),
      quantile = numeric(),
      value = numeric()
    )
    return(null_result)
  }
  args_input[["ahead"]] <- effective_ahead
  args_input[["quantile_levels"]] <- quantile_levels
  args_list <- do.call(arx_args_list, args_input)
  # if you want to ignore extra_sources, setting predictors is the way to do it
  predictors <- c(outcome, extra_sources)
  # TODO: Partial match quantile_level coming from here (on Dmitry's machine)
  argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, trainer, args_list)
  args_list <- argsPredictorsTrainer[[1]]
  predictors <- argsPredictorsTrainer[[2]]
  trainer <- argsPredictorsTrainer[[3]]
  # end of the copypasta
  # finally, any other pre-processing (e.g. smoothing) that isn't performed by
  # epipredict
  # smoothing
  # The rolling mean can be left to `rolling_sd` only if that step actually
  # runs (`sd_width` is set) and uses the same window as the smoothing.
  # `isTRUE` keeps this a length-one logical when either width is `NULL`
  # (`NULL == x` is `logical(0)`), so a `NULL` width never leaks into the
  # `keep_mean` argument and smoothing isn't silently dropped when
  # `sd_width` is `NULL`.
  keep_mean <- !is.null(sd_width) && isTRUE(smooth_width == sd_mean_width)
  if (!is.null(smooth_width) && !keep_mean) {
    epi_data %<>% rolling_mean(
      width = smooth_width,
      cols_to_mean = smooth_cols
    )
  }

  # measuring standard deviation
  # NOTE(review): when `keep_mean` is TRUE the kept means cover `sd_cols`, not
  # `smooth_cols` -- confirm that is intended when the two differ
  if (!is.null(sd_width)) {
    epi_data %<>% rolling_sd(
      sd_width = sd_width,
      mean_width = sd_mean_width,
      cols_to_sd = sd_cols,
      keep_mean = keep_mean
    )
  }
  # NOTE(review): `predictors` is not updated here to reference the smoothed/sd
  # columns created above -- confirm whether that is handled downstream

  # preprocessing supported by epipredict
  preproc <- epi_recipe(epi_data)
  if (pop_scaling) {
    preproc %<>% step_population_scaling(
      all_numeric(),
      df = epipredict::state_census,
      df_pop_col = "pop",
      create_new = FALSE,
      rate_rescaling = 1e5,
      by = c("geo_value" = "abbr")
    )
  }
  preproc %<>% arx_preprocess(outcome, predictors, args_list)

  # postprocessing supported by epipredict
  postproc <- frosting()
  postproc %<>% arx_postprocess(trainer, args_list)
  if (pop_scaling) {
    # undo the population scaling applied in the recipe
    postproc %<>% layer_population_scaling(
      .pred, .pred_distn,
      df = epipredict::state_census,
      df_pop_col = "pop",
      create_new = FALSE,
      rate_rescaling = 1e5,
      by = c("geo_value" = "abbr")
    )
  }
  # with all the setup done, we execute and format
  pred <- run_workflow_and_format(preproc, postproc, trainer, epi_data)
  # now pred has the columns
  # (geo_value, forecast_date, target_end_date, quantile, value)
  # finally, any postprocessing not supported by epipredict e.g. calibration
  return(pred)
}

man/get_trainable_names.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/smooth_scaled.Rd renamed to man/smoothed_scaled.Rd

Lines changed: 13 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-forecasters-basics.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ library(dplyr)
22
# TODO better way to do this than copypasta
33
forecasters <- list(
44
c("scaled_pop", scaled_pop),
5-
c("flatline_fc", flatline_fc)
5+
c("flatline_fc", flatline_fc),
6+
c("smoothed_scaled", smoothed_scaled)
67
)
78
for (forecaster in forecasters) {
89
test_that(paste(forecaster[[1]], "gets the date and columns right"), {

tests/testthat/test-transforms.R

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,41 @@ n_days <- 40
22
simple_dates <- seq(as.Date("2012-01-01"), by = "day", length.out = n_days)
33
rand_vals <- rnorm(n_days)
44
epi_data <- epiprocess::as_epi_df(rbind(tibble(
5-
geo_value = "al",
6-
time_value = simple_dates,
7-
a = 1:n_days,
8-
b = rand_vals
5+
geo_value = "al",
6+
time_value = simple_dates,
7+
a = 1:n_days,
8+
b = rand_vals
99
), tibble(
10-
geo_value = "ca",
11-
time_value = simple_dates,
12-
a = n_days:1,
13-
b = rand_vals + 10
10+
geo_value = "ca",
11+
time_value = simple_dates,
12+
a = n_days:1,
13+
b = rand_vals + 10
1414
)))
1515
test_that("rolling_mean generates correct mean", {
1616
rolled <- rolling_mean(epi_data)
17-
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a7", "b7"))
17+
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a_m7", "b_m7"))
1818
# hand specified rolling mean with a rear window of 7, noting that mean(1:7) = 4
19-
linear_roll_mean <- c(seq(from=1, to = 4, by = .5), seq(from = 4.5, to = 36.5, by = 1))
20-
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a7"), linear_roll_mean)
19+
linear_roll_mean <- c(seq(from = 1, to = 4, by = .5), seq(from = 4.5, to = 36.5, by = 1))
20+
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a_m7"), linear_roll_mean)
2121
# same, but "ca" is reversed, noting mean(40:(40-7)) =36.5
22-
linear_reverse_roll_mean <- c(seq(from=40, to = 36.5, by = -0.5), seq(from = 35.5, to = 4.5, by = -1))
23-
expect_equal(rolled %>% filter(geo_value == "ca") %>% pull("a7"), linear_reverse_roll_mean)
22+
linear_reverse_roll_mean <- c(seq(from = 40, to = 36.5, by = -0.5), seq(from = 35.5, to = 4.5, by = -1))
23+
expect_equal(rolled %>% filter(geo_value == "ca") %>% pull("a_m7"), linear_reverse_roll_mean)
2424
})
2525

2626
test_that("rolling_sd generates correct standard deviation", {
2727
rolled <- rolling_sd(epi_data)
2828
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a_SD28", "b_SD28"))
2929
# hand specified rolling mean with a rear window of 14, noting that mean(1:14) = 7.5
30-
linear_roll_mean <- c(seq(from=1, to = 7.5, by = .5), seq(from = 8, to = 33, by = 1))
30+
linear_roll_mean <- c(seq(from = 1, to = 7.5, by = .5), seq(from = 8, to = 33, by = 1))
3131
# and the standard deviation is
3232
linear_roll_sd <- sqrt(slider::slide_dbl((1:40 - linear_roll_mean)^2, mean, .before = 28))
3333
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a_SD28"), linear_roll_sd)
3434
# even though ca is reversed, the changes are all the same, so the standard deviation is *exactly* the same values
3535
expect_equal(rolled %>% filter(geo_value == "ca") %>% pull("a_SD28"), linear_roll_sd)
36-
})
36+
})
37+
# `test_that()` is the testthat entry point; `testthat` itself is only the
# package name and is not callable, so the original block errored on load
test_that("get_trainable_names pulls out mean and sd columns", {
  # keep_mean = TRUE retains the intermediate rolling-mean columns next to the sd ones
  rolled <- rolling_sd(epi_data, keep_mean = TRUE)
  expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a_m14", "a_SD28", "b_m14", "b_SD28"))
  # only the raw signals should remain once key and derived columns are excluded
  expect_equal(get_trainable_names(rolled, NULL), c("a", "b"))
})
3742
# TODO example with NA's, example with missing days, only one column, keep_mean

0 commit comments

Comments
 (0)