Skip to content

Commit 8a5d36f

Browse files
committed
merge frosting
Merge branch 'frosting' into add-distn # Conflicts: # NAMESPACE
2 parents 637f735 + 018aab2 commit 8a5d36f

14 files changed

+367
-184
lines changed

NAMESPACE

+10-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ S3method(Ops,dist_quantiles)
55
S3method(apply_frosting,default)
66
S3method(apply_frosting,epi_workflow)
77
S3method(augment,epi_workflow)
8-
S3method(bake,step_epi_shift)
8+
S3method(bake,step_epi_ahead)
9+
S3method(bake,step_epi_lag)
910
S3method(detect_layer,frosting)
1011
S3method(detect_layer,workflow)
1112
S3method(epi_keys,default)
@@ -14,6 +15,11 @@ S3method(epi_keys,recipe)
1415
S3method(epi_recipe,default)
1516
S3method(epi_recipe,epi_df)
1617
S3method(epi_recipe,formula)
18+
S3method(extract_argument,epi_workflow)
19+
S3method(extract_argument,frosting)
20+
S3method(extract_argument,layer)
21+
S3method(extract_argument,recipe)
22+
S3method(extract_argument,step)
1723
S3method(extract_layers,frosting)
1824
S3method(extract_layers,workflow)
1925
S3method(extrapolate_quantiles,dist_default)
@@ -23,12 +29,12 @@ S3method(format,dist_quantiles)
2329
S3method(median,dist_quantiles)
2430
S3method(predict,epi_workflow)
2531
S3method(prep,epi_recipe)
26-
S3method(prep,step_epi_shift)
32+
S3method(prep,step_epi_ahead)
33+
S3method(prep,step_epi_lag)
2734
S3method(print,epi_workflow)
2835
S3method(print,frosting)
2936
S3method(print,step_epi_ahead)
3037
S3method(print,step_epi_lag)
31-
S3method(print,step_epi_shift)
3238
S3method(quantile,dist_quantiles)
3339
S3method(refresh_blueprint,default_epi_recipe_blueprint)
3440
S3method(run_mold,default_epi_recipe_blueprint)
@@ -60,6 +66,7 @@ export(epi_keys)
6066
export(epi_recipe)
6167
export(epi_recipe_blueprint)
6268
export(epi_workflow)
69+
export(extract_argument)
6370
export(extract_layers)
6471
export(extrapolate_quantiles)
6572
export(frosting)
@@ -89,7 +96,6 @@ export(smooth_arx_forecaster)
8996
export(step_epi_ahead)
9097
export(step_epi_lag)
9198
export(step_epi_naomit)
92-
export(step_epi_shift)
9399
export(validate_layer)
94100
import(distributional)
95101
import(recipes)

R/compat-purrr.R

+16
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ map_chr <- function(.x, .f, ...) {
2929
out
3030
}
3131

32+
.rlang_purrr_args_recycle <- function(args) {
33+
lengths <- map_int(args, length)
34+
n <- max(lengths)
35+
36+
stopifnot(all(lengths == 1L | lengths == n))
37+
to_recycle <- lengths == 1L
38+
args[to_recycle] <- map(args[to_recycle], function(x) rep.int(x, n))
39+
40+
args
41+
}
42+
3243
map2 <- function(.x, .y, .f, ...) {
3344
.f <- rlang::as_function(.f, env = rlang::global_env())
3445
out <- mapply(.f, .x, .y, MoreArgs = list(...), SIMPLIFY = FALSE)
@@ -63,3 +74,8 @@ pmap <- function(.l, .f, ...) {
6374
SIMPLIFY = FALSE, USE.NAMES = FALSE
6475
))
6576
}
77+
78+
reduce <- function(.x, .f, ..., .init) {
79+
f <- function(x, y) .f(x, y, ...)
80+
Reduce(f, .x, init = .init)
81+
}

R/epi_recipe.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ epi_recipe.default <- function(x, ...) {
4646
#' library(recipes)
4747
#'
4848
#' jhu <- case_death_rate_subset %>%
49-
#' filter(time_value > "2021-08-01") %>%
49+
#' dplyr::filter(time_value > "2021-08-01") %>%
5050
#' dplyr::arrange(geo_value, time_value)
5151
#'
5252
#' r <- epi_recipe(jhu) %>%

R/extract.R

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#' Extract an argument made to a frosting layer or recipe step
2+
#'
3+
#' @param x an epi_workflow, epi_recipe, frosting, step, or layer object
4+
#' @param name the name of the layer
5+
#' @param arg the name of the argument
6+
#' @param ... not used
7+
#'
8+
#' @return An object originally passed as an argument to a layer or step
9+
#' @export
10+
#'
11+
#' @examples
12+
#' f <- frosting() %>%
13+
#' layer_predict() %>%
14+
#' layer_residual_quantile(probs = c(0.0275, 0.975), symmetrize = FALSE) %>%
15+
#' layer_naomit(.pred)
16+
#'
17+
#' extract_argument(f, "layer_residual_quantile", "symmetrize")
18+
extract_argument <- function(x, name, arg, ...) {
19+
UseMethod("extract_argument")
20+
}
21+
22+
#' @export
23+
extract_argument.layer <- function(x, name, arg, ...) {
24+
rlang::check_dots_empty()
25+
arg_is_chr_scalar(name, arg)
26+
in_layer_name = class(x)[1]
27+
if (name != in_layer_name)
28+
cli_stop("Requested {name} not found. This is a(n) {in_layer_name}.")
29+
if (! arg %in% names(x))
30+
cli_stop("Requested argument {arg} not found in {name}.")
31+
x[[arg]]
32+
}
33+
34+
#' @export
35+
extract_argument.step <- function(x, name, arg, ...) {
36+
rlang::check_dots_empty()
37+
arg_is_chr_scalar(name, arg)
38+
in_step_name = class(x)[1]
39+
if (name != in_step_name)
40+
cli_stop("Requested {name} not found. This is a {in_step_name}.")
41+
if (! arg %in% names(x))
42+
cli_stop("Requested argument {arg} not found in {name}.")
43+
x[[arg]]
44+
}
45+
46+
#' @export
47+
extract_argument.recipe <- function(x, name, arg, ...){
48+
rlang::check_dots_empty()
49+
step_names <- map_chr(x$steps, ~class(.x)[1])
50+
has_step <- name %in% step_names
51+
if (!has_step)
52+
cli_stop("recipe object does not contain a {name}.")
53+
step_locations <- which(name == step_names)
54+
out <- map(x$steps[step_locations], extract_argument, name = name, arg = arg)
55+
if (length(out) == 1) out <- out[[1]]
56+
out
57+
}
58+
59+
#' @export
60+
extract_argument.frosting <- function(x, name, arg, ...) {
61+
rlang::check_dots_empty()
62+
layer_names <- map_chr(x$layers, ~ class(.x)[1])
63+
has_layer <- name %in% layer_names
64+
if (! has_layer)
65+
cli_stop("frosting object does not contain a {name} layer.")
66+
layer_locations <- which(name == layer_names)
67+
out <- map(x$layers[layer_locations], extract_argument, name = name, arg = arg)
68+
if (length(out) == 1) out <- out[[1]]
69+
out
70+
}
71+
72+
#' @export
73+
extract_argument.epi_workflow <- function(x, name, arg, ...) {
74+
rlang::check_dots_empty()
75+
type <- sub("_.*", "", name)
76+
if (type %in% c("check", "step")) {
77+
if (!workflows:::has_preprocessor_recipe(x))
78+
cli_stop("The workflow must have a recipe preprocessor.")
79+
out <- extract_argument(x$pre$actions$recipe$recipe, name, arg)
80+
}
81+
if (type %in% "layer")
82+
out <- extract_argument(extract_frosting(x), name, arg)
83+
if (! type %in% c("check", "step", "layer"))
84+
cli_stop("{name} must begin with one of step, check, or layer")
85+
return(out)
86+
}
87+
88+
89+
#' @export
90+
#' @rdname layer-processors
91+
extract_layers <- function(x, ...) {
92+
UseMethod("extract_layers")
93+
}
94+
95+
96+
#' @export
97+
#' @rdname layer-processors
98+
extract_layers.frosting <- function(x, ...) {
99+
rlang::check_dots_empty()
100+
x$layers
101+
}
102+
103+
#' @export
104+
#' @rdname layer-processors
105+
extract_layers.workflow <- function(x, ...) {
106+
rlang::check_dots_empty()
107+
validate_has_postprocessor(x)
108+
extract_layers(x$post$actions$frosting$frosting)
109+
}

R/get_test_data.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ get_test_data <- function(recipe, x){
2828
}
2929
## CHECK if it is epi_df?
3030

31-
max_lags <- max(map_dbl(recipe$steps, ~ max(.x$shift %||% 0)))
31+
max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)))
3232

3333
# CHECK: Return NA if insufficient training data
3434
if (dplyr::n_distinct(x$time_value) < max_lags) {

R/layers.R

-20
Original file line numberDiff line numberDiff line change
@@ -104,27 +104,7 @@ detect_layer.workflow <- function(x, name, ...) {
104104
detect_layer(x$post$actions$frosting$frosting, name)
105105
}
106106

107-
#' @export
108-
#' @rdname layer-processors
109-
extract_layers <- function(x, ...) {
110-
UseMethod("extract_layers")
111-
}
112-
113-
114-
#' @export
115-
#' @rdname layer-processors
116-
extract_layers.frosting <- function(x, ...) {
117-
rlang::check_dots_empty()
118-
x$layers
119-
}
120107

121-
#' @export
122-
#' @rdname layer-processors
123-
extract_layers.workflow <- function(x, ...) {
124-
rlang::check_dots_empty()
125-
validate_has_postprocessor(x)
126-
extract_layers(x$post$actions$frosting$frosting)
127-
}
128108

129109
#' Spread a layer of frosting on a fitted workflow
130110
#'

0 commit comments

Comments
 (0)