Skip to content

242 quantile renaming #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
^musings$
^data-raw$
^vignettes/articles$
^.git-blame-ignore-revs$
11 changes: 10 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ import(distributional)
import(epiprocess)
import(parsnip)
import(recipes)
import(vctrs)
importFrom(cli,cli_abort)
importFrom(epiprocess,growth_rate)
importFrom(generics,augment)
importFrom(generics,fit)
Expand Down Expand Up @@ -194,3 +194,12 @@ importFrom(stats,residuals)
importFrom(tibble,as_tibble)
importFrom(tibble,is_tibble)
importFrom(tibble,tibble)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
importFrom(vctrs,new_vctr)
importFrom(vctrs,vec_cast)
importFrom(vctrs,vec_data)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_recycle_common)
5 changes: 4 additions & 1 deletion R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ arx_class_args_list <- function(
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
log_scale = FALSE,
additional_gr_args = list(),
nafill_buffer = Inf) {
nafill_buffer = Inf,
...) {
.lags <- lags
if (is.list(lags)) lags <- unlist(lags)
method <- match.arg(method)
Expand Down Expand Up @@ -305,3 +306,5 @@ print.arx_class <- function(x, ...) {
name <- "ARX Classifier"
NextMethod(name = name, ...)
}

# this is a trivial change to induce a check
40 changes: 22 additions & 18 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#' out <- arx_forecaster(jhu, "death_rate",
#' c("case_rate", "death_rate"),
#' trainer = quantile_reg(),
#' args_list = arx_args_list(levels = 1:9 / 10)
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
#' )
arx_forecaster <- function(epi_data,
outcome,
Expand Down Expand Up @@ -99,7 +99,7 @@ arx_forecaster <- function(epi_data,
#' arx_fcast_epi_workflow(jhu, "death_rate",
#' c("case_rate", "death_rate"),
#' trainer = quantile_reg(),
#' args_list = arx_args_list(levels = 1:9 / 10)
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
#' )
arx_fcast_epi_workflow <- function(
epi_data,
Expand Down Expand Up @@ -134,18 +134,20 @@ arx_fcast_epi_workflow <- function(
# --- postprocessor
f <- frosting() %>% layer_predict() # %>% layer_naomit()
if (inherits(trainer, "quantile_reg")) {
# add all levels to the forecaster and update postprocessor
tau <- sort(compare_quantile_args(
args_list$levels,
rlang::eval_tidy(trainer$args$tau)
# add all quantile_level to the forecaster and update postprocessor
quantile_levels <- sort(compare_quantile_args(
args_list$quantile_levels,
rlang::eval_tidy(trainer$args$quantile_levels)
))
args_list$levels <- tau
trainer$args$tau <- rlang::enquo(tau)
f <- layer_quantile_distn(f, levels = tau) %>% layer_point_from_distn()
args_list$quantile_levels <- quantile_levels
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>%
layer_point_from_distn()
} else {
f <- layer_residual_quantiles(
f,
probs = args_list$levels, symmetrize = args_list$symmetrize,
quantile_levels = args_list$quantile_levels,
symmetrize = args_list$symmetrize,
by_key = args_list$quantile_by_key
)
}
Expand Down Expand Up @@ -173,7 +175,7 @@ arx_fcast_epi_workflow <- function(
#' The default `NULL` will attempt to determine this automatically.
#' @param target_date Date. The date for which the forecast is intended.
#' The default `NULL` will attempt to determine this automatically.
#' @param levels Vector or `NULL`. A vector of probabilities to produce
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
#' prediction intervals. These are created by computing the quantiles of
#' training residuals. A `NULL` value will result in point forecasts only.
#' @param symmetrize Logical. The default `TRUE` calculates
Expand All @@ -197,6 +199,7 @@ arx_fcast_epi_workflow <- function(
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
#' will be treated as _additional_ allowed recent data rather than the
#' total amount of recent data to examine.
#' @param ... Space to handle future expansions (unused).
#'
#'
#' @return A list containing updated parameter choices with class `arx_flist`.
Expand All @@ -205,18 +208,19 @@ arx_fcast_epi_workflow <- function(
#' @examples
#' arx_args_list()
#' arx_args_list(symmetrize = FALSE)
#' arx_args_list(levels = c(.1, .3, .7, .9), n_training = 120)
#' arx_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
arx_args_list <- function(
lags = c(0L, 7L, 14L),
ahead = 7L,
n_training = Inf,
forecast_date = NULL,
target_date = NULL,
levels = c(0.05, 0.95),
quantile_levels = c(0.05, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L),
nafill_buffer = Inf) {
nafill_buffer = Inf,
...) {
# error checking if lags is a list
.lags <- lags
if (is.list(lags)) lags <- unlist(lags)
Expand All @@ -227,7 +231,7 @@ arx_args_list <- function(
arg_is_date(forecast_date, target_date, allow_null = TRUE)
arg_is_nonneg_int(ahead, lags)
arg_is_lgl(symmetrize, nonneg)
arg_is_probabilities(levels, allow_null = TRUE)
arg_is_probabilities(quantile_levels, allow_null = TRUE)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
Expand All @@ -238,7 +242,7 @@ arx_args_list <- function(
lags = .lags,
ahead,
n_training,
levels,
quantile_levels,
forecast_date,
target_date,
symmetrize,
Expand All @@ -259,8 +263,8 @@ print.arx_fcast <- function(x, ...) {
}

compare_quantile_args <- function(alist, tlist) {
default_alist <- eval(formals(arx_args_list)$levels)
default_tlist <- eval(formals(quantile_reg)$tau)
default_alist <- eval(formals(arx_args_list)$quantile_level)
default_tlist <- eval(formals(quantile_reg)$quantile_level)
if (setequal(alist, default_alist)) {
if (setequal(tlist, default_tlist)) {
return(sort(unique(union(alist, tlist))))
Expand Down
8 changes: 4 additions & 4 deletions R/canned-epipred.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ validate_forecaster_inputs <- function(epi_data, outcome, predictors) {
arg_is_chr(predictors)
arg_is_chr_scalar(outcome)
if (!outcome %in% names(epi_data)) {
cli::cli_abort("{outcome} was not found in the training data.")
cli::cli_abort("{.var {outcome}} was not found in the training data.")
}
check <- hardhat::check_column_names(epi_data, predictors)
if (!check$ok) {
cli::cli_abort(c(
"At least one predictor was not found in the training data.",
"!" = "The following required columns are missing: {check$missing_names}."
"!" = "The following required columns are missing: {.val {check$missing_names}}."
))
}
invisible(TRUE)
Expand All @@ -41,8 +41,8 @@ arx_lags_validator <- function(predictors, lags) {
predictors_miss <- setdiff(predictors, names(lags))
cli::cli_abort(c(
"If lags is a named list, then all predictors must be present.",
i = "The predictors are '{predictors}'.",
i = "So lags is missing '{predictors_miss}'."
i = "The predictors are {.var {predictors}}.",
i = "So lags is missing {.var {predictors_miss}}'."
))
}
}
Expand Down
Loading