Skip to content

Commit 1d427b7

Browse files
committed
refactor residual quantiles to incorporate dist_quantiles.
1 parent 9f7aaec commit 1d427b7

10 files changed

+107
-65
lines changed

NAMESPACE

+6-2
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,12 @@ S3method(refresh_blueprint,default_epi_recipe_blueprint)
3333
S3method(run_mold,default_epi_recipe_blueprint)
3434
S3method(slather,layer_naomit)
3535
S3method(slather,layer_predict)
36-
S3method(slather,layer_residual_quantile)
36+
S3method(slather,layer_residual_quantiles)
3737
S3method(slather,layer_threshold)
38+
S3method(snap,default)
39+
S3method(snap,dist_default)
40+
S3method(snap,dist_quantiles)
41+
S3method(snap,distribution)
3842
S3method(vec_ptype_abbr,dist_quantiles)
3943
S3method(vec_ptype_full,dist_quantiles)
4044
export("%>%")
@@ -70,7 +74,7 @@ export(knnarx_forecaster)
7074
export(layer)
7175
export(layer_naomit)
7276
export(layer_predict)
73-
export(layer_residual_quantile)
77+
export(layer_residual_quantiles)
7478
export(layer_threshold)
7579
export(nested_quantiles)
7680
export(new_default_epi_recipe_blueprint)

R/compat-purrr.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
map <- function(.x, .f, ...) {
5-
.f <- rlang::as_function(.f, env = global_env())
5+
.f <- rlang::as_function(.f, env = rlang::global_env())
66
lapply(.x, .f, ...)
77
}
88
walk <- function(.x, .f, ...) {
@@ -23,14 +23,14 @@ map_chr <- function(.x, .f, ...) {
2323
.rlang_purrr_map_mold(.x, .f, character(1), ...)
2424
}
2525
.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) {
26-
.f <- rlang::as_function(.f, env = global_env())
26+
.f <- rlang::as_function(.f, env = rlang::global_env())
2727
out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE)
2828
names(out) <- names(.x)
2929
out
3030
}
3131

3232
map2 <- function(.x, .y, .f, ...) {
33-
.f <- rlang::as_function(.f, env = global_env())
33+
.f <- rlang::as_function(.f, env = rlang::global_env())
3434
out <- mapply(.f, .x, .y, MoreArgs = list(...), SIMPLIFY = FALSE)
3535
if (length(out) == length(.x)) {
3636
rlang::set_names(out, names(.x))

R/dist_quantiles.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ new_quantiles <- function(q = double(), tau = double()) {
1010
q <- q[o]
1111
tau <- tau[o]
1212
}
13-
if (is.unsorted(q)) rlang::abort("q[order(tau)] produces unsorted quantiles.")
13+
if (is.unsorted(q, na.rm = TRUE))
14+
rlang::abort("q[order(tau)] produces unsorted quantiles.")
1415

1516
new_rcrd(list(q = q, tau = tau),
1617
class = c("dist_quantiles", "dist_default"))
@@ -93,8 +94,7 @@ extrapolate_quantiles <- function(x, p, ...) {
9394
#' @export
9495
extrapolate_quantiles.distribution <- function(x, p, ...) {
9596
arg_is_probabilities(p)
96-
p <- distributional:::arg_listable(p, .ptype = double())
97-
dstn <- distributional:::dist_apply(x, extrapolate_quantiles, p = p, ...)
97+
dstn <- lapply(vec_data(x), extrapolate_quantiles, p = p, ...)
9898
distributional:::wrap_dist(dstn)
9999
}
100100

R/layer_residual_quantile.R renamed to R/layer_residual_quantiles.R

+20-20
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,47 @@
22
#'
33
#' @param frosting a `frosting` postprocessor
44
#' @param ... Unused, include for consistency with other layers.
5-
#' @param probs numeric vector of probabilities with values in (0,1) referring to the desired quantile.
6-
#' @param symmetrize logical. If `TRUE` then interval will be symmetrical.
5+
#' @param probs numeric vector of probabilities with values in (0,1)
6+
#' referring to the desired quantile.
7+
#' @param symmetrize logical. If `TRUE` then interval will be symmetric.
78
#' @param .flag a logical to determine if the layer is added. Passed on to
89
#' `add_layer()`. Default `TRUE`.
910
#' @param id a random id string
1011
#'
11-
#' @return an updated `frosting` postprocessor with additional columns of the residual quantiles added to the prediction
12+
#' @return an updated `frosting` postprocessor with additional columns of the
13+
#' residual quantiles added to the prediction
1214
#' @export
1315
#' @examples
14-
#' jhu <- case_death_rate_subset %>%
16+
#' jhu <- case_death_rate_subset %>%
1517
#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
1618
#'
1719
#' r <- epi_recipe(jhu) %>%
1820
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
1921
#' step_epi_ahead(death_rate, ahead = 7) %>%
20-
#' recipes::step_naomit(recipes::all_predictors()) %>%
21-
#' recipes::step_naomit(recipes::all_outcomes(), skip = TRUE)
22+
#' step_epi_naomit()
2223
#'
2324
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>%
2425
#' parsnip::fit(jhu)
2526
#'
2627
#' latest <- get_test_data(recipe = r, x = jhu)
2728
#'
28-
#' f <- epipredict:::frosting() %>%
29-
#' layer_predict() %>%
30-
#' layer_residual_quantile(probs = c(0.0275, 0.975), symmetrize = FALSE) %>%
31-
#' layer_naomit(.pred)
32-
#' wf1 <- wf %>% epipredict:::add_frosting(f)
29+
#' f <- frosting() %>%
30+
#' layer_predict() %>%
31+
#' layer_residual_quantiles(probs = c(0.0275, 0.975), symmetrize = FALSE) %>%
32+
#' layer_naomit(.pred)
33+
#' wf1 <- wf %>% add_frosting(f)
3334
#'
3435
#' p <- predict(wf1, latest)
3536
#' p
36-
layer_residual_quantile <- function(frosting, ...,
37+
layer_residual_quantiles <- function(frosting, ...,
3738
probs = c(0.0275, 0.975),
3839
symmetrize = TRUE,
3940
.flag = TRUE,
40-
id = rand_id("residual_quantile")) {
41+
id = rand_id("residual_quantiles")) {
4142
rlang::check_dots_empty()
4243
add_layer(
4344
frosting,
44-
layer_residual_quantile_new(
45+
layer_residual_quantiles_new(
4546
probs = probs,
4647
symmetrize = symmetrize,
4748
id = id
@@ -50,12 +51,12 @@ layer_residual_quantile <- function(frosting, ...,
5051
)
5152
}
5253

53-
layer_residual_quantile_new <- function(probs, symmetrize, id) {
54-
layer("residual_quantile", probs = probs, symmetrize = symmetrize, id = id)
54+
layer_residual_quantiles_new <- function(probs, symmetrize, id) {
55+
layer("residual_quantiles", probs = probs, symmetrize = symmetrize, id = id)
5556
}
5657

5758
#' @export
58-
slather.layer_residual_quantile <-
59+
slather.layer_residual_quantiles <-
5960
function(object, components, the_fit, the_recipe, ...) {
6061
if (is.null(object$probs)) return(components)
6162

@@ -64,8 +65,7 @@ slather.layer_residual_quantile <-
6465
q <- quantile(c(r, s * r), probs = object$probs, na.rm = TRUE)
6566

6667
estimate <- components$predictions$.pred
67-
interval <- data.frame(outer(estimate, q, "+"))
68-
names(interval)<- probs_to_string(object$probs)
69-
components$predictions <- dplyr::bind_cols(components$predictions,interval)
68+
dstn <- dist_quantiles(map(estimate, "+", q), object$probs)
69+
components$predictions$.quantiles <- dstn
7070
components
7171
}

R/layer_threshold_preds.R

+31-1
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,41 @@ layer_threshold_new <-
6767
layer("threshold", terms = terms, lower = lower, upper = upper, id = id)
6868
}
6969

70-
snap <- function(x, lower, upper) {
70+
71+
72+
snap <- function(x, lower, upper, ...) {
73+
UseMethod("snap")
74+
}
75+
76+
#' @export
77+
snap.default <- function(x, lower, upper, ...) {
78+
rlang::check_dots_empty()
7179
arg_is_scalar(lower, upper)
7280
pmin(pmax(x, lower), upper)
7381
}
7482

83+
#' @export
84+
snap.distribution <- function(x, lower, upper, ...) {
85+
rlang::check_dots_empty()
86+
arg_is_scalar(lower, upper)
87+
dstn <- lapply(vec_data(x), snap, lower = lower, upper = upper)
88+
distributional:::wrap_dist(dstn)
89+
}
90+
91+
#' @export
92+
snap.dist_default <- function(x, lower, upper, ...) {
93+
rlang::check_dots_empty()
94+
x
95+
}
96+
97+
#' @export
98+
snap.dist_quantiles <- function(x, lower, upper, ...) {
99+
q <- field(x, "q")
100+
tau <- field(x, "tau")
101+
q <- snap(q, lower, upper)
102+
new_quantiles(q = q, tau = tau)
103+
}
104+
75105
#' @export
76106
slather.layer_threshold <-
77107
function(object, components, the_fit, the_recipe, ...) {

man/layer_residual_quantile.Rd renamed to man/layer_residual_quantiles.Rd

+17-16
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-epi_workflow.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@ test_that("outcome of the two methods are the same", {
1919
step_epi_lag(death_rate, lag = c(0, 7)) %>%
2020
step_epi_ahead(death_rate, ahead = 7) %>%
2121
step_epi_lag(case_rate, lag = c(7)) %>%
22-
step_naomit(all_predictors()) %>%
23-
step_naomit(all_outcomes())
22+
step_epi_naomit()
2423

2524
s <- parsnip::linear_reg()
2625
f <- frosting() %>%
2726
layer_predict() %>%
2827
layer_naomit(.pred) %>%
29-
layer_residual_quantile()
28+
layer_residual_quantiles()
3029

3130
ef <- epi_workflow(r, s, f)
3231
ef2 <- epi_workflow(r, s) %>% add_frosting(f)

tests/testthat/test-frosting.R

+3-4
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,20 @@ test_that("layer_predict is added by default if missing", {
5656
r <- epi_recipe(jhu) %>%
5757
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
5858
step_epi_ahead(death_rate, ahead = 7) %>%
59-
step_naomit(all_predictors()) %>%
60-
step_naomit(all_outcomes(), skip = TRUE)
59+
step_epi_naomit()
6160

6261
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
6362

6463
latest <- get_test_data(recipe = r, x = jhu)
6564

6665
f1 <- frosting() %>%
6766
layer_naomit(.pred) %>%
68-
layer_residual_quantile()
67+
layer_residual_quantiles()
6968

7069
f2 <- frosting() %>%
7170
layer_predict() %>%
7271
layer_naomit(.pred) %>%
73-
layer_residual_quantile()
72+
layer_residual_quantiles()
7473

7574
wf1 <- wf %>% add_frosting(f1)
7675
wf2 <- wf %>% add_frosting(f2)

tests/testthat/test-layer_residual_quantile.R renamed to tests/testthat/test-layer_residual_quantiles.R

+10-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ jhu <- case_death_rate_subset %>%
44
r <- epi_recipe(jhu) %>%
55
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
66
step_epi_ahead(death_rate, ahead = 7) %>%
7-
step_naomit(all_predictors()) %>%
8-
step_naomit(all_outcomes(), skip = TRUE)
7+
step_epi_naomit()
98

109
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
1110
latest <- get_test_data(recipe = r, x = jhu)
@@ -15,13 +14,19 @@ test_that("Returns expected number or rows and columns", {
1514
f <- frosting() %>%
1615
layer_predict() %>%
1716
layer_naomit(.pred) %>%
18-
layer_residual_quantile(probs = c(0.0275, 0.8, 0.95), symmetrize = FALSE)
17+
layer_residual_quantiles(probs = c(0.0275, 0.8, 0.95), symmetrize = FALSE)
1918

2019
wf1 <- wf %>% add_frosting(f)
2120

2221
expect_silent(p <- predict(wf1, latest))
23-
expect_equal(ncol(p), 6L)
22+
expect_equal(ncol(p), 4L)
2423
expect_s3_class(p, "epi_df")
2524
expect_equal(nrow(p), 3L)
26-
expect_named(p, c("geo_value", "time_value",".pred","q0.0275","q0.8","q0.95"))
25+
expect_named(p, c("geo_value", "time_value",".pred",".quantiles"))
26+
27+
nested <- p %>% dplyr::mutate(.quantiles = nested_quantiles(.quantiles))
28+
unnested <- nested %>% tidyr::unnest(.quantiles)
29+
30+
expect_equal(nrow(unnested), 9L)
31+
expect_equal(unique(unnested$tau), c(.0275, .8, .95))
2732
})

tests/testthat/test-layer_threshold_preds.R

+12-8
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ jhu <- case_death_rate_subset %>%
33
r <- epi_recipe(jhu) %>%
44
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
55
step_epi_ahead(death_rate, ahead = 7) %>%
6-
step_naomit(all_predictors()) %>%
7-
step_naomit(all_outcomes(), skip = TRUE)
6+
step_epi_naomit()
7+
88
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
99
latest <- jhu %>%
1010
dplyr::filter(time_value >= max(time_value) - 14)
@@ -46,17 +46,21 @@ test_that("thresholds additional columns", {
4646

4747
f <- frosting() %>%
4848
layer_predict() %>%
49-
layer_residual_quantile(probs = c(.1, .9)) %>%
50-
layer_threshold(.pred, dplyr::starts_with("q"), lower = 0.180, upper = 0.31) %>%
49+
layer_residual_quantiles(probs = c(.1, .9)) %>%
50+
layer_threshold(.pred, .quantiles, lower = 0.180, upper = 0.31) %>%
5151
layer_naomit(.pred)
52+
5253
wf2 <- wf %>% add_frosting(f)
5354

5455
expect_silent(p <- predict(wf2, latest))
55-
expect_equal(ncol(p), 5L)
56+
expect_equal(ncol(p), 4L)
5657
expect_s3_class(p, "epi_df")
5758
expect_equal(nrow(p), 3L)
5859
expect_equal(round(p$.pred, digits = 3), c(0.180, 0.180, 0.310))
59-
expect_equal(round(p$q0.1, digits = 3), c(0.180, 0.180, 0.310))
60-
expect_equal(round(p$q0.9, digits = 3), c(0.310, 0.180, 0.310))
61-
expect_named(p, c("geo_value", "time_value", ".pred", "q0.1", "q0.9"))
60+
expect_named(p, c("geo_value", "time_value", ".pred", ".quantiles"))
61+
p <- p %>%
62+
dplyr::mutate(.quantiles = nested_quantiles(.quantiles)) %>%
63+
tidyr::unnest(.quantiles)
64+
expect_equal(round(p$q, digits = 3), c(0.180, 0.31, 0.180, .18, 0.310, .31))
65+
expect_equal(p$tau, rep(c(.1,.9), times = 3))
6266
})

0 commit comments

Comments
 (0)