Skip to content

Commit d7301dc

Browse files
authored
Merge pull request #88 from cmu-delphi/add-distn
Add distn
2 parents 018aab2 + 1306521 commit d7301dc

27 files changed

+959
-172
lines changed

DESCRIPTION

+6-7
Original file line numberDiff line numberDiff line change
@@ -15,40 +15,39 @@ Depends:
1515
Imports:
1616
assertthat,
1717
cli,
18+
distributional,
1819
dplyr,
1920
epiprocess,
2021
fs,
2122
generics,
2223
glue,
23-
hardhat (>= 1.1.0.9000),
24+
hardhat (>= 1.2.0),
2425
magrittr,
2526
purrr,
26-
recipes (>= 0.2.0.9001),
27+
recipes (>= 1.0.0),
2728
rlang,
2829
stats,
2930
tensr,
3031
tibble,
3132
tidyr,
3233
tidyselect,
3334
usethis,
35+
vctrs,
3436
workflows
3537
Suggests:
3638
covidcast,
3739
data.table,
3840
ggplot2,
3941
knitr,
4042
lubridate,
41-
parsnip (>= 0.2.1.9001),
43+
parsnip (>= 1.0.0),
4244
RcppRoll,
4345
rmarkdown,
4446
testthat (>= 3.0.0)
4547
VignetteBuilder:
4648
knitr
4749
Remotes:
48-
dajmcdon/epiprocess,
49-
tidymodels/hardhat,
50-
tidymodels/parsnip,
51-
tidymodels/recipes
50+
dajmcdon/epiprocess
5251
Config/testthat/edition: 3
5352
Encoding: UTF-8
5453
Roxygen: list(markdown = TRUE)

NAMESPACE

+26-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(Math,dist_quantiles)
4+
S3method(Ops,dist_quantiles)
35
S3method(apply_frosting,default)
46
S3method(apply_frosting,epi_workflow)
57
S3method(augment,epi_workflow)
@@ -20,6 +22,11 @@ S3method(extract_argument,recipe)
2022
S3method(extract_argument,step)
2123
S3method(extract_layers,frosting)
2224
S3method(extract_layers,workflow)
25+
S3method(extrapolate_quantiles,dist_default)
26+
S3method(extrapolate_quantiles,dist_quantiles)
27+
S3method(extrapolate_quantiles,distribution)
28+
S3method(format,dist_quantiles)
29+
S3method(median,dist_quantiles)
2330
S3method(predict,epi_workflow)
2431
S3method(prep,epi_recipe)
2532
S3method(prep,step_epi_ahead)
@@ -28,12 +35,20 @@ S3method(print,epi_workflow)
2835
S3method(print,frosting)
2936
S3method(print,step_epi_ahead)
3037
S3method(print,step_epi_lag)
38+
S3method(quantile,dist_quantiles)
3139
S3method(refresh_blueprint,default_epi_recipe_blueprint)
3240
S3method(run_mold,default_epi_recipe_blueprint)
3341
S3method(slather,layer_naomit)
3442
S3method(slather,layer_predict)
35-
S3method(slather,layer_residual_quantile)
43+
S3method(slather,layer_predictive_distn)
44+
S3method(slather,layer_residual_quantiles)
3645
S3method(slather,layer_threshold)
46+
S3method(snap,default)
47+
S3method(snap,dist_default)
48+
S3method(snap,dist_quantiles)
49+
S3method(snap,distribution)
50+
S3method(vec_ptype_abbr,dist_quantiles)
51+
S3method(vec_ptype_full,dist_quantiles)
3752
export("%>%")
3853
export(add_epi_recipe)
3954
export(add_frosting)
@@ -46,12 +61,14 @@ export(create_layer)
4661
export(default_epi_recipe_blueprint)
4762
export(detect_layer)
4863
export(df_mat_mul)
64+
export(dist_quantiles)
4965
export(epi_keys)
5066
export(epi_recipe)
5167
export(epi_recipe_blueprint)
5268
export(epi_workflow)
5369
export(extract_argument)
5470
export(extract_layers)
71+
export(extrapolate_quantiles)
5572
export(frosting)
5673
export(get_precision)
5774
export(get_test_data)
@@ -66,8 +83,10 @@ export(knnarx_forecaster)
6683
export(layer)
6784
export(layer_naomit)
6885
export(layer_predict)
69-
export(layer_residual_quantile)
86+
export(layer_predictive_distn)
87+
export(layer_residual_quantiles)
7088
export(layer_threshold)
89+
export(nested_quantiles)
7190
export(new_default_epi_recipe_blueprint)
7291
export(new_epi_recipe_blueprint)
7392
export(remove_frosting)
@@ -78,7 +97,9 @@ export(step_epi_ahead)
7897
export(step_epi_lag)
7998
export(step_epi_naomit)
8099
export(validate_layer)
100+
import(distributional)
81101
import(recipes)
102+
import(vctrs)
82103
importFrom(generics,augment)
83104
importFrom(generics,fit)
84105
importFrom(hardhat,refresh_blueprint)
@@ -91,10 +112,13 @@ importFrom(rlang,abort)
91112
importFrom(rlang,caller_env)
92113
importFrom(rlang,is_null)
93114
importFrom(stats,as.formula)
115+
importFrom(stats,family)
94116
importFrom(stats,lm)
117+
importFrom(stats,median)
95118
importFrom(stats,model.frame)
96119
importFrom(stats,poly)
97120
importFrom(stats,predict)
121+
importFrom(stats,qnorm)
98122
importFrom(stats,quantile)
99123
importFrom(stats,residuals)
100124
importFrom(stats,setNames)

R/arx_forecaster_mod.R

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
arx_epi_forecaster <- function(epi_data, response,
2+
...,
3+
trainer = parsnip::linear_reg(),
4+
args_list = arx_args_list()) {
5+
6+
r <- epi_recipe(epi_data) %>%
7+
step_epi_lag(..., lag = args_list$lags) %>% # hmmm, same for all predictors
8+
step_epi_ahead(response, ahead = args_list$ahead) %>%
9+
# should use the internal function (in an open PR)
10+
recipes::step_naomit(recipes::all_predictors()) %>%
11+
recipes::step_naomit(recipes::all_outcomes(), skip = TRUE)
12+
# should limit the training window here (in an open PR)
13+
# What to do if insufficient training data? Add issue.
14+
# remove intercept? not sure how this is implemented in tidymodels
15+
f <- frosting() %>%
16+
layer_predict() %>%
17+
layer_naomit(.pred) %>%
18+
layer_residual_quantile(
19+
probs = args_list$levels,
20+
symmetrize = args_list$symmetrize) %>%
21+
layer_threshold(.pred, dplyr::starts_with("q")) #, .flag = args_list$nonneg) in open PR
22+
# need the target date processing here
23+
24+
latest <- get_test_data(r, epi_data)
25+
26+
epi_workflow(r, trainer) %>% # bug, issue 72
27+
add_frosting(f)
28+
29+
}

R/compat-purrr.R

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
map <- function(.x, .f, ...) {
5-
.f <- rlang::as_function(.f, env = global_env())
5+
.f <- rlang::as_function(.f, env = rlang::global_env())
66
lapply(.x, .f, ...)
77
}
88
walk <- function(.x, .f, ...) {
@@ -23,7 +23,7 @@ map_chr <- function(.x, .f, ...) {
2323
.rlang_purrr_map_mold(.x, .f, character(1), ...)
2424
}
2525
.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) {
26-
.f <- rlang::as_function(.f, env = global_env())
26+
.f <- rlang::as_function(.f, env = rlang::global_env())
2727
out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE)
2828
names(out) <- names(.x)
2929
out
@@ -41,12 +41,12 @@ map_chr <- function(.x, .f, ...) {
4141
}
4242

4343
map2 <- function(.x, .y, .f, ...) {
44-
.f <- as_function(.f, env = global_env())
44+
.f <- rlang::as_function(.f, env = rlang::global_env())
4545
out <- mapply(.f, .x, .y, MoreArgs = list(...), SIMPLIFY = FALSE)
4646
if (length(out) == length(.x)) {
47-
set_names(out, names(.x))
47+
rlang::set_names(out, names(.x))
4848
} else {
49-
set_names(out, NULL)
49+
rlang::set_names(out, NULL)
5050
}
5151
}
5252
map2_lgl <- function(.x, .y, .f, ...) {

0 commit comments

Comments
 (0)