diff --git a/.Rbuildignore b/.Rbuildignore index 3a77bb347..d0c10332d 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -3,6 +3,7 @@ ^epipredict\.Rproj$ ^\.Rproj\.user$ ^LICENSE\.md$ +^DEVELOPMENT\.md$ ^drafts$ ^\.Rprofile$ ^man-roxygen$ diff --git a/DESCRIPTION b/DESCRIPTION index c451f755b..e068fe7e3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.9 +Version: 0.0.10 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), @@ -72,4 +72,4 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.0 +RoxygenNote: 7.3.1 diff --git a/NEWS.md b/NEWS.md index 3c5034080..04dc78e4f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -32,3 +32,4 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Two simple forecasters as test beds - Working vignette - use `checkmate` for input validation +- refactor quantile extrapolation (possibly creates different results) diff --git a/R/dist_quantiles.R b/R/dist_quantiles.R index 750e9560d..7f7af40f4 100644 --- a/R/dist_quantiles.R +++ b/R/dist_quantiles.R @@ -172,20 +172,15 @@ mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "line #' @export #' @importFrom stats quantile #' @import distributional -quantile.dist_quantiles <- function( - x, p, ..., - middle = c("cubic", "linear"), - left_tail = c("normal", "exponential"), - right_tail = c("normal", "exponential")) { +quantile.dist_quantiles <- function(x, p, ..., middle = c("cubic", "linear")) { arg_is_probabilities(p) + p <- sort(p) middle <- match.arg(middle) - left_tail <- match.arg(left_tail) - right_tail <- match.arg(right_tail) - quantile_extrapolate(x, p, middle, left_tail, right_tail) + quantile_extrapolate(x, p, middle) } -quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { +quantile_extrapolate <- function(x, tau_out, middle) { tau <- field(x, "quantile_levels") qvals <- field(x, "values") r <- range(tau, na.rm = TRUE) @@ -196,10 +191,9 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { if (all(tau_out %in% tau)) { return(qvals[match(tau_out, tau)]) } - if (length(qvals) < 3 || r[1] > .25 || r[2] < .75) { - cli::cli_warn(c( - "Quantile extrapolation is not possible with fewer than", - "3 quantiles or when the probs don't span [.25, .75]" + if (length(qvals) < 2) { + cli::cli_abort(c( + "Quantile extrapolation is not possible with fewer than 2 quantiles." )) return(qvals_out) } @@ -213,7 +207,6 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { result <- tryCatch( { Q <- stats::splinefun(tau, qvals, method = "hyman") - qvals_out[indm] <- Q(tau_out[indm]) quartiles <- Q(c(.25, .5, .75)) }, error = function(e) { @@ -225,75 +218,47 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { method <- "linear" quartiles <- stats::approx(tau, qvals, c(.25, .5, .75))$y } - - if (any(indm)) { qvals_out[indm] <- switch(method, linear = stats::approx(tau, qvals, tau_out[indm])$y, cubic = Q(tau_out[indm]) ) } + if (any(indl) || any(indr)) { + qv <- data.frame( + q = c(tau, tau_out[indm]), + v = c(qvals, qvals_out[indm]) + ) %>% + dplyr::distinct(q, .keep_all = TRUE) %>% + dplyr::arrange(q) + } if (any(indl)) { - qvals_out[indl] <- tail_extrapolate( - tau_out[indl], quartiles, "left", left_tail - ) + qvals_out[indl] <- tail_extrapolate(tau_out[indl], head(qv, 2)) } if (any(indr)) { - qvals_out[indr] <- tail_extrapolate( - tau_out[indr], quartiles, "right", right_tail - ) + qvals_out[indr] <- tail_extrapolate(tau_out[indr], tail(qv, 2)) } qvals_out } -tail_extrapolate <- function(tau_out, quartiles, tail, type) { - if (tail == "left") { - p <- c(.25, .5) - par <- quartiles[1:2] - } - if (tail == "right") { - p <- c(.75, .5) - par <- quartiles[3:2] - } - if (type == "normal") { - return(norm_tail_q(p, par, tau_out)) - } - if (type == "exponential") { - return(exp_tail_q(p, par, tau_out)) - } +logit <- function(p) { + p <- pmax(pmin(p, 1), 0) + log(p) - log(1 - p) } - -exp_q_par <- function(q) { - # tau should always be c(.75, .5) or c(.25, .5) - iqr <- 2 * abs(diff(q)) - s <- iqr / (2 * log(2)) - m <- q[2] - return(list(m = m, s = s)) -} - -exp_tail_q <- function(p, q, target) { - ms <- exp_q_par(q) - qlaplace(target, ms$m, ms$s) -} - -qlaplace <- function(p, centre = 0, b = 1) { - # lower.tail = TRUE, log.p = FALSE - centre - b * sign(p - 0.5) * log(1 - 2 * abs(p - 0.5)) -} - -norm_q_par <- function(q) { - # tau should always be c(.75, .5) or c(.25, .5) - iqr <- 2 * abs(diff(q)) - s <- iqr / 1.34897950039 # abs(diff(qnorm(c(.75, .25)))) - m <- q[2] - return(list(m = m, s = s)) +# extrapolates linearly on the logistic scale using +# the two points nearest the tail +tail_extrapolate <- function(tau_out, qv) { + if (nrow(qv) == 1L) { + return(rep(qv$v[1], length(tau_out))) + } + x <- logit(qv$q) + x0 <- logit(tau_out) + y <- qv$v + m <- diff(y) / diff(x) + m * (x0 - x[1]) + y[1] } -norm_tail_q <- function(p, q, target) { - ms <- norm_q_par(q) - stats::qnorm(target, ms$m, ms$s) -} #' @method Math dist_quantiles #' @export diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index bd4ed27e3..b09956c2e 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -103,14 +103,16 @@ slather.layer_residual_quantiles <- if (length(common) > 0L) { r <- r %>% dplyr::select(tidyselect::any_of(c(common, ".resid"))) common_in_r <- common[common %in% names(r)] - if (length(common_in_r) != length(common)) { + if (length(common_in_r) == length(common)) { + r <- dplyr::left_join(key_cols, r, by = common_in_r) + } else { cli::cli_warn(c( "Some grouping keys are not in data.frame returned by the", "`residuals()` method. Groupings may not be correct." )) + r <- dplyr::bind_cols(key_cols, r %>% dplyr::select(.resid)) %>% + dplyr::group_by(!!!rlang::syms(common)) } - r <- dplyr::bind_cols(key_cols, r) %>% - dplyr::group_by(!!!rlang::syms(common)) } } diff --git a/tests/testthat/test-dist_quantiles.R b/tests/testthat/test-dist_quantiles.R index 4fc5587d4..99ce742d5 100644 --- a/tests/testthat/test-dist_quantiles.R +++ b/tests/testthat/test-dist_quantiles.R @@ -10,14 +10,6 @@ test_that("constructor returns reasonable quantiles", { expect_error(new_quantiles(c(1, 2, 3), c(.1, .2, 3))) }) -test_that("tail functions give reasonable output", { - expect_equal(norm_q_par(qnorm(c(.75, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(norm_q_par(qnorm(c(.25, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(norm_q_par(qnorm(c(.25, .5), 0, 1)), list(m = 0, s = 1)) - expect_equal(exp_q_par(qlaplace(c(.75, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(exp_q_par(qlaplace(c(.25, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(exp_q_par(qlaplace(c(.25, .5), 0, 1)), list(m = 0, s = 1)) -}) test_that("single dist_quantiles works, quantiles are accessible", { z <- new_quantiles(values = 1:5, quantile_levels = c(.2, .4, .5, .6, .8)) @@ -49,6 +41,34 @@ test_that("quantile extrapolator works", { expect_length(parameters(qq[1])$q[[1]], 7L) }) +test_that("small deviations of quantile requests work", { + l <- c(.05, .1, .25, .75, .9, .95) + v <- c(0.0890306, 0.1424997, 0.1971793, 0.2850978, 0.3832912, 0.4240479) + badl <- l + badl[1] <- badl[1] - 1e-14 + distn <- dist_quantiles(list(v), list(l)) + + # was broken before, now works + expect_equal(quantile(distn, l), quantile(distn, badl)) + + # The tail extrapolation was still poor. It needs to _always_ use + # the smallest (largest) values or we could end up unsorted + l <- 1:9 / 10 + v <- 1:9 + distn <- dist_quantiles(list(v), list(l)) + expect_equal(quantile(distn, c(.25, .75)), list(c(2.5, 7.5))) + expect_equal(quantile(distn, c(.1, .9)), list(c(1, 9))) + qv <- data.frame(q = l, v = v) + expect_equal( + unlist(quantile(distn, c(.01, .05))), + tail_extrapolate(c(.01, .05), head(qv, 2)) + ) + expect_equal( + unlist(quantile(distn, c(.99, .95))), + tail_extrapolate(c(.95, .99), tail(qv, 2)) + ) +}) + test_that("unary math works on quantiles", { dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) dstn2 <- dist_quantiles(list(log(1:4), log(8:11)), list(c(.2, .4, .6, .8))) diff --git a/tests/testthat/test-layer_residual_quantiles.R b/tests/testthat/test-layer_residual_quantiles.R index 73f69b54a..5723fbce9 100644 --- a/tests/testthat/test-layer_residual_quantiles.R +++ b/tests/testthat/test-layer_residual_quantiles.R @@ -49,3 +49,52 @@ test_that("Errors when used with a classifier", { wf <- wf %>% add_frosting(f) expect_error(predict(wf, tib)) }) + + +test_that("Grouping by keys is supported", { + f <- frosting() %>% + layer_predict() %>% + layer_naomit(.pred) %>% + layer_residual_quantiles() + wf1 <- wf %>% add_frosting(f) + expect_silent(p1 <- predict(wf1, latest)) + f2 <- frosting() %>% + layer_predict() %>% + layer_naomit(.pred) %>% + layer_residual_quantiles(by_key = "geo_value") + wf2 <- wf %>% add_frosting(f2) + expect_warning(p2 <- predict(wf2, latest)) + + pivot1 <- pivot_quantiles_wider(p1, .pred_distn) %>% + mutate(width = `0.95` - `0.05`) + pivot2 <- pivot_quantiles_wider(p2, .pred_distn) %>% + mutate(width = `0.95` - `0.05`) + expect_equal(pivot1$width, rep(pivot1$width[1], nrow(pivot1))) + expect_false(all(pivot2$width == pivot2$width[1])) +}) + +test_that("Canned forecasters work with / without", { + meta <- attr(jhu, "metadata") + meta$as_of <- max(jhu$time_value) + attr(jhu, "metadata") <- meta + + expect_silent( + flatline_forecaster(jhu, "death_rate") + ) + expect_silent( + flatline_forecaster( + jhu, "death_rate", + args_list = flatline_args_list(quantile_by_key = "geo_value") + ) + ) + + expect_silent( + arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate")) + ) + expect_silent( + flatline_forecaster( + jhu, "death_rate", + args_list = flatline_args_list(quantile_by_key = "geo_value") + ) + ) +})