Skip to content

Commit be5084c

Browse files
committed
wip: production pipeline
1 parent 97232cf commit be5084c

File tree

5 files changed

+332
-1
lines changed

5 files changed

+332
-1
lines changed

R/plotting.R

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
library(dplyr)
2+
library(evalcast)
3+
library(ggplot2)
4+
library(magrittr)
5+
library(tidyr)
6+
7+
8+
get_truth_data <- function(exclude_geos, ...) {
9+
download_args <- list(...)
10+
truth_data <- do.call(download_signal, download_args)
11+
truth_data %<>% filter(!(.data$geo_value %in% exclude_geos))
12+
truth_data <- truth_data %>%
13+
dplyr::select(.data$geo_value, .data$time_value, .data$value) %>%
14+
dplyr::rename(target_end_date = .data$time_value)
15+
return(truth_data %>% tibble())
16+
}
17+
18+
get_quantiles_df <- function(predictions_cards, intervals = c(.5, .9), ...) {
19+
predictions_cards <- predictions_cards %>%
20+
dplyr::select(
21+
.data$geo_value, .data$quantile,
22+
.data$value, .data$forecaster, .data$forecast_date,
23+
.data$target_end_date
24+
)
25+
26+
lower_bounds <- predictions_cards %>%
27+
select(.data$quantile) %>%
28+
filter(.data$quantile < 0.5) %>%
29+
unique() %>%
30+
pull()
31+
quantiles_to_plot <- as.integer(sort(
32+
round(500L * (1 + intervals %o% c(-1L, 1L)))
33+
))
34+
35+
quantiles_df <- predictions_cards %>%
36+
filter(as.integer(round(.data$quantile * 1000)) %in% c(quantiles_to_plot)) %>%
37+
mutate(
38+
endpoint_type = if_else(.data$quantile < 0.5, "lower", "upper"),
39+
alp = if_else(.data$endpoint_type == "lower",
40+
format(2 * .data$quantile, digits = 3, nsmall = 3),
41+
format(2 * (1 - .data$quantile), digits = 3, nsmall = 3)
42+
),
43+
interval = forcats::fct_rev(
44+
paste0((1 - as.numeric(.data$alp)) * 100, "%")
45+
)
46+
) %>%
47+
select(-.data$quantile, -.data$alp) %>%
48+
pivot_wider(names_from = "endpoint_type", values_from = "value")
49+
50+
return(quantiles_df)
51+
}
52+
53+
get_points_df <- function(predictions_cards) {
54+
points_df <- predictions_cards %>%
55+
filter(as.integer(round(.data$quantile * 1000)) == 500L |
56+
is.na(.data$quantile))
57+
if (any(is.na(points_df$quantile))) {
58+
points_df <- points_df %>%
59+
pivot_wider(names_from = "quantile", values_from = "value") %>%
60+
mutate(value = if_else(!is.na(.data$`NA`), .data$`NA`, .data$`0.5`)) %>%
61+
select(-.data$`0.5`, -.data$`NA`)
62+
} else {
63+
points_df <- points_df %>%
64+
select(-.data$quantile)
65+
}
66+
67+
return(points_df)
68+
}
69+
70+
plot_quantiles <- function(g, quantiles_df) {
71+
n_quantiles <- nlevels(quantiles_df$interval)
72+
l_quantiles <- levels(quantiles_df$interval)
73+
74+
alp <- c(.4, .2, .1)
75+
for (qq in n_quantiles:1) {
76+
g <- g +
77+
geom_ribbon(
78+
data = quantiles_df %>%
79+
filter(.data$interval == l_quantiles[qq]),
80+
mapping = aes(
81+
ymin = .data$lower,
82+
ymax = .data$upper,
83+
group = interaction(.data$forecast_date, .data$forecaster),
84+
color = NULL
85+
),
86+
alpha = alp[qq]
87+
)
88+
}
89+
90+
return(g)
91+
}
92+
93+
plot_points <- function(g, points_df) {
94+
g <- g + geom_point(
95+
data = points_df,
96+
mapping = aes(
97+
y = .data$value,
98+
group = interaction(.data$forecast_date, .data$forecaster)
99+
),
100+
size = 0.125
101+
)
102+
103+
return(g)
104+
}
105+
106+
plot_state_forecasters <- function(predictions_cards, exclude_geos = c(), start_day = NULL, ncol = 5, offline_signal_dir = NULL) {
107+
if (nrow(predictions_cards) == 0) {
108+
return(NULL)
109+
}
110+
111+
td1 <- get_truth_data(exclude_geos = exclude_geos, data_source = "hhs", signal = "confirmed_admissions_covid_1d", start_day = start_day, geo_type = "state", offline_signal_dir = offline_signal_dir)
112+
td1$data_source <- "hhs"
113+
td2 <- get_truth_data(exclude_geos = exclude_geos, data_source = "jhu-csse", signal = "confirmed_7dav_incidence_num", start_day = start_day, geo_type = "state", offline_signal_dir = offline_signal_dir)
114+
td2$data_source <- "jhu"
115+
td1.max <- td1 %>%
116+
group_by(geo_value) %>%
117+
summarize(max_value = max(value))
118+
td2.max <- td2 %>%
119+
group_by(geo_value) %>%
120+
summarize(max_value = max(value))
121+
td2.max <- td2.max %>%
122+
left_join(td1.max, by = "geo_value", suffix = c(".2", ".1")) %>%
123+
mutate(max_ratio = max_value.1 / max_value.2)
124+
td2 <- td2 %>%
125+
left_join(td2.max, by = "geo_value") %>%
126+
mutate(scaled_value = value * max_ratio)
127+
td1 <- td1 %>% mutate(forecaster = "hhs hosp truth")
128+
td2 <- td2 %>% mutate(forecaster = "jhu cases truth")
129+
130+
# Setup plot
131+
g <- ggplot(td1, mapping = aes(x = .data$target_end_date, color = .data$forecaster, fill = .data$forecaster))
132+
133+
points_df <- get_points_df(predictions_cards)
134+
g <- plot_points(g, points_df)
135+
136+
quantiles_df <- get_quantiles_df(predictions_cards)
137+
g <- plot_quantiles(g, quantiles_df)
138+
139+
# Plot truth data by geo
140+
g <- g +
141+
geom_line(mapping = aes(y = .data$value)) +
142+
geom_line(data = td2, mapping = aes(x = .data$target_end_date, y = .data$scaled_value)) +
143+
facet_wrap(~ .data$geo_value, scales = "free_y", ncol = ncol, drop = TRUE) +
144+
theme(legend.position = "top", legend.text = element_text(size = 7))
145+
146+
return(g)
147+
}
148+
149+
plot_nation_forecasters <- function(predictions_cards, exclude_geos = c(), start_day = NULL, ncol = 5, offline_signal_dir = NULL) {
150+
if (nrow(predictions_cards) == 0) {
151+
return(NULL)
152+
}
153+
154+
td1 <- get_truth_data(exclude_geos = exclude_geos, data_source = "hhs", signal = "confirmed_admissions_covid_1d", start_day = start_day, geo_type = "nation", offline_signal_dir = offline_signal_dir)
155+
td1$data_source <- "hhs"
156+
td2 <- get_truth_data(exclude_geos = exclude_geos, data_source = "jhu-csse", signal = "confirmed_7dav_incidence_num", start_day = start_day, geo_type = "nation", offline_signal_dir = offline_signal_dir)
157+
td2$data_source <- "jhu"
158+
td1.max <- td1 %>%
159+
summarize(max_value = max(value)) %>%
160+
pull(max_value)
161+
td2.max <- td2 %>%
162+
summarize(max_value = max(value)) %>%
163+
pull(max_value)
164+
td2 <- td2 %>%
165+
mutate(scaled_value = value * td1.max / td2.max)
166+
167+
# Setup plot
168+
g <- ggplot(td1, mapping = aes(x = .data$target_end_date))
169+
170+
quantiles_df <- get_quantiles_df(predictions_cards)
171+
g <- plot_quantiles(g, quantiles_df)
172+
173+
points_df <- get_points_df(predictions_cards)
174+
g <- plot_points(g, points_df)
175+
176+
# Plot truth data by geo
177+
g <- g +
178+
geom_line(mapping = aes(y = .data$value, color = "confirmed admissions")) +
179+
geom_line(data = td2, mapping = aes(x = .data$target_end_date, y = .data$scaled_value, color = "7day case sum")) +
180+
labs(fill = "Reported Signal") +
181+
theme(legend.position = "top", legend.text = element_text(size = 7))
182+
183+
return(g)
184+
}

covid_hosp_prod/.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
*
2+
!.gitignore
3+
!meta
4+
!*.R
5+
meta/*
6+
# !meta/meta

scripts/covid_hosp_prod.R

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,92 @@
1-
# TODO
1+
# The COVID Hospitalization Production Forecasting Pipeline.
2+
#
3+
# Ran into some issues with targets:
4+
# https://github.com/ropensci/targets/discussions/666#discussioncomment-9050772
5+
6+
source("scripts/targets-common.R")
7+
8+
9+
#' Get exclusions from a JSON file for a given date
10+
#'
11+
#' @param date A date
12+
#' @param exclusions_json A JSON file with exclusions in the format:
13+
#'
14+
#' {"exclusions": {"2024-03-24": "ak,hi"}}
15+
get_exclusions <- function(date, exclusions_json = here::here("scripts", "geo_exclusions.json")) {
16+
s <- jsonlite::read_json(exclusions_json)$exclusions[[as.character(date)]]
17+
if (!is.null(s)) {
18+
return(s)
19+
}
20+
return("")
21+
}
22+
23+
forecast_generation_date <- as.character(seq.Date(as.Date("2024-01-01"), Sys.Date(), by = "1 week"))
24+
geo_exclusions <- Vectorize(get_exclusions)(forecast_generation_date)
25+
26+
tib1 <- tidyr::expand_grid(
27+
tibble(
28+
forecast_generation_date = forecast_generation_date,
29+
geo_exclusions = geo_exclusions
30+
)
31+
)
32+
33+
rlang::list2(
34+
tar_target(
35+
aheads,
36+
command = {
37+
c(1:7)
38+
}
39+
),
40+
tar_map(
41+
values = tib1,
42+
names = "forecast_generation_date",
43+
tar_target(
44+
hhs_latest_data,
45+
command = {
46+
epidatr::pub_covidcast(
47+
source = "hhs",
48+
signals = "confirmed_admissions_covid_1d",
49+
geo_type = "state",
50+
time_type = "day",
51+
geo_values = "*",
52+
time_values = epidatr::epirange(from = "2020-01-01", to = forecast_generation_date),
53+
as_of = forecast_generation_date,
54+
fetch_args = epidatr::fetch_args_list(return_empty = TRUE, timeout_seconds = 400)
55+
) %>%
56+
select(geo_value, time_value, value, issue) %>%
57+
rename("hhs" := value) %>%
58+
rename(version = issue)
59+
}
60+
),
61+
tar_target(
62+
forecast,
63+
command = {
64+
hhs_latest_data %>%
65+
as_epi_df() %>%
66+
smoothed_scaled(outcome = "hhs", ahead = aheads)
67+
},
68+
pattern = map(aheads)
69+
),
70+
tar_target(
71+
forecast_with_exclusions,
72+
command = {
73+
forecast %>% filter(!geo_value %in% strsplit(geo_exclusions, ",")[[1]])
74+
}
75+
),
76+
tar_target(
77+
notebook,
78+
command = {
79+
rmarkdown::render(
80+
"scripts/covid_hosp_prod.Rmd",
81+
output_file = here::here(
82+
"reports",
83+
sprintf("covid_hosp_prod_%s.html", forecast_generation_date)
84+
),
85+
params = list(
86+
forecast = forecast
87+
)
88+
)
89+
}
90+
)
91+
)
92+
)

scripts/covid_hosp_prod.Rmd

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
---
2+
title: COVID Forecaster Predictions
3+
author: COVID Forecast Team
4+
date: "Rendered: `r format(Sys.time(), '%d %B %Y')`"
5+
output:
6+
html_document:
7+
toc: True
8+
# self_contained: False
9+
# lib_dir: libs
10+
params:
11+
forecast_generation_date: !r Sys.Date()
12+
forecast: ""
13+
---
14+
15+
```{css, echo=FALSE}
16+
body {
17+
display: block;
18+
max-width: 1280px !important;
19+
margin-left: auto;
20+
margin-right: auto;
21+
}
22+
23+
body .main-container {
24+
max-width: 1280px !important;
25+
width: 1280px !important;
26+
}
27+
```
28+
29+
```{r setup, include=FALSE}
30+
library(dplyr)
31+
library(evalcast)
32+
library(here)
33+
library(magrittr)
34+
library(rlang)
35+
library(targets)
36+
library(tidyr)
37+
source(here("R", "plotting.R"))
38+
39+
forecast <- params$forecast
40+
```
41+
42+
```{r}
43+
print(forecast)
44+
```

scripts/geo_exclusions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"exclusions":
3+
{
4+
"2024-03-24": "ak,ca"
5+
}
6+
}

0 commit comments

Comments
 (0)