Commit 6e76e2b

dajmcdon authored and dsweber2 committed
skeleton
1 parent 4da9b23 commit 6e76e2b


5 files changed: 53 additions & 312 deletions


DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ Imports:
     generics,
     ggplot2,
     glue,
-    hardhat (>= 1.3.0),
+    hardhat (>= 1.4.0.9002),
     lifecycle,
     lubridate,
     magrittr,
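
The only change here is the hardhat floor: `quantile_pred()` does not exist in hardhat 1.3.0, so the dependency moves to a development version (>= 1.4.0.9002). Below is a minimal sketch of constructing one of the new objects, assuming hardhat's documented `quantile_pred(values, quantile_levels)` signature with a matrix of one row per distribution; the numbers are illustrative, not from this commit.

```r
# Assumes the development version of hardhat (>= 1.4.0.9002) is installed,
# e.g. via pak::pak("tidymodels/hardhat").
library(hardhat)

# one row per predicted distribution, one column per quantile level
qp <- quantile_pred(
  matrix(c(1, 2, 3, 4, 8, 9, 10, 11), nrow = 2, byrow = TRUE),
  quantile_levels = c(.2, .4, .6, .8)
)
qp
```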

NAMESPACE

Lines changed: 2 additions & 22 deletions
@@ -2,8 +2,6 @@
 
 S3method(Add_model,epi_workflow)
 S3method(Add_model,workflow)
-S3method(Math,dist_quantiles)
-S3method(Ops,dist_quantiles)
 S3method(Remove_model,epi_workflow)
 S3method(Remove_model,workflow)
 S3method(Update_model,epi_workflow)
@@ -48,13 +46,9 @@ S3method(fit,epi_workflow)
 S3method(flusight_hub_formatter,canned_epipred)
 S3method(flusight_hub_formatter,data.frame)
 S3method(forecast,epi_workflow)
-S3method(format,dist_quantiles)
-S3method(is.na,dist_quantiles)
-S3method(is.na,distribution)
 S3method(key_colnames,epi_workflow)
 S3method(key_colnames,recipe)
-S3method(mean,dist_quantiles)
-S3method(median,dist_quantiles)
+S3method(mean,quantile_pred)
 S3method(predict,epi_workflow)
 S3method(predict,flatline)
 S3method(prep,check_enough_train_data)
@@ -101,7 +95,7 @@ S3method(print,step_lag_difference)
 S3method(print,step_naomit)
 S3method(print,step_population_scaling)
 S3method(print,step_training_window)
-S3method(quantile,dist_quantiles)
+S3method(quantile,quantile_pred)
 S3method(refresh_blueprint,default_epi_recipe_blueprint)
 S3method(residuals,flatline)
 S3method(run_mold,default_epi_recipe_blueprint)
@@ -125,8 +119,6 @@ S3method(tidy,check_enough_train_data)
 S3method(tidy,frosting)
 S3method(tidy,layer)
 S3method(update,layer)
-S3method(vec_ptype_abbr,dist_quantiles)
-S3method(vec_ptype_full,dist_quantiles)
 S3method(weighted_interval_score,default)
 S3method(weighted_interval_score,dist_default)
 S3method(weighted_interval_score,dist_quantiles)
@@ -160,7 +152,6 @@ export(climate_args_list)
 export(climatological_forecaster)
 export(default_epi_recipe_blueprint)
 export(detect_layer)
-export(dist_quantiles)
 export(epi_recipe)
 export(epi_recipe_blueprint)
 export(epi_workflow)
@@ -230,7 +221,6 @@ export(update_frosting)
 export(update_model)
 export(validate_layer)
 export(weighted_interval_score)
-import(distributional)
 import(epidatasets)
 import(epiprocess)
 import(parsnip)
@@ -325,14 +315,11 @@ importFrom(rlang,list2)
 importFrom(rlang,set_names)
 importFrom(rlang,sym)
 importFrom(stats,as.formula)
-importFrom(stats,family)
 importFrom(stats,lm)
-importFrom(stats,median)
 importFrom(stats,model.frame)
 importFrom(stats,na.omit)
 importFrom(stats,poly)
 importFrom(stats,predict)
-importFrom(stats,qnorm)
 importFrom(stats,quantile)
 importFrom(stats,residuals)
 importFrom(tibble,as_tibble)
@@ -346,13 +333,6 @@ importFrom(tidyr,pivot_wider)
 importFrom(tidyr,unnest)
 importFrom(tidyselect,all_of)
 importFrom(utils,capture.output)
-importFrom(vctrs,as_list_of)
-importFrom(vctrs,field)
-importFrom(vctrs,new_rcrd)
-importFrom(vctrs,new_vctr)
 importFrom(vctrs,vec_cast)
 importFrom(vctrs,vec_data)
-importFrom(vctrs,vec_ptype_abbr)
-importFrom(vctrs,vec_ptype_full)
-importFrom(vctrs,vec_recycle_common)
 importFrom(workflows,extract_preprocessor)
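
Net effect of the NAMESPACE edits: the `Math`, `Ops`, `format`, `is.na`, `median`, and vctrs ptype methods for `dist_quantiles` disappear, the `dist_quantiles()` export and the `distributional` import go away, and `mean()`/`quantile()` are re-registered for hardhat's `quantile_pred` class. A small sketch of what dispatch should look like after the swap; the object and levels are illustrative, not from this commit.

```r
library(hardhat)

# illustrative object; levels and values are not from this commit
qp <- quantile_pred(matrix(1:5, nrow = 1), c(.1, .25, .5, .75, .9))

# should dispatch to the newly registered quantile.quantile_pred() method,
# which interpolates via the package-internal quantile_extrapolate()
quantile(qp, p = c(.25, .75))

# mean.quantile_pred() is only a placeholder that forwards to median()
```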

R/dist_quantiles.R

Lines changed: 31 additions & 228 deletions
@@ -1,209 +1,68 @@
-#' @importFrom vctrs field vec_cast new_rcrd
-new_quantiles <- function(values = double(1), quantile_levels = double(1)) {
-  arg_is_probabilities(quantile_levels)
-
-  vec_cast(values, double())
-  vec_cast(quantile_levels, double())
-  values <- unname(values)
-  if (length(values) == 0L) {
-    return(new_rcrd(
-      list(
-        values = rep(NA_real_, length(quantile_levels)),
-        quantile_levels = quantile_levels
-      ),
-      class = c("dist_quantiles", "dist_default")
-    ))
-  }
-  stopifnot(length(values) == length(quantile_levels))
-
-  stopifnot(!vctrs::vec_duplicate_any(quantile_levels))
-  if (is.unsorted(quantile_levels)) {
-    o <- vctrs::vec_order(quantile_levels)
-    values <- values[o]
-    quantile_levels <- quantile_levels[o]
-  }
-  if (is.unsorted(values, na.rm = TRUE)) {
-    cli_abort("`values[order(quantile_levels)]` produces unsorted quantiles.")
-  }
-
-  new_rcrd(list(values = values, quantile_levels = quantile_levels),
-    class = c("dist_quantiles", "dist_default")
-  )
-}
-
-
-
-#' @importFrom vctrs vec_ptype_abbr vec_ptype_full
-#' @export
-vec_ptype_abbr.dist_quantiles <- function(x, ...) "dist_qntls"
-#' @export
-vec_ptype_full.dist_quantiles <- function(x, ...) "dist_quantiles"
-
-#' @export
-format.dist_quantiles <- function(x, digits = 2, ...) {
-  m <- suppressWarnings(median(x))
-  paste0("quantiles(", round(m, digits), ")[", vctrs::vec_size(x), "]")
-}
-
-
-#' A distribution parameterized by a set of quantiles
-#'
-#' @param values A vector (or list of vectors) of values.
-#' @param quantile_levels A vector (or list of vectors) of probabilities
-#'   corresponding to `values`.
-#'
-#'   When creating multiple sets of `values`/`quantile_levels` resulting in
-#'   different distributions, the sizes must match. See the examples below.
-#'
-#' @return A vector of class `"distribution"`.
-#'
-#' @export
-#'
-#' @examples
-#' dist_quantiles(1:4, 1:4 / 5)
-#' dist_quantiles(list(1:3, 1:4), list(1:3 / 4, 1:4 / 5))
-#' dstn <- dist_quantiles(list(1:4, 8:11), c(.2, .4, .6, .8))
-#' dstn
-#'
-#' quantile(dstn, p = c(.1, .25, .5, .9))
-#' median(dstn)
-#'
-#' # it's a bit annoying to inspect the data
-#' distributional::parameters(dstn[1])
-#' nested_quantiles(dstn[1])[[1]]
-#'
-#' @importFrom vctrs as_list_of vec_recycle_common new_vctr
-dist_quantiles <- function(values, quantile_levels) {
-  if (!is.list(quantile_levels)) {
-    assert_numeric(quantile_levels, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L)
-    quantile_levels <- list(quantile_levels)
-  }
-  if (!is.list(values)) {
-    if (length(values) == 0L) values <- NA_real_
-    values <- list(values)
-  }
-
-  values <- as_list_of(values, .ptype = double())
-  quantile_levels <- as_list_of(quantile_levels, .ptype = double())
-  args <- vec_recycle_common(values = values, quantile_levels = quantile_levels)
-
-  qntls <- as_list_of(
-    map2(args$values, args$quantile_levels, new_quantiles),
-    .ptype = new_quantiles(NA_real_, 0.5)
-  )
-  new_vctr(qntls, class = "distribution")
-}
-
-validate_dist_quantiles <- function(values, quantile_levels) {
-  map(quantile_levels, arg_is_probabilities)
-  common_length <- vctrs::vec_size_common( # aborts internally
-    values = values,
-    quantile_levels = quantile_levels
-  )
-  length_diff <- vctrs::list_sizes(values) != vctrs::list_sizes(quantile_levels)
-  if (any(length_diff)) {
-    cli_abort(c(
-      "`values` and `quantile_levels` must have common length.",
-      i = "Mismatches found at position(s): {.val {which(length_diff)}}."
-    ))
-  }
-  level_duplication <- map_lgl(quantile_levels, vctrs::vec_duplicate_any)
-  if (any(level_duplication)) {
-    cli_abort(c(
-      "`quantile_levels` must not be duplicated.",
-      i = "Duplicates found at position(s): {.val {which(level_duplication)}}."
-    ))
-  }
-}
-
-
-is_dist_quantiles <- function(x) {
-  is_distribution(x) & all(stats::family(x) == "quantiles")
-}
-
-
-
-#' @export
-#' @importFrom stats median qnorm family
-median.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "linear")) {
-  quantile_levels <- field(x, "quantile_levels")
-  values <- field(x, "values")
-  # we have exactly that quantile
-  if (0.5 %in% quantile_levels) {
-    return(values[match(0.5, quantile_levels)])
-  }
-  # if there's only 1 quantile_level (and it isn't 0.5), or the smallest quantile is larger than 0.5 or the largest smaller than 0.5, or if every value is NA, return NA
-  if (length(quantile_levels) < 2 || min(quantile_levels) > 0.5 || max(quantile_levels) < 0.5 || all(is.na(values))) {
-    return(NA)
-  }
-  if (length(quantile_levels) < 3 || min(quantile_levels) > .25 || max(quantile_levels) < .75) {
-    return(stats::approx(quantile_levels, values, xout = 0.5)$y)
-  }
-  quantile(x, 0.5, ..., middle = middle)
-}
 
 # placeholder to avoid errors, but not ideal
 #' @export
-mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "linear")) {
-  median(x, ..., middle = middle)
+mean.quantile_pred <- function(x, na.rm = FALSE, ...) {
+  median(x, ...)
 }
 
 #' @export
 #' @importFrom stats quantile
-#' @import distributional
-quantile.dist_quantiles <- function(x, p, ..., middle = c("cubic", "linear")) {
+quantile.quantile_pred <- function(x, p, ..., middle = c("cubic", "linear")) {
   arg_is_probabilities(p)
   p <- sort(p)
-  middle <- match.arg(middle)
+  middle <- rlang::arg_match(middle)
   quantile_extrapolate(x, p, middle)
 }
 
 
 quantile_extrapolate <- function(x, tau_out, middle) {
-  tau <- field(x, "quantile_levels")
-  qvals <- field(x, "values")
-  nas <- is.na(qvals)
-  if (all(nas)) {
-    return(rep(NA, times = length(tau_out)))
-  }
-  qvals_out <- rep(NA, length(tau_out))
-  qvals <- qvals[!nas]
-  tau <- tau[!nas]
+  tau <- x %@% "quantile_levels"
+  qvals <- as.matrix(x)
 
   # short circuit if we aren't actually extrapolating
   # matches to ~15 decimals
   if (all(tau_out %in% tau)) {
-    return(qvals[match(tau_out, tau)])
+    return(hardhat::quantile_pred(
+      qvals[, match(tau_out, tau), drop = FALSE], tau_out
+    ))
   }
   if (length(tau) < 2) {
-    cli_abort(
-      "Quantile extrapolation is not possible with fewer than 2 quantiles."
-    )
-    return(qvals_out)
+    cli_abort(paste(
+      "Quantile extrapolation is not possible when fewer than 2 quantiles",
+      "are available."
+    ))
   }
+  qvals_out <- map(
+    vctrs::vec_chop(qvals),
+    ~ extrapolate_quantiles_single(.x, tau, tau_out, middle)
+  )
 
+  hardhat::quantile_pred(qvals_out, tau_out)
+}
+
+extrapolate_quantiles_single <- function(qvals, tau, tau_out, middle) {
   indl <- tau_out < min(tau)
   indr <- tau_out > max(tau)
   indm <- !indl & !indr
+  qvals_out <- rep(NA, length(tau_out))
 
   if (middle == "cubic") {
     method <- "cubic"
-    result <- tryCatch(
-      {
-        Q <- stats::splinefun(tau, qvals, method = "hyman")
-        quartiles <- Q(c(.25, .5, .75))
-      },
-      error = function(e) {
-        return(NA)
-      }
-    )
+    result <- tryCatch({
+      Q <- stats::splinefun(tau, qvals, method = "hyman")
+      quartiles <- Q(c(.25, .5, .75))
+    },
+    error = function(e) {
+      return(NA)
+    })
   }
   if (middle == "linear" || any(is.na(result))) {
     method <- "linear"
     quartiles <- stats::approx(tau, qvals, c(.25, .5, .75))$y
   }
   if (any(indm)) {
-    qvals_out[indm] <- switch(method,
+    qvals_out[indm] <- switch(
+      method,
       linear = stats::approx(tau, qvals, tau_out[indm])$y,
       cubic = Q(tau_out[indm])
    )
@@ -242,59 +101,3 @@ tail_extrapolate <- function(tau_out, qv) {
   m <- diff(y) / diff(x)
   m * (x0 - x[1]) + y[1]
 }
-
-
-#' @method Math dist_quantiles
-#' @export
-Math.dist_quantiles <- function(x, ...) {
-  quantile_levels <- field(x, "quantile_levels")
-  values <- field(x, "values")
-  values <- vctrs::vec_math(.Generic, values, ...)
-  new_quantiles(values = values, quantile_levels = quantile_levels)
-}
-
-#' @method Ops dist_quantiles
-#' @export
-Ops.dist_quantiles <- function(e1, e2) {
-  is_quantiles <- c(
-    inherits(e1, "dist_quantiles"),
-    inherits(e2, "dist_quantiles")
-  )
-  is_dist <- c(inherits(e1, "dist_default"), inherits(e2, "dist_default"))
-  tau1 <- tau2 <- NULL
-  if (is_quantiles[1]) {
-    q1 <- field(e1, "values")
-    tau1 <- field(e1, "quantile_levels")
-  }
-  if (is_quantiles[2]) {
-    q2 <- field(e2, "values")
-    tau2 <- field(e2, "quantile_levels")
-  }
-  tau <- union(tau1, tau2)
-  if (all(is_dist)) {
-    cli_abort(
-      "You can't perform arithmetic between two distributions like this."
-    )
-  } else {
-    if (is_quantiles[1]) {
-      q2 <- e2
-    } else {
-      q1 <- e1
-    }
-  }
-  q <- vctrs::vec_arith(.Generic, q1, q2)
-  new_quantiles(values = q, quantile_levels = tau)
-}
-
-#' @method is.na distribution
-#' @export
-is.na.distribution <- function(x) {
-  sapply(vec_data(x), is.na)
-}
-
-#' @method is.na dist_quantiles
-#' @export
-is.na.dist_quantiles <- function(x) {
-  q <- field(x, "values")
-  all(is.na(q))
-}
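
The reworked `quantile_extrapolate()` reads the levels off the `quantile_pred` attribute, converts the vector to a matrix (one row per distribution, one column per level), chops it row by row, and hands each row to the new internal helper `extrapolate_quantiles_single()` before rebuilding a `quantile_pred`. Below is a rough sketch of that row-wise idea in isolation, substituting plain linear interpolation for the helper; the object, levels, and target probabilities are illustrative, not from this commit.

```r
library(hardhat)
library(purrr)

qp <- quantile_pred(
  matrix(c(1, 2, 3, 10, 20, 30), nrow = 2, byrow = TRUE),
  c(.25, .5, .75)
)

tau <- attr(qp, "quantile_levels") # what the diff reads as x %@% "quantile_levels"
qvals <- as.matrix(qp)             # one row per distribution, as in the diff
tau_out <- c(.3, .5, .6)           # stays inside [.25, .75], so no tail extrapolation

# interpolate each distribution (row) separately, mirroring the vec_chop() + map() step
interp <- map(vctrs::vec_chop(qvals), function(row) {
  stats::approx(tau, as.numeric(row), xout = tau_out)$y
})
interp
```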
