|
1 |
| -#' @importFrom vctrs field vec_cast new_rcrd |
2 |
| -new_quantiles <- function(values = double(1), quantile_levels = double(1)) { |
3 |
| - arg_is_probabilities(quantile_levels) |
4 |
| - |
5 |
| - vec_cast(values, double()) |
6 |
| - vec_cast(quantile_levels, double()) |
7 |
| - values <- unname(values) |
8 |
| - if (length(values) == 0L) { |
9 |
| - return(new_rcrd( |
10 |
| - list( |
11 |
| - values = rep(NA_real_, length(quantile_levels)), |
12 |
| - quantile_levels = quantile_levels |
13 |
| - ), |
14 |
| - class = c("dist_quantiles", "dist_default") |
15 |
| - )) |
16 |
| - } |
17 |
| - stopifnot(length(values) == length(quantile_levels)) |
18 |
| - |
19 |
| - stopifnot(!vctrs::vec_duplicate_any(quantile_levels)) |
20 |
| - if (is.unsorted(quantile_levels)) { |
21 |
| - o <- vctrs::vec_order(quantile_levels) |
22 |
| - values <- values[o] |
23 |
| - quantile_levels <- quantile_levels[o] |
24 |
| - } |
25 |
| - if (is.unsorted(values, na.rm = TRUE)) { |
26 |
| - cli_abort("`values[order(quantile_levels)]` produces unsorted quantiles.") |
27 |
| - } |
28 |
| - |
29 |
| - new_rcrd(list(values = values, quantile_levels = quantile_levels), |
30 |
| - class = c("dist_quantiles", "dist_default") |
31 |
| - ) |
32 |
| -} |
33 |
| - |
34 |
| - |
35 |
| - |
36 |
| -#' @importFrom vctrs vec_ptype_abbr vec_ptype_full |
37 |
| -#' @export |
38 |
| -vec_ptype_abbr.dist_quantiles <- function(x, ...) "dist_qntls" |
39 |
| -#' @export |
40 |
| -vec_ptype_full.dist_quantiles <- function(x, ...) "dist_quantiles" |
41 |
| - |
42 |
| -#' @export |
43 |
| -format.dist_quantiles <- function(x, digits = 2, ...) { |
44 |
| - m <- suppressWarnings(median(x)) |
45 |
| - paste0("quantiles(", round(m, digits), ")[", vctrs::vec_size(x), "]") |
46 |
| -} |
47 |
| - |
48 |
| - |
49 |
| -#' A distribution parameterized by a set of quantiles |
50 |
| -#' |
51 |
| -#' @param values A vector (or list of vectors) of values. |
52 |
| -#' @param quantile_levels A vector (or list of vectors) of probabilities |
53 |
| -#' corresponding to `values`. |
54 |
| -#' |
55 |
| -#' When creating multiple sets of `values`/`quantile_levels` resulting in |
56 |
| -#' different distributions, the sizes must match. See the examples below. |
57 |
| -#' |
58 |
| -#' @return A vector of class `"distribution"`. |
59 |
| -#' |
60 |
| -#' @export |
61 |
| -#' |
62 |
| -#' @examples |
63 |
| -#' dist_quantiles(1:4, 1:4 / 5) |
64 |
| -#' dist_quantiles(list(1:3, 1:4), list(1:3 / 4, 1:4 / 5)) |
65 |
| -#' dstn <- dist_quantiles(list(1:4, 8:11), c(.2, .4, .6, .8)) |
66 |
| -#' dstn |
67 |
| -#' |
68 |
| -#' quantile(dstn, p = c(.1, .25, .5, .9)) |
69 |
| -#' median(dstn) |
70 |
| -#' |
71 |
| -#' # it's a bit annoying to inspect the data |
72 |
| -#' distributional::parameters(dstn[1]) |
73 |
| -#' nested_quantiles(dstn[1])[[1]] |
74 |
| -#' |
75 |
| -#' @importFrom vctrs as_list_of vec_recycle_common new_vctr |
76 |
| -dist_quantiles <- function(values, quantile_levels) { |
77 |
| - if (!is.list(quantile_levels)) { |
78 |
| - assert_numeric(quantile_levels, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L) |
79 |
| - quantile_levels <- list(quantile_levels) |
80 |
| - } |
81 |
| - if (!is.list(values)) { |
82 |
| - if (length(values) == 0L) values <- NA_real_ |
83 |
| - values <- list(values) |
84 |
| - } |
85 |
| - |
86 |
| - values <- as_list_of(values, .ptype = double()) |
87 |
| - quantile_levels <- as_list_of(quantile_levels, .ptype = double()) |
88 |
| - args <- vec_recycle_common(values = values, quantile_levels = quantile_levels) |
89 |
| - |
90 |
| - qntls <- as_list_of( |
91 |
| - map2(args$values, args$quantile_levels, new_quantiles), |
92 |
| - .ptype = new_quantiles(NA_real_, 0.5) |
93 |
| - ) |
94 |
| - new_vctr(qntls, class = "distribution") |
95 |
| -} |
96 |
| - |
97 |
| -validate_dist_quantiles <- function(values, quantile_levels) { |
98 |
| - map(quantile_levels, arg_is_probabilities) |
99 |
| - common_length <- vctrs::vec_size_common( # aborts internally |
100 |
| - values = values, |
101 |
| - quantile_levels = quantile_levels |
102 |
| - ) |
103 |
| - length_diff <- vctrs::list_sizes(values) != vctrs::list_sizes(quantile_levels) |
104 |
| - if (any(length_diff)) { |
105 |
| - cli_abort(c( |
106 |
| - "`values` and `quantile_levels` must have common length.", |
107 |
| - i = "Mismatches found at position(s): {.val {which(length_diff)}}." |
108 |
| - )) |
109 |
| - } |
110 |
| - level_duplication <- map_lgl(quantile_levels, vctrs::vec_duplicate_any) |
111 |
| - if (any(level_duplication)) { |
112 |
| - cli_abort(c( |
113 |
| - "`quantile_levels` must not be duplicated.", |
114 |
| - i = "Duplicates found at position(s): {.val {which(level_duplication)}}." |
115 |
| - )) |
116 |
| - } |
117 |
| -} |
118 |
| - |
119 |
| - |
120 |
| -is_dist_quantiles <- function(x) { |
121 |
| - is_distribution(x) & all(stats::family(x) == "quantiles") |
122 |
| -} |
123 |
| - |
124 |
| - |
125 |
| - |
126 |
| -#' @export |
127 |
| -#' @importFrom stats median qnorm family |
128 |
| -median.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "linear")) { |
129 |
| - quantile_levels <- field(x, "quantile_levels") |
130 |
| - values <- field(x, "values") |
131 |
| - # we have exactly that quantile |
132 |
| - if (0.5 %in% quantile_levels) { |
133 |
| - return(values[match(0.5, quantile_levels)]) |
134 |
| - } |
135 |
| - # if there's only 1 quantile_level (and it isn't 0.5), or the smallest quantile is larger than 0.5 or the largest smaller than 0.5, or if every value is NA, return NA |
136 |
| - if (length(quantile_levels) < 2 || min(quantile_levels) > 0.5 || max(quantile_levels) < 0.5 || all(is.na(values))) { |
137 |
| - return(NA) |
138 |
| - } |
139 |
| - if (length(quantile_levels) < 3 || min(quantile_levels) > .25 || max(quantile_levels) < .75) { |
140 |
| - return(stats::approx(quantile_levels, values, xout = 0.5)$y) |
141 |
| - } |
142 |
| - quantile(x, 0.5, ..., middle = middle) |
143 |
| -} |
144 | 1 |
|
145 | 2 | # placeholder to avoid errors, but not ideal
|
146 | 3 | #' @export
|
147 |
| -mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "linear")) { |
148 |
| - median(x, ..., middle = middle) |
| 4 | +mean.quantile_pred <- function(x, na.rm = FALSE, ...) { |
| 5 | + median(x, ...) |
149 | 6 | }
|
150 | 7 |
|
151 | 8 | #' @export
|
152 | 9 | #' @importFrom stats quantile
|
153 |
| -#' @import distributional |
154 |
| -quantile.dist_quantiles <- function(x, p, ..., middle = c("cubic", "linear")) { |
| 10 | +quantile.quantile_pred <- function(x, p, ..., middle = c("cubic", "linear")) { |
155 | 11 | arg_is_probabilities(p)
|
156 | 12 | p <- sort(p)
|
157 |
| - middle <- match.arg(middle) |
| 13 | + middle <- rlang::arg_match(middle) |
158 | 14 | quantile_extrapolate(x, p, middle)
|
159 | 15 | }
|
160 | 16 |
|
161 | 17 |
|
162 | 18 | quantile_extrapolate <- function(x, tau_out, middle) {
|
163 |
| - tau <- field(x, "quantile_levels") |
164 |
| - qvals <- field(x, "values") |
165 |
| - nas <- is.na(qvals) |
166 |
| - if (all(nas)) { |
167 |
| - return(rep(NA, times = length(tau_out))) |
168 |
| - } |
169 |
| - qvals_out <- rep(NA, length(tau_out)) |
170 |
| - qvals <- qvals[!nas] |
171 |
| - tau <- tau[!nas] |
| 19 | + tau <- x %@% "quantile_levels" |
| 20 | + qvals <- as.matrix(x) |
172 | 21 |
|
173 | 22 | # short circuit if we aren't actually extrapolating
|
174 | 23 | # matches to ~15 decimals
|
175 | 24 | if (all(tau_out %in% tau)) {
|
176 |
| - return(qvals[match(tau_out, tau)]) |
| 25 | + return(hardhat::quantile_pred( |
| 26 | + qvals[ ,match(tau_out, tau), drop = FALSE], tau_out |
| 27 | + )) |
177 | 28 | }
|
178 | 29 | if (length(tau) < 2) {
|
179 |
| - cli_abort( |
180 |
| - "Quantile extrapolation is not possible with fewer than 2 quantiles." |
181 |
| - ) |
182 |
| - return(qvals_out) |
| 30 | + cli_abort(paste( |
| 31 | + "Quantile extrapolation is not possible when fewer than 2 quantiles", |
| 32 | + "are available." |
| 33 | + )) |
183 | 34 | }
|
| 35 | + qvals_out <- map( |
| 36 | + vctrs::vec_chop(qvals), |
| 37 | + ~ extrapolate_quantiles_single(.x, tau, tau_out, middle) |
| 38 | + ) |
184 | 39 |
|
| 40 | + hardhat::quantile_pred(qvals_out, tau_out) |
| 41 | +} |
| 42 | + |
| 43 | +extrapolate_quantiles_single <- function(qvals, tau, tau_out, middle) { |
185 | 44 | indl <- tau_out < min(tau)
|
186 | 45 | indr <- tau_out > max(tau)
|
187 | 46 | indm <- !indl & !indr
|
| 47 | + qvals_out <- rep(NA, length(tau_out)) |
188 | 48 |
|
189 | 49 | if (middle == "cubic") {
|
190 | 50 | method <- "cubic"
|
191 |
| - result <- tryCatch( |
192 |
| - { |
193 |
| - Q <- stats::splinefun(tau, qvals, method = "hyman") |
194 |
| - quartiles <- Q(c(.25, .5, .75)) |
195 |
| - }, |
196 |
| - error = function(e) { |
197 |
| - return(NA) |
198 |
| - } |
199 |
| - ) |
| 51 | + result <- tryCatch({ |
| 52 | + Q <- stats::splinefun(tau, qvals, method = "hyman") |
| 53 | + quartiles <- Q(c(.25, .5, .75)) |
| 54 | + }, |
| 55 | + error = function(e) { |
| 56 | + return(NA) |
| 57 | + }) |
200 | 58 | }
|
201 | 59 | if (middle == "linear" || any(is.na(result))) {
|
202 | 60 | method <- "linear"
|
203 | 61 | quartiles <- stats::approx(tau, qvals, c(.25, .5, .75))$y
|
204 | 62 | }
|
205 | 63 | if (any(indm)) {
|
206 |
| - qvals_out[indm] <- switch(method, |
| 64 | + qvals_out[indm] <- switch( |
| 65 | + method, |
207 | 66 | linear = stats::approx(tau, qvals, tau_out[indm])$y,
|
208 | 67 | cubic = Q(tau_out[indm])
|
209 | 68 | )
|
@@ -242,59 +101,3 @@ tail_extrapolate <- function(tau_out, qv) {
|
242 | 101 | m <- diff(y) / diff(x)
|
243 | 102 | m * (x0 - x[1]) + y[1]
|
244 | 103 | }
|
245 |
| - |
246 |
| - |
247 |
| -#' @method Math dist_quantiles |
248 |
| -#' @export |
249 |
| -Math.dist_quantiles <- function(x, ...) { |
250 |
| - quantile_levels <- field(x, "quantile_levels") |
251 |
| - values <- field(x, "values") |
252 |
| - values <- vctrs::vec_math(.Generic, values, ...) |
253 |
| - new_quantiles(values = values, quantile_levels = quantile_levels) |
254 |
| -} |
255 |
| - |
256 |
| -#' @method Ops dist_quantiles |
257 |
| -#' @export |
258 |
| -Ops.dist_quantiles <- function(e1, e2) { |
259 |
| - is_quantiles <- c( |
260 |
| - inherits(e1, "dist_quantiles"), |
261 |
| - inherits(e2, "dist_quantiles") |
262 |
| - ) |
263 |
| - is_dist <- c(inherits(e1, "dist_default"), inherits(e2, "dist_default")) |
264 |
| - tau1 <- tau2 <- NULL |
265 |
| - if (is_quantiles[1]) { |
266 |
| - q1 <- field(e1, "values") |
267 |
| - tau1 <- field(e1, "quantile_levels") |
268 |
| - } |
269 |
| - if (is_quantiles[2]) { |
270 |
| - q2 <- field(e2, "values") |
271 |
| - tau2 <- field(e2, "quantile_levels") |
272 |
| - } |
273 |
| - tau <- union(tau1, tau2) |
274 |
| - if (all(is_dist)) { |
275 |
| - cli_abort( |
276 |
| - "You can't perform arithmetic between two distributions like this." |
277 |
| - ) |
278 |
| - } else { |
279 |
| - if (is_quantiles[1]) { |
280 |
| - q2 <- e2 |
281 |
| - } else { |
282 |
| - q1 <- e1 |
283 |
| - } |
284 |
| - } |
285 |
| - q <- vctrs::vec_arith(.Generic, q1, q2) |
286 |
| - new_quantiles(values = q, quantile_levels = tau) |
287 |
| -} |
288 |
| - |
289 |
| -#' @method is.na distribution |
290 |
| -#' @export |
291 |
| -is.na.distribution <- function(x) { |
292 |
| - sapply(vec_data(x), is.na) |
293 |
| -} |
294 |
| - |
295 |
| -#' @method is.na dist_quantiles |
296 |
| -#' @export |
297 |
| -is.na.dist_quantiles <- function(x) { |
298 |
| - q <- field(x, "values") |
299 |
| - all(is.na(q)) |
300 |
| -} |
0 commit comments