Skip to content

feat: flu hosp explore #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
tmp/
extras/**.html
*.pdf
.Renviron
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Imports:
purrr,
recipes (>= 1.0.4),
rlang,
targets,
tibble,
tidyr
Suggests:
Expand Down
13 changes: 12 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ export(absolute_error)
export(add_id)
export(arx_postprocess)
export(arx_preprocess)
export(clear_lastminute_nas)
export(collapse_cards)
export(confirm_insufficient_data)
export(confirm_sufficient_data)
export(covidhub_probs)
export(evaluate_predictions)
export(extend_ahead)
Expand All @@ -15,6 +16,12 @@ export(format_storage)
export(id_ahead_ensemble_grid)
export(interval_coverage)
export(lookup_ids)
export(make_data_targets)
export(make_ensemble_targets)
export(make_external_names_and_scores)
export(make_forecasts_and_scores)
export(make_forecasts_and_scores_by_ahead)
export(make_shared_grids)
export(make_target_param_grid)
export(manage_S3_forecast_cache)
export(overprediction)
Expand Down Expand Up @@ -71,6 +78,7 @@ importFrom(epipredict,step_epi_lag)
importFrom(epipredict,step_epi_naomit)
importFrom(epipredict,step_population_scaling)
importFrom(epipredict,step_training_window)
importFrom(epiprocess,as_epi_df)
importFrom(epiprocess,epix_slide)
importFrom(here,here)
importFrom(magrittr,"%<>%")
Expand All @@ -84,7 +92,10 @@ importFrom(rlang,.data)
importFrom(rlang,quo)
importFrom(rlang,sym)
importFrom(rlang,syms)
importFrom(targets,tar_group)
importFrom(targets,tar_target)
importFrom(tibble,tibble)
importFrom(tidyr,drop_na)
importFrom(tidyr,expand_grid)
importFrom(tidyr,pivot_wider)
importFrom(tidyr,unnest)
85 changes: 62 additions & 23 deletions R/forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,39 @@ perform_sanity_checks <- function(epi_data,
#' confirm that there's enough data to run this model
#' @description
#' epipredict is a little bit fragile about having enough data to train; we want
#' to be able to return a null result rather than error out; this check say to
#' return a null
#' to be able to return a null result rather than error out.
#' @param epi_data the input data
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
#' this trains on one sample; the default is set so that `linear_reg` isn't
#' rank deficient)
#' @param ahead the effective ahead; may be infinite if there isn't enough data.
#' @param args_input the input as supplied to `forecaster_pred`; lags is the
#' important argument, which may or may not be defined, with the default
#' coming from `arx_args_list`
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
#' this trains on one sample; the default is set so that `linear_reg` isn't
#' rank deficient)
#' @importFrom tidyr drop_na
#' @export
confirm_insufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
confirm_sufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
if (!is.null(args_input$lags)) {
lag_max <- max(args_input$lags)
} else {
lag_max <- 14 # default value of 2 weeks
}

# TODO: Buffer should probably be 2 * n(lags) * n(predictors). But honestly,
# this needs to be fixed in epipredict itself, see
# https://github.com/cmu-delphi/epipredict/issues/106.

return(
is.infinite(ahead) ||
as.integer(max(epi_data$time_value) - min(epi_data$time_value)) <=
lag_max + ahead + buffer
!is.infinite(ahead) &&
epi_data %>%
drop_na() %>%
group_by(geo_value) %>%
summarise(has_enough_data = n_distinct(time_value) >= lag_max + ahead + buffer) %>%
pull(has_enough_data) %>%
any()
)
}

# TODO replace with `step_arx_forecaster`
#' add the default steps for arx_forecaster
#' @description
Expand Down Expand Up @@ -149,7 +159,11 @@ run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
latest <- get_test_data(recipe = preproc, x = epi_data)
pred <- predict(workflow, latest)
# the forecast_date may currently be the max time_value
true_forecast_date <- attributes(epi_data)$metadata$as_of
as_of <- attributes(epi_data)$metadata$as_of
if (is.null(as_of)) {
as_of <- max(epi_data$time_value)
}
true_forecast_date <- as_of
return(format_storage(pred, true_forecast_date))
}

Expand All @@ -176,6 +190,8 @@ run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
#' contain `ahead`
#' @param forecaster_args_names a bit of a hack around targets, it contains
#' the names of the `forecaster_args`.
#' @param date_range_step_size the step size (in days) to use when generating
#' the forecast dates.
#' @importFrom epiprocess epix_slide
#' @importFrom cli cli_abort
#' @importFrom rlang !!
Expand All @@ -187,7 +203,8 @@ forecaster_pred <- function(data,
slide_training = 0,
n_training_pad = 5,
forecaster_args = list(),
forecaster_args_names = list()) {
forecaster_args_names = list(),
date_range_step_size = 1L) {
archive <- data
if (length(forecaster_args) > 0) {
names(forecaster_args) <- forecaster_args_names
Expand All @@ -210,25 +227,47 @@ forecaster_pred <- function(data,
# restrict the dataset to areas where training is possible
start_date <- min(archive$DT$time_value) + net_slide_training
end_date <- max(archive$DT$time_value) - forecaster_args$ahead
valid_predict_dates <- seq.Date(from = start_date, to = end_date, by = 1)
valid_predict_dates <- seq.Date(from = start_date, to = end_date, by = date_range_step_size)

# first generate the forecasts
before <- n_training + n_training_pad - 1
## TODO epix_slide doesn't support infinite `before`
## TODO: epix_slide doesn't support infinite `before`
## https://github.com/cmu-delphi/epiprocess/issues/219
if (before == Inf) before <- 365L * 10000
res <- epix_slide(archive,
function(data, gk, rtv, ...) {
do.call(
forecaster,
append(
list(
epi_data = data,
outcome = outcome,
extra_sources = extra_sources
),
forecaster_args
)
# TODO: Can we get rid of this tryCatch and instead hook it up to targets
# error handling or something else?
# https://github.com/cmu-delphi/exploration-tooling/issues/41
tryCatch(
{
do.call(
forecaster,
append(
list(
epi_data = data,
outcome = outcome,
extra_sources = extra_sources
),
forecaster_args
)
)
},
error = function(e) {
if (interactive()) {
browser()
} else {
dump_vars <- list(
data = data,
rtv = rtv,
forecaster = forecaster,
forecaster_args = forecaster_args,
e = e
)
saveRDS(dump_vars, "forecaster_pred_error.rds")
e
}
}
)
},
before = before,
Expand Down
7 changes: 6 additions & 1 deletion R/forecaster_flatline.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ flatline_fc <- function(epi_data,
quantile_levels = covidhub_probs(),
...) {
# perform any preprocessing not supported by epipredict
# this is a temp fix until a real fix gets put into epipredict
epi_data <- clear_lastminute_nas(epi_data)
# one that every forecaster will need to handle: how to manage max(time_value)
# that's older than the `as_of` date
epidataAhead <- extend_ahead(epi_data, ahead)
Expand All @@ -23,7 +25,7 @@ flatline_fc <- function(epi_data,
effective_ahead <- epidataAhead[[2]]
args_input <- list(...)
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
if (!confirm_sufficient_data(epi_data, effective_ahead, args_input)) {
null_result <- tibble(
geo_value = character(),
forecast_date = lubridate::Date(),
Expand All @@ -48,6 +50,9 @@ flatline_fc <- function(epi_data,
# since this is just the flatline, we don't need much of anything
res <- flatline_forecaster(epi_data, outcome = outcome, args_list = args_list)
true_forecast_date <- attributes(epi_data)$metadata$as_of
if (is.null(true_forecast_date)) {
true_forecast_date <- max(epi_data$time_value)
}
pred <- format_storage(res$predictions, true_forecast_date)
# (geo_value, forecast_date, target_end_date, quantile, value)
# finally, any postprocessing not supported by epipredict e.g. calibration
Expand Down
6 changes: 4 additions & 2 deletions R/forecaster_scaled_pop.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ scaled_pop <- function(epi_data,
quantile_levels = covidhub_probs(),
...) {
# perform any preprocessing not supported by epipredict
# this is a temp fix until a real fix gets put into epipredict
epi_data <- clear_lastminute_nas(epi_data)
# one that every forecaster will need to handle: how to manage max(time_value)
# that's older than the `as_of` date
epidataAhead <- extend_ahead(epi_data, ahead)
Expand All @@ -58,7 +60,7 @@ scaled_pop <- function(epi_data,
effective_ahead <- epidataAhead[[2]]
args_input <- list(...)
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
if (!confirm_sufficient_data(epi_data, effective_ahead, args_input)) {
null_result <- tibble(
geo_value = character(),
forecast_date = lubridate::Date(),
Expand All @@ -73,6 +75,7 @@ scaled_pop <- function(epi_data,
args_list <- do.call(arx_args_list, args_input)
# if you want to ignore extra_sources, setting predictors is the way to do it
predictors <- c(outcome, extra_sources)
# TODO: Partial match quantile_level coming from here (on Dmitry's machine)
argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, trainer, args_list)
args_list <- argsPredictorsTrainer[[1]]
predictors <- argsPredictorsTrainer[[2]]
Expand All @@ -98,7 +101,6 @@ scaled_pop <- function(epi_data,
# postprocessing supported by epipredict
postproc <- frosting()
postproc %<>% arx_postprocess(trainer, args_list)
postproc
if (pop_scaling) {
postproc %<>% layer_population_scaling(
.pred, .pred_distn,
Expand Down
9 changes: 6 additions & 3 deletions R/latency_adjusting.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
extend_ahead <- function(epi_data, ahead) {
time_values <- epi_data$time_value
if (length(time_values) > 0) {
as_of <- attributes(epi_data)$metadata$as_of
max_time <- max(time_values)
if (is.null(as_of)) {
as_of <- max_time
}
effective_ahead <- as.integer(
as.Date(attributes(epi_data)$metadata$as_of) -
max(time_values) +
ahead
as.Date(as_of) - max_time + ahead
)
} else {
effective_ahead <- Inf
Expand Down
Loading