Skip to content

Commit 9da2783

Browse files
authored
Merge pull request #54 from cmu-delphi/ds/refactors
refactor: small bugfixes and cleanup
2 parents 69af77d + 6ddb930 commit 9da2783

34 files changed

+875
-332
lines changed

DESCRIPTION

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,60 +2,39 @@ Package: epieval
22
Title: Evaluating Timeseries Forecasting on Archival Data
33
Version: 0.1.0
44
Date: 2023-09-28
5-
Authors@R:
5+
Authors@R:
66
c(
77
person("David", "Weber", email = "[email protected]", role = c("ctb", "cre")),
8-
person("Dmitry", "Shemetov", email = "[email protected]", role = c("aut"))
8+
person("Dmitry", "Shemetov", email = "[email protected]", role = c("aut")),
9+
person("Nat", "DeFries", email = "[email protected]", role = c("aut"))
910
)
1011
Description: Given a timeseries and accompanying auxiliary timeseries, evaluate a collection of forecasters implementable with epipredict using several metrics on historical data. Assuming the source of your timeseries provides versioned data, the evaluation at any given timepoint will only use data that was available at that point.
1112
License: MIT + file LICENSE
1213
Depends:
13-
epiprocess (>= 0.6.0),
14-
epipredict,
1514
R (>= 3.5.0)
1615
Imports:
1716
assertthat,
1817
aws.s3,
1918
cli,
20-
distributional,
2119
dplyr,
2220
epidatr,
2321
epipredict,
24-
fs,
25-
generics,
26-
glue,
27-
hardhat (>= 1.3.0),
22+
epiprocess,
2823
here,
2924
lubridate,
3025
magrittr,
31-
methods,
32-
openssl,
3326
parsnip (>= 1.0.0),
3427
purrr,
35-
quantreg,
3628
recipes (>= 1.0.4),
3729
rlang,
38-
smoothqr,
39-
stats,
40-
targets,
4130
tibble,
42-
tidyr,
43-
tidyselect,
44-
usethis,
45-
vctrs,
46-
workflows (>= 1.0.0)
31+
tidyr
4732
Suggests:
48-
covidcast,
49-
data.table,
5033
ggplot2,
5134
knitr,
52-
pipeR,
5335
plotly,
54-
poissonreg,
5536
rmarkdown,
56-
shiny,
5737
testthat (>= 3.0.0),
58-
xgboost
5938
VignetteBuilder:
6039
knitr
6140
Remotes:

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
install:
2-
Rscript -e "install.packages(c('renv', 'pak'))"
2+
Rscript -e "install.packages(c('renv', 'pak', 'rspm'))"
33
Rscript -e "renv::restore()"
44

55
run:

NAMESPACE

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ export(id_ahead_ensemble_grid)
1616
export(interval_coverage)
1717
export(lookup_ids)
1818
export(make_target_param_grid)
19+
export(manage_S3_forecast_cache)
1920
export(overprediction)
2021
export(perform_sanity_checks)
2122
export(read_external_predictions_data)
@@ -26,24 +27,64 @@ export(sharpness)
2627
export(single_id)
2728
export(underprediction)
2829
export(weighted_interval_score)
29-
import(dplyr)
30-
import(epipredict)
31-
import(openssl)
32-
import(recipes)
33-
import(rlang)
3430
importFrom(assertthat,assert_that)
31+
importFrom(aws.s3,get_bucket)
32+
importFrom(aws.s3,s3sync)
33+
importFrom(cli,cli_abort)
3534
importFrom(cli,hash_animal)
35+
importFrom(dplyr,across)
36+
importFrom(dplyr,any_of)
37+
importFrom(dplyr,everything)
38+
importFrom(dplyr,filter)
39+
importFrom(dplyr,group_by)
40+
importFrom(dplyr,inner_join)
41+
importFrom(dplyr,join_by)
42+
importFrom(dplyr,left_join)
43+
importFrom(dplyr,mutate)
44+
importFrom(dplyr,reframe)
45+
importFrom(dplyr,relocate)
46+
importFrom(dplyr,rename)
47+
importFrom(dplyr,rowwise)
48+
importFrom(dplyr,select)
49+
importFrom(dplyr,summarize)
50+
importFrom(dplyr,ungroup)
51+
importFrom(epipredict,add_frosting)
52+
importFrom(epipredict,arx_args_list)
3653
importFrom(epipredict,epi_recipe)
54+
importFrom(epipredict,epi_workflow)
55+
importFrom(epipredict,fit)
56+
importFrom(epipredict,flatline_args_list)
57+
importFrom(epipredict,flatline_forecaster)
58+
importFrom(epipredict,frosting)
59+
importFrom(epipredict,get_test_data)
60+
importFrom(epipredict,layer_add_target_date)
61+
importFrom(epipredict,layer_naomit)
62+
importFrom(epipredict,layer_point_from_distn)
63+
importFrom(epipredict,layer_population_scaling)
64+
importFrom(epipredict,layer_predict)
65+
importFrom(epipredict,layer_quantile_distn)
66+
importFrom(epipredict,layer_residual_quantiles)
67+
importFrom(epipredict,layer_threshold)
68+
importFrom(epipredict,nested_quantiles)
69+
importFrom(epipredict,step_epi_ahead)
70+
importFrom(epipredict,step_epi_lag)
71+
importFrom(epipredict,step_epi_naomit)
3772
importFrom(epipredict,step_population_scaling)
73+
importFrom(epipredict,step_training_window)
3874
importFrom(epiprocess,epix_slide)
39-
importFrom(lubridate,Date)
75+
importFrom(here,here)
4076
importFrom(magrittr,"%<>%")
4177
importFrom(magrittr,"%>%")
4278
importFrom(purrr,map)
79+
importFrom(purrr,map2_vec)
4380
importFrom(purrr,transpose)
81+
importFrom(recipes,all_numeric)
82+
importFrom(rlang,"!!")
4483
importFrom(rlang,.data)
4584
importFrom(rlang,quo)
4685
importFrom(rlang,sym)
4786
importFrom(rlang,syms)
4887
importFrom(tibble,tibble)
88+
importFrom(tidyr,expand_grid)
4989
importFrom(tidyr,pivot_wider)
90+
importFrom(tidyr,unnest)

R/epieval-package.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#' @importFrom magrittr %>% %<>%
2+
#' @importFrom dplyr select rename inner_join join_by mutate relocate any_of
3+
#' group_by reframe summarize left_join across filter rowwise everything ungroup
4+
#' @importFrom purrr transpose map map2_vec
5+
#' @keywords internal
6+
"_PACKAGE"
7+
globalVariables(c("ahead", "id", "parent_id", "all_of", "last_col", "time_value", "geo_value", "target_end_date", "forecast_date", "quantile", ".pred_distn", "quantiles", "quantile_levels", "signal", ".dstn", "values", ".", "forecasters","forecaster", "trainer", "forecast_date", ".pred", "n_distinct", "target_date", "value"))
8+
.onLoad <- function(libname, pkgname) {
9+
epidatr::set_cache(
10+
cache_dir = ".exploration_cache",
11+
days = 14,
12+
max_size = 4,
13+
confirm = FALSE
14+
)
15+
}

R/evaluation_utils.R

Lines changed: 0 additions & 36 deletions
This file was deleted.

R/exploration-tooling-package.R

Lines changed: 0 additions & 10 deletions
This file was deleted.

R/forecaster.R

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ confirm_insufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
7575
#' @param predictors a character vector of the columns used as predictors
7676
#' @param args_list an [`epipredict::arx_args_list`]
7777
#' @seealso [arx_postprocess] for the layer equivalent
78+
#' @importFrom epipredict step_epi_lag step_epi_ahead step_epi_naomit step_training_window
7879
#' @export
7980
arx_preprocess <- function(rec, outcome, predictors, args_list) {
8081
# input already validated
@@ -104,6 +105,7 @@ arx_preprocess <- function(rec, outcome, predictors, args_list) {
104105
#' the default of `layer_add_target_date`, which is either
105106
#' `forecast_date+ahead`, or the `max time_value + ahead`
106107
#' @seealso [arx_preprocess] for the step equivalent
108+
#' @importFrom epipredict layer_predict layer_quantile_distn layer_point_from_distn layer_residual_quantiles layer_threshold layer_naomit layer_add_target_date
107109
#' @export
108110
arx_postprocess <- function(postproc,
109111
trainer,
@@ -112,7 +114,9 @@ arx_postprocess <- function(postproc,
112114
target_date = NULL) {
113115
postproc %<>% layer_predict()
114116
if (inherits(trainer, "quantile_reg")) {
115-
postproc %<>% layer_quantile_distn(quantile_levels = args_list$quantile_levels) %>% layer_point_from_distn()
117+
postproc %<>%
118+
layer_quantile_distn(quantile_levels = args_list$quantile_levels) %>%
119+
layer_point_from_distn()
116120
} else {
117121
postproc %<>% layer_residual_quantiles(
118122
quantile_levels = args_list$quantile_levels, symmetrize = args_list$symmetrize,
@@ -123,7 +127,8 @@ arx_postprocess <- function(postproc,
123127
postproc %<>% layer_threshold(dplyr::starts_with(".pred"))
124128
}
125129

126-
postproc %<>% layer_naomit(dplyr::starts_with(".pred")) %>%
130+
postproc %<>%
131+
layer_naomit(dplyr::starts_with(".pred")) %>%
127132
layer_add_target_date(target_date = target_date)
128133
return(postproc)
129134
}
@@ -136,7 +141,7 @@ arx_postprocess <- function(postproc,
136141
#' @param trainer the parsnip trainer
137142
#' @param epi_data the actual epi_df to train on
138143
#' @export
139-
#' @import epipredict recipes
144+
#' @importFrom epipredict epi_workflow fit add_frosting get_test_data
140145
run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
141146
workflow <- epi_workflow(preproc, trainer) %>%
142147
fit(epi_data) %>%
@@ -171,8 +176,9 @@ run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
171176
#' contain `ahead`
172177
#' @param forecaster_args_names a bit of a hack around targets; it contains
173178
#' the names of the `forecaster_args`.
174-
#' @import rlang epipredict dplyr
175179
#' @importFrom epiprocess epix_slide
180+
#' @importFrom cli cli_abort
181+
#' @importFrom rlang !!
176182
#' @export
177183
forecaster_pred <- function(data,
178184
outcome,
@@ -187,7 +193,7 @@ forecaster_pred <- function(data,
187193
names(forecaster_args) <- forecaster_args_names
188194
}
189195
if (is.null(forecaster_args$ahead)) {
190-
cli::cli_abort(
196+
cli_abort(
191197
c(
192198
"exploration-tooling error: forecaster_pred needs some value for ahead."
193199
),

R/forecaster_flatline.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#' does not support `lags` as a parameter, but otherwise has the same parameters as `arx_forecaster`
55
#' @inheritParams scaled_pop
66
#' @importFrom rlang sym
7+
#' @importFrom epipredict flatline_forecaster flatline_args_list
78
#' @export
89
flatline_fc <- function(epi_data,
910
outcome,
@@ -25,8 +26,8 @@ flatline_fc <- function(epi_data,
2526
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
2627
null_result <- tibble(
2728
geo_value = character(),
28-
forecast_date = Date(),
29-
target_end_date = Date(),
29+
forecast_date = lubridate::Date(),
30+
target_end_date = lubridate::Date(),
3031
quantile = numeric(),
3132
value = numeric()
3233
)

R/forecaster_scaled_pop.R

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@
3636
#' covidhub.
3737
#' @seealso some utilities for making forecasters: [format_storage],
3838
#' [perform_sanity_checks]
39-
#' @import recipes epipredict
40-
#' @importFrom magrittr %>% %<>%
41-
#' @importFrom epipredict epi_recipe step_population_scaling
39+
#' @importFrom epipredict epi_recipe step_population_scaling frosting arx_args_list layer_population_scaling
4240
#' @importFrom tibble tibble
43-
#' @importFrom lubridate Date
41+
#' @importFrom recipes all_numeric
4442
#' @export
4543
scaled_pop <- function(epi_data,
4644
outcome,
@@ -63,8 +61,8 @@ scaled_pop <- function(epi_data,
6361
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
6462
null_result <- tibble(
6563
geo_value = character(),
66-
forecast_date = Date(),
67-
target_end_date = Date(),
64+
forecast_date = lubridate::Date(),
65+
target_end_date = lubridate::Date(),
6866
quantile = numeric(),
6967
value = numeric()
7068
)
@@ -88,7 +86,7 @@ scaled_pop <- function(epi_data,
8886
if (pop_scaling) {
8987
preproc %<>% step_population_scaling(
9088
all_numeric(),
91-
df = state_census,
89+
df = epipredict::state_census,
9290
df_pop_col = "pop",
9391
create_new = FALSE,
9492
rate_rescaling = 1e5,
@@ -104,7 +102,7 @@ scaled_pop <- function(epi_data,
104102
if (pop_scaling) {
105103
postproc %<>% layer_population_scaling(
106104
.pred, .pred_distn,
107-
df = state_census,
105+
df = epipredict::state_census,
108106
df_pop_col = "pop",
109107
create_new = FALSE,
110108
rate_rescaling = 1e5,

R/formatters.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
#' @param true_forecast_date the actual date from which the model is
99
#' making the forecast, rather than the last day of available data
1010
#' @param target_end_date the date of the prediction
11-
#' @import dplyr epipredict
12-
#' @importFrom magrittr %>% %<>%
11+
#' @importFrom epipredict nested_quantiles
12+
#' @importFrom tidyr unnest
1313
#' @export
1414
format_storage <- function(pred, true_forecast_date, target_end_date) {
1515
pred %>%
@@ -34,7 +34,6 @@ format_storage <- function(pred, true_forecast_date, target_end_date) {
3434
#' making the forecast, rather than the last day of available data
3535
#' @param target_end_date the date of the prediction
3636
#' @param quantile_levels the quantile levels
37-
#' @import dplyr
3837
format_covidhub <- function(pred, true_forecast_date, target_end_date, quantile_levels) {
3938
pred %<>%
4039
group_by(forecast_date, geo_value, target_date) %>%

R/manage_S3_cache.R

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
1+
#' Manage S3 cache
2+
#' @param rel_cache_dir The relative path to the cache directory, e.g.
3+
#' "data-processed/2021-09-01".
4+
#' @param bucket_name The name of the S3 bucket to sync.
5+
#' @param direction Set 'download' to download files or 'upload' to upload
6+
#' files.
7+
#' @param verbose Set to TRUE to print the files being synced.
8+
#'
9+
#' @importFrom aws.s3 s3sync get_bucket
10+
#' @importFrom here here
11+
#' @export
112
manage_S3_forecast_cache <- function(rel_cache_dir, bucket_name = "forecasting-team-data", direction = "download", verbose = FALSE) {
2-
cache_path <- here::here(rel_cache_dir)
13+
cache_path <- here(rel_cache_dir)
314
if (!dir.exists(cache_path)) dir.create(cache_path)
415

5-
s3b <- aws.s3::get_bucket(bucket_name)
16+
s3b <- get_bucket(bucket_name)
617
if (verbose) {
7-
aws.s3::s3sync(cache_path, s3b, paste0("covid-hosp-forecast/", rel_cache_dir), direction = direction)
18+
s3sync(cache_path, s3b, paste0("covid-hosp-forecast/", rel_cache_dir), direction = direction)
819
} else {
920
sink("/dev/null")
10-
aws.s3::s3sync(cache_path, s3b, paste0("covid-hosp-forecast/", rel_cache_dir), direction = direction, verbose = FALSE)
21+
s3sync(cache_path, s3b, paste0("covid-hosp-forecast/", rel_cache_dir), direction = direction, verbose = FALSE)
1122
sink()
1223
}
1324
return(TRUE)

0 commit comments

Comments
 (0)