Skip to content

testing forecasters on simple datasets #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
^renv$
^renv\.lock$
^LICENSE\.md$
^.lintr$
^.renvignore$
^.github$
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ export(add_id)
export(arx_postprocess)
export(arx_preprocess)
export(collapse_cards)
export(confirm_insufficient_data)
export(covidhub_probs)
export(evaluate_predictions)
export(extend_ahead)
export(flatline_fc)
export(forecaster_pred)
export(format_storage)
export(interval_coverage)
Expand Down Expand Up @@ -38,6 +40,7 @@ importFrom(purrr,map)
importFrom(purrr,transpose)
importFrom(rlang,.data)
importFrom(rlang,quo)
importFrom(rlang,sym)
importFrom(rlang,syms)
importFrom(tibble,tibble)
importFrom(tidyr,pivot_wider)
45 changes: 36 additions & 9 deletions R/forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,32 @@ perform_sanity_checks <- function(epi_data,
return(list(args_list, predictors, trainer))
}

#' confirm that there's enough data to run this model
#' @description
#' epipredict is a little bit fragile about having enough data to train; we want
#' to be able to return a null result rather than error out. This check returns
#' `TRUE` when the caller should return that null result.
#' @param epi_data the input data; its `time_value` column is used to measure
#'   how many days of data are available
#' @param ahead the effective ahead; may be infinite if there isn't enough data.
#' @param args_input the input as supplied to `forecaster_pred`; lags is the
#'   important argument, which may or may not be defined, with the default
#'   coming from `arx_args_list`. `lags` may be a single vector or a list of
#'   vectors; the largest lag overall is used.
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
#'   this trains on one sample; the default is set so that `linear_reg` isn't
#'   rank deficient)
#' @return a single logical: `TRUE` if `ahead` is infinite, if `epi_data` has no
#'   `time_value` entries, or if the span between the first and last
#'   `time_value` is at most `max(lags) + ahead + buffer`; `FALSE` otherwise.
#' @export
confirm_insufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
  # With no rows, max()/min() would warn and the comparison below would be NA,
  # which breaks the `if (confirm_insufficient_data(...))` idiom in callers.
  if (is.infinite(ahead) || length(epi_data$time_value) == 0) {
    return(TRUE)
  }
  # `[[` avoids the partial name matching that `$` performs on plain lists.
  lags <- args_input[["lags"]]
  if (is.null(lags)) {
    lag_max <- 14 # default value of 2 weeks
  } else {
    # unlist() is a no-op on a plain vector and flattens a list of lag vectors
    lag_max <- max(unlist(lags))
  }
  time_span <- as.integer(max(epi_data$time_value) - min(epi_data$time_value))
  time_span <= lag_max + ahead + buffer
}
# TODO replace with `step_arx_forecaster`
#' add the default steps for arx_forecaster
#' @description
Expand Down Expand Up @@ -86,7 +112,6 @@ arx_postprocess <- function(postproc,
target_date = NULL) {
postproc %<>% layer_predict()
if (inherits(trainer, "quantile_reg")) {

postproc %<>% layer_quantile_distn(levels = args_list$levels) %>% layer_point_from_distn()
} else {
postproc %<>% layer_residual_quantiles(
Expand All @@ -98,8 +123,7 @@ arx_postprocess <- function(postproc,
postproc %<>% layer_threshold(dplyr::starts_with(".pred"))
}

postproc %<>% layer_naomit(dplyr::starts_with(".pred"))
postproc %<>% layer_add_forecast_date(forecast_date = forecast_date) %>%
postproc %<>% layer_naomit(dplyr::starts_with(".pred")) %>%
layer_add_target_date(target_date = target_date)
return(postproc)
}
Expand Down Expand Up @@ -162,6 +186,14 @@ forecaster_pred <- function(data,
if (length(forecaster_args) > 0) {
names(forecaster_args) <- forecaster_args_names
}
if (is.null(forecaster_args$ahead)) {
cli::cli_abort(
c(
"exploration-tooling error: forecaster_pred needs some value for ahead."
),
class = "explorationToolingError"
)
}
if (!is.numeric(forecaster_args$n_training) && !is.null(forecaster_args$n_training)) {
n_training <- as.numeric(forecaster_args$n_training)
net_slide_training <- max(slide_training, n_training) + n_training_pad
Expand All @@ -171,11 +203,6 @@ forecaster_pred <- function(data,
}
# restrict the dataset to areas where training is possible
start_date <- min(archive$DT$time_value) + net_slide_training
if (slide_training < Inf) {
start_date <- min(archive$DT$time_value) + slide_training + n_training_pad
} else {
start_date <- min(archive$DT$time_value) + n_training_pad
}
end_date <- max(archive$DT$time_value) - forecaster_args$ahead
valid_predict_dates <- seq.Date(from = start_date, to = end_date, by = 1)

Expand Down Expand Up @@ -206,7 +233,7 @@ forecaster_pred <- function(data,

# append the truth data
true_value <- archive$as_of(archive$versions_end) %>%
select(geo_value, time_value, outcome) %>%
select(geo_value, time_value, !!outcome) %>%
rename(true_value = !!outcome)
res %<>%
inner_join(true_value,
Expand Down
54 changes: 54 additions & 0 deletions R/forecaster_flatline.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#' flatline forecaster (aka baseline)
#' @description
#' a minimal forecaster whose median is just the last value.
#' It does not support `lags` as a parameter, but otherwise has the same
#' parameters as `arx_forecaster`. Note that `trainer` is accepted for
#' signature compatibility only; it is not passed on to anything here.
#' @inheritParams scaled_pop
#' @importFrom rlang sym
#' @export
flatline_fc <- function(epi_data,
                        outcome,
                        extra_sources = "",
                        ahead = 1,
                        trainer = parsnip::linear_reg(),
                        levels = covidhub_probs(),
                        ...) {
  # Preprocessing that epipredict does not do for us: every forecaster has to
  # cope with max(time_value) being older than the `as_of` date.
  # (see latency_adjusting for other examples)
  extended <- extend_ahead(epi_data, ahead)
  # --- start of the boilerplate shared between forecasters ---
  epi_data <- extended[[1]]
  effective_ahead <- extended[[2]]
  args_input <- list(...)
  # Edge case: no data, or less data than the lags need. Hand back an empty
  # result with the expected columns; eventually epipredict will handle this.
  if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
    return(
      tibble(
        geo_value = character(),
        forecast_date = Date(),
        target_end_date = Date(),
        quantile = numeric(),
        value = numeric()
      )
    )
  }
  args_input[["ahead"]] <- effective_ahead
  args_input[["levels"]] <- levels
  args_list <- do.call(flatline_args_list, args_input)
  # To ignore extra_sources, change what goes into `predictors`.
  predictors <- c(outcome, extra_sources)
  checked <- perform_sanity_checks(epi_data, outcome, predictors, NULL, args_list)
  args_list <- checked[[1]]
  predictors <- checked[[2]]
  # --- end of the shared boilerplate ---
  # Any further pre-processing (e.g. smoothing) not done by epipredict would go
  # here; the flatline needs none.

  fit <- flatline_forecaster(epi_data, outcome = outcome, args_list = args_list)
  # The forecast date is taken from the `as_of` stored in the data's metadata.
  true_forecast_date <- attributes(epi_data)$metadata$as_of
  # Columns: (geo_value, forecast_date, target_end_date, quantile, value).
  # Postprocessing not supported by epipredict (e.g. calibration) would follow.
  format_storage(fit$predictions, true_forecast_date)
}
9 changes: 4 additions & 5 deletions R/forecaster_scaled_pop.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
scaled_pop <- function(epi_data,
outcome,
extra_sources = "",
ahead=1,
ahead = 1,
pop_scaling = TRUE,
trainer = parsnip::linear_reg(),
levels = covidhub_probs(),
Expand All @@ -58,9 +58,9 @@ scaled_pop <- function(epi_data,
# this next part is basically unavoidable boilerplate you'll want to copy
epi_data <- epidataAhead[[1]]
effective_ahead <- epidataAhead[[2]]
# edge case where there is no data; eventually epipredict will handle this
if (is.infinite(effective_ahead)) {
effective_ahead <- 0
args_input <- list(...)
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
null_result <- tibble(
geo_value = character(),
forecast_date = Date(),
Expand All @@ -70,7 +70,6 @@ scaled_pop <- function(epi_data,
)
return(null_result)
}
args_input <- list(...)
args_input[["ahead"]] <- effective_ahead
args_input[["levels"]] <- levels
args_list <- do.call(arx_args_list, args_input)
Expand Down
2 changes: 1 addition & 1 deletion R/formatters.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
.dstn = nested_quantiles(.pred_distn)
) %>%
unnest(.dstn) %>%
select(-.pred_distn, -.pred, -time_value) %>%
select(-any_of(c(".pred_distn", ".pred", "time_value"))) %>%
rename(quantile = tau, value = q, target_end_date = target_date) %>%
relocate(geo_value, forecast_date, target_end_date, quantile, value)
}
Expand Down
57 changes: 52 additions & 5 deletions R/small_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,63 @@ covidhub_probs <- function(type = c("standard", "inc_case")) {
#' @importFrom cli hash_animal
#' @export
add_id <- function(df, n_adj = 2) {
stringified <- df %>%
select(-ahead) %>%
no_ahead <- df %>%
select(-ahead)
stringified <- no_ahead %>%
select(order(colnames(no_ahead))) %>%
rowwise() %>%
mutate(id = paste(across(everything()), collapse = ""), .keep="none") %>%
mutate(id = paste(across(everything()), sep = "", collapse = ""), .keep = "none") %>%
mutate(id = hash_animal(id, n_adj = n_adj)$words) %>%
mutate(id = paste(id[1:n_adj], sep="", collapse = " "))
mutate(id = paste(id[1:n_adj], sep = "", collapse = "."))
df %<>%
ungroup %>%
mutate(parent_id = stringified$id) %>%
rowwise() %>%
mutate(id = paste(parent_id, ahead, collapse = " ")) %>%
mutate(id = paste(parent_id, ahead, sep = ".", collapse = " ")) %>%
ungroup()
return(df)
}

#' generate an id from a simple list of parameters
#' @description
#' Every parameter except `ahead` is sorted by name, pasted together and hashed
#' with `hash_animal`; the first `n_adj` words of the hash are joined with "."
#' and the ahead is appended after a further ".".
#' @param param_list the list of parameters. must include `ahead` if `ahead = NULL`
#' @param ahead the ahead to use; when `NULL`, `param_list$ahead` is used.
#' @inheritParams add_id
#' @export
single_id <- function(param_list, ahead = NULL, n_adj = 2) {
  # the ahead does not take part in the hash
  hashed_params <- param_list[names(param_list) != "ahead"]
  # alphabetical order, so the id does not depend on the order given
  hashed_params <- hashed_params[order(names(hashed_params))]
  full_hash <- hash_animal(paste(hashed_params, collapse = ""), n_adj = n_adj)
  hash_words <- full_hash$words[[1]][1:n_adj]
  single_string <- paste(hash_words, sep = ".", collapse = ".")
  # an explicitly supplied ahead wins over the one stored in the list
  ahead_part <- if (is.null(ahead)) param_list$ahead else ahead
  paste(single_string, ahead_part, sep = ".")
}
#' given target name(s), lookup the corresponding parameters
#' @description
#' Placeholder: the body is currently empty, so this takes no arguments and
#' returns `NULL`.
#' @export
lookup_ids <- function() {
# TODO not yet implemented
}


#' add aheads, forecaster_ids, and ids to a list of ensemble models
#' @description
#' minor utility. Crosses every ensemble with every ahead, adds ids via
#' `add_id`, and adds a `forecaster_ids` list column holding one `single_id`
#' per entry of `forecasters` for that row's ahead.
#' @param ensemble_grid the list of ensembles; must have a `forecasters` column
#' @param aheads the aheads to add
#' @inheritParams add_id
#' @return `ensemble_grid` crossed with `aheads`, with the id columns added.
#'   The result is still a rowwise data frame, because the `rowwise()` applied
#'   here is not undone.
#' @export
id_ahead_ensemble_grid <- function(ensemble_grid, aheads, n_adj = 2) {
  ensemble_grid <- expand_grid(
    ensemble_grid,
    tibble(ahead = aheads)
  )

  # `n_adj` is forwarded to both id generators (it used to be hard-coded to 2,
  # so the argument had no effect). Inside `mutate()` it is injected with `!!`
  # so that a column named `n_adj` in the grid cannot mask it.
  ensemble_grid %<>%
    add_id(., n_adj = n_adj) %>%
    rowwise() %>%
    mutate(
      forecaster_ids = list(
        map2_vec(forecasters, ahead, single_id, n_adj = !!n_adj)
      )
    )
  return(ensemble_grid)
}
2 changes: 1 addition & 1 deletion covid_hosp_explore.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ tar_option_set(
imports = c("epieval", "parsnip"),
format = "qs", # Optionally set the default storage format. qs is fast.
controller = crew::crew_controller_local(workers = parallel::detectCores() - 5),
)
)
# Run the R scripts in the R/ folder with your custom functions:
# tar_source()
# where the forecasters and parameters are joined; see either the variable param_grid or `tar_read(forecasters)`
Expand Down
26 changes: 26 additions & 0 deletions man/confirm_insufficient_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions man/flatline_fc.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 4 additions & 6 deletions renv.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1926,12 +1926,10 @@
"renv": {
"Package": "renv",
"Version": "1.0.3",
"Source": "Repository",
"Repository": "CRAN",
"Requirements": [
"utils"
],
"Hash": "41b847654f567341725473431dd0d5ab"
"OS_type": null,
"NeedsCompilation": "no",
"Repository": "RSPM",
"Source": "Repository"
},
"rlang": {
"Package": "rlang",
Expand Down
Loading