Skip to content

testing forecasters on simple datasets #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
^renv$
^renv\.lock$
^LICENSE\.md$
^.lintr$
^.renvignore$
^.github$
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ export(add_id)
export(arx_postprocess)
export(arx_preprocess)
export(collapse_cards)
export(confirm_insufficient_data)
export(covidhub_probs)
export(evaluate_predictions)
export(extend_ahead)
export(flatline_fc)
export(forecaster_pred)
export(format_storage)
export(interval_coverage)
Expand Down Expand Up @@ -38,6 +40,7 @@ importFrom(purrr,map)
importFrom(purrr,transpose)
importFrom(rlang,.data)
importFrom(rlang,quo)
importFrom(rlang,sym)
importFrom(rlang,syms)
importFrom(tibble,tibble)
importFrom(tidyr,pivot_wider)
45 changes: 36 additions & 9 deletions R/forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,32 @@ perform_sanity_checks <- function(epi_data,
return(list(args_list, predictors, trainer))
}

#' confirm that there's enough data to run this model
#' @description
#' epipredict is a little bit fragile about having enough data to train; we
#' want to be able to return a null result rather than error out, so this
#' check says whether the caller should return a null result instead of
#' training.
#' @param epi_data the input data; must contain a `time_value` column
#' @param ahead the effective ahead; may be infinite if there isn't enough data
#' @param args_input the input as supplied to `forecaster_pred`; lags is the
#'   important argument, which may or may not be defined, with the default
#'   coming from `arx_args_list`
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
#'   this trains on one sample; the default is set so that `linear_reg` isn't
#'   rank deficient)
#' @return `TRUE` when there is insufficient data to train (the caller should
#'   return a null result), `FALSE` otherwise
#' @export
confirm_insufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
  if (!is.null(args_input$lags)) {
    lag_max <- max(args_input$lags)
  } else {
    # no lags supplied; 14 presumably matches the default max lag from
    # `arx_args_list` -- TODO confirm against epipredict
    lag_max <- 14
  }
  # insufficient when the ahead is infinite (no usable data at all), or when
  # the available time span can't cover the lags, the ahead, and the buffer.
  # note: previously the buffer was hard-coded as 9 here, silently ignoring
  # the `buffer` argument; it is now honored.
  return(
    is.infinite(ahead) ||
      as.integer(max(epi_data$time_value) - min(epi_data$time_value)) <=
        lag_max + ahead + buffer
  )
}
# TODO replace with `step_arx_forecaster`
#' add the default steps for arx_forecaster
#' @description
Expand Down Expand Up @@ -86,7 +112,6 @@ arx_postprocess <- function(postproc,
target_date = NULL) {
postproc %<>% layer_predict()
if (inherits(trainer, "quantile_reg")) {

postproc %<>% layer_quantile_distn(levels = args_list$levels) %>% layer_point_from_distn()
} else {
postproc %<>% layer_residual_quantiles(
Expand All @@ -98,8 +123,7 @@ arx_postprocess <- function(postproc,
postproc %<>% layer_threshold(dplyr::starts_with(".pred"))
}

postproc %<>% layer_naomit(dplyr::starts_with(".pred"))
postproc %<>% layer_add_forecast_date(forecast_date = forecast_date) %>%
postproc %<>% layer_naomit(dplyr::starts_with(".pred")) %>%
layer_add_target_date(target_date = target_date)
return(postproc)
}
Expand Down Expand Up @@ -162,6 +186,14 @@ forecaster_pred <- function(data,
if (length(forecaster_args) > 0) {
names(forecaster_args) <- forecaster_args_names
}
if (is.null(forecaster_args$ahead)) {
cli::cli_abort(
c(
"exploration-tooling error: forecaster_pred needs some value for ahead."
),
class = "explorationToolingError"
)
}
if (!is.numeric(forecaster_args$n_training) && !is.null(forecaster_args$n_training)) {
n_training <- as.numeric(forecaster_args$n_training)
net_slide_training <- max(slide_training, n_training) + n_training_pad
Expand All @@ -171,11 +203,6 @@ forecaster_pred <- function(data,
}
# restrict the dataset to areas where training is possible
start_date <- min(archive$DT$time_value) + net_slide_training
if (slide_training < Inf) {
start_date <- min(archive$DT$time_value) + slide_training + n_training_pad
} else {
start_date <- min(archive$DT$time_value) + n_training_pad
}
end_date <- max(archive$DT$time_value) - forecaster_args$ahead
valid_predict_dates <- seq.Date(from = start_date, to = end_date, by = 1)

Expand Down Expand Up @@ -206,7 +233,7 @@ forecaster_pred <- function(data,

# append the truth data
true_value <- archive$as_of(archive$versions_end) %>%
select(geo_value, time_value, outcome) %>%
select(geo_value, time_value, !!outcome) %>%
rename(true_value = !!outcome)
res %<>%
inner_join(true_value,
Expand Down
54 changes: 54 additions & 0 deletions R/forecaster_flatline.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#' flatline forecaster (aka baseline)
#' @description
#' a minimal forecaster whose median is just the last value
#' does not support `lags` as a parameter, but otherwise has the same parameters as `arx_forecaster`
#' @inheritParams scaled_pop
#' @importFrom rlang sym
#' @export
flatline_fc <- function(epi_data,
                        outcome,
                        extra_sources = "",
                        ahead = 1,
                        trainer = parsnip::linear_reg(),
                        levels = covidhub_probs(),
                        ...) {
  # pre-processing that epipredict doesn't cover: reconcile a max(time_value)
  # that is older than the `as_of` date (see latency_adjusting for other
  # examples). this is boilerplate shared by every forecaster here.
  adjusted <- extend_ahead(epi_data, ahead)
  epi_data <- adjusted[[1]]
  effective_ahead <- adjusted[[2]]
  extra_args <- list(...)
  # bail out with an empty-but-correctly-shaped result when there is no data,
  # or less data than the lags require; eventually epipredict will handle this
  if (confirm_insufficient_data(epi_data, effective_ahead, extra_args)) {
    return(tibble(
      geo_value = character(),
      forecast_date = Date(),
      target_end_date = Date(),
      quantile = numeric(),
      value = numeric()
    ))
  }
  extra_args[["ahead"]] <- effective_ahead
  extra_args[["levels"]] <- levels
  args_list <- do.call(flatline_args_list, extra_args)
  # setting predictors explicitly is how you'd ignore extra_sources
  predictors <- c(outcome, extra_sources)
  checked <- perform_sanity_checks(epi_data, outcome, predictors, NULL, args_list)
  args_list <- checked[[1]]
  predictors <- checked[[2]]
  # end of the shared boilerplate; any other pre-processing (e.g. smoothing)
  # not done by epipredict would go here -- the flatline needs none

  res <- flatline_forecaster(epi_data, outcome = outcome, args_list = args_list)
  # the true forecast date comes from the data's as_of, not Sys.Date()
  true_forecast_date <- attributes(epi_data)$metadata$as_of
  # reshape into (geo_value, forecast_date, target_end_date, quantile, value);
  # any postprocessing not supported by epipredict (e.g. calibration) would
  # happen before returning
  return(format_storage(res$predictions, true_forecast_date))
}
9 changes: 4 additions & 5 deletions R/forecaster_scaled_pop.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
scaled_pop <- function(epi_data,
outcome,
extra_sources = "",
ahead=1,
ahead = 1,
pop_scaling = TRUE,
trainer = parsnip::linear_reg(),
levels = covidhub_probs(),
Expand All @@ -58,9 +58,9 @@ scaled_pop <- function(epi_data,
# this next part is basically unavoidable boilerplate you'll want to copy
epi_data <- epidataAhead[[1]]
effective_ahead <- epidataAhead[[2]]
# edge case where there is no data; eventually epipredict will handle this
if (is.infinite(effective_ahead)) {
effective_ahead <- 0
args_input <- list(...)
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
null_result <- tibble(
geo_value = character(),
forecast_date = Date(),
Expand All @@ -70,7 +70,6 @@ scaled_pop <- function(epi_data,
)
return(null_result)
}
args_input <- list(...)
args_input[["ahead"]] <- effective_ahead
args_input[["levels"]] <- levels
args_list <- do.call(arx_args_list, args_input)
Expand Down
2 changes: 1 addition & 1 deletion R/formatters.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
.dstn = nested_quantiles(.pred_distn)
) %>%
unnest(.dstn) %>%
select(-.pred_distn, -.pred, -time_value) %>%
select(-any_of(c(".pred_distn", ".pred", "time_value"))) %>%
rename(quantile = tau, value = q, target_end_date = target_date) %>%
relocate(geo_value, forecast_date, target_end_date, quantile, value)
}
Expand Down
26 changes: 26 additions & 0 deletions man/confirm_insufficient_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions man/flatline_fc.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 0 additions & 45 deletions tests/testthat/test-example_spec.R

This file was deleted.

55 changes: 55 additions & 0 deletions tests/testthat/test-forecasters-basics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# TODO better way to do this than copypasta
# each entry pairs a display name with the forecaster function; the same
# battery of sanity checks runs against every forecaster in the list
forecasters <- list(
  c("scaled_pop", scaled_pop),
  c("flatline_fc", flatline_fc)
)
for (forecaster in forecasters) {
  test_that(forecaster[[1]], {
    jhu <- case_death_rate_subset %>%
      dplyr::filter(time_value >= as.Date("2021-12-01"))
    # the as_of for this is wildly far in the future
    attributes(jhu)$metadata$as_of <- max(jhu$time_value) + 3
    res <- forecaster[[2]](jhu, "case_rate", c("death_rate"), -2L)
    expect_equal(
      names(res),
      c("geo_value", "forecast_date", "target_end_date", "quantile", "value")
    )
    expect_true(all(
      res$target_end_date ==
        as.Date("2022-01-01")
    ))
    # any forecaster specific tests
    if (forecaster[[1]] == "scaled_pop") {
      # confirm scaling produces different results
      res_unscaled <- forecaster[[2]](jhu,
        "case_rate",
        c("death_rate"),
        -2L,
        pop_scaling = FALSE
      )
      expect_false(res_unscaled %>%
        full_join(res,
          by = join_by(geo_value, forecast_date, target_end_date, quantile),
          suffix = c(".unscaled", ".scaled")
        ) %>%
        mutate(equal = value.unscaled == value.scaled) %>%
        summarize(all(equal)) %>% pull(`all(equal)`))
    }
    # TODO confirming that it produces exactly the same result as arx_forecaster
    # test case where extra_sources is "empty"; wrapped so a failure is
    # reported as an expectation failure rather than aborting the test
    expect_no_error(forecaster[[2]](
      jhu,
      "case_rate",
      c(""),
      1L
    ))
    # test case where the epi_df is empty; expect_no_error also performs the
    # assignment, so null_res is available for the checks below
    null_jhu <- jhu %>% filter(time_value < as.Date("0009-01-01"))
    expect_no_error(null_res <- forecaster[[2]](null_jhu, "case_rate", c("death_rate")))
    expect_identical(names(null_res), names(res))
    expect_equal(nrow(null_res), 0)
    expect_identical(null_res, tibble(geo_value = character(), forecast_date = Date(), target_end_date = Date(), quantile = numeric(), value = numeric()))
  })
}
Loading