Skip to content

Commit ffa6909

Browse files
committed
refactor: clean epieval package imports
1 parent dddacac commit ffa6909

16 files changed

+165
-91
lines changed

DESCRIPTION

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,57 +6,35 @@ Authors@R:
66
c(
77
person("David", "Weber", email = "[email protected]", role = c("ctb", "cre")),
88
person("Dmitry", "Shemetov", email = "[email protected]", role = c("aut")),
9-
person("Nat", "DeFries", email = "[email protected]", role = c("aut")),
9+
person("Nat", "DeFries", email = "[email protected]", role = c("aut"))
1010
)
1111
Description: Given a timeseries and accompanying auxillary timeseries, evaluate a collection of forecasters implementable with epipredict using several metrics on historical data. Assuming the source of your timeseries provides versioned data, the evaluation at any given timepoint will only use data that was available at that point.
1212
License: MIT + file LICENSE
1313
Depends:
14-
epiprocess (>= 0.6.0),
15-
epipredict,
1614
R (>= 3.5.0)
1715
Imports:
1816
assertthat,
1917
aws.s3,
2018
cli,
21-
distributional,
2219
dplyr,
2320
epidatr,
2421
epipredict,
25-
fs,
26-
generics,
27-
glue,
28-
hardhat (>= 1.3.0),
22+
epiprocess,
2923
here,
3024
lubridate,
3125
magrittr,
32-
methods,
33-
openssl,
3426
parsnip (>= 1.0.0),
3527
purrr,
36-
quantreg,
3728
recipes (>= 1.0.4),
3829
rlang,
39-
smoothqr,
40-
stats,
41-
targets,
4230
tibble,
43-
tidyr,
44-
tidyselect,
45-
usethis,
46-
vctrs,
47-
workflows (>= 1.0.0)
31+
tidyr
4832
Suggests:
49-
covidcast,
50-
data.table,
5133
ggplot2,
5234
knitr,
53-
pipeR,
5435
plotly,
55-
poissonreg,
5636
rmarkdown,
57-
shiny,
5837
testthat (>= 3.0.0),
59-
xgboost
6038
VignetteBuilder:
6139
knitr
6240
Remotes:

NAMESPACE

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
export("%>%")
34
export(absolute_error)
45
export(add_id)
56
export(arx_postprocess)
@@ -16,6 +17,7 @@ export(id_ahead_ensemble_grid)
1617
export(interval_coverage)
1718
export(lookup_ids)
1819
export(make_target_param_grid)
20+
export(manage_S3_forecast_cache)
1921
export(overprediction)
2022
export(perform_sanity_checks)
2123
export(read_external_predictions_data)
@@ -26,24 +28,63 @@ export(sharpness)
2628
export(single_id)
2729
export(underprediction)
2830
export(weighted_interval_score)
29-
import(dplyr)
30-
import(epipredict)
31-
import(openssl)
32-
import(recipes)
33-
import(rlang)
3431
importFrom(assertthat,assert_that)
32+
importFrom(aws.s3,get_bucket)
33+
importFrom(aws.s3,s3sync)
34+
importFrom(cli,cli_abort)
3535
importFrom(cli,hash_animal)
36+
importFrom(dplyr,across)
37+
importFrom(dplyr,any_of)
38+
importFrom(dplyr,everything)
39+
importFrom(dplyr,filter)
40+
importFrom(dplyr,group_by)
41+
importFrom(dplyr,inner_join)
42+
importFrom(dplyr,join_by)
43+
importFrom(dplyr,left_join)
44+
importFrom(dplyr,mutate)
45+
importFrom(dplyr,reframe)
46+
importFrom(dplyr,relocate)
47+
importFrom(dplyr,rename)
48+
importFrom(dplyr,rowwise)
49+
importFrom(dplyr,select)
50+
importFrom(dplyr,summarize)
51+
importFrom(dplyr,ungroup)
52+
importFrom(epipredict,add_frosting)
53+
importFrom(epipredict,arx_args_list)
3654
importFrom(epipredict,epi_recipe)
55+
importFrom(epipredict,epi_workflow)
56+
importFrom(epipredict,fit)
57+
importFrom(epipredict,flatline_args_list)
58+
importFrom(epipredict,flatline_forecaster)
59+
importFrom(epipredict,frosting)
60+
importFrom(epipredict,get_test_data)
61+
importFrom(epipredict,layer_add_target_date)
62+
importFrom(epipredict,layer_naomit)
63+
importFrom(epipredict,layer_point_from_distn)
64+
importFrom(epipredict,layer_population_scaling)
65+
importFrom(epipredict,layer_predict)
66+
importFrom(epipredict,layer_quantile_distn)
67+
importFrom(epipredict,layer_residual_quantiles)
68+
importFrom(epipredict,layer_threshold)
69+
importFrom(epipredict,nested_quantiles)
70+
importFrom(epipredict,step_epi_ahead)
71+
importFrom(epipredict,step_epi_lag)
72+
importFrom(epipredict,step_epi_naomit)
3773
importFrom(epipredict,step_population_scaling)
74+
importFrom(epipredict,step_training_window)
3875
importFrom(epiprocess,epix_slide)
39-
importFrom(lubridate,Date)
40-
importFrom(magrittr,"%<>%")
76+
importFrom(here,here)
4177
importFrom(magrittr,"%>%")
4278
importFrom(purrr,map)
79+
importFrom(purrr,map2_vec)
4380
importFrom(purrr,transpose)
81+
importFrom(recipes,all_numeric)
82+
importFrom(rlang,"!!")
4483
importFrom(rlang,.data)
4584
importFrom(rlang,quo)
4685
importFrom(rlang,sym)
4786
importFrom(rlang,syms)
4887
importFrom(tibble,tibble)
88+
importFrom(tidyr,expand_grid)
4989
importFrom(tidyr,pivot_wider)
90+
importFrom(tidyr,unnest)

R/forecaster.R

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,16 @@ confirm_insufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
7575
#' @param predictors a character vector of the columns used as predictors
7676
#' @param args_list an [`epipredict::arx_args_list`]
7777
#' @seealso [arx_postprocess] for the layer equivalent
78+
#' @importFrom epipredict step_epi_lag step_epi_ahead step_epi_naomit step_training_window
7879
#' @export
7980
arx_preprocess <- function(rec, outcome, predictors, args_list) {
8081
# input already validated
8182
lags <- args_list$lags
8283
for (l in seq_along(lags)) {
8384
p <- predictors[l]
84-
rec %<>% step_epi_lag(!!p, lag = lags[[l]])
85+
rec <- rec %>% step_epi_lag(!!p, lag = lags[[l]])
8586
}
86-
rec %<>%
87+
rec <- rec %>%
8788
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
8889
step_epi_naomit() %>%
8990
step_training_window(n_recent = args_list$n_training)
@@ -104,26 +105,30 @@ arx_preprocess <- function(rec, outcome, predictors, args_list) {
104105
#' the default of `layer_add_target_date`, which is either
105106
#' `forecast_date+ahead`, or the `max time_value + ahead`
106107
#' @seealso [arx_preprocess] for the step equivalent
108+
#' @importFrom epipredict layer_predict layer_quantile_distn layer_point_from_distn layer_residual_quantiles layer_threshold layer_naomit layer_add_target_date
107109
#' @export
108110
arx_postprocess <- function(postproc,
109111
trainer,
110112
args_list,
111113
forecast_date = NULL,
112114
target_date = NULL) {
113-
postproc %<>% layer_predict()
115+
postproc <- postproc %>% layer_predict()
114116
if (inherits(trainer, "quantile_reg")) {
115-
postproc %<>% layer_quantile_distn(quantile_levels = args_list$quantile_levels) %>% layer_point_from_distn()
117+
postproc <- postproc %>%
118+
layer_quantile_distn(quantile_levels = args_list$quantile_levels) %>%
119+
layer_point_from_distn()
116120
} else {
117-
postproc %<>% layer_residual_quantiles(
121+
postproc <- postproc %>% layer_residual_quantiles(
118122
quantile_levels = args_list$quantile_levels, symmetrize = args_list$symmetrize,
119123
by_key = args_list$quantile_by_key
120124
)
121125
}
122126
if (args_list$nonneg) {
123-
postproc %<>% layer_threshold(dplyr::starts_with(".pred"))
127+
postproc <- postproc %>% layer_threshold(dplyr::starts_with(".pred"))
124128
}
125129

126-
postproc %<>% layer_naomit(dplyr::starts_with(".pred")) %>%
130+
postproc <- postproc %>%
131+
layer_naomit(dplyr::starts_with(".pred")) %>%
127132
layer_add_target_date(target_date = target_date)
128133
return(postproc)
129134
}
@@ -136,7 +141,7 @@ arx_postprocess <- function(postproc,
136141
#' @param trainer the parsnip trainer
137142
#' @param epi_data the actual epi_df to train on
138143
#' @export
139-
#' @import epipredict recipes
144+
#' @importFrom epipredict epi_workflow fit add_frosting get_test_data
140145
run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
141146
workflow <- epi_workflow(preproc, trainer) %>%
142147
fit(epi_data) %>%
@@ -171,8 +176,10 @@ run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
171176
#' contain `ahead`
172177
#' @param forecaster_args_names a bit of a hack around targets, it contains
173178
#' the names of the `forecaster_args`.
174-
#' @import rlang epipredict dplyr
175179
#' @importFrom epiprocess epix_slide
180+
#' @importFrom dplyr select rename inner_join join_by
181+
#' @importFrom cli cli_abort
182+
#' @importFrom rlang !!
176183
#' @export
177184
forecaster_pred <- function(data,
178185
outcome,
@@ -187,7 +194,7 @@ forecaster_pred <- function(data,
187194
names(forecaster_args) <- forecaster_args_names
188195
}
189196
if (is.null(forecaster_args$ahead)) {
190-
cli::cli_abort(
197+
cli_abort(
191198
c(
192199
"exploration-tooling error: forecaster_pred needs some value for ahead."
193200
),
@@ -228,14 +235,14 @@ forecaster_pred <- function(data,
228235
before = before,
229236
ref_time_values = valid_predict_dates,
230237
)
231-
res %<>% select(-time_value)
238+
res <- res %>% select(-time_value)
232239
names(res) <- sub("^slide_value_", "", names(res))
233240

234241
# append the truth data
235242
true_value <- archive$as_of(archive$versions_end) %>%
236243
select(geo_value, time_value, !!outcome) %>%
237244
rename(true_value = !!outcome)
238-
res %<>%
245+
res <- res %>%
239246
inner_join(true_value,
240247
by = join_by(geo_value, target_end_date == time_value)
241248
)

R/forecaster_flatline.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#' does not support `lags` as a parameter, but otherwise has the same parameters as `arx_forecaster`
55
#' @inheritParams scaled_pop
66
#' @importFrom rlang sym
7+
#' @importFrom epipredict flatline_forecaster flatline_args_list
78
#' @export
89
flatline_fc <- function(epi_data,
910
outcome,
@@ -25,8 +26,8 @@ flatline_fc <- function(epi_data,
2526
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
2627
null_result <- tibble(
2728
geo_value = character(),
28-
forecast_date = Date(),
29-
target_end_date = Date(),
29+
forecast_date = lubridate::Date(),
30+
target_end_date = lubridate::Date(),
3031
quantile = numeric(),
3132
value = numeric()
3233
)

R/forecaster_scaled_pop.R

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@
3636
#' covidhub.
3737
#' @seealso some utilities for making forecasters: [format_storage],
3838
#' [perform_sanity_checks]
39-
#' @import recipes epipredict
40-
#' @importFrom magrittr %>% %<>%
41-
#' @importFrom epipredict epi_recipe step_population_scaling
39+
#' @importFrom epipredict epi_recipe step_population_scaling frosting arx_args_list layer_population_scaling
4240
#' @importFrom tibble tibble
43-
#' @importFrom lubridate Date
41+
#' @importFrom recipes all_numeric
4442
#' @export
4543
scaled_pop <- function(epi_data,
4644
outcome,
@@ -63,8 +61,8 @@ scaled_pop <- function(epi_data,
6361
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
6462
null_result <- tibble(
6563
geo_value = character(),
66-
forecast_date = Date(),
67-
target_end_date = Date(),
64+
forecast_date = lubridate::Date(),
65+
target_end_date = lubridate::Date(),
6866
quantile = numeric(),
6967
value = numeric()
7068
)
@@ -86,25 +84,25 @@ scaled_pop <- function(epi_data,
8684
# preprocessing supported by epipredict
8785
preproc <- epi_recipe(epi_data)
8886
if (pop_scaling) {
89-
preproc %<>% step_population_scaling(
87+
preproc <- preproc %>% step_population_scaling(
9088
all_numeric(),
91-
df = state_census,
89+
df = epipredict::state_census,
9290
df_pop_col = "pop",
9391
create_new = FALSE,
9492
rate_rescaling = 1e5,
9593
by = c("geo_value" = "abbr")
9694
)
9795
}
98-
preproc %<>% arx_preprocess(outcome, predictors, args_list)
96+
preproc <- preproc %>% arx_preprocess(outcome, predictors, args_list)
9997

10098
# postprocessing supported by epipredict
10199
postproc <- frosting()
102-
postproc %<>% arx_postprocess(trainer, args_list)
100+
postproc <- postproc %>% arx_postprocess(trainer, args_list)
103101
postproc
104102
if (pop_scaling) {
105-
postproc %<>% layer_population_scaling(
103+
postproc <- postproc %>% layer_population_scaling(
106104
.pred, .pred_distn,
107-
df = state_census,
105+
df = epipredict::state_census,
108106
df_pop_col = "pop",
109107
create_new = FALSE,
110108
rate_rescaling = 1e5,

R/formatters.R

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
#' @param true_forecast_date the actual date from which the model is
99
#' making the forecast, rather than the last day of available data
1010
#' @param target_end_date the date of the prediction
11-
#' @import dplyr epipredict
12-
#' @importFrom magrittr %>% %<>%
11+
#' @importFrom epipredict nested_quantiles
12+
#' @importFrom dplyr mutate select rename relocate any_of
13+
#' @importFrom tidyr unnest
1314
#' @export
1415
format_storage <- function(pred, true_forecast_date, target_end_date) {
1516
pred %>%
@@ -34,14 +35,14 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
3435
#' making the forecast, rather than the last day of available data
3536
#' @param target_end_date the date of the prediction
3637
#' @param quantile_levels the quantile levels
37-
#' @import dplyr
38+
#' @importFrom dplyr group_by rename reframe mutate
3839
format_covidhub <- function(pred, true_forecast_date, target_end_date, quantile_levels) {
39-
pred %<>%
40+
pred <- pred %>%
4041
group_by(forecast_date, geo_value, target_date) %>%
4142
rename(target_end_date = target_date) %>%
4243
reframe(quantile = quantile_levels, value = quantile(.pred_distn, quantile_levels)[[1]])
4344
forecasts$ahead <- ahead
44-
forecasts %<>%
45+
forecasts <- forecasts %>%
4546
group_by(forecast_date, geo_value, target_date) %>%
4647
mutate(forecast_date = target_date - ahead) %>%
4748
rename(target_end_date = target_date) %>%

R/manage_S3_cache.R

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
1+
#' Manage S3 cache
2+
#' @param rel_cache_dir The relative path to the cache directory, e.g.
3+
#' "data-processed/2021-09-01".
4+
#' @param bucket_name The name of the S3 bucket to sync.
5+
#' @param direction Set 'download' to download files or 'upload' to upload
6+
#' files.
7+
#' @param verbose Set to TRUE to print the files being synced.
8+
#'
9+
#' @importFrom aws.s3 s3sync get_bucket
10+
#' @importFrom here here
11+
#' @export
112
manage_S3_forecast_cache <- function(rel_cache_dir, bucket_name = "forecasting-team-data", direction = "download", verbose = FALSE) {
2-
cache_path <- here::here(rel_cache_dir)
13+
cache_path <- here(rel_cache_dir)
314
if (!dir.exists(cache_path)) dir.create(cache_path)
415

5-
s3b <- aws.s3::get_bucket(bucket_name)
16+
s3b <- get_bucket(bucket_name)
617
if (verbose) {
7-
aws.s3::s3sync(cache_path, s3b, paste0("covid-hosp-forecast/", rel_cache_dir), direction = direction)
18+
s3sync(cache_path, s3b, paste0("covid-hosp-forecast/", rel_cache_dir), direction = direction)
819
} else {
920
sink("/dev/null")
10-
aws.s3::s3sync(cache_path, s3b, paste0("covid-hosp-forecast/", rel_cache_dir), direction = direction, verbose = FALSE)
21+
s3sync(cache_path, s3b, paste0("covid-hosp-forecast/", rel_cache_dir), direction = direction, verbose = FALSE)
1122
sink()
1223
}
1324
return(TRUE)

0 commit comments

Comments
 (0)