Skip to content

Commit 8ea28c1

Browse files
committed
add step to aggregate all aheads for a given model into one obj
1 parent 3ba7ea6 commit 8ea28c1

File tree

4 files changed

+57
-11
lines changed

4 files changed

+57
-11
lines changed

R/small_utils.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ add_id <- function(df, n_adj = 2) {
2727
mutate(id = hash_animal(id, n_adj = n_adj)$words) %>%
2828
mutate(id = paste(id[1:n_adj], sep="", collapse = " "))
2929
df %<>%
30-
mutate(id = stringified) %>%
30+
mutate(parent_id = stringified$id) %>%
3131
rowwise() %>%
32-
mutate(id = paste(id, ahead, collapse = " ")) %>%
32+
mutate(id = paste(parent_id, ahead, collapse = " ")) %>%
3333
ungroup()
3434
return(df)
3535
}

covid_hosp_explore.R

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ tar_option_set(
3535
source("covid_hosp_explore/forecaster_instantiation.R")
3636
source("covid_hosp_explore/data_targets.R")
3737

38-
forecasts_and_scores <- tar_map(
38+
forecasts_and_scores_separate_aheads <- tar_map(
3939
values = forecaster_param_grids,
4040
names = id,
4141
unlist = FALSE,
4242
tar_target(
43-
name = forecast,
43+
name = forecast_by_ahead,
4444
command = {
4545
forecaster_pred(
4646
data = joined_archive_data_2022,
@@ -55,10 +55,10 @@ forecasts_and_scores <- tar_map(
5555
}
5656
),
5757
tar_target(
58-
name = score,
58+
name = score_by_ahead,
5959
command = {
6060
run_evaluation_measure(
61-
data = forecast,
61+
data = forecast_by_ahead,
6262
evaluation_data = hhs_evaluation_data,
6363
measure = list(
6464
wis = weighted_interval_score,
@@ -70,6 +70,43 @@ forecasts_and_scores <- tar_map(
7070
)
7171
)
7272

73+
forecasts_and_scores <- tar_map(
74+
values = forecaster_parent_id_map,
75+
names = parent_id,
76+
tar_target(
77+
name = forecast,
78+
command = {
79+
bind_rows(forecast_component_ids) %>%
80+
mutate(parent_forecaster = parent_id)
81+
}
82+
),
83+
tar_target(
84+
name = score,
85+
command = {
86+
bind_rows(score_component_ids) %>%
87+
mutate(parent_forecaster = parent_id)
88+
}
89+
)
90+
# tar_target(
91+
# name = score,
92+
# command = {
93+
# bind_rows(!!!component_ids) %>%
94+
# mutate(parent_forecaster = parent_id)
95+
# }
96+
# )
97+
#,
98+
# tar_combine(
99+
# name = score,
100+
# list(
101+
# forecasts_and_scores_separate_aheads[["score_by_ahead"]][filter(param_grid, parent_id == value)[["rownum"]]]
102+
# ),
103+
# command = {
104+
# bind_rows(!!!.x) %>%
105+
# mutate(parent_forecaster = name)
106+
# }
107+
# )
108+
)
109+
73110
ensemble_keys <- list(a = c(300, 15))
74111
ensembles <- list(
75112
tar_target(
@@ -89,8 +126,8 @@ ensemble_forecast <- tar_map(
89126
name = ensemble_forecast,
90127
# TODO: Needs a lookup table to select the right forecasters
91128
list(
92-
forecasts_and_scores[["forecast"]][[1]],
93-
forecasts_and_scores[["forecast"]][[2]]
129+
forecasts_and_scores_separate_aheads[["forecast_by_ahead"]][[1]],
130+
forecasts_and_scores_separate_aheads[["forecast_by_ahead"]][[2]]
94131
),
95132
command = {
96133
bind_rows(!!!.x, .id = "forecaster") %>%
@@ -124,6 +161,7 @@ ensemble_forecast <- tar_map(
124161
list(
125162
data,
126163
forecasters,
164+
forecasts_and_scores_separate_aheads,
127165
forecasts_and_scores,
128166
ensembles,
129167
ensemble_forecast

covid_hosp_explore/forecaster_instantiation.R

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,17 @@ grids <- list(
2020
# bind them together and give static ids; if you add a new field to a given
2121
# expand_grid, everything will get a new id, so it's better to add a new
2222
# expand_grid instead
23-
param_grid <- bind_rows(map(grids, add_id)) %>% relocate(id, .after = last_col())
23+
param_grid <- bind_rows(map(grids, add_id)) %>%
24+
relocate(parent_id, id, .after = last_col())
2425

25-
forecaster_param_grids <- make_target_param_grid(param_grid)
26+
forecaster_parent_id_map <- param_grid %>%
27+
group_by(parent_id) %>%
28+
summarize(
29+
forecast_component_ids = list(syms(paste0("forecast_by_ahead_", gsub(" ", ".", id, fixed = TRUE)))),
30+
score_component_ids = list(syms(paste0("score_by_ahead_", gsub(" ", ".", id, fixed = TRUE))))
31+
)
32+
33+
forecaster_param_grids <- make_target_param_grid(select(param_grid, -parent_id))
2634

2735
# not actually used downstream, this is for lookup during plotting and human evaluation
2836
forecasters <- list(

run.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ tar_make()
5353

5454
# Prevent functions defined in /R dir from being loaded unnecessarily
5555
options(shiny.autoload.r=FALSE)
56-
forecaster_options <- tar_read(forecasters)[["id"]]
56+
forecaster_options <- unique(tar_read(forecasters)[["parent_id"]])
5757
# Map forecaster names to score files
5858
forecaster_options <- setNames(
5959
paste0("score_", gsub(" ", ".", forecaster_options)),

0 commit comments

Comments
 (0)