Skip to content

Commit 7278c69

Browse files
dajmcdondsweber2
authored andcommitted
remove the distributional package
1 parent d508069 commit 7278c69

29 files changed

+244
-413
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ Depends:
3030
Imports:
3131
checkmate,
3232
cli,
33-
distributional,
3433
dplyr,
3534
epiprocess (>= 0.10.4),
3635
generics,
@@ -51,6 +50,7 @@ Imports:
5150
workflows (>= 1.0.0)
5251
Suggests:
5352
data.table,
53+
distributional,
5454
epidatr (>= 1.0.0),
5555
fs,
5656
grf,

NAMESPACE

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ S3method(slather,layer_residual_quantiles)
110110
S3method(slather,layer_threshold)
111111
S3method(slather,layer_unnest)
112112
S3method(snap,default)
113-
S3method(snap,dist_default)
114-
S3method(snap,dist_quantiles)
115-
S3method(snap,distribution)
113+
S3method(snap,quantile_pred)
116114
S3method(tidy,check_enough_train_data)
117115
S3method(tidy,frosting)
118116
S3method(tidy,layer)
@@ -184,8 +182,6 @@ export(layer_quantile_distn)
184182
export(layer_residual_quantiles)
185183
export(layer_threshold)
186184
export(layer_unnest)
187-
export(mutate)
188-
export(nested_quantiles)
189185
export(new_default_epi_recipe_blueprint)
190186
export(new_epi_recipe_blueprint)
191187
export(pivot_longer)
@@ -202,6 +198,7 @@ export(rename)
202198
export(select)
203199
export(slather)
204200
export(smooth_quantile_reg)
201+
export(snap)
205202
export(step_adjust_latency)
206203
export(step_climate)
207204
export(step_epi_ahead)
@@ -280,6 +277,7 @@ importFrom(ggplot2,geom_point)
280277
importFrom(ggplot2,geom_ribbon)
281278
importFrom(glue,glue)
282279
importFrom(hardhat,extract_recipe)
280+
importFrom(hardhat,quantile_pred)
283281
importFrom(hardhat,refresh_blueprint)
284282
importFrom(hardhat,run_mold)
285283
importFrom(lubridate,"%m-%")

R/autoplot.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ plot_bands <- function(
243243
ntarget_dates <- dplyr::n_distinct(predictions$time_value)
244244

245245
predictions <- predictions %>%
246-
mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, levels), levels)) %>%
246+
mutate(.pred_distn = quantile_pred(quantile(.pred_distn, levels), levels)) %>%
247247
pivot_quantiles_wider(.pred_distn)
248248
qnames <- setdiff(names(predictions), innames)
249249

R/extrapolate_quantiles.R

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,26 @@
11
#' Summarize a distribution with a set of quantiles
22
#'
3-
#' @param x a `distribution` vector
3+
#' This function takes a `quantile_pred` vector and returns the same
4+
#' type of object, expanded to include
5+
#' *additional* quantiles computed at `probs`. If you want behaviour more
6+
#' similar to [stats::quantile()], then `quantile(x,...)` may be more
7+
#' appropriate.
8+
#'
9+
#' @param x A vector of class `quantile_pred`.
410
#' @param probs a vector of probabilities at which to calculate quantiles
511
#' @param replace_na logical. If `x` contains `NA`'s, these are imputed if
6-
#' possible (if `TRUE`) or retained (if `FALSE`). This only effects
7-
#' elements of class `dist_quantiles`.
12+
#' possible (if `TRUE`) or retained (if `FALSE`).
813
#' @param ... additional arguments passed on to the `quantile` method
914
#'
10-
#' @return a `distribution` vector containing `dist_quantiles`. Any elements
11-
#' of `x` which were originally `dist_quantiles` will now have a superset
15+
#' @return a `quantile_pred` vector. Each element
16+
#' of `x` will now have a superset
1217
#' of the original `quantile_values` (the union of those and `probs`).
1318
#' @export
1419
#'
1520
#' @examples
16-
#' library(distributional)
17-
#' dstn <- dist_normal(c(10, 2), c(5, 10))
18-
#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75))
19-
#'
20-
#' dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
21-
#' # because this distribution is already quantiles, any extra quantiles are
22-
#' # appended
23-
#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75))
24-
#'
25-
#' dstn <- c(
26-
#' dist_normal(c(10, 2), c(5, 10)),
27-
#' dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
28-
#' )
29-
#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75))
21+
#' dstn <- quantile_dstn(rbind(1:4, 8:11), c(.2, .4, .6, .8))
22+
#' # extra quantiles are appended
23+
#' as.tibble(extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)))
3024
extrapolate_quantiles <- function(x, probs, replace_na = TRUE, ...) {
3125
UseMethod("extrapolate_quantiles")
3226
}

R/flusight_hub_formatter.R

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,11 @@ flusight_hub_formatter.data.frame <- function(
104104

105105
object <- object %>%
106106
# combine the predictions and the distribution
107-
mutate(.pred_distn = nested_quantiles(.pred_distn)) %>%
108-
tidyr::unnest(.pred_distn) %>%
107+
pivot_quantiles_longer(.pred_distn) %>%
109108
# now we create the correct column names
110109
rename(
111-
value = values,
112-
output_type_id = quantile_levels,
110+
value = .pred_distn_value,
111+
output_type_id = .pred_distn_quantile_level,
113112
reference_date = forecast_date
114113
) %>%
115114
# convert to fips codes, and add any constant cols passed in ...

R/layer_cdc_flatline_quantiles.R

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
#' in an additional `<list-col>` named `.pred_distn_all` containing 2-column
5252
#' [tibble::tibble()]'s. For each
5353
#' desired combination of `key`'s, the tibble will contain one row per ahead
54-
#' with the associated [dist_quantiles()].
54+
#' with the associated [quantile_pred()].
5555
#' @export
5656
#'
5757
#' @examples
@@ -265,11 +265,10 @@ propagate_samples <- function(
265265
}
266266
}
267267
res <- res[aheads]
268+
res_quantiles <- map(res, quantile, probs = quantile_levels)
268269
list(tibble(
269270
ahead = aheads,
270-
.pred_distn = map_vec(
271-
res, ~ dist_quantiles(quantile(.x, quantile_levels), quantile_levels)
272-
)
271+
.pred_distn = quantile_pred(do.call(rbind, res_quantiles), quantile_levels)
273272
))
274273
}
275274

R/layer_predictive_distn.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ layer_predictive_distn <- function(frosting,
4545
truncate = c(-Inf, Inf),
4646
name = ".pred_distn",
4747
id = rand_id("predictive_distn")) {
48+
if (!requireNamespace("distributional", quietly = TRUE)) {
49+
cli_abort(paste(
50+
"You must install the {.pkg distributional} package for",
51+
"this functionality."
52+
))
53+
}
4854
rlang::check_dots_empty()
4955
arg_is_chr_scalar(name, id)
5056
dist_type <- match.arg(dist_type)

R/layer_quantile_distn.R

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,23 @@ layer_quantile_distn_new <- function(quantile_levels, truncate, name, id) {
7878
slather.layer_quantile_distn <-
7979
function(object, components, workflow, new_data, ...) {
8080
dstn <- components$predictions$.pred
81-
if (!inherits(dstn, "distribution")) {
82-
cli_abort(c(
83-
"`layer_quantile_distn()` requires distributional predictions.",
84-
"These are of class {.cls {class(dstn)}}."
85-
))
81+
is_supported <- inherits(dstn, "distribution") ||
82+
inherits(dstn, "quantile_pred")
83+
if (!is_supported) {
84+
cli_abort(
85+
"`layer_quantile_distn()` requires distributional or quantile
86+
predictions. These are of class {.cls {class(dstn)}}."
87+
)
88+
}
89+
if (inherits(dstn, "distribution") && !requireNamespace("distributional", quietly = TRUE)) {
90+
cli_abort(
91+
"You must install the {.pkg distributional} package for this
92+
functionality."
93+
)
8694
}
8795
rlang::check_dots_empty()
8896

89-
dstn <- dist_quantiles(
97+
dstn <- quantile_pred(
9098
quantile(dstn, object$quantile_levels),
9199
object$quantile_levels
92100
)

R/layer_residual_quantiles.R

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,24 +125,19 @@ slather.layer_residual_quantiles <-
125125
}
126126

127127
r <- r %>%
128-
summarise(
129-
dstn = list(quantile(
130-
c(.resid, s * .resid),
131-
probs = object$quantile_levels, na.rm = TRUE
132-
))
133-
)
128+
summarize(dstn = quantile_pred(matrix(quantile(
129+
c(.resid, s * .resid), probs = object$quantile_levels, na.rm = TRUE
130+
), nrow = 1), quantile_levels = object$quantile_levels))
134131
# Check for NA
135-
if (any(sapply(r$dstn, is.na))) {
132+
if (anyNA(as.matrix(r$dstn))) {
136133
cli_abort(c(
137134
"Residual quantiles could not be calculated due to missing residuals.",
138135
i = "This may be due to `n_train` < `ahead` in your {.cls epi_recipe}."
139136
))
140137
}
141138

142139
estimate <- components$predictions$.pred
143-
res <- tibble(
144-
.pred_distn = dist_quantiles(map2(estimate, r$dstn, "+"), object$quantile_levels)
145-
)
140+
res <- tibble(.pred_distn = r$dstn + estimate)
146141
res <- check_pname(res, components$predictions, object)
147142
components$predictions <- mutate(components$predictions, !!!res)
148143
components

R/layer_threshold_preds.R

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ layer_threshold_new <-
6262

6363

6464
# restrict various objects to the interval [lower, upper]
65+
#' @export
6566
snap <- function(x, lower, upper, ...) {
6667
UseMethod("snap")
6768
}
@@ -74,25 +75,11 @@ snap.default <- function(x, lower, upper, ...) {
7475
}
7576

7677
#' @export
77-
snap.distribution <- function(x, lower, upper, ...) {
78-
rlang::check_dots_empty()
79-
arg_is_scalar(lower, upper)
80-
dstn <- lapply(vec_data(x), snap, lower = lower, upper = upper)
81-
distributional:::wrap_dist(dstn)
82-
}
83-
84-
#' @export
85-
snap.dist_default <- function(x, lower, upper, ...) {
86-
rlang::check_dots_empty()
87-
x
88-
}
89-
90-
#' @export
91-
snap.dist_quantiles <- function(x, lower, upper, ...) {
92-
values <- field(x, "values")
93-
quantile_levels <- field(x, "quantile_levels")
94-
values <- snap(values, lower, upper)
95-
new_quantiles(values = values, quantile_levels = quantile_levels)
78+
snap.quantile_pred <- function(x, lower, upper, ...) {
79+
values <- as.matrix(x)
80+
quantile_levels <- x %@% "quantile_levels"
81+
values <- map(vctrs::vec_chop(values), ~ snap(.x, lower, upper))
82+
quantile_pred(do.call(rbind, values), quantile_levels = quantile_levels)
9683
}
9784

9885
#' @export

R/make_grf_quantiles.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,12 @@ make_grf_quantiles <- function() {
162162
)
163163
)
164164

165-
# turn the predictions into a tibble with a dist_quantiles column
165+
# turn the predictions into a tibble with a quantile_pred column
166166
process_qrf_preds <- function(x, object) {
167167
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig %>% sort()
168168
x <- x$predictions
169169
out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
170-
out <- dist_quantiles(out, list(quantile_levels))
170+
out <- hardhat::quantile_pred(do.call(rbind, out), quantile_levels)
171171
return(dplyr::tibble(.pred = out))
172172
}
173173

R/make_quantile_reg.R

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,11 @@ make_quantile_reg <- function() {
110110

111111
process_rq_preds <- function(x, object) {
112112
object <- parsnip::extract_fit_engine(object)
113-
type <- class(object)[1]
114-
115-
# can't make a method because object is second
116-
out <- switch(type,
117-
rq = dist_quantiles(unname(as.list(x)), object$tau), # one quantile
118-
rqs = {
119-
x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
120-
dist_quantiles(x, list(object$tau))
121-
},
122-
cli_abort(c(
123-
"Prediction is not implemented for this `rq` type.",
124-
i = "See {.fun quantreg::rq}."
125-
))
126-
)
127-
return(dplyr::tibble(.pred = out))
113+
if (!is.matrix(x)) x <- as.matrix(x)
114+
rownames(x) <- NULL
115+
n_pred_quantiles <- ncol(x)
116+
quantile_levels <- object$tau
117+
tibble(.pred = hardhat::quantile_pred(x, quantile_levels))
128118
}
129119

130120
parsnip::set_pred(

R/make_smooth_quantile_reg.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
#' x = x[length(x) - 20] + ahead / 100 * 2 * pi,
4848
#' ahead = NULL
4949
#' ) %>%
50-
#' pivot_wider(names_from = quantile_levels, values_from = values)
50+
#' pivot_wider(names_from = distn_quantile_levels, values_from = distn_value)
5151
#' plot(x, y, pch = 16, xlim = c(pi, 2 * pi), col = "lightgrey")
5252
#' curve(sin(x), add = TRUE)
5353
#' abline(v = fd, lty = 2)
@@ -171,7 +171,7 @@ make_smooth_quantile_reg <- function() {
171171
x <- lapply(unname(split(
172172
p, seq(nrow(p))
173173
)), function(q) unname(sort(q, na.last = TRUE)))
174-
dist_quantiles(x, list(object$tau))
174+
quantile_pred(do.call(rbind, x), object$tau)
175175
})
176176
n_preds <- length(list_of_pred_distns[[1]])
177177
nout <- length(list_of_pred_distns)

0 commit comments

Comments
 (0)