Skip to content

Commit 786b3dc

Browse files
committed
docs+fix: point_and_distn
1 parent 6a58a9e commit 786b3dc

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ export(layer)
132132
export(layer_add_forecast_date)
133133
export(layer_add_target_date)
134134
export(layer_naomit)
135+
export(layer_point_and_distn)
135136
export(layer_point_from_distn)
136137
export(layer_population_scaling)
137138
export(layer_predict)

R/layer_point_and_distn.R

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,31 @@
1-
#' returns both the point estimate and the quantile distribution, regardless of the underlying trainer
2-
#' f
1+
#' returns both the point estimate and the quantile distribution
2+
#' @description
3+
#' This function adds a frosting layer that produces both a point estimate as
4+
#' well as quantile estimates.
5+
#' @param distn_id a random id string for the layer that creates the quantile
6+
#' estimate
7+
#' @param point_id a random id string for the layer that creates the point
8+
#' estimate. Only present for trainers that produce quantiles
9+
#' @param point_type character. Either `mean` or `median`.
10+
#' @param use_predictive_distribution only usable for `linear_reg` type models
11+
#' @param distn_type character. Only used if `use_predictive_distribution=TRUE`,
12+
#' for `linear_reg` type models. Either gaussian or student_t
13+
#' @param distn_name an alternate name for the distribution column; defaults
14+
#' to `.pred_distn`.
15+
#' @param point_name an alternate name for the point estimate column; defaults
16+
#' to `.pred`.
17+
#' @param symmetrize logical. If `TRUE` then interval will be symmetric.
18+
#' Applies for residual quantiles only
19+
#' @param by_key A character vector of keys to group the residuals by before
20+
#' calculating quantiles. The default, `c()` performs no grouping. Only used
21+
#' by `layer_residual_quantiles`
22+
#' @inheritParams layer_quantile_distn
23+
#' @export
24+
#' @return an updated `frosting postprocessor` with an additional prediction
25+
#' column; if the trainer produces a point estimate, it has added a
26+
#' distribution estimate, and vice versa.
327
layer_point_and_distn <- function(frosting, trainer, ...,
4-
probs = c(0.05, 0.95),
28+
levels = c(0.25, 0.75),
529
symmetrize = TRUE,
630
by_key = character(0L),
731
distn_name = ".pred_distn",
@@ -10,26 +34,27 @@ layer_point_and_distn <- function(frosting, trainer, ...,
1034
point_id = NULL,
1135
point_type = c("median", "mean"),
1236
truncate = c(-Inf, Inf),
13-
use_predictive_distribution = TRUE,
37+
use_predictive_distribution = FALSE,
1438
dist_type = "gaussian") {
1539
rlang::check_dots_empty()
1640
stopifnot(inherits(recipe, "recipe"))
17-
# not sure what to do about the dots...
41+
levels <- sort(levels)
1842
if (inherits(trainer, "quantile_reg")) {
43+
# sort the probabilities
1944
tau <- sort(compare_quantile_args(
20-
args_list$levels,
45+
levels,
2146
rlang::eval_tidy(trainer$args$tau)
2247
))
23-
args_list$levels <- tau
48+
levels <- tau
2449
trainer$args$tau <- rlang::enquo(tau)
2550
if (is.null(point_id)) {
2651
point_id <- rand_id("point_from_distn")
2752
}
2853
if (is.null(distn_id)) {
2954
distn_id <- rand_id("quantile_distn")
3055
}
31-
frosting %<>% layer_quantile_distn(...,
32-
levels = tau,
56+
frosting %<>% layer_quantile_distn(
57+
levels = levels,
3358
truncate = trucate,
3459
name = distn_name,
3560
id = distn_id
@@ -44,7 +69,7 @@ layer_point_and_distn <- function(frosting, trainer, ...,
4469
distn_id <- rand_id("residual_quantiles")
4570
}
4671
if (inherits(trainer, "linear_reg") && use_predictive_distribution) {
47-
frosting %<>% layer_residual_quantiles(
72+
frosting %<>% layer_predictive_distn(
4873
dist_type = dist_type,
4974
name = distn_name,
5075
id = distn_id

0 commit comments

Comments
 (0)