Skip to content

Commit dacf6e7

Browse files
authored
Merge pull request #243 from cmu-delphi/242-quantile-renaming
242 quantile renaming
2 parents b2d1e11 + b8865c4 commit dacf6e7

36 files changed

+337
-259
lines changed

NAMESPACE

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ import(distributional)
167167
import(epiprocess)
168168
import(parsnip)
169169
import(recipes)
170-
import(vctrs)
170+
importFrom(cli,cli_abort)
171171
importFrom(epiprocess,growth_rate)
172172
importFrom(generics,augment)
173173
importFrom(generics,fit)
@@ -202,3 +202,12 @@ importFrom(stats,residuals)
202202
importFrom(tibble,as_tibble)
203203
importFrom(tibble,is_tibble)
204204
importFrom(tibble,tibble)
205+
importFrom(vctrs,as_list_of)
206+
importFrom(vctrs,field)
207+
importFrom(vctrs,new_rcrd)
208+
importFrom(vctrs,new_vctr)
209+
importFrom(vctrs,vec_cast)
210+
importFrom(vctrs,vec_data)
211+
importFrom(vctrs,vec_ptype_abbr)
212+
importFrom(vctrs,vec_ptype_full)
213+
importFrom(vctrs,vec_recycle_common)

R/arx_classifier.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ arx_class_args_list <- function(
250250
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
251251
log_scale = FALSE,
252252
additional_gr_args = list(),
253-
nafill_buffer = Inf) {
253+
nafill_buffer = Inf,
254+
...) {
254255
.lags <- lags
255256
if (is.list(lags)) lags <- unlist(lags)
256257
method <- match.arg(method)
@@ -305,3 +306,5 @@ print.arx_class <- function(x, ...) {
305306
name <- "ARX Classifier"
306307
NextMethod(name = name, ...)
307308
}
309+
310+
# this is a trivial change to induce a check

R/arx_forecaster.R

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#' out <- arx_forecaster(jhu, "death_rate",
3434
#' c("case_rate", "death_rate"),
3535
#' trainer = quantile_reg(),
36-
#' args_list = arx_args_list(levels = 1:9 / 10)
36+
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
3737
#' )
3838
arx_forecaster <- function(epi_data,
3939
outcome,
@@ -99,7 +99,7 @@ arx_forecaster <- function(epi_data,
9999
#' arx_fcast_epi_workflow(jhu, "death_rate",
100100
#' c("case_rate", "death_rate"),
101101
#' trainer = quantile_reg(),
102-
#' args_list = arx_args_list(levels = 1:9 / 10)
102+
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
103103
#' )
104104
arx_fcast_epi_workflow <- function(
105105
epi_data,
@@ -134,18 +134,20 @@ arx_fcast_epi_workflow <- function(
134134
# --- postprocessor
135135
f <- frosting() %>% layer_predict() # %>% layer_naomit()
136136
if (inherits(trainer, "quantile_reg")) {
137-
# add all levels to the forecaster and update postprocessor
138-
tau <- sort(compare_quantile_args(
139-
args_list$levels,
140-
rlang::eval_tidy(trainer$args$tau)
137+
# add all quantile_level to the forecaster and update postprocessor
138+
quantile_levels <- sort(compare_quantile_args(
139+
args_list$quantile_levels,
140+
rlang::eval_tidy(trainer$args$quantile_levels)
141141
))
142-
args_list$levels <- tau
143-
trainer$args$tau <- rlang::enquo(tau)
144-
f <- layer_quantile_distn(f, levels = tau) %>% layer_point_from_distn()
142+
args_list$quantile_levels <- quantile_levels
143+
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
144+
f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>%
145+
layer_point_from_distn()
145146
} else {
146147
f <- layer_residual_quantiles(
147148
f,
148-
probs = args_list$levels, symmetrize = args_list$symmetrize,
149+
quantile_levels = args_list$quantile_levels,
150+
symmetrize = args_list$symmetrize,
149151
by_key = args_list$quantile_by_key
150152
)
151153
}
@@ -173,7 +175,7 @@ arx_fcast_epi_workflow <- function(
173175
#' The default `NULL` will attempt to determine this automatically.
174176
#' @param target_date Date. The date for which the forecast is intended.
175177
#' The default `NULL` will attempt to determine this automatically.
176-
#' @param levels Vector or `NULL`. A vector of probabilities to produce
178+
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
177179
#' prediction intervals. These are created by computing the quantiles of
178180
#' training residuals. A `NULL` value will result in point forecasts only.
179181
#' @param symmetrize Logical. The default `TRUE` calculates
@@ -197,6 +199,7 @@ arx_fcast_epi_workflow <- function(
197199
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
198200
#' will be treated as _additional_ allowed recent data rather than the
199201
#' total amount of recent data to examine.
202+
#' @param ... Space to handle future expansions (unused).
200203
#'
201204
#'
202205
#' @return A list containing updated parameter choices with class `arx_flist`.
@@ -205,18 +208,19 @@ arx_fcast_epi_workflow <- function(
205208
#' @examples
206209
#' arx_args_list()
207210
#' arx_args_list(symmetrize = FALSE)
208-
#' arx_args_list(levels = c(.1, .3, .7, .9), n_training = 120)
211+
#' arx_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
209212
arx_args_list <- function(
210213
lags = c(0L, 7L, 14L),
211214
ahead = 7L,
212215
n_training = Inf,
213216
forecast_date = NULL,
214217
target_date = NULL,
215-
levels = c(0.05, 0.95),
218+
quantile_levels = c(0.05, 0.95),
216219
symmetrize = TRUE,
217220
nonneg = TRUE,
218221
quantile_by_key = character(0L),
219-
nafill_buffer = Inf) {
222+
nafill_buffer = Inf,
223+
...) {
220224
# error checking if lags is a list
221225
.lags <- lags
222226
if (is.list(lags)) lags <- unlist(lags)
@@ -227,7 +231,7 @@ arx_args_list <- function(
227231
arg_is_date(forecast_date, target_date, allow_null = TRUE)
228232
arg_is_nonneg_int(ahead, lags)
229233
arg_is_lgl(symmetrize, nonneg)
230-
arg_is_probabilities(levels, allow_null = TRUE)
234+
arg_is_probabilities(quantile_levels, allow_null = TRUE)
231235
arg_is_pos(n_training)
232236
if (is.finite(n_training)) arg_is_pos_int(n_training)
233237
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
@@ -238,7 +242,7 @@ arx_args_list <- function(
238242
lags = .lags,
239243
ahead,
240244
n_training,
241-
levels,
245+
quantile_levels,
242246
forecast_date,
243247
target_date,
244248
symmetrize,
@@ -259,8 +263,8 @@ print.arx_fcast <- function(x, ...) {
259263
}
260264

261265
compare_quantile_args <- function(alist, tlist) {
262-
default_alist <- eval(formals(arx_args_list)$levels)
263-
default_tlist <- eval(formals(quantile_reg)$tau)
266+
default_alist <- eval(formals(arx_args_list)$quantile_level)
267+
default_tlist <- eval(formals(quantile_reg)$quantile_level)
264268
if (setequal(alist, default_alist)) {
265269
if (setequal(tlist, default_tlist)) {
266270
return(sort(unique(union(alist, tlist))))

R/canned-epipred.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ validate_forecaster_inputs <- function(epi_data, outcome, predictors) {
88
arg_is_chr(predictors)
99
arg_is_chr_scalar(outcome)
1010
if (!outcome %in% names(epi_data)) {
11-
cli::cli_abort("{outcome} was not found in the training data.")
11+
cli::cli_abort("{.var {outcome}} was not found in the training data.")
1212
}
1313
check <- hardhat::check_column_names(epi_data, predictors)
1414
if (!check$ok) {
1515
cli::cli_abort(c(
1616
"At least one predictor was not found in the training data.",
17-
"!" = "The following required columns are missing: {check$missing_names}."
17+
"!" = "The following required columns are missing: {.val {check$missing_names}}."
1818
))
1919
}
2020
invisible(TRUE)
@@ -41,8 +41,8 @@ arx_lags_validator <- function(predictors, lags) {
4141
predictors_miss <- setdiff(predictors, names(lags))
4242
cli::cli_abort(c(
4343
"If lags is a named list, then all predictors must be present.",
44-
i = "The predictors are '{predictors}'.",
45-
i = "So lags is missing '{predictors_miss}'."
44+
i = "The predictors are {.var {predictors}}.",
45+
i = "So lags is missing {.var {predictors_miss}}'."
4646
))
4747
}
4848
}

0 commit comments

Comments
 (0)