Skip to content

testing forecasters on simple datasets #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
^renv$
^renv\.lock$
^LICENSE\.md$
^.lintr$
^.renvignore$
^.github$
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ export(add_id)
export(arx_postprocess)
export(arx_preprocess)
export(collapse_cards)
export(confirm_insufficient_data)
export(covidhub_probs)
export(evaluate_predictions)
export(extend_ahead)
export(flatline_fc)
export(forecaster_pred)
export(format_storage)
export(id_ahead_ensemble_grid)
export(interval_coverage)
export(lookup_ids)
export(make_target_param_grid)
export(overprediction)
export(perform_sanity_checks)
Expand All @@ -19,6 +23,7 @@ export(run_evaluation_measure)
export(run_workflow_and_format)
export(scaled_pop)
export(sharpness)
export(single_id)
export(underprediction)
export(weighted_interval_score)
import(dplyr)
Expand All @@ -38,6 +43,7 @@ importFrom(purrr,map)
importFrom(purrr,transpose)
importFrom(rlang,.data)
importFrom(rlang,quo)
importFrom(rlang,sym)
importFrom(rlang,syms)
importFrom(tibble,tibble)
importFrom(tidyr,pivot_wider)
61 changes: 44 additions & 17 deletions R/forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,44 @@ perform_sanity_checks <- function(epi_data,
if (!is.null(trainer) && !epipredict:::is_regression(trainer)) {
cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.")
} else if (inherits(trainer, "quantile_reg")) {
# add all levels to the trainer and update args list
tau <- sort(epipredict:::compare_quantile_args(
args_list$levels,
rlang::eval_tidy(trainer$args$tau)
# add all quantile_levels to the trainer and update args list
quantile_levels <- sort(epipredict:::compare_quantile_args(
args_list$quantile_levels,
rlang::eval_tidy(trainer$args$quantile_levels)
))
args_list$levels <- tau
trainer$args$tau <- rlang::enquo(tau)
args_list$quantile_levels <- quantile_levels
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
}
args_list$lags <- epipredict:::arx_lags_validator(predictors, args_list$lags)
return(list(args_list, predictors, trainer))
}

#' confirm that there's enough data to run this model
#' @description
#' epipredict is somewhat fragile about having enough data to train; rather
#' than erroring out, callers use this check to decide whether to return a
#' null result instead of fitting.
#' @param epi_data the input data; only `time_value` is inspected
#' @param ahead the effective ahead; may be infinite if there isn't enough data.
#' @param args_input the input as supplied to `forecaster_pred`; `lags` is the
#'   important argument, which may or may not be defined, with the default
#'   coming from `arx_args_list`
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
#'   this trains on one sample; the default is set so that `linear_reg` isn't
#'   rank deficient)
#' @return `TRUE` when there is too little data to train on, `FALSE` otherwise
#' @export
confirm_insufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
  # fall back to 14 (the default max lag, i.e. 2 weeks) when lags are unspecified
  lag_max <- if (is.null(args_input$lags)) 14 else max(args_input$lags)
  # an infinite ahead always means there's no usable training window
  if (is.infinite(ahead)) {
    return(TRUE)
  }
  # the observed time span must exceed what lags + ahead + buffer consume
  observed_span <- as.integer(max(epi_data$time_value) - min(epi_data$time_value))
  observed_span <= lag_max + ahead + buffer
}
# TODO replace with `step_arx_forecaster`
#' add the default steps for arx_forecaster
#' @description
Expand Down Expand Up @@ -86,20 +112,18 @@ arx_postprocess <- function(postproc,
target_date = NULL) {
postproc %<>% layer_predict()
if (inherits(trainer, "quantile_reg")) {

postproc %<>% layer_quantile_distn(levels = args_list$levels) %>% layer_point_from_distn()
postproc %<>% layer_quantile_distn(quantile_levels = args_list$quantile_levels) %>% layer_point_from_distn()
} else {
postproc %<>% layer_residual_quantiles(
probs = args_list$levels, symmetrize = args_list$symmetrize,
quantile_levels = args_list$quantile_levels, symmetrize = args_list$symmetrize,
by_key = args_list$quantile_by_key
)
}
if (args_list$nonneg) {
postproc %<>% layer_threshold(dplyr::starts_with(".pred"))
}

postproc %<>% layer_naomit(dplyr::starts_with(".pred"))
postproc %<>% layer_add_forecast_date(forecast_date = forecast_date) %>%
postproc %<>% layer_naomit(dplyr::starts_with(".pred")) %>%
layer_add_target_date(target_date = target_date)
return(postproc)
}
Expand Down Expand Up @@ -162,6 +186,14 @@ forecaster_pred <- function(data,
if (length(forecaster_args) > 0) {
names(forecaster_args) <- forecaster_args_names
}
if (is.null(forecaster_args$ahead)) {
cli::cli_abort(
c(
"exploration-tooling error: forecaster_pred needs some value for ahead."
),
class = "explorationToolingError"
)
}
if (!is.numeric(forecaster_args$n_training) && !is.null(forecaster_args$n_training)) {
n_training <- as.numeric(forecaster_args$n_training)
net_slide_training <- max(slide_training, n_training) + n_training_pad
Expand All @@ -171,11 +203,6 @@ forecaster_pred <- function(data,
}
# restrict the dataset to areas where training is possible
start_date <- min(archive$DT$time_value) + net_slide_training
if (slide_training < Inf) {
start_date <- min(archive$DT$time_value) + slide_training + n_training_pad
} else {
start_date <- min(archive$DT$time_value) + n_training_pad
}
end_date <- max(archive$DT$time_value) - forecaster_args$ahead
valid_predict_dates <- seq.Date(from = start_date, to = end_date, by = 1)

Expand Down Expand Up @@ -206,7 +233,7 @@ forecaster_pred <- function(data,

# append the truth data
true_value <- archive$as_of(archive$versions_end) %>%
select(geo_value, time_value, outcome) %>%
select(geo_value, time_value, !!outcome) %>%
rename(true_value = !!outcome)
res %<>%
inner_join(true_value,
Expand Down
54 changes: 54 additions & 0 deletions R/forecaster_flatline.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#' flatline forecaster (aka baseline)
#' @description
#' a minimal forecaster whose median is just the last value
#' does not support `lags` as a parameter, but otherwise has the same parameters as `arx_forecaster`
#' @inheritParams scaled_pop
#' @importFrom rlang sym
#' @export
flatline_fc <- function(epi_data,
                        outcome,
                        extra_sources = "",
                        ahead = 1,
                        trainer = parsnip::linear_reg(),
                        quantile_levels = covidhub_probs(),
                        ...) {
  # pre-processing that epipredict doesn't cover; every forecaster needs to
  # handle a max(time_value) older than the `as_of` date by shifting the ahead
  # (see latency_adjusting for other examples)
  adjusted <- extend_ahead(epi_data, ahead)
  # this next part is basically unavoidable boilerplate you'll want to copy
  epi_data <- adjusted[[1]]
  effective_ahead <- adjusted[[2]]
  extra_args <- list(...)
  # edge case where there is no data or less data than the lags; return an
  # empty-but-well-formed tibble; eventually epipredict will handle this
  if (confirm_insufficient_data(epi_data, effective_ahead, extra_args)) {
    return(tibble(
      geo_value = character(),
      forecast_date = Date(),
      target_end_date = Date(),
      quantile = numeric(),
      value = numeric()
    ))
  }
  extra_args$ahead <- effective_ahead
  extra_args$quantile_levels <- quantile_levels
  args_list <- do.call(flatline_args_list, extra_args)
  # if you want to ignore extra_sources, setting predictors is the way to do it
  predictors <- c(outcome, extra_sources)
  checked <- perform_sanity_checks(epi_data, outcome, predictors, NULL, args_list)
  args_list <- checked[[1]]
  predictors <- checked[[2]]
  # end of the copypasta; any other pre-processing (e.g. smoothing) that isn't
  # performed by epipredict would go here

  # since this is just the flatline, we don't need much of anything
  res <- flatline_forecaster(epi_data, outcome = outcome, args_list = args_list)
  true_forecast_date <- attributes(epi_data)$metadata$as_of
  # result columns: (geo_value, forecast_date, target_end_date, quantile, value)
  # any postprocessing not supported by epipredict (e.g. calibration) would go here
  format_storage(res$predictions, true_forecast_date)
}
15 changes: 7 additions & 8 deletions R/forecaster_scaled_pop.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#' should be ones that will store well in a data.table; if you need more
#' complicated parameters, it is better to store them in separate files, and
#' use the filename as the parameter.
#' @param levels The quantile levels to predict. Defaults to those required by
#' @param quantile_levels The quantile levels to predict. Defaults to those required by
#' covidhub.
#' @seealso some utilities for making forecasters: [format_storage],
#' [perform_sanity_checks]
Expand All @@ -45,10 +45,10 @@
scaled_pop <- function(epi_data,
outcome,
extra_sources = "",
ahead=1,
ahead = 1,
pop_scaling = TRUE,
trainer = parsnip::linear_reg(),
levels = covidhub_probs(),
quantile_levels = covidhub_probs(),
...) {
# perform any preprocessing not supported by epipredict
# one that every forecaster will need to handle: how to manage max(time_value)
Expand All @@ -58,9 +58,9 @@ scaled_pop <- function(epi_data,
# this next part is basically unavoidable boilerplate you'll want to copy
epi_data <- epidataAhead[[1]]
effective_ahead <- epidataAhead[[2]]
# edge case where there is no data; eventually epipredict will handle this
if (is.infinite(effective_ahead)) {
effective_ahead <- 0
args_input <- list(...)
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
null_result <- tibble(
geo_value = character(),
forecast_date = Date(),
Expand All @@ -70,9 +70,8 @@ scaled_pop <- function(epi_data,
)
return(null_result)
}
args_input <- list(...)
args_input[["ahead"]] <- effective_ahead
args_input[["levels"]] <- levels
args_input[["quantile_levels"]] <- quantile_levels
args_list <- do.call(arx_args_list, args_input)
# if you want to ignore extra_sources, setting predictors is the way to do it
predictors <- c(outcome, extra_sources)
Expand Down
10 changes: 5 additions & 5 deletions R/formatters.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
.dstn = nested_quantiles(.pred_distn)
) %>%
unnest(.dstn) %>%
select(-.pred_distn, -.pred, -time_value) %>%
rename(quantile = tau, value = q, target_end_date = target_date) %>%
select(-any_of(c(".pred_distn", ".pred", "time_value"))) %>%
rename(quantile = quantile_levels, value = values, target_end_date = target_date) %>%
relocate(geo_value, forecast_date, target_end_date, quantile, value)
}

Expand All @@ -33,13 +33,13 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
#' @param true_forecast_date the actual date from which the model is
#' making the forecast, rather than the last day of available data
#' @param target_end_date the date of the prediction
#' @param levels the quantile levels
#' @param quantile_levels the quantile levels
#' @import dplyr
format_covidhub <- function(pred, true_forecast_date, target_end_date, levels) {
format_covidhub <- function(pred, true_forecast_date, target_end_date, quantile_levels) {
pred %<>%
group_by(forecast_date, geo_value, target_date) %>%
rename(target_end_date = target_date) %>%
reframe(quantile = levels, value = quantile(.pred_distn, levels)[[1]])
reframe(quantile = quantile_levels, value = quantile(.pred_distn, quantile_levels)[[1]])
forecasts$ahead <- ahead
forecasts %<>%
group_by(forecast_date, geo_value, target_date) %>%
Expand Down
57 changes: 52 additions & 5 deletions R/small_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,63 @@ covidhub_probs <- function(type = c("standard", "inc_case")) {
#' @importFrom cli hash_animal
#' @export
add_id <- function(df, n_adj = 2) {
stringified <- df %>%
select(-ahead) %>%
no_ahead <- df %>%
select(-ahead)
stringified <- no_ahead %>%
select(order(colnames(no_ahead))) %>%
rowwise() %>%
mutate(id = paste(across(everything()), collapse = ""), .keep="none") %>%
mutate(id = paste(across(everything()), sep = "", collapse = ""), .keep = "none") %>%
mutate(id = hash_animal(id, n_adj = n_adj)$words) %>%
mutate(id = paste(id[1:n_adj], sep="", collapse = " "))
mutate(id = paste(id[1:n_adj], sep = "", collapse = "."))
df %<>%
ungroup %>%
mutate(parent_id = stringified$id) %>%
rowwise() %>%
mutate(id = paste(parent_id, ahead, collapse = " ")) %>%
mutate(id = paste(parent_id, ahead, sep = ".", collapse = " ")) %>%
ungroup()
return(df)
}

#' generate an id from a simple list of parameters
#' @param param_list the named list of parameters; it must contain `ahead`
#'   whenever the `ahead` argument is left `NULL`
#' @param ahead the ahead to use; when `NULL`, `param_list$ahead` is used instead
#' @inheritParams add_id
#' @export
single_id <- function(param_list, ahead = NULL, n_adj = 2) {
  # hash everything except ahead, in alphabetical order so the id is stable
  # regardless of how the list happens to be ordered
  non_ahead <- param_list[names(param_list) != "ahead"]
  hashed <- non_ahead[order(names(non_ahead))] %>%
    paste(collapse = "") %>%
    hash_animal(n_adj = n_adj)
  base_id <- paste(hashed$words[[1]][1:n_adj], sep = ".", collapse = ".")
  # the ahead is appended last so ids sharing parameters share a prefix
  effective_ahead <- if (is.null(ahead)) param_list$ahead else ahead
  paste(base_id, effective_ahead, sep = ".")
}
#' given target name(s), lookup the corresponding parameters
#' @description
#' TODO(review): unimplemented stub — the body is empty, so calling this
#' currently returns `NULL` for any input; confirm intended behavior before use.
#' @export
lookup_ids <- function() {
}


#' add aheads, forecaster_ids, and ids to a list of ensemble models
#' @description
#' minor utility that crosses every ensemble with every ahead and attaches
#' hash-based identifiers
#' @param ensemble_grid the list of ensembles,
#' @param aheads the aheads to add
#' @inheritParams add_id
#' @export
id_ahead_ensemble_grid <- function(ensemble_grid, aheads, n_adj = 2) {
  # one row per (ensemble, ahead) combination
  ensemble_grid <- expand_grid(
    ensemble_grid,
    tibble(ahead = aheads)
  )

  ensemble_grid %<>%
    # bug fix: n_adj was previously hard-coded to 2 here and below, silently
    # ignoring the caller's n_adj argument; forward it instead
    add_id(., n_adj = n_adj) %>%
    rowwise() %>%
    mutate(forecaster_ids = list(map2_vec(forecasters, ahead, single_id, n_adj = n_adj)))
  return(ensemble_grid)
}
4 changes: 0 additions & 4 deletions _targets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,3 @@ covid_hosp_prod:
script: covid_hosp_prod.R
store: covid_hosp_prod
use_crew: yes
forecaster_testing:
script: forecaster_testing.R
store: forecaster_testing
use_crew: no
5 changes: 0 additions & 5 deletions _targets/.gitignore

This file was deleted.

2 changes: 1 addition & 1 deletion covid_hosp_explore.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ tar_option_set(
imports = c("epieval", "parsnip"),
format = "qs", # Optionally set the default storage format. qs is fast.
controller = crew::crew_controller_local(workers = parallel::detectCores() - 5),
)
)
# Run the R scripts in the R/ folder with your custom functions:
# tar_source()
# where the forecasters and parameters are joined; see either the variable param_grid or `tar_read(forecasters)`
Expand Down
26 changes: 26 additions & 0 deletions man/confirm_insufficient_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading