|
13 | 13 | #' @param actual double. Actual value(s)
|
14 | 14 | #' @param quantile_levels probabilities. If specified, the score will be
|
15 | 15 | #' computed at this set of levels.
|
| 16 | +#' @param na_handling character. Determines how `quantile_levels` without a |
| 17 | +#' corresponding `value` are handled. For `"impute"`, missing values will be |
| 18 | +#' calculated if possible using the available quantiles. For `"drop"`, |
| 19 | +#' explicitly missing values are ignored in the calculation of the score, but |
| 20 | +#' implicitly missing values are imputed if possible. |
| 21 | +#' For `"propogate"`, the resulting score will be `NA` if any missing values |
| 22 | +#' exist in the original `quantile_levels`. Finally, if |
| 23 | +#' `quantile_levels` is specified, `"fail"` will result in |
| 24 | +#' the score being `NA` when any required quantile levels (implicit or explicit) |
| 25 | +#' are do not have corresponding values. |
16 | 26 | #' @param ... not used
|
17 | 27 | #'
|
18 | 28 | #' @return a vector of nonnegative scores.
|
|
44 | 54 | #'
|
45 | 55 | #' # Using some actual forecasts --------
|
46 | 56 | #' library(dplyr)
|
47 |
| -#' jhu <- covid_case_death_rates %>% |
| 57 | +#' jhu <- case_death_rate_subset %>% |
48 | 58 | #' filter(time_value >= "2021-10-01", time_value <= "2021-12-01")
|
49 | 59 | #' preds <- flatline_forecaster(
|
50 | 60 | #' jhu, "death_rate",
|
51 | 61 | #' flatline_args_list(quantile_levels = c(.01, .025, 1:19 / 20, .975, .99))
|
52 | 62 | #' )$predictions
|
53 |
| -#' actuals <- covid_case_death_rates %>% |
| 63 | +#' actuals <- case_death_rate_subset %>% |
54 | 64 | #' filter(time_value == as.Date("2021-12-01") + 7) %>%
|
55 | 65 | #' select(geo_value, time_value, actual = death_rate)
|
56 | 66 | #' preds <- left_join(preds, actuals,
|
57 | 67 | #' by = c("target_date" = "time_value", "geo_value")
|
58 | 68 | #' ) %>%
|
59 | 69 | #' mutate(wis = weighted_interval_score(.pred_distn, actual))
|
60 | 70 | #' preds
|
61 |
| -weighted_interval_score <- function(x, actual, quantile_levels = NULL, ...) { |
| 71 | +weighted_interval_score <- function( |
| 72 | + x, |
| 73 | + actual, |
| 74 | + quantile_levels = NULL, |
| 75 | + na_handling = c("impute", "drop", "propagate", "fail"), |
| 76 | + ...) { |
62 | 77 | UseMethod("weighted_interval_score")
|
63 | 78 | }
|
64 | 79 |
|
65 |
| -#' @export |
66 |
| -weighted_interval_score.default <- function(x, actual, |
67 |
| - quantile_levels = NULL, ...) { |
68 |
| - cli_abort(c( |
69 |
| - "Weighted interval score can only be calculated if `x`", |
70 |
| - "has class {.cls distribution}." |
71 |
| - )) |
72 |
| -} |
73 |
| - |
74 |
| -#' @export |
75 |
| -weighted_interval_score.distribution <- function( |
76 |
| - x, actual, |
77 |
| - quantile_levels = NULL, ...) { |
78 |
| - assert_numeric(actual, finite = TRUE) |
79 |
| - l <- vctrs::vec_recycle_common(x = x, actual = actual) |
80 |
| - map2_dbl( |
81 |
| - .x = vctrs::vec_data(l$x), |
82 |
| - .y = l$actual, |
83 |
| - .f = weighted_interval_score, |
84 |
| - quantile_levels = quantile_levels, |
85 |
| - ... |
86 |
| - ) |
87 |
| -} |
88 |
| - |
89 |
| -#' @export |
90 |
| -weighted_interval_score.dist_default <- function(x, actual, |
91 |
| - quantile_levels = NULL, ...) { |
92 |
| - rlang::check_dots_empty() |
93 |
| - if (is.null(quantile_levels)) { |
94 |
| - cli_warn(c( |
95 |
| - "Weighted interval score isn't implemented for {.cls {class(x)}}", |
96 |
| - "as we don't know what set of quantile levels to use.", |
97 |
| - "Use a {.cls dist_quantiles} or pass `quantile_levels`.", |
98 |
| - "The result for this element will be `NA`." |
99 |
| - )) |
100 |
| - return(NA) |
101 |
| - } |
102 |
| - x <- extrapolate_quantiles(x, probs = quantile_levels) |
103 |
| - weighted_interval_score(x, actual, quantile_levels = NULL) |
104 |
| -} |
105 | 80 |
|
106 |
| -#' @param na_handling character. Determines how `quantile_levels` without a |
107 |
| -#' corresponding `value` are handled. For `"impute"`, missing values will be |
108 |
| -#' calculated if possible using the available quantiles. For `"drop"`, |
109 |
| -#' explicitly missing values are ignored in the calculation of the score, but |
110 |
| -#' implicitly missing values are imputed if possible. |
111 |
| -#' For `"propogate"`, the resulting score will be `NA` if any missing values |
112 |
| -#' exist in the original `quantile_levels`. Finally, if |
113 |
| -#' `quantile_levels` is specified, `"fail"` will result in |
114 |
| -#' the score being `NA` when any required quantile levels (implicit or explicit) |
115 |
| -#' are do not have corresponding values. |
116 |
| -#' @describeIn weighted_interval_score Weighted interval score with |
117 |
| -#' `dist_quantiles` allows for different `NA` behaviours. |
118 | 81 | #' @export
|
119 |
| -weighted_interval_score.dist_quantiles <- function( |
| 82 | +weighted_interval_score.quantile_pred <- function( |
120 | 83 | x, actual,
|
121 | 84 | quantile_levels = NULL,
|
122 | 85 | na_handling = c("impute", "drop", "propagate", "fail"),
|
123 | 86 | ...) {
|
124 | 87 | rlang::check_dots_empty()
|
125 |
| - if (is.na(actual)) { |
126 |
| - return(NA) |
127 |
| - } |
128 |
| - if (all(is.na(vctrs::field(x, "values")))) { |
129 |
| - return(NA) |
130 |
| - } |
| 88 | + n <- vctrs::vec_size(x) |
| 89 | + if (length(actual) == 1L) actual <- rep(actual, n) |
| 90 | + assert_numeric(actual, finite = TRUE, len = n) |
| 91 | + assert_numeric(quantile_levels, lower = 0, upper = 1, null.ok = TRUE) |
131 | 92 | na_handling <- rlang::arg_match(na_handling)
|
132 |
| - old_quantile_levels <- field(x, "quantile_levels") |
| 93 | + old_quantile_levels <- x %@% "quantile_levels" |
133 | 94 | if (na_handling == "fail") {
|
134 | 95 | if (is.null(quantile_levels)) {
|
135 | 96 | cli_abort('`na_handling = "fail"` requires `quantile_levels` to be specified.')
|
136 | 97 | }
|
137 |
| - old_values <- field(x, "values") |
138 |
| - if (!all(quantile_levels %in% old_quantile_levels) || any(is.na(old_values))) { |
139 |
| - return(NA) |
| 98 | + if (!all(quantile_levels %in% old_quantile_levels)) { |
| 99 | + return(rep(NA_real_, n)) |
140 | 100 | }
|
141 | 101 | }
|
142 | 102 | tau <- quantile_levels %||% old_quantile_levels
|
143 |
| - x <- extrapolate_quantiles(x, probs = tau, replace_na = (na_handling == "impute")) |
144 |
| - q <- field(x, "values")[field(x, "quantile_levels") %in% tau] |
| 103 | + x <- extrapolate_quantiles(x, tau, replace_na = (na_handling == "impute")) |
| 104 | + x <- as.matrix(x)[, attr(x, "quantile_levels") %in% tau] |
145 | 105 | na_rm <- (na_handling == "drop")
|
| 106 | + map2_dbl(vctrs::vec_chop(x), actual, ~ wis_one_quantile(.x, tau, .y, na_rm)) |
| 107 | +} |
| 108 | + |
| 109 | +wis_one_quantile <- function(q, tau, actual, na_rm) { |
146 | 110 | 2 * mean(pmax(tau * (actual - q), (1 - tau) * (q - actual)), na.rm = na_rm)
|
147 | 111 | }
|
0 commit comments