Skip to content

Djm/resids hotfix #294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
^epipredict\.Rproj$
^\.Rproj\.user$
^LICENSE\.md$
^DEVELOPMENT\.md$
^drafts$
^\.Rprofile$
^man-roxygen$
Expand Down
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.9
Version: 0.0.10
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down Expand Up @@ -72,4 +72,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.0
RoxygenNote: 7.3.1
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Two simple forecasters as test beds
- Working vignette
- use `checkmate` for input validation
- refactor quantile extrapolation (possibly creates different results)
99 changes: 30 additions & 69 deletions R/dist_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,15 @@ mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "line
#' @export
#' @importFrom stats quantile
#' @import distributional
quantile.dist_quantiles <- function(
x, p, ...,
middle = c("cubic", "linear"),
left_tail = c("normal", "exponential"),
right_tail = c("normal", "exponential")) {
quantile.dist_quantiles <- function(x, p, ..., middle = c("cubic", "linear")) {
arg_is_probabilities(p)
p <- sort(p)
middle <- match.arg(middle)
left_tail <- match.arg(left_tail)
right_tail <- match.arg(right_tail)
quantile_extrapolate(x, p, middle, left_tail, right_tail)
quantile_extrapolate(x, p, middle)
}


quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
quantile_extrapolate <- function(x, tau_out, middle) {
tau <- field(x, "quantile_levels")
qvals <- field(x, "values")
r <- range(tau, na.rm = TRUE)
Expand All @@ -196,10 +191,9 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
if (all(tau_out %in% tau)) {
return(qvals[match(tau_out, tau)])
}
if (length(qvals) < 3 || r[1] > .25 || r[2] < .75) {
cli::cli_warn(c(
"Quantile extrapolation is not possible with fewer than",
"3 quantiles or when the probs don't span [.25, .75]"
if (length(qvals) < 2) {
cli::cli_abort(c(
"Quantile extrapolation is not possible with fewer than 2 quantiles."
))
return(qvals_out)
}
Expand All @@ -213,87 +207,54 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
result <- tryCatch(
{
Q <- stats::splinefun(tau, qvals, method = "hyman")
qvals_out[indm] <- Q(tau_out[indm])
quartiles <- Q(c(.25, .5, .75))
},
error = function(e) {
return(NA)
}
error = function(e) { return(NA) }
)
}
if (middle == "linear" || any(is.na(result))) {
method <- "linear"
quartiles <- stats::approx(tau, qvals, c(.25, .5, .75))$y
}


if (any(indm)) {
qvals_out[indm] <- switch(method,
linear = stats::approx(tau, qvals, tau_out[indm])$y,
cubic = Q(tau_out[indm])
)
}
if (any(indl) || any(indr)) {
qv <- data.frame(
q = c(tau, tau_out[indm]),
v = c(qvals, qvals_out[indm])
) %>%
dplyr::distinct(q, .keep_all = TRUE) %>%
dplyr::arrange(q)
}
if (any(indl)) {
qvals_out[indl] <- tail_extrapolate(
tau_out[indl], quartiles, "left", left_tail
)
qvals_out[indl] <- tail_extrapolate(tau_out[indl], head(qv, 2))
}
if (any(indr)) {
qvals_out[indr] <- tail_extrapolate(
tau_out[indr], quartiles, "right", right_tail
)
qvals_out[indr] <- tail_extrapolate(tau_out[indr], tail(qv, 2))
}
qvals_out
}

# Extrapolates quantile values in one tail from the fitted quartiles.
#
# tau_out:   probability levels to extrapolate to.
# quartiles: values at levels c(.25, .5, .75), in that order.
# tail:      "left" or "right"; selects which outer quartile is paired
#            with the median.
# type:      "normal" or "exponential"; selects the tail family.
#
# Returns the extrapolated values, or invisible NULL when `type` matches
# neither family.
tail_extrapolate <- function(tau_out, quartiles, tail, type) {
  # anchors are ordered (outer quartile, median)
  if (tail == "left") {
    par <- quartiles[1:2]
    p <- c(.25, .5)
  }
  if (tail == "right") {
    par <- quartiles[3:2]
    p <- c(.75, .5)
  }
  extrapolator <- switch(type,
    normal = norm_tail_q,
    exponential = exp_tail_q
  )
  if (is.null(extrapolator)) {
    return(invisible(NULL))
  }
  extrapolator(p, par, tau_out)
}


# Recovers the centre (m) and scale (s) used by qlaplace() from two
# quantile values.
#
# q: values at levels c(.75, .5) or c(.25, .5), i.e. an outer quartile
#    followed by the median.
#
# Returns list(m = median, s = scale).
exp_q_par <- function(q) {
  half_spread <- abs(diff(q))
  list(
    m = q[2],
    # interquartile range divided by 2 * log(2)
    s = (2 * half_spread) / (2 * log(2))
  )
}
# Log-odds transform, computed as log(p) - log(1 - p).
# Inputs are clamped into [0, 1] first, so values at or beyond the
# boundaries map to -Inf / Inf.
logit <- function(p) {
  clamped <- pmin(p, 1)
  clamped <- pmax(clamped, 0)
  log(clamped) - log(1 - clamped)
}

exp_tail_q <- function(p, q, target) {
ms <- exp_q_par(q)
qlaplace(target, ms$m, ms$s)
# Extrapolates linearly on the logistic (log-odds) scale using the two
# points nearest the tail.
#
# tau_out: probability levels to extrapolate to.
# qv:      data frame with columns `q` (quantile levels) and `v` (values);
#          callers pass its first or last two rows.
#
# Returns a numeric vector with one value per element of `tau_out`.
tail_extrapolate <- function(tau_out, qv) {
  # a single anchor gives no slope to estimate, so repeat its value
  if (nrow(qv) == 1L) {
    return(rep(qv$v[1], length(tau_out)))
  }
  x <- logit(qv$q)
  x0 <- logit(tau_out)
  y <- qv$v
  m <- diff(y) / diff(x)
  # A flat tail stays flat. Without this guard, a request at level 0 or 1
  # (logit = -Inf / Inf) would compute 0 * Inf and return NaN.
  if (isTRUE(m == 0)) {
    return(rep(y[1], length(tau_out)))
  }
  m * (x0 - x[1]) + y[1]
}

# Quantile function of the Laplace (double-exponential) distribution.
#
# p:      probability levels.
# centre: location parameter.
# b:      scale parameter.
qlaplace <- function(p, centre = 0, b = 1) {
  # fixed to the lower.tail = TRUE, log.p = FALSE convention
  offset <- p - 0.5
  centre - b * sign(offset) * log(1 - 2 * abs(offset))
}

# Recovers a normal mean (m) and standard deviation (s) from two quantile
# values.
#
# q: values at levels c(.75, .5) or c(.25, .5), i.e. an outer quartile
#    followed by the median.
#
# Returns list(m = median, s = sd).
norm_q_par <- function(q) {
  std_normal_iqr <- 1.34897950039 # abs(diff(qnorm(c(.75, .25))))
  list(
    m = q[2],
    # interquartile range rescaled by the standard normal's IQR
    s = (2 * abs(diff(q))) / std_normal_iqr
  )
}

# Normal-tail extrapolation: fits a normal to the two anchor quantiles `q`
# and evaluates its quantile function at the `target` levels.
# `p` is accepted but not used in the body.
norm_tail_q <- function(p, q, target) {
  fit <- norm_q_par(q)
  stats::qnorm(target, mean = fit$m, sd = fit$s)
}

#' @method Math dist_quantiles
#' @export
Expand Down
8 changes: 5 additions & 3 deletions R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,16 @@ slather.layer_residual_quantiles <-
if (length(common) > 0L) {
r <- r %>% dplyr::select(tidyselect::any_of(c(common, ".resid")))
common_in_r <- common[common %in% names(r)]
if (length(common_in_r) != length(common)) {
if (length(common_in_r) == length(common)) {
r <- dplyr::left_join(key_cols, r, by = common_in_r)
} else {
cli::cli_warn(c(
"Some grouping keys are not in data.frame returned by the",
"`residuals()` method. Groupings may not be correct."
))
r <- dplyr::bind_cols(key_cols, r %>% dplyr::select(.resid)) %>%
dplyr::group_by(!!!rlang::syms(common))
}
r <- dplyr::bind_cols(key_cols, r) %>%
dplyr::group_by(!!!rlang::syms(common))
}
}

Expand Down
36 changes: 28 additions & 8 deletions tests/testthat/test-dist_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,6 @@ test_that("constructor returns reasonable quantiles", {
expect_error(new_quantiles(c(1, 2, 3), c(.1, .2, 3)))
})

test_that("tail functions give reasonable output", {
  # Each helper should recover (location, scale) from an outer quartile
  # plus the median of the matching distribution.
  shifted <- list(m = 10, s = 5)
  standard <- list(m = 0, s = 1)

  expect_equal(norm_q_par(qnorm(c(.75, .5), 10, 5)), shifted)
  expect_equal(norm_q_par(qnorm(c(.25, .5), 10, 5)), shifted)
  expect_equal(norm_q_par(qnorm(c(.25, .5), 0, 1)), standard)

  expect_equal(exp_q_par(qlaplace(c(.75, .5), 10, 5)), shifted)
  expect_equal(exp_q_par(qlaplace(c(.25, .5), 10, 5)), shifted)
  expect_equal(exp_q_par(qlaplace(c(.25, .5), 0, 1)), standard)
})

test_that("single dist_quantiles works, quantiles are accessible", {
z <- new_quantiles(values = 1:5, quantile_levels = c(.2, .4, .5, .6, .8))
Expand Down Expand Up @@ -49,6 +41,34 @@ test_that("quantile extrapolator works", {
expect_length(parameters(qq[1])$q[[1]], 7L)
})

test_that("small deviations of quantile requests work", {
l <- c(.05, .1, .25, .75, .9, .95)
v <- c(0.0890306, 0.1424997, 0.1971793, 0.2850978, 0.3832912, 0.4240479)
badl <- l
badl[1] <- badl[1] - 1e-14
distn <- dist_quantiles(list(v), list(l))

# was broken before, now works
expect_equal(quantile(distn, l), quantile(distn, badl))

# The tail extrapolation was still poor. It needs to _always_ use
# the smallest (largest) values or we could end up unsorted
l <- 1:9 / 10
v <- 1:9
distn <- dist_quantiles(list(v), list(l))
expect_equal(quantile(distn, c(.25, .75)), list(c(2.5, 7.5)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sanity check, this is 2.5 because that's halfway between 2 and 3, which are the .2 and .3 quantile values?

expect_equal(quantile(distn, c(.1, .9)), list(c(1, 9)))
qv <- data.frame(q = l, v = v)
expect_equal(
unlist(quantile(distn, c(.01, .05))),
tail_extrapolate(c(.01, .05), head(qv, 2))
)
Comment on lines +62 to +65
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I quite get how this is testing the tail behavior; is it just checking that quantile is using tail_extrapolate to calculate the values at those quantiles?

expect_equal(
unlist(quantile(distn, c(.99, .95))),
tail_extrapolate(c(.95, .99), tail(qv, 2))
)
})

test_that("unary math works on quantiles", {
dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
dstn2 <- dist_quantiles(list(log(1:4), log(8:11)), list(c(.2, .4, .6, .8)))
Expand Down
49 changes: 49 additions & 0 deletions tests/testthat/test-layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,52 @@ test_that("Errors when used with a classifier", {
wf <- wf %>% add_frosting(f)
expect_error(predict(wf, tib))
})


test_that("Grouping by keys is supported", {
  # layers shared by both frostings
  base_frosting <- frosting() %>%
    layer_predict() %>%
    layer_naomit(.pred)

  # residual quantiles without `by_key`: prediction should be silent
  pooled_wf <- wf %>%
    add_frosting(base_frosting %>% layer_residual_quantiles())
  expect_silent(p1 <- predict(pooled_wf, latest))

  # residual quantiles with by_key = "geo_value": prediction warns
  keyed_wf <- wf %>%
    add_frosting(
      base_frosting %>% layer_residual_quantiles(by_key = "geo_value")
    )
  expect_warning(p2 <- predict(keyed_wf, latest))

  # width of the 5%-95% interval for every prediction row
  interval_width <- function(preds) {
    wide <- pivot_quantiles_wider(preds, .pred_distn)
    wide[["0.95"]] - wide[["0.05"]]
  }
  width1 <- interval_width(p1)
  width2 <- interval_width(p2)

  # without `by_key` every row has the same width; with it, widths differ
  expect_equal(width1, rep(width1[1], length(width1)))
  expect_false(all(width2 == width2[1]))
})

test_that("Canned forecasters work with / without", {
  # set `as_of` in the metadata to the most recent time_value
  attr(jhu, "metadata")$as_of <- max(jhu$time_value)

  # flatline forecaster, without and then with `quantile_by_key`
  expect_silent(flatline_forecaster(jhu, "death_rate"))
  expect_silent(
    flatline_forecaster(
      jhu, "death_rate",
      args_list = flatline_args_list(quantile_by_key = "geo_value")
    )
  )

  # ARX forecaster with default arguments
  expect_silent(
    arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"))
  )
  # NOTE(review): this repeats the keyed flatline call above verbatim; it may
  # have been meant to cover arx_forecaster() with `quantile_by_key` -- confirm
  expect_silent(
    flatline_forecaster(
      jhu, "death_rate",
      args_list = flatline_args_list(quantile_by_key = "geo_value")
    )
  )
})