Skip to content

Commit 109d9ce

Browse files
authored
Merge pull request #32 from cmu-delphi/forecaster_testing_init
testing forecasters on simple datasets
2 parents 71bad06 + 0481615 commit 109d9ce

22 files changed

+542
-102
lines changed

.Rbuildignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
^renv$
22
^renv\.lock$
3+
^LICENSE\.md$
4+
^.lintr$
5+
^.renvignore$
6+
^.github$

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@ export(add_id)
55
export(arx_postprocess)
66
export(arx_preprocess)
77
export(collapse_cards)
8+
export(confirm_insufficient_data)
89
export(covidhub_probs)
910
export(evaluate_predictions)
1011
export(extend_ahead)
12+
export(flatline_fc)
1113
export(forecaster_pred)
1214
export(format_storage)
15+
export(id_ahead_ensemble_grid)
1316
export(interval_coverage)
17+
export(lookup_ids)
1418
export(make_target_param_grid)
1519
export(overprediction)
1620
export(perform_sanity_checks)
@@ -19,6 +23,7 @@ export(run_evaluation_measure)
1923
export(run_workflow_and_format)
2024
export(scaled_pop)
2125
export(sharpness)
26+
export(single_id)
2227
export(underprediction)
2328
export(weighted_interval_score)
2429
import(dplyr)
@@ -38,6 +43,7 @@ importFrom(purrr,map)
3843
importFrom(purrr,transpose)
3944
importFrom(rlang,.data)
4045
importFrom(rlang,quo)
46+
importFrom(rlang,sym)
4147
importFrom(rlang,syms)
4248
importFrom(tibble,tibble)
4349
importFrom(tidyr,pivot_wider)

R/forecaster.R

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,44 @@ perform_sanity_checks <- function(epi_data,
2828
if (!is.null(trainer) && !epipredict:::is_regression(trainer)) {
2929
cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.")
3030
} else if (inherits(trainer, "quantile_reg")) {
31-
# add all levels to the trainer and update args list
32-
tau <- sort(epipredict:::compare_quantile_args(
33-
args_list$levels,
34-
rlang::eval_tidy(trainer$args$tau)
31+
# add all quantile_levels to the trainer and update args list
32+
quantile_levels <- sort(epipredict:::compare_quantile_args(
33+
args_list$quantile_levels,
34+
rlang::eval_tidy(trainer$args$quantile_levels)
3535
))
36-
args_list$levels <- tau
37-
trainer$args$tau <- rlang::enquo(tau)
36+
args_list$quantile_levels <- quantile_levels
37+
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
3838
}
3939
args_list$lags <- epipredict:::arx_lags_validator(predictors, args_list$lags)
4040
return(list(args_list, predictors, trainer))
4141
}
4242

43+
#' confirm that there's enough data to run this model
44+
#' @description
45+
#' epipredict is a little bit fragile about having enough data to train; we want
46+
#' to be able to return a null result rather than error out; this check say to
47+
#' return a null
48+
#' @param epi_data the input data
49+
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
50+
#' this trains on one sample; the default is set so that `linear_reg` isn't
51+
#' rank deficient)
52+
#' @param ahead the effective ahead; may be infinite if there isn't enough data.
53+
#' @param args_input the input as supplied to `forecaster_pred`; lags is the
54+
#' important argument, which may or may not be defined, with the default
55+
#' coming from `arx_args_list`
56+
#' @export
57+
confirm_insufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
58+
if (!is.null(args_input$lags)) {
59+
lag_max <- max(args_input$lags)
60+
} else {
61+
lag_max <- 14 # default value of 2 weeks
62+
}
63+
return(
64+
is.infinite(ahead) ||
65+
as.integer(max(epi_data$time_value) - min(epi_data$time_value)) <=
66+
lag_max + ahead + buffer
67+
)
68+
}
4369
# TODO replace with `step_arx_forecaster`
4470
#' add the default steps for arx_forecaster
4571
#' @description
@@ -86,20 +112,18 @@ arx_postprocess <- function(postproc,
86112
target_date = NULL) {
87113
postproc %<>% layer_predict()
88114
if (inherits(trainer, "quantile_reg")) {
89-
90-
postproc %<>% layer_quantile_distn(levels = args_list$levels) %>% layer_point_from_distn()
115+
postproc %<>% layer_quantile_distn(quantile_levels = args_list$quantile_levels) %>% layer_point_from_distn()
91116
} else {
92117
postproc %<>% layer_residual_quantiles(
93-
probs = args_list$levels, symmetrize = args_list$symmetrize,
118+
quantile_levels = args_list$quantile_levels, symmetrize = args_list$symmetrize,
94119
by_key = args_list$quantile_by_key
95120
)
96121
}
97122
if (args_list$nonneg) {
98123
postproc %<>% layer_threshold(dplyr::starts_with(".pred"))
99124
}
100125

101-
postproc %<>% layer_naomit(dplyr::starts_with(".pred"))
102-
postproc %<>% layer_add_forecast_date(forecast_date = forecast_date) %>%
126+
postproc %<>% layer_naomit(dplyr::starts_with(".pred")) %>%
103127
layer_add_target_date(target_date = target_date)
104128
return(postproc)
105129
}
@@ -162,6 +186,14 @@ forecaster_pred <- function(data,
162186
if (length(forecaster_args) > 0) {
163187
names(forecaster_args) <- forecaster_args_names
164188
}
189+
if (is.null(forecaster_args$ahead)) {
190+
cli::cli_abort(
191+
c(
192+
"exploration-tooling error: forecaster_pred needs some value for ahead."
193+
),
194+
class = "explorationToolingError"
195+
)
196+
}
165197
if (!is.numeric(forecaster_args$n_training) && !is.null(forecaster_args$n_training)) {
166198
n_training <- as.numeric(forecaster_args$n_training)
167199
net_slide_training <- max(slide_training, n_training) + n_training_pad
@@ -171,11 +203,6 @@ forecaster_pred <- function(data,
171203
}
172204
# restrict the dataset to areas where training is possible
173205
start_date <- min(archive$DT$time_value) + net_slide_training
174-
if (slide_training < Inf) {
175-
start_date <- min(archive$DT$time_value) + slide_training + n_training_pad
176-
} else {
177-
start_date <- min(archive$DT$time_value) + n_training_pad
178-
}
179206
end_date <- max(archive$DT$time_value) - forecaster_args$ahead
180207
valid_predict_dates <- seq.Date(from = start_date, to = end_date, by = 1)
181208

@@ -206,7 +233,7 @@ forecaster_pred <- function(data,
206233

207234
# append the truth data
208235
true_value <- archive$as_of(archive$versions_end) %>%
209-
select(geo_value, time_value, outcome) %>%
236+
select(geo_value, time_value, !!outcome) %>%
210237
rename(true_value = !!outcome)
211238
res %<>%
212239
inner_join(true_value,

R/forecaster_flatline.R

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#' flatline forecaster (aka baseline)
2+
#' @description
3+
#' a minimal forecaster whose median is just the last value
4+
#' does not support `lags` as a parameter, but otherwise has the same parameters as `arx_forecaster`
5+
#' @inheritParams scaled_pop
6+
#' @importFrom rlang sym
7+
#' @export
8+
flatline_fc <- function(epi_data,
9+
outcome,
10+
extra_sources = "",
11+
ahead = 1,
12+
trainer = parsnip::linear_reg(),
13+
quantile_levels = covidhub_probs(),
14+
...) {
15+
# perform any preprocessing not supported by epipredict
16+
# one that every forecaster will need to handle: how to manage max(time_value)
17+
# that's older than the `as_of` date
18+
epidataAhead <- extend_ahead(epi_data, ahead)
19+
# see latency_adjusting for other examples
20+
# this next part is basically unavoidable boilerplate you'll want to copy
21+
epi_data <- epidataAhead[[1]]
22+
effective_ahead <- epidataAhead[[2]]
23+
args_input <- list(...)
24+
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
25+
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
26+
null_result <- tibble(
27+
geo_value = character(),
28+
forecast_date = Date(),
29+
target_end_date = Date(),
30+
quantile = numeric(),
31+
value = numeric()
32+
)
33+
return(null_result)
34+
}
35+
args_input[["ahead"]] <- effective_ahead
36+
args_input[["quantile_levels"]] <- quantile_levels
37+
args_list <- do.call(flatline_args_list, args_input)
38+
# if you want to ignore extra_sources, setting predictors is the way to do it
39+
predictors <- c(outcome, extra_sources)
40+
argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, NULL, args_list)
41+
args_list <- argsPredictorsTrainer[[1]]
42+
predictors <- argsPredictorsTrainer[[2]]
43+
# end of the copypasta
44+
# finally, any other pre-processing (e.g. smoothing) that isn't performed by
45+
# epipredict
46+
47+
# since this is just the flatline, we don't need much of anything
48+
res <- flatline_forecaster(epi_data, outcome = outcome, args_list = args_list)
49+
true_forecast_date <- attributes(epi_data)$metadata$as_of
50+
pred <- format_storage(res$predictions, true_forecast_date)
51+
# (geo_value, forecast_date, target_end_date, quantile, value)
52+
# finally, any postprocessing not supported by epipredict e.g. calibration
53+
return(pred)
54+
}

R/forecaster_scaled_pop.R

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
#' should be ones that will store well in a data.table; if you need more
3333
#' complicated parameters, it is better to store them in separate files, and
3434
#' use the filename as the parameter.
35-
#' @param levels The quantile levels to predict. Defaults to those required by
35+
#' @param quantile_levels The quantile levels to predict. Defaults to those required by
3636
#' covidhub.
3737
#' @seealso some utilities for making forecasters: [format_storage],
3838
#' [perform_sanity_checks]
@@ -45,10 +45,10 @@
4545
scaled_pop <- function(epi_data,
4646
outcome,
4747
extra_sources = "",
48-
ahead=1,
48+
ahead = 1,
4949
pop_scaling = TRUE,
5050
trainer = parsnip::linear_reg(),
51-
levels = covidhub_probs(),
51+
quantile_levels = covidhub_probs(),
5252
...) {
5353
# perform any preprocessing not supported by epipredict
5454
# one that every forecaster will need to handle: how to manage max(time_value)
@@ -58,9 +58,9 @@ scaled_pop <- function(epi_data,
5858
# this next part is basically unavoidable boilerplate you'll want to copy
5959
epi_data <- epidataAhead[[1]]
6060
effective_ahead <- epidataAhead[[2]]
61-
# edge case where there is no data; eventually epipredict will handle this
62-
if (is.infinite(effective_ahead)) {
63-
effective_ahead <- 0
61+
args_input <- list(...)
62+
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
63+
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
6464
null_result <- tibble(
6565
geo_value = character(),
6666
forecast_date = Date(),
@@ -70,9 +70,8 @@ scaled_pop <- function(epi_data,
7070
)
7171
return(null_result)
7272
}
73-
args_input <- list(...)
7473
args_input[["ahead"]] <- effective_ahead
75-
args_input[["levels"]] <- levels
74+
args_input[["quantile_levels"]] <- quantile_levels
7675
args_list <- do.call(arx_args_list, args_input)
7776
# if you want to ignore extra_sources, setting predictors is the way to do it
7877
predictors <- c(outcome, extra_sources)

R/formatters.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
1818
.dstn = nested_quantiles(.pred_distn)
1919
) %>%
2020
unnest(.dstn) %>%
21-
select(-.pred_distn, -.pred, -time_value) %>%
22-
rename(quantile = tau, value = q, target_end_date = target_date) %>%
21+
select(-any_of(c(".pred_distn", ".pred", "time_value"))) %>%
22+
rename(quantile = quantile_levels, value = values, target_end_date = target_date) %>%
2323
relocate(geo_value, forecast_date, target_end_date, quantile, value)
2424
}
2525

@@ -33,13 +33,13 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
3333
#' @param true_forecast_date the actual date from which the model is
3434
#' making the forecast, rather than the last day of available data
3535
#' @param target_end_date the date of the prediction
36-
#' @param levels the quantile levels
36+
#' @param quantile_levels the quantile levels
3737
#' @import dplyr
38-
format_covidhub <- function(pred, true_forecast_date, target_end_date, levels) {
38+
format_covidhub <- function(pred, true_forecast_date, target_end_date, quantile_levels) {
3939
pred %<>%
4040
group_by(forecast_date, geo_value, target_date) %>%
4141
rename(target_end_date = target_date) %>%
42-
reframe(quantile = levels, value = quantile(.pred_distn, levels)[[1]])
42+
reframe(quantile = quantile_levels, value = quantile(.pred_distn, quantile_levels)[[1]])
4343
forecasts$ahead <- ahead
4444
forecasts %<>%
4545
group_by(forecast_date, geo_value, target_date) %>%

R/small_utils.R

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,63 @@ covidhub_probs <- function(type = c("standard", "inc_case")) {
2020
#' @importFrom cli hash_animal
2121
#' @export
2222
add_id <- function(df, n_adj = 2) {
23-
stringified <- df %>%
24-
select(-ahead) %>%
23+
no_ahead <- df %>%
24+
select(-ahead)
25+
stringified <- no_ahead %>%
26+
select(order(colnames(no_ahead))) %>%
2527
rowwise() %>%
26-
mutate(id = paste(across(everything()), collapse = ""), .keep="none") %>%
28+
mutate(id = paste(across(everything()), sep = "", collapse = ""), .keep = "none") %>%
2729
mutate(id = hash_animal(id, n_adj = n_adj)$words) %>%
28-
mutate(id = paste(id[1:n_adj], sep="", collapse = " "))
30+
mutate(id = paste(id[1:n_adj], sep = "", collapse = "."))
2931
df %<>%
32+
ungroup %>%
3033
mutate(parent_id = stringified$id) %>%
3134
rowwise() %>%
32-
mutate(id = paste(parent_id, ahead, collapse = " ")) %>%
35+
mutate(id = paste(parent_id, ahead, sep = ".", collapse = " ")) %>%
3336
ungroup()
3437
return(df)
3538
}
39+
40+
#' generate an id from a simple list of parameters
41+
#' @param param_list the list of parameters. must include `ahead` if `ahead = NULL`
42+
#' @param ahead the ahead to use.
43+
#' @inheritParams add_id
44+
#' @export
45+
single_id <- function(param_list, ahead = NULL, n_adj = 2) {
46+
full_hash <- param_list[names(param_list) != "ahead"] %>%
47+
.[order(names(.))] %>% # put in alphabetical order
48+
paste(collapse = "") %>%
49+
hash_animal(n_adj = n_adj)
50+
single_string <- full_hash$words[[1]][1:n_adj] %>% paste(sep = ".", collapse = ".")
51+
if (is.null(ahead)) {
52+
full_name <- paste(single_string, param_list$ahead, sep = ".")
53+
} else {
54+
full_name <- paste(single_string, ahead, sep = ".")
55+
}
56+
return(full_name)
57+
}
58+
#' given target name(s), lookup the corresponding parameters
59+
#' @export
60+
lookup_ids <- function() {
61+
}
62+
63+
64+
#' add aheads, forecaster_ids, and ids to a list of ensemble models
65+
#' @description
66+
#' minor utility
67+
#' @param ensemble_grid the list of ensembles,
68+
#' @param aheads the aheads to add
69+
#' @inheritParams add_id
70+
#' @export
71+
id_ahead_ensemble_grid <- function(ensemble_grid, aheads, n_adj = 2) {
72+
ensemble_grid <- expand_grid(
73+
ensemble_grid,
74+
tibble(ahead = aheads)
75+
)
76+
77+
ensemble_grid %<>%
78+
add_id(., n_adj = 2) %>%
79+
rowwise() %>%
80+
mutate(forecaster_ids = list(map2_vec(forecasters, ahead, single_id, n_adj = 2)))
81+
return(ensemble_grid)
82+
}

_targets.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,3 @@ covid_hosp_prod:
1414
script: covid_hosp_prod.R
1515
store: covid_hosp_prod
1616
use_crew: yes
17-
forecaster_testing:
18-
script: forecaster_testing.R
19-
store: forecaster_testing
20-
use_crew: no

_targets/.gitignore

Lines changed: 0 additions & 5 deletions
This file was deleted.

covid_hosp_explore.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ tar_option_set(
2828
imports = c("epieval", "parsnip"),
2929
format = "qs", # Optionally set the default storage format. qs is fast.
3030
controller = crew::crew_controller_local(workers = parallel::detectCores() - 5),
31-
)
31+
)
3232
# Run the R scripts in the R/ folder with your custom functions:
3333
# tar_source()
3434
# where the forecasters and parameters are joined; see either the variable param_grid or `tar_read(forecasters)`

man/confirm_insufficient_data.Rd

Lines changed: 26 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)