Skip to content

Commit 77a2a09

Browse files
authored
Merge pull request #183 from cmu-delphi/smooth-quant-reg
Smooth quant reg
2 parents d6e685a + f5c794a commit 77a2a09

11 files changed

+422
-22
lines changed

DESCRIPTION

+3-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Imports:
3737
quantreg,
3838
recipes (>= 1.0.4),
3939
rlang,
40+
smoothqr,
4041
stats,
4142
tibble,
4243
tidyr,
@@ -61,7 +62,8 @@ VignetteBuilder:
6162
knitr
6263
Remotes:
6364
cmu-delphi/epidatr,
64-
cmu-delphi/epiprocess@dev
65+
cmu-delphi/epiprocess@dev,
66+
dajmcdon/smoothqr
6567
Config/testthat/edition: 3
6668
Encoding: UTF-8
6769
LazyData: true

NAMESPACE

+6
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ S3method(extrapolate_quantiles,dist_default)
3333
S3method(extrapolate_quantiles,dist_quantiles)
3434
S3method(extrapolate_quantiles,distribution)
3535
S3method(format,dist_quantiles)
36+
S3method(is.na,dist_quantiles)
37+
S3method(is.na,distribution)
3638
S3method(mean,dist_quantiles)
3739
S3method(median,dist_quantiles)
3840
S3method(predict,epi_workflow)
@@ -61,6 +63,7 @@ S3method(print,layer_predictive_distn)
6163
S3method(print,layer_quantile_distn)
6264
S3method(print,layer_residual_quantiles)
6365
S3method(print,layer_threshold)
66+
S3method(print,layer_unnest)
6467
S3method(print,step_epi_ahead)
6568
S3method(print,step_epi_lag)
6669
S3method(print,step_growth_rate)
@@ -81,6 +84,7 @@ S3method(slather,layer_predictive_distn)
8184
S3method(slather,layer_quantile_distn)
8285
S3method(slather,layer_residual_quantiles)
8386
S3method(slather,layer_threshold)
87+
S3method(slather,layer_unnest)
8488
S3method(snap,default)
8589
S3method(snap,dist_default)
8690
S3method(snap,dist_quantiles)
@@ -132,6 +136,7 @@ export(layer_predictive_distn)
132136
export(layer_quantile_distn)
133137
export(layer_residual_quantiles)
134138
export(layer_threshold)
139+
export(layer_unnest)
135140
export(nested_quantiles)
136141
export(new_default_epi_recipe_blueprint)
137142
export(new_epi_recipe_blueprint)
@@ -140,6 +145,7 @@ export(prep)
140145
export(quantile_reg)
141146
export(remove_frosting)
142147
export(slather)
148+
export(smooth_quantile_reg)
143149
export(step_epi_ahead)
144150
export(step_epi_lag)
145151
export(step_epi_naomit)

R/create-layer.R

+8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@
2020
#'
2121
create_layer <- function(name = NULL, open = rlang::is_interactive()) {
2222
name <- name %||% usethis:::get_active_r_file(path = "R")
23+
if (substr(name, 1, 5) == "layer") {
24+
nn <- substring(name, 6)
25+
if (substr(nn, 1, 1) == "_") nn <- substring(nn, 2)
26+
cli::cli_abort(
27+
c('`name` should not begin with "layer" or "layer_".',
28+
i = 'Did you mean to use `create_layer("{ nn }")`?')
29+
)
30+
}
2331
layer_name <- name
2432
name <- paste0("layer_", name)
2533
name <- usethis:::slug(name, "R")

R/dist_quantiles.R

+12
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,15 @@ Ops.dist_quantiles <- function(e1, e2) {
360360
new_quantiles(q = q, tau = tau)
361361
}
362362

363+
#' @method is.na distribution
#' @export
is.na.distribution <- function(x) {
  # Drop the vctrs wrapper to get at the underlying list of distribution
  # objects, then apply the `is.na()` generic to each element in turn.
  dist_elements <- vctrs::vec_data(x)
  sapply(dist_elements, is.na)
}
368+
369+
#' @method is.na dist_quantiles
#' @export
is.na.dist_quantiles <- function(x) {
  # A quantile distribution counts as missing only when every value stored
  # in its `q` field is `NA`.
  all(is.na(field(x, "q")))
}

R/layer_unnest.R

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#' Unnest prediction list-cols
#'
#' Adds a layer to a `frosting` postprocessor that unnests the selected
#' columns of the predictions.
#'
#' @param frosting a `frosting` postprocessor
#' @param ... <[`tidy-select`][dplyr::dplyr_tidy_select]> One or more unquoted
#'   expressions separated by commas, selecting the prediction columns to
#'   unnest. Variable names can be used as if they were positions in the data
#'   frame, so expressions like `x:y` can be used to select a range of
#'   variables.
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
#' @export
layer_unnest <- function(frosting, ..., id = rand_id("unnest")) {
  arg_is_chr_scalar(id)

  # Capture the selections unevaluated; they are resolved against the
  # predictions only when the layer is applied.
  unnest_layer <- layer_unnest_new(terms = dplyr::enquos(...), id = id)
  add_layer(frosting, unnest_layer)
}
23+
24+
# Internal constructor: builds the "unnest" layer object holding the captured
# column selections (`terms`) and the layer's `id`.
layer_unnest_new <- function(terms, id) {
  unnest_layer <- layer("unnest", terms = terms, id = id)
  unnest_layer
}
27+
28+
#' @export
slather.layer_unnest <-
  function(object, components, the_fit, the_recipe, ...) {
    # Applies the unnest layer: resolves the stored column selections against
    # `components$predictions`, unnests those columns, and returns the updated
    # `components`. `the_fit` and `the_recipe` are not used by this layer.

    # Splice the captured selection quosures into a single `c(...)` call so
    # tidyselect can resolve them all at once.
    exprs <- rlang::expr(c(!!!object$terms))
    pos <- tidyselect::eval_select(exprs, components$predictions)
    col_names <- names(pos)

    # `col_names` is an external character vector, so it must be wrapped in
    # `all_of()`. Passed bare, tidyselect treats it as ambiguous: a prediction
    # column that happened to be called `col_names` would be selected instead,
    # and recent tidyselect versions emit a deprecation warning.
    components$predictions <- components$predictions %>%
      tidyr::unnest(cols = tidyselect::all_of(col_names))

    components
  }
39+
40+
#' @export
print.layer_unnest <- function(
    x, width = max(20, options()$width - 30), ...) {
  # Delegate to the shared layer printer, showing the stored selections under
  # a fixed title.
  print_layer(
    x$terms,
    title = "Unnesting prediction list-cols",
    width = width
  )
}

R/make_smooth_quantile_reg.R

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
2+
#' Smooth quantile regression
3+
#'
4+
#' @description
5+
#' `smooth_quantile_reg()` generates a quantile regression model _specification_ for
6+
#' the [tidymodels](https://www.tidymodels.org/) framework. Currently, the
7+
#' only supported engine is [smoothqr::smooth_qr()].
8+
#'
9+
#' @param mode A single character string for the type of model.
10+
#' The only possible value for this model is "regression".
11+
#' @param engine Character string naming the fitting function. Currently, only
12+
#' "smoothqr" is supported.
13+
#' @param tau A scalar or vector of values in (0, 1) to determine which
14+
#' quantiles to estimate (default is 0.5).
15+
#' @param outcome_locations Defaults to the vector `1:ncol(y)` but if the
16+
#' responses are observed at a different spacing (or appear in a different
17+
#' order), that information should be used here. This
18+
#' argument will be mapped to the `aheads` argument of [smoothqr::smooth_qr()].
19+
#' @param degree the number of polynomials used for response smoothing. Must
20+
#' be no more than the number of responses.
21+
#' @export
22+
#'
23+
#' @seealso [fit.model_spec()], [set_engine()]
24+
#'
25+
#' @importFrom quantreg rq
26+
#' @examples
27+
#' tib <- data.frame(
28+
#' y1 = rnorm(100), y2 = rnorm(100), y3 = rnorm(100),
29+
#' y4 = rnorm(100), y5 = rnorm(100), y6 = rnorm(100),
30+
#' x1 = rnorm(100), x2 = rnorm(100))
31+
#' qr_spec <- smooth_quantile_reg(tau = c(.2, .5, .8), outcome_locations = 1:6)
32+
#' ff <- qr_spec %>% fit(cbind(y1, y2 , y3 , y4 , y5 , y6) ~ ., data = tib)
33+
#' p <- predict(ff, new_data = tib)
34+
#'
35+
#' x <- -99:99 / 100 * 2 * pi
36+
#' y <- sin(x) + rnorm(length(x), sd = .1)
37+
#' fd <- x[length(x) - 20]
38+
#' XY <- smoothqr::lagmat(y[1:(length(y) - 20)], c(-20:20))
39+
#' XY <- tibble::as_tibble(XY)
40+
#' qr_spec <- smooth_quantile_reg(tau = c(.2, .5, .8), outcome_locations = 20:1)
41+
#' tt <- qr_spec %>% fit_xy(x = XY[,21:41], y = XY[,1:20])
42+
#'
43+
#' library(tidyr)
44+
#' library(dplyr)
45+
#' pl <- predict(
46+
#' object = tt,
47+
#' new_data = XY[max(which(complete.cases(XY[,21:41]))), 21:41]
48+
#' )
49+
#' pl <- pl %>%
50+
#' unnest(.pred) %>%
51+
#' mutate(distn = nested_quantiles(distn)) %>%
52+
#' unnest(distn) %>%
53+
#' mutate(x = x[length(x) - 20] + ahead / 100 * 2 * pi,
54+
#' ahead = NULL) %>%
55+
#' pivot_wider(names_from = tau, values_from = q)
56+
#' plot(x, y, pch = 16, xlim = c(pi, 2 * pi), col = "lightgrey")
57+
#' curve(sin(x), add = TRUE)
58+
#' abline(v = fd, lty = 2)
59+
#' lines(pl$x, pl$`0.2`, col = "blue")
60+
#' lines(pl$x, pl$`0.8`, col = "blue")
61+
#' lines(pl$x, pl$`0.5`, col = "red")
62+
#' \dontrun{
63+
#' ggplot(data.frame(x = x, y = y), aes(x)) +
64+
#' geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") +
65+
#' geom_point(aes(y = y), colour = "grey") + # observed data
66+
#' geom_function(fun = sin, colour = "black") + # truth
67+
#' geom_vline(xintercept = fd, linetype = "dashed") + # end of training data
68+
#' geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction
69+
#' theme_bw() +
70+
#' coord_cartesian(xlim = c(0, NA)) +
71+
#' ylab("y")
72+
#' }
73+
smooth_quantile_reg <- function(
    mode = "regression",
    engine = "smoothqr",
    outcome_locations = NULL,
    tau = 0.5,
    degree = 3L) {
  # Only a single mode / engine combination is accepted for this model.
  if (mode != "regression") {
    rlang::abort("`mode` must be 'regression'")
  }
  if (engine != "smoothqr") {
    rlang::abort("`engine` must be 'smoothqr'")
  }

  # Validate the user-supplied arguments before capturing them.
  arg_is_probabilities(tau)
  arg_is_pos_int(degree)
  arg_is_scalar(degree)
  arg_is_numeric(outcome_locations, allow_null = TRUE)

  # Quantile levels are stored in increasing order; warn when reordering.
  if (is.unsorted(tau)) {
    rlang::warn("Sorting tau to increasing order.")
    tau <- sort(tau)
  }

  # Main model arguments are stored as quosures in the specification.
  spec_args <- list(
    tau = rlang::enquo(tau),
    degree = rlang::enquo(degree),
    outcome_locations = rlang::enquo(outcome_locations)
  )

  # `eng_args` and `method` are left empty here, to be filled in by later
  # parts of the specification.
  parsnip::new_model_spec(
    "smooth_quantile_reg",
    args = spec_args,
    eng_args = NULL,
    mode = mode,
    method = NULL,
    engine = engine
  )
}
106+
107+
108+
# Registers the "smooth_quantile_reg" model with parsnip: its mode, its
# "smoothqr" engine and package dependency, the argument translations, the
# fit and encoding settings, and the prediction method.
make_smooth_quantile_reg <- function() {
  model_name <- "smooth_quantile_reg"
  engine_name <- "smoothqr"
  mode_name <- "regression"

  parsnip::set_new_model(model_name)
  parsnip::set_model_mode(model_name, mode_name)
  parsnip::set_model_engine(model_name, mode_name, eng = engine_name)
  parsnip::set_dependency(
    model_name,
    eng = engine_name,
    pkg = "smoothqr",
    mode = mode_name
  )

  # Map each parsnip-facing argument name to the engine's own argument name.
  engine_arg_names <- c(
    tau = "tau",
    degree = "degree",
    outcome_locations = "aheads"
  )
  for (parsnip_arg in names(engine_arg_names)) {
    parsnip::set_model_arg(
      model = model_name,
      eng = engine_name,
      parsnip = parsnip_arg,
      original = engine_arg_names[[parsnip_arg]],
      func = list(pkg = "smoothqr", fun = "smooth_qr"),
      has_submodel = FALSE
    )
  }

  fit_settings <- list(
    interface = "data.frame",
    protect = c("x", "y"), # the user may not override these
    func = c(pkg = "smoothqr", fun = "smooth_qr"),
    defaults = list(intercept = TRUE)
  )
  parsnip::set_fit(
    model = model_name,
    eng = engine_name,
    mode = mode_name,
    value = fit_settings
  )

  encoding_settings <- list(
    # factors are converted to dummy variables with a baseline level
    predictor_indicators = "traditional",
    # an intercept column is computed in the design matrix ...
    compute_intercept = TRUE,
    # ... and then removed again; the fit defaults ask for `intercept = TRUE`
    remove_intercept = TRUE,
    # sparse predictor matrices are not allowed for this engine
    allow_sparse_x = FALSE
  )
  parsnip::set_encoding(
    model = model_name,
    eng = engine_name,
    mode = mode_name,
    options = encoding_settings
  )

  # Post-processor for raw engine predictions. `x` is iterated over as a list
  # of matrix-like objects; each row of each element is sorted (NAs last) and
  # wrapped as a quantile distribution at the fitted `tau` levels. The result
  # is a one-column tibble, `.pred`, holding one nested (ahead, distn) table
  # per prediction id.
  process_smooth_qr_preds <- function(x, object) {
    engine_fit <- parsnip::extract_fit_engine(object)
    sort_row <- function(q) unname(sort(q, na.last = TRUE))

    distns_by_outcome <- lapply(x, function(p) {
      sorted_rows <- unname(apply(p, 1, sort_row, simplify = FALSE))
      dist_quantiles(sorted_rows, list(engine_fit$tau))
    })

    n_preds <- length(distns_by_outcome[[1]])
    nout <- length(distns_by_outcome)

    long_preds <- tibble::tibble(
      ids = rep(seq(n_preds), times = nout),
      ahead = rep(engine_fit$aheads, each = n_preds),
      distn = do.call(c, unname(distns_by_outcome))
    )
    nested_preds <- tidyr::nest(long_preds, .pred = c(ahead, distn))

    nested_preds[".pred"]
  }

  pred_settings <- list(
    pre = NULL,
    post = process_smooth_qr_preds,
    func = c(fun = "predict"),
    args = list(object = quote(object$fit), newdata = quote(new_data))
  )
  parsnip::set_pred(
    model = model_name,
    eng = engine_name,
    mode = mode_name,
    type = "numeric",
    value = pred_settings
  )
}
201+

0 commit comments

Comments
 (0)