Skip to content

Commit 7a87ae6

Browse files
dajmcdon authored and dsweber2 committed
add math ops and related tests
1 parent 6e76e2b commit 7a87ae6

File tree

4 files changed

+131
-92
lines changed

4 files changed

+131
-92
lines changed

NAMESPACE

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ S3method(extract_frosting,default)
3939
S3method(extract_frosting,epi_workflow)
4040
S3method(extract_layers,frosting)
4141
S3method(extract_layers,workflow)
42-
S3method(extrapolate_quantiles,dist_default)
43-
S3method(extrapolate_quantiles,dist_quantiles)
44-
S3method(extrapolate_quantiles,distribution)
42+
S3method(extrapolate_quantiles,quantile_pred)
4543
S3method(fit,epi_workflow)
4644
S3method(flusight_hub_formatter,canned_epipred)
4745
S3method(flusight_hub_formatter,data.frame)
@@ -119,6 +117,10 @@ S3method(tidy,check_enough_train_data)
119117
S3method(tidy,frosting)
120118
S3method(tidy,layer)
121119
S3method(update,layer)
120+
S3method(vec_arith,quantile_pred)
121+
S3method(vec_arith.numeric,quantile_pred)
122+
S3method(vec_arith.quantile_pred,numeric)
123+
S3method(vec_math,quantile_pred)
122124
S3method(weighted_interval_score,default)
123125
S3method(weighted_interval_score,dist_default)
124126
S3method(weighted_interval_score,dist_quantiles)
@@ -333,6 +335,8 @@ importFrom(tidyr,pivot_wider)
333335
importFrom(tidyr,unnest)
334336
importFrom(tidyselect,all_of)
335337
importFrom(utils,capture.output)
338+
importFrom(vctrs,vec_arith)
339+
importFrom(vctrs,vec_arith.numeric)
336340
importFrom(vctrs,vec_cast)
337-
importFrom(vctrs,vec_data)
341+
importFrom(vctrs,vec_math)
338342
importFrom(workflows,extract_preprocessor)

R/extrapolate_quantiles.R

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,34 +32,28 @@ extrapolate_quantiles <- function(x, probs, replace_na = TRUE, ...) {
3232
}
3333

3434
#' @export
35-
#' @importFrom vctrs vec_data
36-
extrapolate_quantiles.distribution <- function(x, probs, replace_na = TRUE, ...) {
37-
rlang::check_dots_empty()
35+
extrapolate_quantiles.quantile_pred <- function(x, probs, replace_na = TRUE, ...) {
3836
arg_is_lgl_scalar(replace_na)
3937
arg_is_probabilities(probs)
4038
if (is.unsorted(probs)) probs <- sort(probs)
41-
dstn <- lapply(vec_data(x), extrapolate_quantiles, probs = probs, replace_na = replace_na)
42-
new_vctr(dstn, vars = NULL, class = "distribution")
43-
}
44-
45-
#' @export
46-
extrapolate_quantiles.dist_default <- function(x, probs, replace_na = TRUE, ...) {
47-
values <- quantile(x, probs, ...)
48-
new_quantiles(values = values, quantile_levels = probs)
49-
}
39+
orig_probs <- x %@% "quantile_levels"
40+
orig_values <- as.matrix(x)
5041

51-
#' @export
52-
extrapolate_quantiles.dist_quantiles <- function(x, probs, replace_na = TRUE, ...) {
53-
orig_probs <- field(x, "quantile_levels")
54-
orig_values <- field(x, "values")
55-
new_probs <- c(orig_probs, probs)
56-
dups <- duplicated(new_probs)
5742
if (!replace_na || !anyNA(orig_values)) {
58-
new_values <- c(orig_values, quantile(x, probs, ...))
43+
all_values <- cbind(orig_values, quantile(x, probs, ...))
5944
} else {
60-
nas <- is.na(orig_values)
61-
orig_values[nas] <- quantile(x, orig_probs[nas], ...)
62-
new_values <- c(orig_values, quantile(x, probs, ...))
45+
newx <- quantile(x, orig_probs, ...) %>%
46+
hardhat::quantile_pred(orig_probs)
47+
all_values <- cbind(as.matrix(newx), quantile(newx, probs, ...))
6348
}
64-
new_quantiles(new_values[!dups], new_probs[!dups])
49+
all_probs <- c(orig_probs, probs)
50+
dups <- duplicated(all_probs)
51+
all_values <- all_values[, !dups, drop = FALSE]
52+
all_probs <- all_probs[!dups]
53+
o <- order(all_probs)
54+
55+
hardhat::quantile_pred(
56+
all_values[, o, drop = FALSE],
57+
quantile_levels = all_probs[o]
58+
)
6559
}

R/dist_quantiles.R renamed to R/quantile_pred-methods.R

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,29 @@ mean.quantile_pred <- function(x, na.rm = FALSE, ...) {
55
median(x, ...)
66
}
77

8+
9+
# quantiles by treating quantile_pred like a distribution -----------------
10+
11+
812
#' @export
913
#' @importFrom stats quantile
10-
quantile.quantile_pred <- function(x, p, ..., middle = c("cubic", "linear")) {
14+
quantile.quantile_pred <- function(x, p, na.rm = FALSE, ...,
15+
middle = c("cubic", "linear")) {
1116
arg_is_probabilities(p)
1217
p <- sort(p)
1318
middle <- rlang::arg_match(middle)
14-
quantile_extrapolate(x, p, middle)
19+
quantile_internal(x, p, middle)
1520
}
1621

1722

18-
quantile_extrapolate <- function(x, tau_out, middle) {
23+
quantile_internal <- function(x, tau_out, middle) {
1924
tau <- x %@% "quantile_levels"
2025
qvals <- as.matrix(x)
2126

2227
# short circuit if we aren't actually extrapolating
2328
# matches to ~15 decimals
24-
if (all(tau_out %in% tau)) {
25-
return(hardhat::quantile_pred(
26-
qvals[ ,match(tau_out, tau), drop = FALSE], tau_out
27-
))
29+
if (all(tau_out %in% tau) && !anyNA(qvals)) {
30+
return(qvals[ , match(tau_out, tau), drop = FALSE])
2831
}
2932
if (length(tau) < 2) {
3033
cli_abort(paste(
@@ -36,15 +39,26 @@ quantile_extrapolate <- function(x, tau_out, middle) {
3639
vctrs::vec_chop(qvals),
3740
~ extrapolate_quantiles_single(.x, tau, tau_out, middle)
3841
)
39-
40-
hardhat::quantile_pred(qvals_out, tau_out)
42+
qvals_out <- do.call(rbind, qvals_out) # ensure a matrix of the proper dims
43+
qvals_out
4144
}
4245

4346
extrapolate_quantiles_single <- function(qvals, tau, tau_out, middle) {
47+
qvals_out <- rep(NA, length(tau_out))
48+
good <- !is.na(qvals)
49+
qvals <- qvals[good]
50+
tau <- tau[good]
51+
52+
# in case we only have one point, and it matches something we wanted
53+
if (length(good) < 2) {
54+
matched_one <- tau_out %in% tau
55+
qvals_out[matched_one] <- qvals[matched_one]
56+
return(qvals_out)
57+
}
58+
4459
indl <- tau_out < min(tau)
4560
indr <- tau_out > max(tau)
4661
indm <- !indl & !indr
47-
qvals_out <- rep(NA, length(tau_out))
4862

4963
if (middle == "cubic") {
5064
method <- "cubic"
@@ -101,3 +115,44 @@ tail_extrapolate <- function(tau_out, qv) {
101115
m <- diff(y) / diff(x)
102116
m * (x0 - x[1]) + y[1]
103117
}
118+
119+
120+
# mathematical operations on the values -----------------------------------
121+
122+
123+
#' Apply a base math function to the values of a `quantile_pred`
#'
#' The function named by `.fn` is looked up among the exports of the base
#' namespace and applied to the matrix of quantile values obtained with
#' `as.matrix(.x)`. The quantile levels of `.x` are carried over unchanged.
#' The operations listed in `unsupported` below are rejected with an error.
#'
#' @param .fn Name of a function exported from base R, as a string.
#' @param .x A `quantile_pred` vector.
#' @param ... Further arguments forwarded to the function named by `.fn`
#'   (for example `base` for `log()` or `digits` for `round()`).
#' @return The result of [hardhat::quantile_pred()] called on the transformed
#'   values and the quantile levels of `.x`.
#'
#' @importFrom vctrs vec_math
#' @export
#' @method vec_math quantile_pred
vec_math.quantile_pred <- function(.fn, .x, ...) {
  unsupported <- c(
    "any", "all", "prod", "sum", "cumsum", "cummax", "cummin", "cumprod"
  )
  # `fn` keeps the name as a string; the error message below interpolates it.
  fn <- .fn
  # Reject unsupported operations before doing any other work.
  if (fn %in% unsupported) {
    cli_abort("{.fn {fn}} is not a supported operation for {.cls quantile_pred}.")
  }
  .fn <- getExportedValue("base", .fn)
  quantile_levels <- .x %@% "quantile_levels"
  values <- as.matrix(.x)
  # Forward `...` so that extra arguments (e.g. `digits`, `base`) are not
  # silently dropped.
  hardhat::quantile_pred(.fn(values, ...), quantile_levels)
}
136+
137+
#' Arithmetic with a `quantile_pred` as the left operand
#'
#' Second-stage dispatcher: selects a `vec_arith.quantile_pred.*` method
#' based on the class of `y`.
#'
#' @param op The operator, passed through unchanged to the selected method.
#' @param x The left operand.
#' @param y The right operand; its class selects the method.
#' @param ... Passed through to the selected method.
#'
#' @importFrom vctrs vec_arith vec_arith.numeric
#' @export
#' @method vec_arith quantile_pred
vec_arith.quantile_pred <- function(op, x, y, ...) {
  UseMethod("vec_arith.quantile_pred", y)
}
143+
144+
#' Arithmetic between a `quantile_pred` (left) and a numeric (right)
#'
#' Looks up the function named by `op` among the exports of the base
#' namespace, applies it to `as.matrix(x)` and `y`, and rebuilds the result
#' with the quantile levels of `x`.
#'
#' @param op Name of a function exported from base R, as a string.
#' @param x The left operand, whose `"quantile_levels"` attribute is reused.
#' @param y The right operand, handed to the operator as is.
#' @param ... Unused.
#' @return The result of [hardhat::quantile_pred()] called on the computed
#'   values and the quantile levels of `x`.
#'
#' @export
#' @method vec_arith.quantile_pred numeric
vec_arith.quantile_pred.numeric <- function(op, x, y, ...) {
  apply_op <- getExportedValue("base", op)
  values <- apply_op(as.matrix(x), y)
  q_levels <- x %@% "quantile_levels"
  hardhat::quantile_pred(values, q_levels)
}
151+
152+
#' Arithmetic between a numeric (left) and a `quantile_pred` (right)
#'
#' Looks up the function named by `op` among the exports of the base
#' namespace, applies it to `x` and `as.matrix(y)`, and rebuilds the result
#' with the quantile levels of `y`.
#'
#' @param op Name of a function exported from base R, as a string.
#' @param x The left operand, handed to the operator as is.
#' @param y The right operand, whose `"quantile_levels"` attribute is reused.
#' @param ... Unused.
#' @return The result of [hardhat::quantile_pred()] called on the computed
#'   values and the quantile levels of `y`.
#'
#' @export
#' @method vec_arith.numeric quantile_pred
vec_arith.numeric.quantile_pred <- function(op, x, y, ...) {
  apply_op <- getExportedValue("base", op)
  values <- apply_op(x, as.matrix(y))
  q_levels <- y %@% "quantile_levels"
  hardhat::quantile_pred(values, q_levels)
}

tests/testthat/test-dist_quantiles.R

Lines changed: 41 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,53 +4,38 @@ test_that("single quantile_pred works, quantiles are accessible", {
44
quantile_levels = c(.2, .4, .5, .6, .8)
55
)
66
expect_equal(median(z), 3)
7-
expect_equal(
8-
quantile(z, c(.2, .4, .5, .6, .8)),
9-
hardhat::quantile_pred(matrix(1:5, nrow = 1), c(.2, .4, .5, .6, .8))
10-
)
7+
expect_equal(quantile(z, c(.2, .4, .5, .6, .8)), matrix(1:5, nrow = 1))
118
expect_equal(
129
quantile(z, c(.3, .7), middle = "linear"),
13-
hardhat::quantile_pred(matrix(c(1.5, 4.5), nrow = 1), c(.3, .7))
10+
matrix(c(1.5, 4.5), nrow = 1)
1411
)
1512

1613
Q <- stats::splinefun(c(.2, .4, .5, .6, .8), 1:5, method = "hyman")
17-
expect_equal(quantile(z, c(.3, .7), middle = "cubic"), Q(c(.3, .7)))
14+
expect_equal(quantile(z, c(.3, .7)), Q(c(.3, .7)))
1815
expect_identical(
1916
extrapolate_quantiles(z, c(.3, .7), middle = "linear"),
20-
hardhat::quantile_pred(c(1, 1.5, 2, 3, 4, 4.5, 5), 2:8 / 10)
17+
hardhat::quantile_pred(matrix(c(1, 1.5, 2, 3, 4, 4.5, 5), nrow = 1), 2:8 / 10)
2118
)
22-
# empty values slot results in a length zero distribution
23-
# see issue #361
24-
# expect_length(dist_quantiles(list(), c(.1, .9)), 0L)
25-
# expect_identical(
26-
# dist_quantiles(list(), c(.1, .9)),
27-
# distributional::dist_degenerate(double())
28-
# )
2919
})
3020

3121

3222
test_that("quantile extrapolator works", {
33-
dstn <- dist_normal(c(10, 2), c(5, 10))
34-
qq <- extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75))
35-
expect_s3_class(qq, "distribution")
36-
expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles")
37-
expect_length(parameters(qq[1])$quantile_levels[[1]], 3L)
38-
39-
40-
dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
23+
dstn <- hardhat::quantile_pred(
24+
matrix(c(1:4, 8:11), nrow = 2, byrow = TRUE),
25+
c(.2, .4, .6, .8)
26+
)
4127
qq <- extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75))
42-
expect_s3_class(qq, "distribution")
43-
expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles")
44-
expect_length(parameters(qq[1])$quantile_levels[[1]], 7L)
28+
expect_s3_class(qq, c("quantile_pred", "vctrs_vctr", "list"))
29+
expect_length(qq %@% "quantile_levels", 7L)
4530

46-
dstn <- dist_quantiles(1:4, 1:4 / 5)
31+
dstn <- hardhat::quantile_pred(matrix(1:4, nrow = 1), 1:4 / 5)
4732
qq <- extrapolate_quantiles(dstn, 1:9 / 10)
48-
dstn_na <- dist_quantiles(c(1, 2, NA, 4), 1:4 / 5)
33+
dstn_na <- hardhat::quantile_pred(matrix(c(1, 2, NA, 4), nrow = 1), 1:4 / 5)
4934
qq2 <- extrapolate_quantiles(dstn_na, 1:9 / 10)
5035
expect_equal(qq, qq2)
5136
qq3 <- extrapolate_quantiles(dstn_na, 1:9 / 10, replace_na = FALSE)
52-
qq2_vals <- field(vec_data(qq2)[[1]], "values")
53-
qq3_vals <- field(vec_data(qq3)[[1]], "values")
37+
qq2_vals <- unlist(qq2)
38+
qq3_vals <- unlist(qq3)
5439
qq2_vals[6] <- NA
5540
expect_equal(qq2_vals, qq3_vals)
5641
})
@@ -60,7 +45,7 @@ test_that("small deviations of quantile requests work", {
6045
v <- c(0.0890306, 0.1424997, 0.1971793, 0.2850978, 0.3832912, 0.4240479)
6146
badl <- l
6247
badl[1] <- badl[1] - 1e-14
63-
distn <- dist_quantiles(list(v), list(l))
48+
distn <- hardhat::quantile_pred(matrix(v, nrow = 1), l)
6449

6550
# was broken before, now works
6651
expect_equal(quantile(distn, l), quantile(distn, badl))
@@ -69,50 +54,51 @@ test_that("small deviations of quantile requests work", {
6954
# the smallest (largest) values or we could end up unsorted
7055
l <- 1:9 / 10
7156
v <- 1:9
72-
distn <- dist_quantiles(list(v), list(l))
73-
expect_equal(quantile(distn, c(.25, .75)), list(c(2.5, 7.5)))
74-
expect_equal(quantile(distn, c(.1, .9)), list(c(1, 9)))
57+
distn <- hardhat::quantile_pred(matrix(v, nrow = 1), l)
58+
expect_equal(quantile(distn, c(.25, .75)), matrix(c(2.5, 7.5), nrow = 1))
59+
expect_equal(quantile(distn, c(.1, .9)), matrix(c(1, 9), nrow = 1))
7560
qv <- data.frame(q = l, v = v)
7661
expect_equal(
77-
unlist(quantile(distn, c(.01, .05))),
62+
drop(quantile(distn, c(.01, .05))),
7863
tail_extrapolate(c(.01, .05), head(qv, 2))
7964
)
8065
expect_equal(
81-
unlist(quantile(distn, c(.99, .95))),
66+
drop(quantile(distn, c(.99, .95))),
8267
tail_extrapolate(c(.95, .99), tail(qv, 2))
8368
)
8469
})
8570

8671
test_that("unary math works on quantiles", {
87-
dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
88-
dstn2 <- dist_quantiles(list(log(1:4), log(8:11)), list(c(.2, .4, .6, .8)))
72+
dstn <- hardhat::quantile_pred(
73+
matrix(c(1:4, 8:11), nrow = 2, byrow = TRUE),
74+
1:4 / 5
75+
)
76+
dstn2 <- hardhat::quantile_pred(
77+
log(matrix(c(1:4, 8:11), nrow = 2, byrow = TRUE)),
78+
1:4 / 5
79+
)
8980
expect_identical(log(dstn), dstn2)
9081

91-
dstn2 <- dist_quantiles(list(cumsum(1:4), cumsum(8:11)), list(c(.2, .4, .6, .8)))
92-
expect_identical(cumsum(dstn), dstn2)
9382
})
9483

9584
test_that("arithmetic works on quantiles", {
96-
dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
97-
dstn2 <- dist_quantiles(list(1:4 + 1, 8:11 + 1), list(c(.2, .4, .6, .8)))
85+
dstn <- hardhat::quantile_pred(
86+
matrix(c(1:4, 8:11), nrow = 2, byrow = TRUE),
87+
1:4 / 5
88+
)
89+
dstn2 <- hardhat::quantile_pred(
90+
matrix(c(1:4, 8:11), nrow = 2, byrow = TRUE) + 1,
91+
1:4 / 5
92+
)
9893
expect_identical(dstn + 1, dstn2)
9994
expect_identical(1 + dstn, dstn2)
10095

101-
dstn2 <- dist_quantiles(list(1:4 / 4, 8:11 / 4), list(c(.2, .4, .6, .8)))
96+
dstn2 <- hardhat::quantile_pred(
97+
matrix(c(1:4, 8:11), nrow = 2, byrow = TRUE) / 4,
98+
1:4 / 5
99+
)
102100
expect_identical(dstn / 4, dstn2)
103101
expect_identical((1 / 4) * dstn, dstn2)
104102

105-
expect_snapshot(error = TRUE, sum(dstn))
106-
expect_snapshot(error = TRUE, suppressWarnings(dstn + distributional::dist_normal()))
107-
})
108-
109-
test_that("quantile.dist_quantile works for NA vectors", {
110-
distn <- dist_quantiles(
111-
list(c(NA, NA)),
112-
list(1:2 / 3)
113-
)
114-
expect_true(is.na(quantile(distn, p = 0.5)))
115-
expect_true(is.na(median(distn)))
116-
expect_true(is.na(mean(distn)))
117-
expect_equal(format(distn), "quantiles(NA)[2]")
103+
expect_error(sum(dstn))
118104
})

0 commit comments

Comments
 (0)