|
| 1 | +# TODO replace with `step_arx_forecaster` |
| 2 | +#' add the default steps for arx_forecaster |
| 3 | +#' @description |
| 4 | +#' add the default steps for arx_forecaster |
| 5 | +#' @param rec an [`epipredict::epi_recipe`] |
| 6 | +#' @param outcome a character of the column to be predicted |
| 7 | +#' @param predictors a character vector of the columns used as predictors |
| 8 | +#' @param args_list an [`epipredict::arx_args_list`] |
| 9 | +#' @seealso [arx_postprocess] for the layer equivalent |
| 10 | +#' @importFrom epipredict step_epi_lag step_epi_ahead step_epi_naomit step_training_window |
| 11 | +#' @export |
| 12 | +arx_preprocess <- function(rec, outcome, predictors, args_list) { |
| 13 | + # input already validated |
| 14 | + lags <- args_list$lags |
| 15 | + for (l in seq_along(lags)) { |
| 16 | + p <- predictors[l] |
| 17 | + rec %<>% step_epi_lag(!!p, lag = lags[[l]]) |
| 18 | + } |
| 19 | + rec %<>% |
| 20 | + step_epi_ahead(!!outcome, ahead = args_list$ahead) %>% |
| 21 | + step_epi_naomit() %>% |
| 22 | + step_training_window(n_recent = args_list$n_training) |
| 23 | + return(rec) |
| 24 | +} |
| 25 | + |
| 26 | +# TODO replace with `layer_arx_forecaster` |
| 27 | +#' add the default layers for arx_forecaster |
| 28 | +#' @description |
| 29 | +#' add the default layers for arx_forecaster |
| 30 | +#' @param postproc an [`epipredict::frosting`] |
| 31 | +#' @param trainer the trainer used (e.g. linear_reg() or quantile_reg()) |
| 32 | +#' @param args_list an [`epipredict::arx_args_list`] |
| 33 | +#' @param forecast_date the date from which the forecast was made. defaults to |
| 34 | +#' the default of `layer_add_forecast_date`, which is currently the max |
| 35 | +#' time_value present in the data |
| 36 | +#' @param target_date the date about which the forecast was made. defaults to |
| 37 | +#' the default of `layer_add_target_date`, which is either |
| 38 | +#' `forecast_date+ahead`, or the `max time_value + ahead` |
| 39 | +#' @seealso [arx_preprocess] for the step equivalent |
| 40 | +#' @importFrom epipredict layer_predict layer_quantile_distn layer_point_from_distn layer_residual_quantiles layer_threshold layer_naomit layer_add_target_date |
| 41 | +#' @export |
| 42 | +arx_postprocess <- function(postproc, |
| 43 | + trainer, |
| 44 | + args_list, |
| 45 | + forecast_date = NULL, |
| 46 | + target_date = NULL) { |
| 47 | + postproc %<>% layer_predict() |
| 48 | + if (inherits(trainer, "quantile_reg")) { |
| 49 | + postproc %<>% |
| 50 | + layer_quantile_distn(quantile_levels = args_list$quantile_levels) %>% |
| 51 | + layer_point_from_distn() |
| 52 | + } else { |
| 53 | + postproc %<>% layer_residual_quantiles( |
| 54 | + quantile_levels = args_list$quantile_levels, symmetrize = args_list$symmetrize, |
| 55 | + by_key = args_list$quantile_by_key |
| 56 | + ) |
| 57 | + } |
| 58 | + if (args_list$nonneg) { |
| 59 | + postproc %<>% layer_threshold(dplyr::starts_with(".pred")) |
| 60 | + } |
| 61 | + |
| 62 | + postproc %<>% |
| 63 | + layer_naomit(dplyr::starts_with(".pred")) %>% |
| 64 | + layer_add_target_date(target_date = target_date) |
| 65 | + return(postproc) |
| 66 | +} |
| 67 | + |
| 68 | +#' helper function to run a epipredict model and reformat to hub format |
| 69 | +#' @description |
| 70 | +#' helper function to run a epipredict model and reformat to hub format |
| 71 | +#' @param preproc the preprocessing steps |
| 72 | +#' @param postproc the postprocessing frosting |
| 73 | +#' @param trainer the parsnip trainer |
| 74 | +#' @param epi_data the actual epi_df to train on |
| 75 | +#' @export |
| 76 | +#' @importFrom epipredict epi_workflow fit add_frosting get_test_data |
| 77 | +run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) { |
| 78 | + workflow <- epi_workflow(preproc, trainer) %>% |
| 79 | + fit(epi_data) %>% |
| 80 | + add_frosting(postproc) |
| 81 | + latest <- get_test_data(recipe = preproc, x = epi_data) |
| 82 | + pred <- predict(workflow, latest) |
| 83 | + # the forecast_date may currently be the max time_value |
| 84 | + as_of <- attributes(epi_data)$metadata$as_of |
| 85 | + if (is.null(as_of)) { |
| 86 | + as_of <- max(epi_data$time_value) |
| 87 | + } |
| 88 | + true_forecast_date <- as_of |
| 89 | + return(format_storage(pred, true_forecast_date)) |
| 90 | +} |
0 commit comments