Skip to content

Commit 806119b

Browse files
authored
Merge pull request #75 from cmu-delphi/refactor
Minor Refactor
2 parents b762470 + c67d89d commit 806119b

22 files changed

+366
-424
lines changed

NAMESPACE

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ export(evaluate_predictions)
1515
export(extend_ahead)
1616
export(flatline_fc)
1717
export(forecaster_lookup)
18-
export(forecaster_pred)
1918
export(format_storage)
2019
export(id_ahead_ensemble_grid)
2120
export(interval_coverage)
@@ -28,7 +27,6 @@ export(make_shared_ensembles)
2827
export(make_shared_grids)
2928
export(make_target_ensemble_grid)
3029
export(make_target_param_grid)
31-
export(manage_S3_forecast_cache)
3230
export(overprediction)
3331
export(perform_sanity_checks)
3432
export(read_external_predictions_data)
@@ -37,11 +35,10 @@ export(run_workflow_and_format)
3735
export(scaled_pop)
3836
export(sharpness)
3937
export(single_id)
38+
export(slide_forecaster)
4039
export(underprediction)
4140
export(weighted_interval_score)
4241
importFrom(assertthat,assert_that)
43-
importFrom(aws.s3,get_bucket)
44-
importFrom(aws.s3,s3sync)
4542
importFrom(cli,cli_abort)
4643
importFrom(cli,hash_animal)
4744
importFrom(dplyr,across)
@@ -85,7 +82,6 @@ importFrom(epipredict,step_population_scaling)
8582
importFrom(epipredict,step_training_window)
8683
importFrom(epiprocess,as_epi_df)
8784
importFrom(epiprocess,epix_slide)
88-
importFrom(here,here)
8985
importFrom(magrittr,"%<>%")
9086
importFrom(magrittr,"%>%")
9187
importFrom(purrr,imap)

R/data_validation.R

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#' helper function for those writing forecasters
#' @description
#' a smorgasbord of checks that any epipredict-based forecaster should do:
#' 1. check that the args list is created correctly,
#' 2. drop empty strings from the vector of predictors,
#' 3. validate the outcome and predictors as present,
#' 4. make sure the trainer is a `regression` model from `parsnip`
#' 5. adjust the trainer's quantiles based on those in args_list if it's a
#'   quantile trainer
#' 6. remake the lags to match the number of predictors
#' @inheritParams scaled_pop
#' @param predictors the full list of predictors including the outcome. can
#'   include empty strings
#' @param args_list the args list created by [`epipredict::arx_args_list`]
#' @return an unnamed list of length 3 holding, in order, the updated
#'   `args_list`, the `predictors` with empty strings removed, and the
#'   (possibly updated) `trainer`
#' @export
perform_sanity_checks <- function(epi_data,
                                  outcome,
                                  predictors,
                                  trainer,
                                  args_list) {
  if (!inherits(args_list, c("arx_fcast", "alist"))) {
    cli::cli_abort("args_list was not created using `arx_args_list()`.")
  }

  # empty strings stand in for "no predictor"; remove them before validating
  predictors <- predictors[predictors != ""]
  epipredict:::validate_forecaster_inputs(epi_data, outcome, predictors)

  if (!is.null(trainer) && !epipredict:::is_regression(trainer)) {
    # use cli inline markup rather than `{trainer}` / `{parsnip}`: bare braces
    # are glue-interpolated by cli, which would try to format the trainer
    # object and look up a variable called `parsnip`
    cli::cli_abort(
      "{.arg trainer} must be a {.pkg parsnip} model of mode 'regression'."
    )
  } else if (inherits(trainer, "quantile_reg")) {
    # add all quantile_levels to the trainer and update args list
    quantile_levels <- sort(epipredict:::compare_quantile_args(
      args_list$quantile_levels,
      rlang::eval_tidy(trainer$args$quantile_levels)
    ))
    args_list$quantile_levels <- quantile_levels
    trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
  }
  args_list$lags <- epipredict:::arx_lags_validator(predictors, args_list$lags)
  return(list(args_list, predictors, trainer))
}
42+
43+
#' confirm that there's enough data to run this model
#' @description
#' epipredict is a little bit fragile about having enough data to train; we want
#' to be able to return a null result rather than error out.
#' @param epi_data the input data
#' @param ahead the effective ahead; may be infinite if there isn't enough data.
#' @param args_input the input as supplied to `slide_forecaster`; lags is the
#'   important argument, which may or may not be defined, with the default
#'   coming from `arx_args_list`. `lags` may be a single vector or a list of
#'   vectors (one per predictor)
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
#'   this trains on one sample; the default is set so that `linear_reg` isn't
#'   rank deficient)
#' @return `FALSE` if `ahead` is infinite; otherwise `TRUE` when at least one
#'   `geo_value` has `lag_max + ahead + buffer` or more distinct `time_value`s
#'   among the rows without missing values
#' @importFrom tidyr drop_na
#' @export
confirm_sufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
  if (!is.null(args_input$lags)) {
    # flatten first: `max()` errors when handed a list of per-predictor lags
    lag_max <- max(unlist(args_input$lags))
  } else {
    lag_max <- 14 # default value of 2 weeks
  }

  # TODO: Buffer should probably be 2 * n(lags) * n(predictors). But honestly,
  # this needs to be fixed in epipredict itself, see
  # https://github.com/cmu-delphi/epipredict/issues/106.

  # `&&` short-circuits, so the data is not inspected when `ahead` is infinite
  return(
    !is.infinite(ahead) &&
      epi_data %>%
        drop_na() %>%
        group_by(geo_value) %>%
        summarise(has_enough_data = n_distinct(time_value) >= lag_max + ahead + buffer) %>%
        pull(has_enough_data) %>%
        any()
  )
}

R/epipredict_utilities.R

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# TODO replace with `step_arx_forecaster`
#' add the default steps for arx_forecaster
#' @description
#' Appends, in order, one lag step per entry of `args_list$lags`, an ahead
#' step for the outcome, an NA-omitting step, and a training-window step.
#' @param rec an [`epipredict::epi_recipe`]
#' @param outcome a character of the column to be predicted
#' @param predictors a character vector of the columns used as predictors;
#'   matched by position with the entries of `args_list$lags`
#' @param args_list an [`epipredict::arx_args_list`]; the `lags`, `ahead` and
#'   `n_training` entries are read
#' @return the recipe `rec` with the steps above added
#' @seealso [arx_postprocess] for the layer equivalent
#' @importFrom epipredict step_epi_lag step_epi_ahead step_epi_naomit step_training_window
#' @export
arx_preprocess <- function(rec, outcome, predictors, args_list) {
  # no validation happens here; the inputs are taken as already checked
  lag_sets <- args_list$lags
  for (idx in seq_along(lag_sets)) {
    predictor <- predictors[idx]
    rec <- step_epi_lag(rec, !!predictor, lag = lag_sets[[idx]])
  }
  rec <- step_epi_ahead(rec, !!outcome, ahead = args_list$ahead)
  rec <- step_epi_naomit(rec)
  rec <- step_training_window(rec, n_recent = args_list$n_training)
  rec
}
25+
26+
# TODO replace with `layer_arx_forecaster`
#' add the default layers for arx_forecaster
#' @description
#' Appends a prediction layer, then quantile layers chosen by trainer class
#' (`quantile_reg` trainers get distribution-based quantiles plus a point
#' forecast; all others get residual quantiles), an optional threshold layer
#' when `args_list$nonneg` is true, an NA-omitting layer, and a target-date
#' layer.
#' @param postproc an [`epipredict::frosting`]
#' @param trainer the trainer used (e.g. linear_reg() or quantile_reg())
#' @param args_list an [`epipredict::arx_args_list`]; the `quantile_levels`,
#'   `symmetrize`, `quantile_by_key` and `nonneg` entries are read
#' @param forecast_date the date from which the forecast was made. Note that
#'   the current function body does not reference this argument
#' @param target_date the date about which the forecast was made; passed
#'   through to `layer_add_target_date`, so `NULL` falls back on that layer's
#'   own default
#' @return the frosting `postproc` with the layers above added
#' @seealso [arx_preprocess] for the step equivalent
#' @importFrom epipredict layer_predict layer_quantile_distn layer_point_from_distn layer_residual_quantiles layer_threshold layer_naomit layer_add_target_date
#' @export
arx_postprocess <- function(postproc,
                            trainer,
                            args_list,
                            forecast_date = NULL,
                            target_date = NULL) {
  postproc <- layer_predict(postproc)
  if (inherits(trainer, "quantile_reg")) {
    postproc <- layer_quantile_distn(
      postproc,
      quantile_levels = args_list$quantile_levels
    )
    postproc <- layer_point_from_distn(postproc)
  } else {
    postproc <- layer_residual_quantiles(
      postproc,
      quantile_levels = args_list$quantile_levels,
      symmetrize = args_list$symmetrize,
      by_key = args_list$quantile_by_key
    )
  }
  if (args_list$nonneg) {
    postproc <- layer_threshold(postproc, dplyr::starts_with(".pred"))
  }

  postproc <- layer_naomit(postproc, dplyr::starts_with(".pred"))
  layer_add_target_date(postproc, target_date = target_date)
}
67+
68+
#' helper function to run an epipredict model and reformat to hub format
#' @description
#' Builds an `epi_workflow` from the recipe and trainer, fits it on
#' `epi_data`, attaches the frosting, predicts on the test data derived from
#' `epi_data`, and hands the predictions to `format_storage`.
#' @param preproc the preprocessing steps
#' @param postproc the postprocessing frosting
#' @param trainer the parsnip trainer
#' @param epi_data the actual epi_df to train on
#' @return the result of `format_storage` applied to the predictions and the
#'   forecast date
#' @export
#' @importFrom epipredict epi_workflow fit add_frosting get_test_data
run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
  unfitted <- epi_workflow(preproc, trainer)
  fitted <- fit(unfitted, epi_data)
  fitted <- add_frosting(fitted, postproc)

  test_data <- get_test_data(recipe = preproc, x = epi_data)
  pred <- predict(fitted, test_data)

  # use the data's `as_of` metadata as the forecast date when it is set;
  # otherwise fall back on the latest time_value present
  true_forecast_date <- attributes(epi_data)$metadata$as_of
  if (is.null(true_forecast_date)) {
    true_forecast_date <- max(epi_data$time_value)
  }
  format_storage(pred, true_forecast_date)
}

0 commit comments

Comments
 (0)