Commit 99ed809

refactor: big reduction in redundant pipeline code
* make targets factories in targets_utils.R
* simplify covid_hosp_explore and flu_hosp_explore
Parent: 70680df

14 files changed: +517 −798 lines
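The covid_hosp_explore.R and flu_hosp_explore.R diffs are not reproduced below, but the intent of the refactor is that each exploration script shrinks to a thin composition of the new factory functions exported from targets_utils.R. A rough sketch of what such a script might now look like (hypothetical structure; not the actual file contents, and the globals the factories read would be defined where the comment indicates):

library(targets)
library(tarchetypes)

# Globals read by the factories (hhs_signal, chng_signal, geo_type, time_type,
# geo_values, time_values, issues, fetch_args, targets_param_grid, ...) would
# be defined here, as documented in make_data_targets() below.

list(
  make_data_targets(),
  make_forecasts_and_scores_by_ahead(),
  make_forecasts_and_scores(),
  make_ensemble_targets(),
  make_external_names_and_scores()
)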

NAMESPACE

Lines changed: 6 additions & 0 deletions
@@ -15,6 +15,12 @@ export(format_storage)
 export(id_ahead_ensemble_grid)
 export(interval_coverage)
 export(lookup_ids)
+export(make_data_targets)
+export(make_ensemble_targets)
+export(make_external_names_and_scores)
+export(make_forecasts_and_scores)
+export(make_forecasts_and_scores_by_ahead)
+export(make_shared_grids)
 export(make_target_param_grid)
 export(manage_S3_forecast_cache)
 export(overprediction)

R/forecaster.R

Lines changed: 11 additions & 6 deletions
@@ -45,23 +45,26 @@ perform_sanity_checks <- function(epi_data,
 #' epipredict is a little bit fragile about having enough data to train; we want
 #' to be able to return a null result rather than error out.
 #' @param epi_data the input data
-#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
-#'   this trains on one sample; the default is set so that `linear_reg` isn't
-#'   rank deficient)
 #' @param ahead the effective ahead; may be infinite if there isn't enough data.
 #' @param args_input the input as supplied to `forecaster_pred`; lags is the
 #'   important argument, which may or may not be defined, with the default
 #'   coming from `arx_args_list`
-#'
-#' # TODO: Buffer should probably be 2 * n(lags) * n(predictors).
+#' @param buffer how many training data to insist on having (e.g. if `buffer=1`,
+#'   this trains on one sample; the default is set so that `linear_reg` isn't
+#'   rank deficient)
 #'
 #' @export
-confirm_sufficient_data <- function(epi_data, ahead, args_input, buffer = 15) {
+confirm_sufficient_data <- function(epi_data, ahead, args_input, buffer = 20) {
   if (!is.null(args_input$lags)) {
     lag_max <- max(args_input$lags)
   } else {
     lag_max <- 14 # default value of 2 weeks
   }
+
+  # TODO: Buffer should probably be 2 * n(lags) * n(predictors). But honestly,
+  # this needs to be fixed in epipredict itself, see
+  # https://github.com/cmu-delphi/epipredict/issues/106.
+
   return(
     !is.infinite(ahead) &&
       epi_data %>%
@@ -233,6 +236,7 @@ forecaster_pred <- function(data,
     function(data, gk, rtv, ...) {
       # TODO: Can we get rid of this tryCatch and instead hook it up to targets
       # error handling or something else?
+      # https://github.com/cmu-delphi/exploration-tooling/issues/41
      tryCatch(
        {
          do.call(
@@ -259,6 +263,7 @@ forecaster_pred <- function(data,
            e = e
          )
          saveRDS(dump_vars, "forecaster_pred_error.rds")
+          e
        }
      }
    )
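The TODO added above suggests tying the buffer to the model's width rather than to a constant (15 before this commit, 20 after). A hypothetical sketch of that heuristic, not part of this commit and using a helper name invented here for illustration:

# Hypothetical helper: 2 * n(lags) * n(predictors), as the TODO suggests.
suggested_buffer <- function(args_input, n_predictors = 1) {
  lags <- if (!is.null(args_input$lags)) args_input$lags else c(0, 7, 14)
  2 * length(unlist(lags)) * n_predictors
}
# e.g. suggested_buffer(list(lags = c(0, 3, 5, 7, 14)), n_predictors = 1) returns 10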

R/forecaster_scaled_pop.R

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ scaled_pop <- function(epi_data,
   args_list <- do.call(arx_args_list, args_input)
   # if you want to ignore extra_sources, setting predictors is the way to do it
   predictors <- c(outcome, extra_sources)
-  # TODO: Partial match quantile_level coming from here
+  # TODO: Partial match quantile_level coming from here (on Dmitry's machine)
   argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, trainer, args_list)
   args_list <- argsPredictorsTrainer[[1]]
   predictors <- argsPredictorsTrainer[[2]]

R/targets_utils.R

Lines changed: 281 additions & 2 deletions
@@ -6,8 +6,10 @@
 #' @export
 #' @importFrom rlang syms
 make_target_param_grid <- function(param_grid) {
-  param_grid %<>% mutate(forecaster = syms(forecaster))
-  param_grid %<>% mutate(trainer = syms(trainer))
+  param_grid %>%
+    select(-any_of("parent_id")) %>%
+    mutate(forecaster = syms(forecaster)) %>%
+    mutate(trainer = syms(trainer))
   list_of_params <- lists_of_real_values(param_grid)
   list_names <- map(list_of_params, names)
   tibble(
@@ -27,3 +29,280 @@ lists_of_real_values <- function(param_grid) {
   }
   map(full_lists, filter_nonvalues)
 }
+
+#' Make common targets for fetching data
+#'
+#' Relies on the following globals:
+#' - `hhs_signal`
+#' - `chng_signal`
+#' - `geo_type`
+#' - `time_type`
+#' - `geo_values`
+#' - `time_values`
+#' - `issues`
+#' - `fetch_args`
+#'
+#' @export
+make_data_targets <- function() {
+  list(
+    tar_target(
+      name = hhs_latest_data,
+      command = {
+        epidatr::pub_covidcast(
+          source = "hhs",
+          signals = hhs_signal,
+          geo_type = geo_type,
+          time_type = time_type,
+          geo_values = geo_values,
+          time_values = time_values,
+          fetch_args = fetch_args
+        )
+      }
+    ),
+    tar_target(
+      name = chng_latest_data,
+      command = {
+        epidatr::pub_covidcast(
+          source = "chng",
+          signals = chng_signal,
+          geo_type = geo_type,
+          time_type = time_type,
+          geo_values = geo_values,
+          time_values = time_values,
+          fetch_args = fetch_args
+        )
+      }
+    ),
+    tar_target(
+      name = hhs_evaluation_data,
+      command = {
+        hhs_latest_data %>%
+          rename(
+            actual = value,
+            target_end_date = time_value
+          )
+      }
+    ),
+    tar_target(
+      name = hhs_latest_data_2022,
+      command = {
+        hhs_latest_data %>%
+          filter(time_value >= "2022-01-01")
+      }
+    ),
+    tar_target(
+      name = chng_latest_data_2022,
+      command = {
+        chng_latest_data %>%
+          filter(time_value >= "2022-01-01")
+      }
+    ),
+    tar_target(
+      name = hhs_archive_data_2022,
+      command = {
+        epidatr::pub_covidcast(
+          source = "hhs",
+          signals = hhs_signal,
+          geo_type = geo_type,
+          time_type = time_type,
+          geo_values = geo_values,
+          time_values = time_values,
+          issues = issues,
+          fetch_args = fetch_args
+        )
+      }
+    ),
+    tar_target(
+      name = chng_archive_data_2022,
+      command = {
+        epidatr::pub_covidcast(
+          source = "chng",
+          signals = chng_signal,
+          geo_type = geo_type,
+          time_type = time_type,
+          geo_values = geo_values,
+          time_values = time_values,
+          issues = issues,
+          fetch_args = fetch_args
+        )
+      }
+    ),
+    tar_target(
+      name = joined_archive_data_2022,
+      command = {
+        hhs_archive_data_2022 %<>%
+          select(geo_value, time_value, value, issue) %>%
+          rename("hhs" := value) %>%
+          rename(version = issue) %>%
+          as_epi_archive(
+            geo_type = geo_type,
+            time_type = time_type,
+            compactify = TRUE
+          )
+        chng_archive_data_2022 %<>%
+          select(geo_value, time_value, value, issue) %>%
+          rename("chng" := value) %>%
+          rename(version = issue) %>%
+          as_epi_archive(
+            geo_type = geo_type,
+            time_type = time_type,
+            compactify = TRUE
+          )
+        epix_merge(hhs_archive_data_2022, chng_archive_data_2022, sync = "locf")$DT %>%
+          drop_na() %>%
+          filter(!geo_value %in% c("as", "pr", "vi", "gu", "mp")) %>%
+          epiprocess::as_epi_archive()
+      }
+    )
+  )
+}
+
+#' Make common targets for forecasting experiments
+#' @export
+make_shared_grids <- function() {
+  list(
+    tidyr::expand_grid(
+      forecaster = "scaled_pop",
+      trainer = c("linreg", "quantreg"),
+      ahead = 1:4,
+      pop_scaling = c(FALSE)
+    ),
+    tidyr::expand_grid(
+      forecaster = "scaled_pop",
+      trainer = c("linreg", "quantreg"),
+      ahead = 5:7,
+      lags = list(c(0, 3, 5, 7, 14), c(0, 7, 14)),
+      pop_scaling = c(FALSE)
+    )
+  )
+}
+
+#' Make forecasts and scores by ahead targets
+#' @export
+make_forecasts_and_scores_by_ahead <- function() {
+  tar_map(
+    values = targets_param_grid,
+    names = id,
+    unlist = FALSE,
+    tar_target_raw(
+      name = ONE_AHEAD_FORECAST_NAME,
+      command = expression(
+        forecaster_pred(
+          data = joined_archive_data_2022,
+          outcome = "hhs",
+          extra_sources = "",
+          forecaster = forecaster,
+          n_training_pad = 30L,
+          forecaster_args = params,
+          forecaster_args_names = param_names,
+          date_range_step_size = 7L
+        )
+      )
+    ),
+    tar_target_raw(
+      name = ONE_AHEAD_SCORE_NAME,
+      command = expression(
+        run_evaluation_measure(
+          data = forecast_by_ahead,
+          evaluation_data = hhs_evaluation_data,
+          measure = list(
+            wis = weighted_interval_score,
+            ae = absolute_error,
+            cov_80 = interval_coverage(0.8)
+          )
+        )
+      )
+    )
+  )
+}
+
+#' Make forecasts and scores targets
+#' @export
+make_forecasts_and_scores <- function() {
+  tar_map(
+    values = forecaster_parent_id_map,
+    names = parent_id,
+    tar_target(
+      name = forecast,
+      command = {
+        bind_rows(forecast_component_ids) %>%
+          mutate(parent_forecaster = parent_id)
+      }
+    ),
+    tar_target(
+      name = score,
+      command = {
+        bind_rows(score_component_ids) %>%
+          mutate(parent_forecaster = parent_id)
+      }
+    )
+  )
+}
+
+#' Make ensemble targets
+#' @export
+make_ensemble_targets <- function() {
+  list()
+}
+
+
+#' Make external names and scores targets
+#' @export
+make_external_names_and_scores <- function() {
+  external_scores_path <- Sys.getenv("EXTERNAL_SCORES_PATH", "")
+  if (external_scores_path != "") {
+    external_names_and_scores <- list(
+      tar_target(
+        name = external_scores_df,
+        command = {
+          readRDS(external_scores_path) %>%
+            group_by(forecaster) %>%
+            targets::tar_group()
+        },
+        iteration = "group",
+        garbage_collection = TRUE
+      ),
+      tar_target(
+        name = external_names,
+        command = {
+          external_scores_df %>%
+            group_by(forecaster) %>%
+            group_keys() %>%
+            pull(forecaster)
+        },
+        garbage_collection = TRUE
+      ),
+      tar_target(
+        name = external_scores,
+        pattern = map(external_scores_df),
+        command = {
+          external_scores_df
+        },
+        # This step causes the pipeline to exit with an error, apparently due to
+        # running out of memory. Run this in series on a non-parallel `crew`
+        # controller to avoid.
+        # https://books.ropensci.org/targets/crew.html#heterogeneous-workers
+        resources = tar_resources(
+          crew = tar_resources_crew(controller = "serial_controller")
+        ),
+        memory = "transient",
+        garbage_collection = TRUE
+      )
+    )
+  } else {
+    external_names_and_scores <- list(
+      tar_target(
+        name = external_names,
+        command = {
+          c()
+        }
+      ),
+      tar_target(
+        name = external_scores,
+        command = {
+          data.frame()
+        }
+      )
+    )
+  }
+}
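make_data_targets() and make_forecasts_and_scores_by_ahead() read their configuration from globals rather than from arguments, so a calling pipeline has to define those names before the factories run. A hedged sketch of that setup (the global names come from the roxygen block above; the concrete signal names and date ranges are illustrative only, and the id-generation step for the parameter grid is elided):

library(dplyr)
library(epidatr)

# Illustrative values; the real pipelines define their own.
hhs_signal <- "confirmed_admissions_covid_1d"
chng_signal <- "smoothed_adj_outpatient_covid"
geo_type <- "state"
time_type <- "day"
geo_values <- "*"
time_values <- epidatr::epirange(from = "2022-01-01", to = "2023-06-01")
issues <- epidatr::epirange(from = "2022-01-01", to = "2023-06-01")
fetch_args <- epidatr::fetch_args_list()

# The shared grids then become the `targets_param_grid` global that
# make_forecasts_and_scores_by_ahead() maps over (id columns elided here).
param_grid <- bind_rows(make_shared_grids())
targets_param_grid <- make_target_param_grid(param_grid)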
File renamed without changes.

_targets.yaml

Lines changed: 12 additions & 12 deletions
@@ -1,16 +1,16 @@
 covid_hosp_explore:
-  script: covid_hosp_explore.R
-  store: covid_hosp_explore
-  use_crew: yes
+  script: covid_hosp_explore.R
+  store: covid_hosp_explore
+  use_crew: no
 flu_hosp_explore:
-  script: flu_hosp_explore.R
-  store: flu_hosp_explore
-  use_crew: yes
+  script: flu_hosp_explore.R
+  store: flu_hosp_explore
+  use_crew: no
 flu_hosp_prod:
-  script: flu_hosp_prod.R
-  store: flu_hosp_prod
-  use_crew: yes
+  script: flu_hosp_prod.R
+  store: flu_hosp_prod
+  use_crew: no
 covid_hosp_prod:
-  script: covid_hosp_prod.R
-  store: covid_hosp_prod
-  use_crew: yes
+  script: covid_hosp_prod.R
+  store: covid_hosp_prod
+  use_crew: no
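These are per-project settings in the targets multi-project config; flipping use_crew to no makes the explore and prod pipelines run without parallel crew workers. A brief sketch of how one project is selected and the flag read back (assuming use_crew is consumed through targets' config API, which is how the field is normally used):

library(targets)

Sys.setenv(TAR_PROJECT = "covid_hosp_explore")  # pick one block from _targets.yaml
tar_config_get("script")    # "covid_hosp_explore.R"
tar_config_get("use_crew")  # now FALSE, so tar_make() ignores any crew controller
tar_make()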
