Skip to content

Commit 98c1370

Browse files
authored
Merge pull request #40 from cmu-delphi/ndefries/baseline-display-and-scaling
Support baseline score scaling and baseline forecaster selection
2 parents edeef98 + 22ca3ea commit 98c1370

File tree

5 files changed

+93
-27
lines changed

5 files changed

+93
-27
lines changed

R/small_utils.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ add_id <- function(df, n_adj = 2) {
2727
mutate(id = hash_animal(id, n_adj = n_adj)$words) %>%
2828
mutate(id = paste(id[1:n_adj], sep="", collapse = " "))
2929
df %<>%
30-
mutate(id = stringified) %>%
30+
mutate(parent_id = stringified$id) %>%
3131
rowwise() %>%
32-
mutate(id = paste(id, ahead, collapse = " ")) %>%
32+
mutate(id = paste(parent_id, ahead, collapse = " ")) %>%
3333
ungroup()
3434
return(df)
3535
}

app.R

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,19 @@ shinyApp(
6767
choices = forecaster_options,
6868
multiple = TRUE
6969
),
70+
selectInput("baseline",
71+
"Baseline forecaster:",
72+
choices = forecaster_options,
73+
multiple = FALSE
74+
),
75+
checkboxInput(
76+
"scale_by_baseline",
77+
"Scale by baseline forecaster",
78+
value = FALSE,
79+
),
7080
radioButtons(
7181
"selected_metric",
72-
"Metric:",
82+
"Error metric:",
7383
c(
7484
"Mean WIS" = "wis",
7585
# "Mean WIS per 100k" = "wis_per_100k",
@@ -88,12 +98,10 @@ shinyApp(
8898
choices = c("forecaster", "ahead", "geo_value"),
8999
multiple = TRUE
90100
),
91-
radioButtons("facets_share_scale",
92-
"Share y scale between subplots:",
93-
c(
94-
"Yes" = "fixed",
95-
"No" = "free_y"
96-
)
101+
checkboxInput(
102+
"facets_share_scale",
103+
"Share y scale between subplots",
104+
value = TRUE,
97105
),
98106
sliderInput("selected_forecast_date_range",
99107
"Forecast date range:",
@@ -125,9 +133,10 @@ shinyApp(
125133
},
126134
server = function(input, output, session) {
127135
filtered_scorecards_reactive <- reactive({
128-
if (length(input$selected_forecasters) == 0) { return(data.frame()) }
136+
agg_forecasters <- unique(c(input$selected_forecasters, input$baseline))
137+
if (length(agg_forecasters) == 0) { return(data.frame()) }
129138

130-
processed_evaluations_internal <- lapply(input$selected_forecasters, function(forecaster) {
139+
processed_evaluations_internal <- lapply(agg_forecasters, function(forecaster) {
131140
load_forecast_data(forecaster) %>>%
132141
filter(
133142
.data$forecast_date %>>% between(.env$input$selected_forecast_date_range[[1L]], .env$input$selected_forecast_date_range[[2L]]),
@@ -141,8 +150,35 @@ shinyApp(
141150
input_df <- filtered_scorecards_reactive()
142151
if (nrow(input_df) == 0) { return() }
143152

153+
# Normalize by baseline scores. This is not relevant for coverage, which is compared
154+
# to the nominal confidence level.
155+
if (input$scale_by_baseline && input$selected_metric != "ic80") {
156+
# These merge keys are overkill; this should be fully specified by
157+
# c("forecast_date", "target_end_date", "geo_value")
158+
merge_keys <- c("forecast_date", "target_end_date", "ahead", "issue", "geo_value")
159+
# Load selected baseline
160+
baseline_scores <- load_forecast_data(input$baseline)[, c(merge_keys, input$selected_metric)]
161+
162+
baseline_scores$score_baseline <- baseline_scores[[input$selected_metric]]
163+
baseline_scores[[input$selected_metric]] <- NULL
164+
165+
# Add on reference scores from baseline forecaster.
166+
# Note that this drops any scores where there isn't a corresponding
167+
# baseline value. If a forecaster and a baseline cover
168+
# non-overlapping dates or use different aheads, the forecaster will
169+
# not be shown.
170+
input_df <- inner_join(
171+
input_df, baseline_scores,
172+
by = merge_keys, suffix = c("", "")
173+
)
174+
# Scale score by baseline forecaster
175+
input_df[[input$selected_metric]] <- input_df[[input$selected_metric]] / input_df$score_baseline
176+
}
177+
178+
144179
x_tick_angle <- list(tickangle = -30)
145180
facet_x_tick_angles <- setNames(rep(list(x_tick_angle), 10), paste0("xaxis", 1:10))
181+
scale_type <- ifelse(input$facets_share_scale, "fixed", "free_y" )
146182

147183
input_df %>>%
148184
# Aggregate scores over all geos
@@ -180,9 +216,9 @@ shinyApp(
180216
`+`(if (length(input$facet_vars) == 0L) {
181217
theme()
182218
} else if (length(input$facet_vars) == 1L) {
183-
facet_wrap(input$facet_vars, scales = input$facets_share_scale)
219+
facet_wrap(input$facet_vars, scales = scale_type)
184220
} else {
185-
facet_grid(as.formula(paste0(input$facet_vars[[1L]], " ~ ", paste(collapse = " + ", input$facet_vars[-1L]))), scales = input$facets_share_scale)
221+
facet_grid(as.formula(paste0(input$facet_vars[[1L]], " ~ ", paste(collapse = " + ", input$facet_vars[-1L]))), scales = scale_type)
186222
}) %>>%
187223
ggplotly() %>>%
188224
{inject(layout(., hovermode = "x unified", legend = list(orientation = "h", title = list(text = "forecaster")), xaxis = x_tick_angle, !!!facet_x_tick_angles))}

covid_hosp_explore.R

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ tar_option_set(
3535
source("covid_hosp_explore/forecaster_instantiation.R")
3636
source("covid_hosp_explore/data_targets.R")
3737

38-
forecasts_and_scores <- tar_map(
38+
forecasts_and_scores_by_ahead <- tar_map(
3939
values = forecaster_param_grids,
4040
names = id,
4141
unlist = FALSE,
42-
tar_target(
43-
name = forecast,
44-
command = {
42+
tar_target_raw(
43+
name = ONE_AHEAD_FORECAST_NAME,
44+
command = expression(
4545
forecaster_pred(
4646
data = joined_archive_data_2022,
4747
outcome = "hhs",
@@ -52,20 +52,39 @@ forecasts_and_scores <- tar_map(
5252
forecaster_args = params,
5353
forecaster_args_names = param_names
5454
)
55-
}
55+
)
5656
),
57-
tar_target(
58-
name = score,
59-
command = {
57+
tar_target_raw(
58+
name = ONE_AHEAD_SCORE_NAME,
59+
command = expression(
6060
run_evaluation_measure(
61-
data = forecast,
61+
data = forecast_by_ahead,
6262
evaluation_data = hhs_evaluation_data,
6363
measure = list(
6464
wis = weighted_interval_score,
6565
ae = absolute_error,
6666
ic80 = interval_coverage(0.8)
6767
)
6868
)
69+
)
70+
)
71+
)
72+
73+
forecasts_and_scores <- tar_map(
74+
values = forecaster_parent_id_map,
75+
names = parent_id,
76+
tar_target(
77+
name = forecast,
78+
command = {
79+
bind_rows(forecast_component_ids) %>%
80+
mutate(parent_forecaster = parent_id)
81+
}
82+
),
83+
tar_target(
84+
name = score,
85+
command = {
86+
bind_rows(score_component_ids) %>%
87+
mutate(parent_forecaster = parent_id)
6988
}
7089
)
7190
)
@@ -89,8 +108,8 @@ ensemble_forecast <- tar_map(
89108
name = ensemble_forecast,
90109
# TODO: Needs a lookup table to select the right forecasters
91110
list(
92-
forecasts_and_scores[["forecast"]][[1]],
93-
forecasts_and_scores[["forecast"]][[2]]
111+
forecasts_and_scores_by_ahead[["forecast_by_ahead"]][[1]],
112+
forecasts_and_scores_by_ahead[["forecast_by_ahead"]][[2]]
94113
),
95114
command = {
96115
bind_rows(!!!.x, .id = "forecaster") %>%
@@ -124,6 +143,7 @@ ensemble_forecast <- tar_map(
124143
list(
125144
data,
126145
forecasters,
146+
forecasts_and_scores_by_ahead,
127147
forecasts_and_scores,
128148
ensembles,
129149
ensemble_forecast

covid_hosp_explore/forecaster_instantiation.R

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,19 @@ grids <- list(
2020
# bind them together and give static ids; if you add a new field to a given
2121
# expand_grid, everything will get a new id, so it's better to add a new
2222
# expand_grid instead
23-
param_grid <- bind_rows(map(grids, add_id)) %>% relocate(id, .after = last_col())
23+
param_grid <- bind_rows(map(grids, add_id)) %>%
24+
relocate(parent_id, id, .after = last_col())
2425

25-
forecaster_param_grids <- make_target_param_grid(param_grid)
26+
ONE_AHEAD_FORECAST_NAME <- "forecast_by_ahead"
27+
ONE_AHEAD_SCORE_NAME <- "score_by_ahead"
28+
forecaster_parent_id_map <- param_grid %>%
29+
group_by(parent_id) %>%
30+
summarize(
31+
forecast_component_ids = list(syms(paste0(ONE_AHEAD_FORECAST_NAME, "_", gsub(" ", ".", id, fixed = TRUE)))),
32+
score_component_ids = list(syms(paste0(ONE_AHEAD_SCORE_NAME, "_", gsub(" ", ".", id, fixed = TRUE))))
33+
)
34+
35+
forecaster_param_grids <- make_target_param_grid(select(param_grid, -parent_id))
2636

2737
# not actually used downstream, this is for lookup during plotting and human evaluation
2838
forecasters <- list(

run.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ tar_make()
5353

5454
# Prevent functions defined in /R dir from being loaded unnecessarily
5555
options(shiny.autoload.r=FALSE)
56-
forecaster_options <- tar_read(forecasters)[["id"]]
56+
forecaster_options <- unique(tar_read(forecasters)[["parent_id"]])
5757
# Map forecaster names to score files
5858
forecaster_options <- setNames(
5959
paste0("score_", gsub(" ", ".", forecaster_options)),

0 commit comments

Comments
 (0)