|
| 1 | + |
#' Smooth quantile regression
#'
#' @description
#' `smooth_quantile_reg()` generates a quantile regression model _specification_ for
#' the [tidymodels](https://www.tidymodels.org/) framework. Currently, the
#' only supported engine is [smoothqr::smooth_qr()].
#'
#' @param mode A single character string for the type of model.
#'   The only possible value for this model is "regression".
#' @param engine Character string naming the fitting function. Currently, only
#'   "smoothqr" is supported.
#' @param tau A scalar or vector of values in (0, 1) to determine which
#'   quantiles to estimate (default is 0.5). If the values are not in
#'   increasing order they are sorted, with a warning.
#' @param outcome_locations Defaults to the vector `1:ncol(y)` but if the
#'   responses are observed at a different spacing (or appear in a different
#'   order), that information should be used here. This
#'   argument will be mapped to the `aheads` argument of [smoothqr::smooth_qr()].
#' @param degree the number of polynomials used for response smoothing. Must
#'   be no more than the number of responses.
#'
#' @return A model specification with class `smooth_quantile_reg`, as
#'   constructed by [parsnip::new_model_spec()].
#' @export
#'
#' @seealso [fit.model_spec()], [set_engine()]
#'
#' @importFrom quantreg rq
#' @examples
#' tib <- data.frame(
#'   y1 = rnorm(100), y2 = rnorm(100), y3 = rnorm(100),
#'   y4 = rnorm(100), y5 = rnorm(100), y6 = rnorm(100),
#'   x1 = rnorm(100), x2 = rnorm(100))
#' qr_spec <- smooth_quantile_reg(tau = c(.2, .5, .8), outcome_locations = 1:6)
#' ff <- qr_spec %>% fit(cbind(y1, y2 , y3 , y4 , y5 , y6) ~ ., data = tib)
#' p <- predict(ff, new_data = tib)
#'
#' x <- -99:99 / 100 * 2 * pi
#' y <- sin(x) + rnorm(length(x), sd = .1)
#' fd <- x[length(x) - 20]
#' XY <- smoothqr::lagmat(y[1:(length(y) - 20)], c(-20:20))
#' XY <- tibble::as_tibble(XY)
#' qr_spec <- smooth_quantile_reg(tau = c(.2, .5, .8), outcome_locations = 20:1)
#' tt <- qr_spec %>% fit_xy(x = XY[,21:41], y = XY[,1:20])
#'
#' library(tidyr)
#' library(dplyr)
#' pl <- predict(
#'   object = tt,
#'   new_data = XY[max(which(complete.cases(XY[,21:41]))), 21:41]
#' )
#' pl <- pl %>%
#'   unnest(.pred) %>%
#'   mutate(distn = nested_quantiles(distn)) %>%
#'   unnest(distn) %>%
#'   mutate(x = x[length(x) - 20] + ahead / 100 * 2 * pi,
#'          ahead = NULL) %>%
#'   pivot_wider(names_from = tau, values_from = q)
#' plot(x, y, pch = 16, xlim = c(pi, 2 * pi), col = "lightgrey")
#' curve(sin(x), add = TRUE)
#' abline(v = fd, lty = 2)
#' lines(pl$x, pl$`0.2`, col = "blue")
#' lines(pl$x, pl$`0.8`, col = "blue")
#' lines(pl$x, pl$`0.5`, col = "red")
#' \dontrun{
#' ggplot(data.frame(x = x, y = y), aes(x)) +
#'   geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") +
#'   geom_point(aes(y = y), colour = "grey") + # observed data
#'   geom_function(fun = sin, colour = "black") + # truth
#'   geom_vline(xintercept = fd, linetype = "dashed") + # end of training data
#'   geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction
#'   theme_bw() +
#'   coord_cartesian(xlim = c(0, NA)) +
#'   ylab("y")
#' }
smooth_quantile_reg <- function(
    mode = "regression",
    engine = "smoothqr",
    outcome_locations = NULL,
    tau = 0.5,
    degree = 3L) {

  # Only one mode and one engine are supported. `isTRUE()` keeps the condition
  # a single TRUE/FALSE even for NULL, NA, or length > 1 input, so the
  # informative abort below is raised instead of a base-R condition error.
  if (!isTRUE(mode == "regression")) rlang::abort("`mode` must be 'regression'")
  if (!isTRUE(engine == "smoothqr")) rlang::abort("`engine` must be 'smoothqr'")

  # Argument validation via the package's internal checkers
  arg_is_probabilities(tau)
  arg_is_pos_int(degree)
  arg_is_scalar(degree)
  arg_is_numeric(outcome_locations, allow_null = TRUE)

  # Quantile levels are stored in increasing order; tell the user if we had to
  # reorder what they supplied.
  if (is.unsorted(tau)) {
    rlang::warn("Sorting tau to increasing order.")
    tau <- sort(tau)
  }

  # Main arguments are captured as quosures for the parsnip model spec
  args <- list(
    tau = rlang::enquo(tau),
    degree = rlang::enquo(degree),
    outcome_locations = rlang::enquo(outcome_locations)
  )

  # Save some empty slots for future parts of the specification
  parsnip::new_model_spec(
    "smooth_quantile_reg",
    args = args,
    eng_args = NULL,
    mode = mode,
    method = NULL,
    engine = engine
  )
}
| 106 | + |
| 107 | + |
# Register the `smooth_quantile_reg` model with parsnip.
#
# Declares the model, its single mode ("regression") and single engine
# ("smoothqr"), the package dependency, the mapping from parsnip argument names
# to engine argument names, and the fit / encoding / prediction modules.
# Takes no arguments and is called for its side effects on the parsnip
# model registry.
make_smooth_quantile_reg <- function() {
  parsnip::set_new_model("smooth_quantile_reg")
  parsnip::set_model_mode("smooth_quantile_reg", "regression")
  parsnip::set_model_engine("smooth_quantile_reg", "regression", eng = "smoothqr")
  parsnip::set_dependency(
    "smooth_quantile_reg",
    eng = "smoothqr",
    pkg = "smoothqr",
    mode = "regression"
  )

  # Argument translation: parsnip name -> name used by smoothqr::smooth_qr()
  parsnip::set_model_arg(
    model = "smooth_quantile_reg",
    eng = "smoothqr",
    parsnip = "tau",
    original = "tau",
    func = list(pkg = "smoothqr", fun = "smooth_qr"),
    has_submodel = FALSE
  )
  parsnip::set_model_arg(
    model = "smooth_quantile_reg",
    eng = "smoothqr",
    parsnip = "degree",
    original = "degree",
    func = list(pkg = "smoothqr", fun = "smooth_qr"),
    has_submodel = FALSE
  )
  # `outcome_locations` is the only argument whose name differs in the engine
  parsnip::set_model_arg(
    model = "smooth_quantile_reg",
    eng = "smoothqr",
    parsnip = "outcome_locations",
    original = "aheads",
    func = list(pkg = "smoothqr", fun = "smooth_qr"),
    has_submodel = FALSE
  )

  parsnip::set_fit(
    model = "smooth_quantile_reg",
    eng = "smoothqr",
    mode = "regression",
    value = list(
      interface = "data.frame",
      protect = c("x", "y"), # prevent user from touching these
      func = c(pkg = "smoothqr", fun = "smooth_qr"),
      defaults = list(intercept = TRUE)
    )
  )

  parsnip::set_encoding(
    model = "smooth_quantile_reg",
    eng = "smoothqr",
    mode = "regression",
    options = list(
      predictor_indicators = "traditional", # factor -> dummy conversion w/ baseline
      compute_intercept = TRUE, # put an intercept into the design matrix
      remove_intercept = TRUE, # but then remove it, we'll put it back in the function
      allow_sparse_x = FALSE # do not pass sparse predictor matrices to the engine
    )
  )

  # Post-processor for raw engine predictions.
  #
  # `x` is whatever `predict()` returned for the engine fit; it is iterated
  # with `lapply()`, and each element is processed row-wise with `apply(p, 1, ...)`
  # (presumably one matrix per outcome location, rows = observations,
  # columns = quantile levels -- confirm against smoothqr's predict method).
  # `object` is the parsnip fit, from which the engine fit is extracted to
  # read `tau` and `aheads`.
  #
  # Returns a one-column tibble `.pred`; each entry is a nested tibble with
  # columns `ahead` and `distn` for one observation.
  process_smooth_qr_preds <- function(x, object) {
    object <- parsnip::extract_fit_engine(object)
    list_of_pred_distns <- lapply(x, function(p) {
      # Sort each row's predicted quantiles into increasing order (NAs last)
      # before pairing them with the (increasing) levels in `object$tau`.
      sorted_rows <- unname(apply(
        p, 1, function(q) unname(sort(q, na.last = TRUE)), simplify = FALSE
      ))
      dist_quantiles(sorted_rows, list(object$tau))
    })
    n_preds <- length(list_of_pred_distns[[1]])
    nout <- length(list_of_pred_distns)
    # Distributions are concatenated outcome-location by outcome-location, so
    # `ids` cycles through the observations within each location.
    # `seq_len()` (not `seq()`) so that zero predictions give an empty index
    # rather than c(1, 0).
    tib <- tibble::tibble(
      ids = rep(seq_len(n_preds), times = nout),
      ahead = rep(object$aheads, each = n_preds),
      distn = do.call(c, unname(list_of_pred_distns))
    ) %>%
      tidyr::nest(.pred = c(ahead, distn))

    tib[".pred"]
  }


  parsnip::set_pred(
    model = "smooth_quantile_reg",
    eng = "smoothqr",
    mode = "regression",
    type = "numeric",
    value = list(
      pre = NULL,
      post = process_smooth_qr_preds,
      func = c(fun = "predict"),
      args = list(object = quote(object$fit), newdata = quote(new_data))
    )
  )
}
| 201 | + |
0 commit comments