Skip to content

Commit 7a4ea55

Browse files
authored
Merge pull request #294 from cmu-delphi/djm/resids-hotfix
Djm/resids hotfix
2 parents c21f366 + 3b8c889 commit 7a4ea55

7 files changed

+117
-79
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
^epipredict\.Rproj$
44
^\.Rproj\.user$
55
^LICENSE\.md$
6+
^DEVELOPMENT\.md$
67
^drafts$
78
^\.Rprofile$
89
^man-roxygen$

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.9
3+
Version: 0.0.10
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -72,4 +72,4 @@ Config/testthat/edition: 3
7272
Encoding: UTF-8
7373
LazyData: true
7474
Roxygen: list(markdown = TRUE)
75-
RoxygenNote: 7.3.0
75+
RoxygenNote: 7.3.1

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
3232
- Two simple forecasters as test beds
3333
- Working vignette
3434
- use `checkmate` for input validation
35+
- refactor quantile extrapolation (may produce different results than previous versions)

R/dist_quantiles.R

Lines changed: 31 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -172,20 +172,15 @@ mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "line
172172
#' @export
173173
#' @importFrom stats quantile
174174
#' @import distributional
175-
quantile.dist_quantiles <- function(
176-
x, p, ...,
177-
middle = c("cubic", "linear"),
178-
left_tail = c("normal", "exponential"),
179-
right_tail = c("normal", "exponential")) {
175+
quantile.dist_quantiles <- function(x, p, ..., middle = c("cubic", "linear")) {
180176
arg_is_probabilities(p)
177+
p <- sort(p)
181178
middle <- match.arg(middle)
182-
left_tail <- match.arg(left_tail)
183-
right_tail <- match.arg(right_tail)
184-
quantile_extrapolate(x, p, middle, left_tail, right_tail)
179+
quantile_extrapolate(x, p, middle)
185180
}
186181

187182

188-
quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
183+
quantile_extrapolate <- function(x, tau_out, middle) {
189184
tau <- field(x, "quantile_levels")
190185
qvals <- field(x, "values")
191186
r <- range(tau, na.rm = TRUE)
@@ -196,10 +191,9 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
196191
if (all(tau_out %in% tau)) {
197192
return(qvals[match(tau_out, tau)])
198193
}
199-
if (length(qvals) < 3 || r[1] > .25 || r[2] < .75) {
200-
cli::cli_warn(c(
201-
"Quantile extrapolation is not possible with fewer than",
202-
"3 quantiles or when the probs don't span [.25, .75]"
194+
if (length(qvals) < 2) {
195+
cli::cli_abort(c(
196+
"Quantile extrapolation is not possible with fewer than 2 quantiles."
203197
))
204198
return(qvals_out)
205199
}
@@ -213,7 +207,6 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
213207
result <- tryCatch(
214208
{
215209
Q <- stats::splinefun(tau, qvals, method = "hyman")
216-
qvals_out[indm] <- Q(tau_out[indm])
217210
quartiles <- Q(c(.25, .5, .75))
218211
},
219212
error = function(e) {
@@ -225,75 +218,47 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
225218
method <- "linear"
226219
quartiles <- stats::approx(tau, qvals, c(.25, .5, .75))$y
227220
}
228-
229-
230221
if (any(indm)) {
231222
qvals_out[indm] <- switch(method,
232223
linear = stats::approx(tau, qvals, tau_out[indm])$y,
233224
cubic = Q(tau_out[indm])
234225
)
235226
}
227+
if (any(indl) || any(indr)) {
228+
qv <- data.frame(
229+
q = c(tau, tau_out[indm]),
230+
v = c(qvals, qvals_out[indm])
231+
) %>%
232+
dplyr::distinct(q, .keep_all = TRUE) %>%
233+
dplyr::arrange(q)
234+
}
236235
if (any(indl)) {
237-
qvals_out[indl] <- tail_extrapolate(
238-
tau_out[indl], quartiles, "left", left_tail
239-
)
236+
qvals_out[indl] <- tail_extrapolate(tau_out[indl], head(qv, 2))
240237
}
241238
if (any(indr)) {
242-
qvals_out[indr] <- tail_extrapolate(
243-
tau_out[indr], quartiles, "right", right_tail
244-
)
239+
qvals_out[indr] <- tail_extrapolate(tau_out[indr], tail(qv, 2))
245240
}
246241
qvals_out
247242
}
248243

249-
tail_extrapolate <- function(tau_out, quartiles, tail, type) {
250-
if (tail == "left") {
251-
p <- c(.25, .5)
252-
par <- quartiles[1:2]
253-
}
254-
if (tail == "right") {
255-
p <- c(.75, .5)
256-
par <- quartiles[3:2]
257-
}
258-
if (type == "normal") {
259-
return(norm_tail_q(p, par, tau_out))
260-
}
261-
if (type == "exponential") {
262-
return(exp_tail_q(p, par, tau_out))
263-
}
244+
logit <- function(p) {
245+
p <- pmax(pmin(p, 1), 0)
246+
log(p) - log(1 - p)
264247
}
265248

266-
267-
exp_q_par <- function(q) {
268-
# tau should always be c(.75, .5) or c(.25, .5)
269-
iqr <- 2 * abs(diff(q))
270-
s <- iqr / (2 * log(2))
271-
m <- q[2]
272-
return(list(m = m, s = s))
273-
}
274-
275-
exp_tail_q <- function(p, q, target) {
276-
ms <- exp_q_par(q)
277-
qlaplace(target, ms$m, ms$s)
278-
}
279-
280-
qlaplace <- function(p, centre = 0, b = 1) {
281-
# lower.tail = TRUE, log.p = FALSE
282-
centre - b * sign(p - 0.5) * log(1 - 2 * abs(p - 0.5))
283-
}
284-
285-
norm_q_par <- function(q) {
286-
# tau should always be c(.75, .5) or c(.25, .5)
287-
iqr <- 2 * abs(diff(q))
288-
s <- iqr / 1.34897950039 # abs(diff(qnorm(c(.75, .25))))
289-
m <- q[2]
290-
return(list(m = m, s = s))
249+
# extrapolates linearly on the logit scale of the quantile levels, using
250+
# the two points nearest the tail
251+
tail_extrapolate <- function(tau_out, qv) {
252+
if (nrow(qv) == 1L) {
253+
return(rep(qv$v[1], length(tau_out)))
254+
}
255+
x <- logit(qv$q)
256+
x0 <- logit(tau_out)
257+
y <- qv$v
258+
m <- diff(y) / diff(x)
259+
m * (x0 - x[1]) + y[1]
291260
}
292261

293-
norm_tail_q <- function(p, q, target) {
294-
ms <- norm_q_par(q)
295-
stats::qnorm(target, ms$m, ms$s)
296-
}
297262

298263
#' @method Math dist_quantiles
299264
#' @export

R/layer_residual_quantiles.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,16 @@ slather.layer_residual_quantiles <-
103103
if (length(common) > 0L) {
104104
r <- r %>% dplyr::select(tidyselect::any_of(c(common, ".resid")))
105105
common_in_r <- common[common %in% names(r)]
106-
if (length(common_in_r) != length(common)) {
106+
if (length(common_in_r) == length(common)) {
107+
r <- dplyr::left_join(key_cols, r, by = common_in_r)
108+
} else {
107109
cli::cli_warn(c(
108110
"Some grouping keys are not in data.frame returned by the",
109111
"`residuals()` method. Groupings may not be correct."
110112
))
113+
r <- dplyr::bind_cols(key_cols, r %>% dplyr::select(.resid)) %>%
114+
dplyr::group_by(!!!rlang::syms(common))
111115
}
112-
r <- dplyr::bind_cols(key_cols, r) %>%
113-
dplyr::group_by(!!!rlang::syms(common))
114116
}
115117
}
116118

tests/testthat/test-dist_quantiles.R

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,6 @@ test_that("constructor returns reasonable quantiles", {
1010
expect_error(new_quantiles(c(1, 2, 3), c(.1, .2, 3)))
1111
})
1212

13-
test_that("tail functions give reasonable output", {
14-
expect_equal(norm_q_par(qnorm(c(.75, .5), 10, 5)), list(m = 10, s = 5))
15-
expect_equal(norm_q_par(qnorm(c(.25, .5), 10, 5)), list(m = 10, s = 5))
16-
expect_equal(norm_q_par(qnorm(c(.25, .5), 0, 1)), list(m = 0, s = 1))
17-
expect_equal(exp_q_par(qlaplace(c(.75, .5), 10, 5)), list(m = 10, s = 5))
18-
expect_equal(exp_q_par(qlaplace(c(.25, .5), 10, 5)), list(m = 10, s = 5))
19-
expect_equal(exp_q_par(qlaplace(c(.25, .5), 0, 1)), list(m = 0, s = 1))
20-
})
2113

2214
test_that("single dist_quantiles works, quantiles are accessible", {
2315
z <- new_quantiles(values = 1:5, quantile_levels = c(.2, .4, .5, .6, .8))
@@ -49,6 +41,34 @@ test_that("quantile extrapolator works", {
4941
expect_length(parameters(qq[1])$q[[1]], 7L)
5042
})
5143

44+
test_that("small deviations of quantile requests work", {
45+
l <- c(.05, .1, .25, .75, .9, .95)
46+
v <- c(0.0890306, 0.1424997, 0.1971793, 0.2850978, 0.3832912, 0.4240479)
47+
badl <- l
48+
badl[1] <- badl[1] - 1e-14
49+
distn <- dist_quantiles(list(v), list(l))
50+
51+
# was broken before, now works
52+
expect_equal(quantile(distn, l), quantile(distn, badl))
53+
54+
# The tail extrapolation was still poor. It needs to _always_ use
55+
# the two smallest (largest) points, or the extrapolated quantiles could end up unsorted
56+
l <- 1:9 / 10
57+
v <- 1:9
58+
distn <- dist_quantiles(list(v), list(l))
59+
expect_equal(quantile(distn, c(.25, .75)), list(c(2.5, 7.5)))
60+
expect_equal(quantile(distn, c(.1, .9)), list(c(1, 9)))
61+
qv <- data.frame(q = l, v = v)
62+
expect_equal(
63+
unlist(quantile(distn, c(.01, .05))),
64+
tail_extrapolate(c(.01, .05), head(qv, 2))
65+
)
66+
expect_equal(
67+
unlist(quantile(distn, c(.99, .95))),
68+
tail_extrapolate(c(.95, .99), tail(qv, 2))
69+
)
70+
})
71+
5272
test_that("unary math works on quantiles", {
5373
dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
5474
dstn2 <- dist_quantiles(list(log(1:4), log(8:11)), list(c(.2, .4, .6, .8)))

tests/testthat/test-layer_residual_quantiles.R

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,52 @@ test_that("Errors when used with a classifier", {
4949
wf <- wf %>% add_frosting(f)
5050
expect_error(predict(wf, tib))
5151
})
52+
53+
54+
test_that("Grouping by keys is supported", {
55+
f <- frosting() %>%
56+
layer_predict() %>%
57+
layer_naomit(.pred) %>%
58+
layer_residual_quantiles()
59+
wf1 <- wf %>% add_frosting(f)
60+
expect_silent(p1 <- predict(wf1, latest))
61+
f2 <- frosting() %>%
62+
layer_predict() %>%
63+
layer_naomit(.pred) %>%
64+
layer_residual_quantiles(by_key = "geo_value")
65+
wf2 <- wf %>% add_frosting(f2)
66+
expect_warning(p2 <- predict(wf2, latest))
67+
68+
pivot1 <- pivot_quantiles_wider(p1, .pred_distn) %>%
69+
mutate(width = `0.95` - `0.05`)
70+
pivot2 <- pivot_quantiles_wider(p2, .pred_distn) %>%
71+
mutate(width = `0.95` - `0.05`)
72+
expect_equal(pivot1$width, rep(pivot1$width[1], nrow(pivot1)))
73+
expect_false(all(pivot2$width == pivot2$width[1]))
74+
})
75+
76+
test_that("Canned forecasters work with / without", {
77+
meta <- attr(jhu, "metadata")
78+
meta$as_of <- max(jhu$time_value)
79+
attr(jhu, "metadata") <- meta
80+
81+
expect_silent(
82+
flatline_forecaster(jhu, "death_rate")
83+
)
84+
expect_silent(
85+
flatline_forecaster(
86+
jhu, "death_rate",
87+
args_list = flatline_args_list(quantile_by_key = "geo_value")
88+
)
89+
)
90+
91+
expect_silent(
92+
arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"))
93+
)
94+
expect_silent(
95+
flatline_forecaster(
96+
jhu, "death_rate",
97+
args_list = flatline_args_list(quantile_by_key = "geo_value")
98+
)
99+
)
100+
})

0 commit comments

Comments
 (0)