Skip to content

Commit d508069

Browse files
dajmcdondsweber2
authored andcommitted
updates to WIS
1 parent 316065b commit d508069

File tree

3 files changed

+38
-86
lines changed

3 files changed

+38
-86
lines changed

NAMESPACE

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,7 @@ S3method(vec_arith,quantile_pred)
121121
S3method(vec_arith.numeric,quantile_pred)
122122
S3method(vec_arith.quantile_pred,numeric)
123123
S3method(vec_math,quantile_pred)
124-
S3method(weighted_interval_score,default)
125-
S3method(weighted_interval_score,dist_default)
126-
S3method(weighted_interval_score,dist_quantiles)
127-
S3method(weighted_interval_score,distribution)
124+
S3method(weighted_interval_score,quantile_pred)
128125
export("%>%")
129126
export(Add_model)
130127
export(Remove_model)

R/weighted_interval_score.R

Lines changed: 32 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@
1313
#' @param actual double. Actual value(s)
1414
#' @param quantile_levels probabilities. If specified, the score will be
1515
#' computed at this set of levels.
16+
#' @param na_handling character. Determines how `quantile_levels` without a
17+
#' corresponding `value` are handled. For `"impute"`, missing values will be
18+
#' calculated if possible using the available quantiles. For `"drop"`,
19+
#' explicitly missing values are ignored in the calculation of the score, but
20+
#' implicitly missing values are imputed if possible.
21+
#' For `"propogate"`, the resulting score will be `NA` if any missing values
22+
#' exist in the original `quantile_levels`. Finally, if
23+
#' `quantile_levels` is specified, `"fail"` will result in
24+
#' the score being `NA` when any required quantile levels (implicit or explicit)
25+
#' are do not have corresponding values.
1626
#' @param ... not used
1727
#'
1828
#' @return a vector of nonnegative scores.
@@ -44,104 +54,58 @@
4454
#'
4555
#' # Using some actual forecasts --------
4656
#' library(dplyr)
47-
#' jhu <- covid_case_death_rates %>%
57+
#' jhu <- case_death_rate_subset %>%
4858
#' filter(time_value >= "2021-10-01", time_value <= "2021-12-01")
4959
#' preds <- flatline_forecaster(
5060
#' jhu, "death_rate",
5161
#' flatline_args_list(quantile_levels = c(.01, .025, 1:19 / 20, .975, .99))
5262
#' )$predictions
53-
#' actuals <- covid_case_death_rates %>%
63+
#' actuals <- case_death_rate_subset %>%
5464
#' filter(time_value == as.Date("2021-12-01") + 7) %>%
5565
#' select(geo_value, time_value, actual = death_rate)
5666
#' preds <- left_join(preds, actuals,
5767
#' by = c("target_date" = "time_value", "geo_value")
5868
#' ) %>%
5969
#' mutate(wis = weighted_interval_score(.pred_distn, actual))
6070
#' preds
61-
weighted_interval_score <- function(x, actual, quantile_levels = NULL, ...) {
71+
weighted_interval_score <- function(
72+
x,
73+
actual,
74+
quantile_levels = NULL,
75+
na_handling = c("impute", "drop", "propagate", "fail"),
76+
...) {
6277
UseMethod("weighted_interval_score")
6378
}
6479

65-
#' @export
66-
weighted_interval_score.default <- function(x, actual,
67-
quantile_levels = NULL, ...) {
68-
cli_abort(c(
69-
"Weighted interval score can only be calculated if `x`",
70-
"has class {.cls distribution}."
71-
))
72-
}
73-
74-
#' @export
75-
weighted_interval_score.distribution <- function(
76-
x, actual,
77-
quantile_levels = NULL, ...) {
78-
assert_numeric(actual, finite = TRUE)
79-
l <- vctrs::vec_recycle_common(x = x, actual = actual)
80-
map2_dbl(
81-
.x = vctrs::vec_data(l$x),
82-
.y = l$actual,
83-
.f = weighted_interval_score,
84-
quantile_levels = quantile_levels,
85-
...
86-
)
87-
}
88-
89-
#' @export
90-
weighted_interval_score.dist_default <- function(x, actual,
91-
quantile_levels = NULL, ...) {
92-
rlang::check_dots_empty()
93-
if (is.null(quantile_levels)) {
94-
cli_warn(c(
95-
"Weighted interval score isn't implemented for {.cls {class(x)}}",
96-
"as we don't know what set of quantile levels to use.",
97-
"Use a {.cls dist_quantiles} or pass `quantile_levels`.",
98-
"The result for this element will be `NA`."
99-
))
100-
return(NA)
101-
}
102-
x <- extrapolate_quantiles(x, probs = quantile_levels)
103-
weighted_interval_score(x, actual, quantile_levels = NULL)
104-
}
10580

106-
#' @param na_handling character. Determines how `quantile_levels` without a
107-
#' corresponding `value` are handled. For `"impute"`, missing values will be
108-
#' calculated if possible using the available quantiles. For `"drop"`,
109-
#' explicitly missing values are ignored in the calculation of the score, but
110-
#' implicitly missing values are imputed if possible.
111-
#' For `"propogate"`, the resulting score will be `NA` if any missing values
112-
#' exist in the original `quantile_levels`. Finally, if
113-
#' `quantile_levels` is specified, `"fail"` will result in
114-
#' the score being `NA` when any required quantile levels (implicit or explicit)
115-
#' are do not have corresponding values.
116-
#' @describeIn weighted_interval_score Weighted interval score with
117-
#' `dist_quantiles` allows for different `NA` behaviours.
11881
#' @export
119-
weighted_interval_score.dist_quantiles <- function(
82+
weighted_interval_score.quantile_pred <- function(
12083
x, actual,
12184
quantile_levels = NULL,
12285
na_handling = c("impute", "drop", "propagate", "fail"),
12386
...) {
12487
rlang::check_dots_empty()
125-
if (is.na(actual)) {
126-
return(NA)
127-
}
128-
if (all(is.na(vctrs::field(x, "values")))) {
129-
return(NA)
130-
}
88+
n <- vctrs::vec_size(x)
89+
if (length(actual) == 1L) actual <- rep(actual, n)
90+
assert_numeric(actual, finite = TRUE, len = n)
91+
assert_numeric(quantile_levels, lower = 0, upper = 1, null.ok = TRUE)
13192
na_handling <- rlang::arg_match(na_handling)
132-
old_quantile_levels <- field(x, "quantile_levels")
93+
old_quantile_levels <- x %@% "quantile_levels"
13394
if (na_handling == "fail") {
13495
if (is.null(quantile_levels)) {
13596
cli_abort('`na_handling = "fail"` requires `quantile_levels` to be specified.')
13697
}
137-
old_values <- field(x, "values")
138-
if (!all(quantile_levels %in% old_quantile_levels) || any(is.na(old_values))) {
139-
return(NA)
98+
if (!all(quantile_levels %in% old_quantile_levels)) {
99+
return(rep(NA_real_, n))
140100
}
141101
}
142102
tau <- quantile_levels %||% old_quantile_levels
143-
x <- extrapolate_quantiles(x, probs = tau, replace_na = (na_handling == "impute"))
144-
q <- field(x, "values")[field(x, "quantile_levels") %in% tau]
103+
x <- extrapolate_quantiles(x, tau, replace_na = (na_handling == "impute"))
104+
x <- as.matrix(x)[, attr(x, "quantile_levels") %in% tau]
145105
na_rm <- (na_handling == "drop")
106+
map2_dbl(vctrs::vec_chop(x), actual, ~ wis_one_quantile(.x, tau, .y, na_rm))
107+
}
108+
109+
wis_one_quantile <- function(q, tau, actual, na_rm) {
146110
2 * mean(pmax(tau * (actual - q), (1 - tau) * (q - actual)), na.rm = na_rm)
147111
}

man/weighted_interval_score.Rd

Lines changed: 5 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)