1
- # ' returns both the point estimate and the quantile distribution, regardless of the underlying trainer
2
- # ' f
1
+ # ' returns both the point estimate and the quantile distribution
2
+ # ' @description
3
+ # ' This function adds a frosting layer that produces both a point estimate as
4
+ # ' well as quantile estimates.
5
+ # ' @param distn_id a random id string for the layer that creates the quantile
6
+ # ' estimate
7
+ # ' @param point_id a random id string for the layer that creates the point
8
+ # ' estimate. Only present for trainers that produce quantiles
9
+ # ' @param point_type character. Either `mean` or `median`.
10
+ # ' @param use_predictive_distribution only usable for `linear_reg` type models
11
+ # ' @param distn_type character. Only used if `use_predictive_distribution=TRUE`,
12
+ # ' for `linear_reg` type models. Either gaussian or student_t
13
+ # ' @param distn_name an alternate name for the distribution column; defaults
14
+ # ' to `.pred_distn`.
15
+ # ' @param point_name an alternate name for the point estimate column; defaults
16
+ # ' to `.pred`.
17
+ # ' @param symmetrize logical. If `TRUE` then interval will be symmetric.
18
+ # ' Applies for residual quantiles only
19
+ # ' @param by_key A character vector of keys to group the residuals by before
20
+ # ' calculating quantiles. The default, `c()` performs no grouping. Only used
21
+ # ' by `layer_residual_quantiles`
22
+ # ' @inheritParams layer_quantile_distn
23
+ # ' @export
24
+ # ' @return an updated `frosting postprocessor` with an additional prediction
25
+ # ' column; if the trainer produces a point estimate, it has added a
26
+ # ' distribution estimate, and vice versa.
3
27
layer_point_and_distn <- function (frosting , trainer , ... ,
4
- probs = c(0.05 , 0.95 ),
28
+ levels = c(0.25 , 0.75 ),
5
29
symmetrize = TRUE ,
6
30
by_key = character (0L ),
7
31
distn_name = " .pred_distn" ,
@@ -10,26 +34,27 @@ layer_point_and_distn <- function(frosting, trainer, ...,
10
34
point_id = NULL ,
11
35
point_type = c(" median" , " mean" ),
12
36
truncate = c(- Inf , Inf ),
13
- use_predictive_distribution = TRUE ,
37
+ use_predictive_distribution = FALSE ,
14
38
dist_type = " gaussian" ) {
15
39
rlang :: check_dots_empty()
16
40
stopifnot(inherits(recipe , " recipe" ))
17
- # not sure what to do about the dots...
41
+ levels <- sort( levels )
18
42
if (inherits(trainer , " quantile_reg" )) {
43
+ # sort the probabilities
19
44
tau <- sort(compare_quantile_args(
20
- args_list $ levels ,
45
+ levels ,
21
46
rlang :: eval_tidy(trainer $ args $ tau )
22
47
))
23
- args_list $ levels <- tau
48
+ levels <- tau
24
49
trainer $ args $ tau <- rlang :: enquo(tau )
25
50
if (is.null(point_id )) {
26
51
point_id <- rand_id(" point_from_distn" )
27
52
}
28
53
if (is.null(distn_id )) {
29
54
distn_id <- rand_id(" quantile_distn" )
30
55
}
31
- frosting %<> % layer_quantile_distn(... ,
32
- levels = tau ,
56
+ frosting %<> % layer_quantile_distn(
57
+ levels = levels ,
33
58
truncate = trucate ,
34
59
name = distn_name ,
35
60
id = distn_id
@@ -44,7 +69,7 @@ layer_point_and_distn <- function(frosting, trainer, ...,
44
69
distn_id <- rand_id(" residual_quantiles" )
45
70
}
46
71
if (inherits(trainer , " linear_reg" ) && use_predictive_distribution ) {
47
- frosting %<> % layer_residual_quantiles (
72
+ frosting %<> % layer_predictive_distn (
48
73
dist_type = dist_type ,
49
74
name = distn_name ,
50
75
id = distn_id
0 commit comments