From 1fc981345519e5b6501c7926acb56c8b963c9f3c Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 20 Feb 2024 12:24:30 -0800 Subject: [PATCH 1/9] handle the quantile grouping correctly --- R/layer_residual_quantiles.R | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index bd4ed27e3..4a47a579f 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -103,14 +103,16 @@ slather.layer_residual_quantiles <- if (length(common) > 0L) { r <- r %>% dplyr::select(tidyselect::any_of(c(common, ".resid"))) common_in_r <- common[common %in% names(r)] - if (length(common_in_r) != length(common)) { + if (length(common_in_r) == length(common)) { + r <- dplyr::left_join(key_cols, r, by = common_in_r) + } else { cli::cli_warn(c( "Some grouping keys are not in data.frame returned by the", "`residuals()` method. Groupings may not be correct." )) + r <- dplyr::bind_cols(key_cols, r %>% tidyselect::select(.resid)) %>% + dplyr::group_by(!!!rlang::syms(common)) } - r <- dplyr::bind_cols(key_cols, r) %>% - dplyr::group_by(!!!rlang::syms(common)) } } From 3094da9688a45baddbda0154fa224352e43f27ca Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 20 Feb 2024 12:34:50 -0800 Subject: [PATCH 2/9] existing tests pass, remove non-standard file --- .Rbuildignore | 1 + DESCRIPTION | 2 +- R/layer_residual_quantiles.R | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.Rbuildignore b/.Rbuildignore index 3a77bb347..d0c10332d 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -3,6 +3,7 @@ ^epipredict\.Rproj$ ^\.Rproj\.user$ ^LICENSE\.md$ +^DEVELOPMENT\.md$ ^drafts$ ^\.Rprofile$ ^man-roxygen$ diff --git a/DESCRIPTION b/DESCRIPTION index c451f755b..d1e966278 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -72,4 +72,4 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.0 +RoxygenNote: 7.3.1 diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index 4a47a579f..b09956c2e 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -110,7 +110,7 @@ slather.layer_residual_quantiles <- "Some grouping keys are not in data.frame returned by the", "`residuals()` method. Groupings may not be correct." )) - r <- dplyr::bind_cols(key_cols, r %>% tidyselect::select(.resid)) %>% + r <- dplyr::bind_cols(key_cols, r %>% dplyr::select(.resid)) %>% dplyr::group_by(!!!rlang::syms(common)) } } From 63cd99145366eccd7353524405e0837cd1c5572a Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 20 Feb 2024 12:44:17 -0800 Subject: [PATCH 3/9] add a basic test --- .../testthat/test-layer_residual_quantiles.R | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/testthat/test-layer_residual_quantiles.R b/tests/testthat/test-layer_residual_quantiles.R index 73f69b54a..9cee038d5 100644 --- a/tests/testthat/test-layer_residual_quantiles.R +++ b/tests/testthat/test-layer_residual_quantiles.R @@ -49,3 +49,26 @@ test_that("Errors when used with a classifier", { wf <- wf %>% add_frosting(f) expect_error(predict(wf, tib)) }) + + +test_that("Grouping by keys is supported", { + f <- frosting() %>% + layer_predict() %>% + layer_naomit(.pred) %>% + layer_residual_quantiles() + wf1 <- wf %>% add_frosting(f) + expect_silent(p1 <- predict(wf1, latest)) + f2 <- frosting() %>% + layer_predict() %>% + layer_naomit(.pred) %>% + layer_residual_quantiles(by_key = "geo_value") + wf2 <- wf %>% add_frosting(f2) + expect_warning(p2 <- predict(wf2, latest)) + + pivot1 <- pivot_quantiles_wider(p1, .pred_distn) %>% + mutate(width = `0.95` - `0.05`) + pivot2 <- pivot_quantiles_wider(p2, .pred_distn) %>% + mutate(width = `0.95` - `0.05`) + expect_equal(pivot1$width, rep(pivot1$width[1], nrow(pivot1))) + expect_false(all(pivot2$width == pivot2$width[1])) +}) From aceeac4544f89a72b0bdc9f5a736832d35b652d6 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 20 Feb 2024 18:06:04 -0800 Subject: [PATCH 4/9] add flatline test --- .../testthat/test-layer_residual_quantiles.R | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/testthat/test-layer_residual_quantiles.R b/tests/testthat/test-layer_residual_quantiles.R index 9cee038d5..5723fbce9 100644 --- a/tests/testthat/test-layer_residual_quantiles.R +++ b/tests/testthat/test-layer_residual_quantiles.R @@ -72,3 +72,29 @@ test_that("Grouping by keys is supported", { expect_equal(pivot1$width, rep(pivot1$width[1], nrow(pivot1))) expect_false(all(pivot2$width == pivot2$width[1])) }) + +test_that("Canned forecasters work with / without", { + meta <- attr(jhu, "metadata") + meta$as_of <- max(jhu$time_value) + attr(jhu, "metadata") <- meta + + expect_silent( + flatline_forecaster(jhu, "death_rate") + ) + expect_silent( + flatline_forecaster( + jhu, "death_rate", + args_list = flatline_args_list(quantile_by_key = "geo_value") + ) + ) + + expect_silent( + arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate")) + ) + expect_silent( + flatline_forecaster( + jhu, "death_rate", + args_list = flatline_args_list(quantile_by_key = "geo_value") + ) + ) +}) From 695aeb07bcfc1791807282b35eaf243374a01b88 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Wed, 21 Feb 2024 09:03:53 -0800 Subject: [PATCH 5/9] fix: refactor quantile calculation * avoids bug when requesting quantile_values close to existing ones but off by a small tolerance * never creates unsorted quantiles * extrapolates outside the existing range by linearly interpolating on the logistic scale --- R/dist_quantiles.R | 99 +++++++++------------------- tests/testthat/test-dist_quantiles.R | 28 ++++++++ 2 files changed, 58 insertions(+), 69 deletions(-) diff --git a/R/dist_quantiles.R b/R/dist_quantiles.R index 750e9560d..c027873d9 100644 --- a/R/dist_quantiles.R +++ b/R/dist_quantiles.R @@ -172,20 +172,15 @@ mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "line #' @export #' @importFrom stats quantile #' @import distributional -quantile.dist_quantiles <- function( - x, p, ..., - middle = c("cubic", "linear"), - left_tail = c("normal", "exponential"), - right_tail = c("normal", "exponential")) { +quantile.dist_quantiles <- function(x, p, ..., middle = c("cubic", "linear")) { arg_is_probabilities(p) + p <- sort(p) middle <- match.arg(middle) - left_tail <- match.arg(left_tail) - right_tail <- match.arg(right_tail) - quantile_extrapolate(x, p, middle, left_tail, right_tail) + quantile_extrapolate(x, p, middle) } -quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { +quantile_extrapolate <- function(x, tau_out, middle) { tau <- field(x, "quantile_levels") qvals <- field(x, "values") r <- range(tau, na.rm = TRUE) @@ -196,10 +191,9 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { if (all(tau_out %in% tau)) { return(qvals[match(tau_out, tau)]) } - if (length(qvals) < 3 || r[1] > .25 || r[2] < .75) { - cli::cli_warn(c( - "Quantile extrapolation is not possible with fewer than", - "3 quantiles or when the probs don't span [.25, .75]" + if (length(qvals) < 2) { + cli::cli_abort(c( + "Quantile extrapolation is not possible with fewer than 2 quantiles." )) return(qvals_out) } @@ -213,87 +207,54 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { result <- tryCatch( { Q <- stats::splinefun(tau, qvals, method = "hyman") - qvals_out[indm] <- Q(tau_out[indm]) quartiles <- Q(c(.25, .5, .75)) }, - error = function(e) { - return(NA) - } + error = function(e) { return(NA) } ) } if (middle == "linear" || any(is.na(result))) { method <- "linear" quartiles <- stats::approx(tau, qvals, c(.25, .5, .75))$y } - - if (any(indm)) { qvals_out[indm] <- switch(method, linear = stats::approx(tau, qvals, tau_out[indm])$y, cubic = Q(tau_out[indm]) ) } + if (any(indl) || any(indr)) { + qv <- data.frame( + q = c(tau, tau_out[indm]), + v = c(qvals, qvals_out[indm]) + ) %>% + dplyr::distinct(q, .keep_all = TRUE) %>% + dplyr::arrange(q) + } if (any(indl)) { - qvals_out[indl] <- tail_extrapolate( - tau_out[indl], quartiles, "left", left_tail - ) + qvals_out[indl] <- tail_extrapolate(tau_out[indl], head(qv, 2)) } if (any(indr)) { - qvals_out[indr] <- tail_extrapolate( - tau_out[indr], quartiles, "right", right_tail - ) + qvals_out[indr] <- tail_extrapolate(tau_out[indr], tail(qv, 2)) } qvals_out } -tail_extrapolate <- function(tau_out, quartiles, tail, type) { - if (tail == "left") { - p <- c(.25, .5) - par <- quartiles[1:2] - } - if (tail == "right") { - p <- c(.75, .5) - par <- quartiles[3:2] - } - if (type == "normal") { - return(norm_tail_q(p, par, tau_out)) - } - if (type == "exponential") { - return(exp_tail_q(p, par, tau_out)) - } -} - - -exp_q_par <- function(q) { - # tau should always be c(.75, .5) or c(.25, .5) - iqr <- 2 * abs(diff(q)) - s <- iqr / (2 * log(2)) - m <- q[2] - return(list(m = m, s = s)) +logit <- function(p) { + p <- pmax(pmin(p, 1), 0) + log(p) - log(1 - p) } -exp_tail_q <- function(p, q, target) { - ms <- exp_q_par(q) - qlaplace(target, ms$m, ms$s) +# extrapolates linearly on the logistic scale using +# the two points nearest the tail +tail_extrapolate <- function(tau_out, qv) { + if (nrow(qv) == 1L) return(rep(qv$v[1], length(tau_out))) + x <- logit(qv$q) + x0 <- logit(tau_out) + y <- qv$v + m <- diff(y) / diff(x) + m * (x0 - x[1]) + y[1] } -qlaplace <- function(p, centre = 0, b = 1) { - # lower.tail = TRUE, log.p = FALSE - centre - b * sign(p - 0.5) * log(1 - 2 * abs(p - 0.5)) -} - -norm_q_par <- function(q) { - # tau should always be c(.75, .5) or c(.25, .5) - iqr <- 2 * abs(diff(q)) - s <- iqr / 1.34897950039 # abs(diff(qnorm(c(.75, .25)))) - m <- q[2] - return(list(m = m, s = s)) -} - -norm_tail_q <- function(p, q, target) { - ms <- norm_q_par(q) - stats::qnorm(target, ms$m, ms$s) -} #' @method Math dist_quantiles #' @export diff --git a/tests/testthat/test-dist_quantiles.R b/tests/testthat/test-dist_quantiles.R index 4fc5587d4..64bf16b23 100644 --- a/tests/testthat/test-dist_quantiles.R +++ b/tests/testthat/test-dist_quantiles.R @@ -49,6 +49,34 @@ test_that("quantile extrapolator works", { expect_length(parameters(qq[1])$q[[1]], 7L) }) +test_that("small deviations of quantile requests work", { + l <- c(.05, .1, .25, .75, .9, .95) + v <- c(0.0890306, 0.1424997, 0.1971793, 0.2850978, 0.3832912, 0.4240479) + badl <- l + badl[1] <- badl[1] - 1e-14 + distn <- dist_quantiles(list(v), list(l)) + + # was broken before, now works + expect_equal(quantile(distn, l), quantile(distn, badl)) + + # The tail extrapolation was still poor. It needs to _always_ use + # the smallest (largest) values or we could end up unsorted + l <- 1:9 / 10 + v <- 1:9 + distn <- dist_quantiles(list(v), list(l)) + expect_equal(quantile(distn, c(.25, .75)), list(c(2.5, 7.5))) + expect_equal(quantile(distn, c(.1, .9)), list(c(1, 9))) + qv <- data.frame(q = l, v = v) + expect_equal( + unlist(quantile(distn, c(.01, .05))), + tail_extrapolate(c(.01, .05), head(qv, 2)) + ) + expect_equal( + unlist(quantile(distn, c(.99, .95))), + tail_extrapolate(c(.95, .99), tail(qv, 2)) + ) +}) + test_that("unary math works on quantiles", { dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) dstn2 <- dist_quantiles(list(log(1:4), log(8:11)), list(c(.2, .4, .6, .8))) From fb04f20241a84514965f44de8bf1fa08cd94e759 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Wed, 21 Feb 2024 09:04:32 -0800 Subject: [PATCH 6/9] bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index d1e966278..e068fe7e3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.9 +Version: 0.0.10 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), From 0165001d04e2e00bb93b5dc0dad24e6c1e258b43 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Wed, 21 Feb 2024 09:06:52 -0800 Subject: [PATCH 7/9] add to news --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 3c5034080..04dc78e4f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -32,3 +32,4 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Two simple forecasters as test beds - Working vignette - use `checkmate` for input validation +- refactor quantile extrapolation (possibly creates different results) From 0d133df5b92e38e72b039ee350186a36dc416101 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Wed, 21 Feb 2024 09:34:29 -0800 Subject: [PATCH 8/9] remove test for nonexistent functions --- tests/testthat/test-dist_quantiles.R | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/testthat/test-dist_quantiles.R b/tests/testthat/test-dist_quantiles.R index 64bf16b23..99ce742d5 100644 --- a/tests/testthat/test-dist_quantiles.R +++ b/tests/testthat/test-dist_quantiles.R @@ -10,14 +10,6 @@ test_that("constructor returns reasonable quantiles", { expect_error(new_quantiles(c(1, 2, 3), c(.1, .2, 3))) }) -test_that("tail functions give reasonable output", { - expect_equal(norm_q_par(qnorm(c(.75, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(norm_q_par(qnorm(c(.25, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(norm_q_par(qnorm(c(.25, .5), 0, 1)), list(m = 0, s = 1)) - expect_equal(exp_q_par(qlaplace(c(.75, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(exp_q_par(qlaplace(c(.25, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(exp_q_par(qlaplace(c(.25, .5), 0, 1)), list(m = 0, s = 1)) -}) test_that("single dist_quantiles works, quantiles are accessible", { z <- new_quantiles(values = 1:5, quantile_levels = c(.2, .4, .5, .6, .8)) From 3b8c889e569165aa717bcc0cab0b59bb4b1cff35 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Wed, 21 Feb 2024 10:41:25 -0800 Subject: [PATCH 9/9] style only --- R/dist_quantiles.R | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/R/dist_quantiles.R b/R/dist_quantiles.R index c027873d9..7f7af40f4 100644 --- a/R/dist_quantiles.R +++ b/R/dist_quantiles.R @@ -209,7 +209,9 @@ quantile_extrapolate <- function(x, tau_out, middle) { Q <- stats::splinefun(tau, qvals, method = "hyman") quartiles <- Q(c(.25, .5, .75)) }, - error = function(e) { return(NA) } + error = function(e) { + return(NA) + } ) } if (middle == "linear" || any(is.na(result))) { @@ -247,7 +249,9 @@ logit <- function(p) { # extrapolates linearly on the logistic scale using # the two points nearest the tail tail_extrapolate <- function(tau_out, qv) { - if (nrow(qv) == 1L) return(rep(qv$v[1], length(tau_out))) + if (nrow(qv) == 1L) { + return(rep(qv$v[1], length(tau_out))) + } x <- logit(qv$q) x0 <- logit(tau_out) y <- qv$v