Skip to content

Commit f5c794a

Browse files
authored
Merge branch 'v0.0.5' into smooth-quant-reg
2 parents 8c48889 + d6e685a commit f5c794a

9 files changed

+262
-68
lines changed

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ export(layer_unnest)
140140
export(nested_quantiles)
141141
export(new_default_epi_recipe_blueprint)
142142
export(new_epi_recipe_blueprint)
143+
export(pivot_quantiles)
143144
export(prep)
144145
export(quantile_reg)
145146
export(remove_frosting)

R/dist_quantiles.R

+66-3
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ extrapolate_quantiles.dist_quantiles <- function(x, p, ...) {
107107
new_quantiles(q = c(qvals, q), tau = c(tau, p))
108108
}
109109

110-
110+
is_dist_quantiles <- function(x) {
111+
is_distribution(x) && all(stats::family(x) == "quantiles")
112+
}
111113

112114

113115
#' Turn a a vector of quantile distributions into a list-col
@@ -124,8 +126,7 @@ extrapolate_quantiles.dist_quantiles <- function(x, p, ...) {
124126
#' edf_nested <- edf %>% dplyr::mutate(q = nested_quantiles(q))
125127
#' edf_nested %>% tidyr::unnest(q)
126128
nested_quantiles <- function(x) {
127-
stopifnot(is_distribution(x),
128-
all(stats::family(x) == "quantiles"))
129+
stopifnot(is_dist_quantiles(x))
129130
distributional:::dist_apply(x, .f = function(z) {
130131
tibble::as_tibble(vec_data(z)) %>%
131132
dplyr::mutate(dplyr::across(tidyselect::everything(), as.double)) %>%
@@ -134,6 +135,68 @@ nested_quantiles <- function(x) {
134135
}
135136

136137

138+
#' Pivot columns containing `dist_quantile` wider
139+
#'
140+
#' Any selected columns that contain `dist_quantiles` will be "widened" with
141+
#' the "taus" (quantile) serving as names and the values in the data frame.
142+
#' When pivoting multiple columns, the original column name will be used as
143+
#' a prefix.
144+
#'
145+
#' @param .data A data frame, or a data frame extension such as a tibble or
146+
#' epi_df.
147+
#' @param ... <[`tidy-select`][dplyr::dplyr_tidy_select]> One or more unquoted
148+
#' expressions separated by commas. Variable names can be used as if they
149+
#' were positions in the data frame, so expressions like `x:y` can
150+
#' be used to select a range of variables. Any selected columns should
151+
#'
152+
#' @return An object of the same class as `.data`
153+
#' @export
154+
#'
155+
#' @examples
156+
#' d1 <- c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:4, 1:3 / 4))
157+
#' d2 <- c(dist_quantiles(2:4, 2:4 / 5), dist_quantiles(3:5, 2:4 / 5))
158+
#' tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2)
159+
#'
160+
#' pivot_quantiles(tib, c("d1", "d2"))
161+
#' pivot_quantiles(tib, tidyselect::starts_with("d"))
162+
#' pivot_quantiles(tib, d2)
163+
pivot_quantiles <- function(.data, ...) {
164+
expr <- rlang::expr(c(...))
165+
cols <- names(tidyselect::eval_select(expr, .data))
166+
dqs <- map_lgl(cols, ~ is_dist_quantiles(.data[[.x]]))
167+
if (!all(dqs)) {
168+
nms <- cols[!dqs]
169+
cli::cli_abort(
170+
"Variables(s) {.var {nms}} are not `dist_quantiles`. Cannot pivot them."
171+
)
172+
}
173+
.data <- .data %>%
174+
dplyr::mutate(dplyr::across(tidyselect::all_of(cols), nested_quantiles))
175+
checks <- map_lgl(cols, ~ diff(range(vctrs::list_sizes(.data[[.x]]))) == 0L)
176+
if (!all(checks)) {
177+
nms <- cols[!checks]
178+
cli::cli_abort(
179+
c("Quantiles must be the same length and have the same set of taus.",
180+
i = "Check failed for variables(s) {.var {nms}}."))
181+
}
182+
if (length(cols) > 1L) {
183+
for (col in cols) {
184+
.data <- .data %>%
185+
tidyr::unnest(tidyselect::all_of(col)) %>%
186+
tidyr::pivot_wider(
187+
names_from = "tau", values_from = "q",
188+
names_prefix = paste0(col, "_")
189+
)
190+
}
191+
} else {
192+
.data <- .data %>%
193+
tidyr::unnest(tidyselect::all_of(cols)) %>%
194+
tidyr::pivot_wider(names_from = "tau", values_from = "q")
195+
}
196+
.data
197+
}
198+
199+
137200

138201

139202
#' @export

R/flatline.R

+21-1
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,31 @@ predict.flatline <- function(object, newdata, ...) {
7373
object <- object$.pred
7474
metadata <- names(object)[names(object) != ".pred"]
7575
ek <- names(newdata)
76-
if (! all(metadata %in% ek)) {
76+
if (!all(metadata %in% ek)) {
7777
cli_stop("`newdata` has different metadata than was used",
7878
"to fit the flatline forecaster")
7979
}
8080

8181
dplyr::left_join(newdata, object, by = metadata) %>%
8282
dplyr::pull(.pred)
8383
}
84+
85+
#' @export
86+
print.flatline <- function(x, ...) {
87+
keys <- colnames(x$.pred)
88+
keys <- paste(keys[!(keys %in% ".pred")], collapse = ", ")
89+
nloc <- nrow(x$.pred)
90+
nres <- nrow(x$residuals)
91+
pmsg <- glue::glue(
92+
"Predictions produced by {keys} resulting in {nloc} total forecasts."
93+
)
94+
rmsg <- glue::glue(
95+
"A total of {nres} residuals are available from the training set."
96+
)
97+
cat("Flatline forecaster\n")
98+
cat("\n")
99+
cat(pmsg)
100+
cat("\n")
101+
cat(rmsg)
102+
cat("\n\n")
103+
}

R/flatline_forecaster.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ flatline_forecaster <- function(
3737
cli_stop("args_list was not created using `flatline_args_list().")
3838
}
3939
keys <- epi_keys(epi_data)
40-
ek <- keys[-1]
40+
ek <- kill_time_value(keys)
4141
outcome <- rlang::sym(outcome)
4242

4343

4444
r <- epi_recipe(epi_data) %>%
4545
step_epi_ahead(!!outcome, ahead = args_list$ahead, skip = TRUE) %>%
4646
recipes::update_role(!!outcome, new_role = "predictor") %>%
47-
recipes::add_role(dplyr::all_of(keys), new_role = "predictor") %>%
47+
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
4848
step_training_window(n_recent = args_list$n_training)
4949

5050
latest <- get_test_data(epi_recipe(epi_data), epi_data)
@@ -65,7 +65,6 @@ flatline_forecaster <- function(
6565
eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")
6666

6767
wf <- epi_workflow(r, eng, f)
68-
6968
wf <- generics::fit(wf, epi_data)
7069
preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
7170
tibble::as_tibble() %>%

R/layer_residual_quantiles.R

+61-20
Original file line numberDiff line numberDiff line change
@@ -75,29 +75,43 @@ slather.layer_residual_quantiles <-
7575
function(object, components, the_fit, the_recipe, ...) {
7676
if (is.null(object$probs)) return(components)
7777

78-
7978
s <- ifelse(object$symmetrize, -1, NA)
80-
r <- dplyr::bind_cols(
81-
r = grab_residuals(the_fit, components),
82-
geo_value = components$mold$extras$roles$geo_value,
83-
components$mold$extras$roles$key)
79+
r <- grab_residuals(the_fit, components)
8480

8581
## Handle any grouping requests
8682
if (length(object$by_key) > 0L) {
87-
common <- intersect(object$by_key, names(r))
88-
excess <- setdiff(object$by_key, names(r))
83+
key_cols <- dplyr::bind_cols(
84+
geo_value = components$mold$extras$roles$geo_value,
85+
components$mold$extras$roles$key
86+
)
87+
common <- intersect(object$by_key, names(key_cols))
88+
excess <- setdiff(object$by_key, names(key_cols))
8989
if (length(excess) > 0L) {
90-
cli_warn("Requested residual grouping key(s) {excess} unavailable ",
91-
"in the original data. Grouping by the remainder {common}.")
92-
90+
rlang::warn(
91+
"Requested residual grouping key(s) {excess} are unavailable ",
92+
"in the original data. Grouping by the remainder: {common}."
93+
)
94+
}
95+
if (length(common) > 0L) {
96+
r <- r %>% dplyr::select(tidyselect::any_of(c(common, ".resid")))
97+
common_in_r <- common[common %in% names(r)]
98+
if (length(common_in_r) != length(common)) {
99+
rlang::warn(
100+
"Some grouping keys are not in data.frame returned by the",
101+
"`residuals()` method. Groupings may not be correct."
102+
)
103+
}
104+
r <- dplyr::bind_cols(key_cols, r) %>%
105+
dplyr::group_by(!!!rlang::syms(common))
93106
}
94-
if (length(common) > 0L)
95-
r <- r %>% dplyr::group_by(!!!rlang::syms(common))
96107
}
97108

98109
r <- r %>%
99-
dplyr::summarise(
100-
q = list(quantile(c(r, s * r), probs = object$probs, na.rm = TRUE))
110+
dplyr::summarize(
111+
q = list(quantile(
112+
c(.resid, s * .resid),
113+
probs = object$probs, na.rm = TRUE
114+
))
101115
)
102116

103117
estimate <- components$predictions$.pred
@@ -112,13 +126,40 @@ slather.layer_residual_quantiles <-
112126
grab_residuals <- function(the_fit, components) {
113127
if (the_fit$spec$mode != "regression")
114128
rlang::abort("For meaningful residuals, the predictor should be a regression model.")
115-
r_generic <- attr(utils::methods(class = class(the_fit)[1]), "info")$generic
116-
if ("residuals" %in% r_generic) {
117-
r <- residuals(the_fit)
118-
} else {
119-
yhat <- predict(the_fit, new_data = components$mold$predictors)
120-
r <- c(components$mold$outcomes - yhat)[[1]]
129+
r_generic <- attr(utils::methods(class = class(the_fit$fit)[1]), "info")$generic
130+
if ("residuals" %in% r_generic) { # Try to use the available method.
131+
cl <- class(the_fit$fit)[1]
132+
r <- residuals(the_fit$fit)
133+
if (inherits(r, "data.frame")) {
134+
if (".resid" %in% names(r)) { # success
135+
return(r)
136+
} else { # failure
137+
rlang::warn(c(
138+
"The `residuals()` method for objects of class {cl} results in",
139+
"a data frame without a column named `.resid`.",
140+
i = "Residual quantiles will be calculated directly from the",
141+
i = "difference between predictions and observations.",
142+
i = "This may result in unexpected behaviour."
143+
))
144+
}
145+
} else if (is.vector(drop(r))) { # also success
146+
return(tibble(.resid = drop(r)))
147+
} else { # failure
148+
rlang::warn(c(
149+
"The `residuals()` method for objects of class {cl} results in an",
150+
"object that is neither a data frame with a column named `.resid`,",
151+
"nor something coercible to a vector.",
152+
i = "Residual quantiles will be calculated directly from the",
153+
i = "difference between predictions and observations.",
154+
i = "This may result in unexpected behaviour."
155+
))
156+
}
121157
}
158+
# The method failed for one reason or another and a warning was issued
159+
# Or there was no method available.
160+
yhat <- predict(the_fit, new_data = components$mold$predictors)
161+
r <- c(components$mold$outcomes - yhat)[[1]] # this will be a vector
162+
r <- tibble(.resid = r)
122163
r
123164
}
124165

0 commit comments

Comments
 (0)