Skip to content

Commit 4f16887

Browse files
authored
Merge pull request #55 from cmu-delphi/ds/flu-hosp-explore
feat: flu hosp explore
2 parents 87fd65c + e996175 commit 4f16887

32 files changed

+780
-724
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
tmp/
55
extras/**.html
66
*.pdf
7+
.Renviron

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Imports:
2727
purrr,
2828
recipes (>= 1.0.4),
2929
rlang,
30+
targets,
3031
tibble,
3132
tidyr
3233
Suggests:

NAMESPACE

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ export(absolute_error)
44
export(add_id)
55
export(arx_postprocess)
66
export(arx_preprocess)
7+
export(clear_lastminute_nas)
78
export(collapse_cards)
8-
export(confirm_insufficient_data)
9+
export(confirm_sufficient_data)
910
export(covidhub_probs)
1011
export(evaluate_predictions)
1112
export(extend_ahead)
@@ -15,6 +16,12 @@ export(format_storage)
1516
export(id_ahead_ensemble_grid)
1617
export(interval_coverage)
1718
export(lookup_ids)
19+
export(make_data_targets)
20+
export(make_ensemble_targets)
21+
export(make_external_names_and_scores)
22+
export(make_forecasts_and_scores)
23+
export(make_forecasts_and_scores_by_ahead)
24+
export(make_shared_grids)
1825
export(make_target_param_grid)
1926
export(manage_S3_forecast_cache)
2027
export(overprediction)
@@ -71,6 +78,7 @@ importFrom(epipredict,step_epi_lag)
7178
importFrom(epipredict,step_epi_naomit)
7279
importFrom(epipredict,step_population_scaling)
7380
importFrom(epipredict,step_training_window)
81+
importFrom(epiprocess,as_epi_df)
7482
importFrom(epiprocess,epix_slide)
7583
importFrom(here,here)
7684
importFrom(magrittr,"%<>%")
@@ -84,7 +92,10 @@ importFrom(rlang,.data)
8492
importFrom(rlang,quo)
8593
importFrom(rlang,sym)
8694
importFrom(rlang,syms)
95+
importFrom(targets,tar_group)
96+
importFrom(targets,tar_target)
8797
importFrom(tibble,tibble)
98+
importFrom(tidyr,drop_na)
8899
importFrom(tidyr,expand_grid)
89100
importFrom(tidyr,pivot_wider)
90101
importFrom(tidyr,unnest)

R/forecaster.R

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,29 +43,39 @@ perform_sanity_checks <- function(epi_data,
4343
#' confirm that there's enough data to run this model
4444
#' @description
4545
#' epipredict is a little bit fragile about having enough data to train; we want
46-
#' to be able to return a null result rather than error out; this check say to
47-
#' return a null
46+
#' to be able to return a null result rather than error out.
4847
#' @param epi_data the input data
49-
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
50-
#' this trains on one sample; the default is set so that `linear_reg` isn't
51-
#' rank deficient)
5248
#' @param ahead the effective ahead; may be infinite if there isn't enough data.
5349
#' @param args_input the input as supplied to `forecaster_pred`; lags is the
5450
#' important argument, which may or may not be defined, with the default
5551
#' coming from `arx_args_list`
52+
#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
53+
#' this trains on one sample; the default is set so that `linear_reg` isn't
54+
#' rank deficient)
55+
#' @importFrom tidyr drop_na
5656
#' @export
57-
confirm_insufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
57+
confirm_sufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
5858
if (!is.null(args_input$lags)) {
5959
lag_max <- max(args_input$lags)
6060
} else {
6161
lag_max <- 14 # default value of 2 weeks
6262
}
63+
64+
# TODO: Buffer should probably be 2 * n(lags) * n(predictors). But honestly,
65+
# this needs to be fixed in epipredict itself, see
66+
# https://github.com/cmu-delphi/epipredict/issues/106.
67+
6368
return(
64-
is.infinite(ahead) ||
65-
as.integer(max(epi_data$time_value) - min(epi_data$time_value)) <=
66-
lag_max + ahead + buffer
69+
!is.infinite(ahead) &&
70+
epi_data %>%
71+
drop_na() %>%
72+
group_by(geo_value) %>%
73+
summarise(has_enough_data = n_distinct(time_value) >= lag_max + ahead + buffer) %>%
74+
pull(has_enough_data) %>%
75+
any()
6776
)
6877
}
78+
6979
# TODO replace with `step_arx_forecaster`
7080
#' add the default steps for arx_forecaster
7181
#' @description
@@ -149,7 +159,11 @@ run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
149159
latest <- get_test_data(recipe = preproc, x = epi_data)
150160
pred <- predict(workflow, latest)
151161
# the forecast_date may currently be the max time_value
152-
true_forecast_date <- attributes(epi_data)$metadata$as_of
162+
as_of <- attributes(epi_data)$metadata$as_of
163+
if (is.null(as_of)) {
164+
as_of <- max(epi_data$time_value)
165+
}
166+
true_forecast_date <- as_of
153167
return(format_storage(pred, true_forecast_date))
154168
}
155169

@@ -176,6 +190,8 @@ run_workflow_and_format <- function(preproc, postproc, trainer, epi_data) {
176190
#' contain `ahead`
177191
#' @param forecaster_args_names a bit of a hack around targets, it contains
178192
#' the names of the `forecaster_args`.
193+
#' @param date_range_step_size the step size (in days) to use when generating
194+
#' the forecast dates.
179195
#' @importFrom epiprocess epix_slide
180196
#' @importFrom cli cli_abort
181197
#' @importFrom rlang !!
@@ -187,7 +203,8 @@ forecaster_pred <- function(data,
187203
slide_training = 0,
188204
n_training_pad = 5,
189205
forecaster_args = list(),
190-
forecaster_args_names = list()) {
206+
forecaster_args_names = list(),
207+
date_range_step_size = 1L) {
191208
archive <- data
192209
if (length(forecaster_args) > 0) {
193210
names(forecaster_args) <- forecaster_args_names
@@ -210,25 +227,47 @@ forecaster_pred <- function(data,
210227
# restrict the dataset to areas where training is possible
211228
start_date <- min(archive$DT$time_value) + net_slide_training
212229
end_date <- max(archive$DT$time_value) - forecaster_args$ahead
213-
valid_predict_dates <- seq.Date(from = start_date, to = end_date, by = 1)
230+
valid_predict_dates <- seq.Date(from = start_date, to = end_date, by = date_range_step_size)
214231

215232
# first generate the forecasts
216233
before <- n_training + n_training_pad - 1
217-
## TODO epix_slide doesn't support infinite `before`
234+
## TODO: epix_slide doesn't support infinite `before`
218235
## https://github.com/cmu-delphi/epiprocess/issues/219
219236
if (before == Inf) before <- 365L * 10000
220237
res <- epix_slide(archive,
221238
function(data, gk, rtv, ...) {
222-
do.call(
223-
forecaster,
224-
append(
225-
list(
226-
epi_data = data,
227-
outcome = outcome,
228-
extra_sources = extra_sources
229-
),
230-
forecaster_args
231-
)
239+
# TODO: Can we get rid of this tryCatch and instead hook it up to targets
240+
# error handling or something else?
241+
# https://github.com/cmu-delphi/exploration-tooling/issues/41
242+
tryCatch(
243+
{
244+
do.call(
245+
forecaster,
246+
append(
247+
list(
248+
epi_data = data,
249+
outcome = outcome,
250+
extra_sources = extra_sources
251+
),
252+
forecaster_args
253+
)
254+
)
255+
},
256+
error = function(e) {
257+
if (interactive()) {
258+
browser()
259+
} else {
260+
dump_vars <- list(
261+
data = data,
262+
rtv = rtv,
263+
forecaster = forecaster,
264+
forecaster_args = forecaster_args,
265+
e = e
266+
)
267+
saveRDS(dump_vars, "forecaster_pred_error.rds")
268+
e
269+
}
270+
}
232271
)
233272
},
234273
before = before,

R/forecaster_flatline.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ flatline_fc <- function(epi_data,
1414
quantile_levels = covidhub_probs(),
1515
...) {
1616
# perform any preprocessing not supported by epipredict
17+
# this is a temp fix until a real fix gets put into epipredict
18+
epi_data <- clear_lastminute_nas(epi_data)
1719
# one that every forecaster will need to handle: how to manage max(time_value)
1820
# that's older than the `as_of` date
1921
epidataAhead <- extend_ahead(epi_data, ahead)
@@ -23,7 +25,7 @@ flatline_fc <- function(epi_data,
2325
effective_ahead <- epidataAhead[[2]]
2426
args_input <- list(...)
2527
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
26-
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
28+
if (!confirm_sufficient_data(epi_data, effective_ahead, args_input)) {
2729
null_result <- tibble(
2830
geo_value = character(),
2931
forecast_date = lubridate::Date(),
@@ -48,6 +50,9 @@ flatline_fc <- function(epi_data,
4850
# since this is just the flatline, we don't need much of anything
4951
res <- flatline_forecaster(epi_data, outcome = outcome, args_list = args_list)
5052
true_forecast_date <- attributes(epi_data)$metadata$as_of
53+
if (is.null(true_forecast_date)) {
54+
true_forecast_date <- max(epi_data$time_value)
55+
}
5156
pred <- format_storage(res$predictions, true_forecast_date)
5257
# (geo_value, forecast_date, target_end_date, quantile, value)
5358
# finally, any postprocessing not supported by epipredict e.g. calibration

R/forecaster_scaled_pop.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ scaled_pop <- function(epi_data,
4949
quantile_levels = covidhub_probs(),
5050
...) {
5151
# perform any preprocessing not supported by epipredict
52+
# this is a temp fix until a real fix gets put into epipredict
53+
epi_data <- clear_lastminute_nas(epi_data)
5254
# one that every forecaster will need to handle: how to manage max(time_value)
5355
# that's older than the `as_of` date
5456
epidataAhead <- extend_ahead(epi_data, ahead)
@@ -58,7 +60,7 @@ scaled_pop <- function(epi_data,
5860
effective_ahead <- epidataAhead[[2]]
5961
args_input <- list(...)
6062
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
61-
if (confirm_insufficient_data(epi_data, effective_ahead, args_input)) {
63+
if (!confirm_sufficient_data(epi_data, effective_ahead, args_input)) {
6264
null_result <- tibble(
6365
geo_value = character(),
6466
forecast_date = lubridate::Date(),
@@ -73,6 +75,7 @@ scaled_pop <- function(epi_data,
7375
args_list <- do.call(arx_args_list, args_input)
7476
# if you want to ignore extra_sources, setting predictors is the way to do it
7577
predictors <- c(outcome, extra_sources)
78+
# TODO: Partial match quantile_level coming from here (on Dmitry's machine)
7679
argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, trainer, args_list)
7780
args_list <- argsPredictorsTrainer[[1]]
7881
predictors <- argsPredictorsTrainer[[2]]
@@ -98,7 +101,6 @@ scaled_pop <- function(epi_data,
98101
# postprocessing supported by epipredict
99102
postproc <- frosting()
100103
postproc %<>% arx_postprocess(trainer, args_list)
101-
postproc
102104
if (pop_scaling) {
103105
postproc %<>% layer_population_scaling(
104106
.pred, .pred_distn,

R/latency_adjusting.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
extend_ahead <- function(epi_data, ahead) {
1212
time_values <- epi_data$time_value
1313
if (length(time_values) > 0) {
14+
as_of <- attributes(epi_data)$metadata$as_of
15+
max_time <- max(time_values)
16+
if (is.null(as_of)) {
17+
as_of <- max_time
18+
}
1419
effective_ahead <- as.integer(
15-
as.Date(attributes(epi_data)$metadata$as_of) -
16-
max(time_values) +
17-
ahead
20+
as.Date(as_of) - max_time + ahead
1821
)
1922
} else {
2023
effective_ahead <- Inf

0 commit comments

Comments
 (0)