Skip to content

Commit 4481a3b

Browse files
committed
feat: add covid hosp prod pipeline
1 parent 97232cf commit 4481a3b

File tree

6 files changed

+431
-1
lines changed

6 files changed

+431
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ scripts/**.html
99
nohup.out
1010
run.Rout
1111
tmp.R
12+
reports/

R/plotting.R

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
library(dplyr)
2+
library(ggplot2)
3+
library(magrittr)
4+
library(tidyr)
5+
6+
7+
#' Build the interval-endpoint table used to draw prediction ribbons
#'
#' Keeps only the quantile rows that are endpoints of the requested central
#' prediction intervals, labels each row with its interval (e.g. "90%") and
#' reshapes the result so the lower and upper endpoints sit side by side.
#'
#' @param predictions_cards A data frame containing at least the columns
#'   `geo_value`, `quantile`, `value`, `forecaster`, `forecast_date` and
#'   `target_end_date`. Any other columns are dropped.
#' @param intervals Numeric vector of central interval coverages. A coverage
#'   `c` selects the quantile levels `(1 - c) / 2` and `(1 + c) / 2`.
#' @param ... Unused; accepted so existing callers that pass extra arguments
#'   keep working.
#' @return A data frame with the identifying columns, an `interval` factor and
#'   one column per endpoint type present in the data (`lower` and/or `upper`).
get_quantiles_df <- function(predictions_cards, intervals = c(.5, .9), ...) {
  predictions_cards <- predictions_cards %>%
    dplyr::select(
      geo_value,
      quantile,
      value,
      forecaster,
      forecast_date,
      target_end_date
    )

  # Quantile levels are compared as integers in thousandths to avoid
  # floating-point equality problems (0.05 -> 50, 0.95 -> 950, ...).
  quantiles_to_plot <- as.integer(sort(
    round(500L * (1 + outer(intervals, c(-1L, 1L))))
  ))

  predictions_cards %>%
    filter(as.integer(round(.data$quantile * 1000)) %in% quantiles_to_plot) %>%
    mutate(
      endpoint_type = if_else(.data$quantile < 0.5, "lower", "upper"),
      # `alp` is the two-sided miss rate of the interval this endpoint belongs
      # to; formatting to three decimals removes floating-point noise so both
      # endpoints of an interval get exactly the same label.
      alp = if_else(.data$endpoint_type == "lower",
        format(2 * .data$quantile, digits = 3, nsmall = 3),
        format(2 * (1 - .data$quantile), digits = 3, nsmall = 3)
      ),
      interval = forcats::fct_rev(
        paste0((1 - as.numeric(.data$alp)) * 100, "%")
      )
    ) %>%
    # Column names are selected as strings: `.data$` inside tidyselect calls is
    # deprecated.
    select(-c("quantile", "alp")) %>%
    pivot_wider(names_from = "endpoint_type", values_from = "value")
}
44+
45+
#' Extract one point prediction per forecast row
#'
#' Keeps the rows that are either explicit point predictions (`quantile` is
#' `NA`) or medians (`quantile` rounds to 0.5). When both exist for the same
#' identifying columns, the explicit point prediction wins and the median is
#' used only as a fallback.
#'
#' @param predictions_cards A data frame with at least `quantile` and `value`
#'   columns; all remaining columns identify a prediction.
#' @return The filtered data frame with the `quantile` column removed and a
#'   single `value` per prediction.
get_points_df <- function(predictions_cards) {
  points_df <- predictions_cards %>%
    filter(
      as.integer(round(.data$quantile * 1000)) == 500L | is.na(.data$quantile)
    )

  # Only medians present: nothing to reconcile.
  if (!anyNA(points_df$quantile)) {
    return(points_df %>% select(-"quantile"))
  }

  # Pivot on an explicit label instead of the raw quantile so the resulting
  # column names do not depend on how 0.5 / NA happen to be printed.
  point_col <- ".point_value"
  median_col <- ".median_value"
  wide_df <- points_df %>%
    mutate(quantile = if_else(is.na(.data$quantile), point_col, median_col)) %>%
    pivot_wider(names_from = "quantile", values_from = "value")

  # Point-only input has no median rows, so the median column would be absent;
  # add a typed all-NA column so the fallback below still works.
  if (!median_col %in% names(wide_df)) {
    wide_df[[median_col]] <- rep(
      wide_df[[point_col]][NA_integer_],
      nrow(wide_df)
    )
  }

  wide_df %>%
    mutate(
      value = dplyr::coalesce(.data[[point_col]], .data[[median_col]])
    ) %>%
    select(-dplyr::all_of(c(point_col, median_col)))
}
61+
62+
#' Add one prediction-interval ribbon layer per interval level
#'
#' @param g A ggplot object whose default mapping supplies the x aesthetic.
#' @param quantiles_df A data frame with an `interval` factor and `lower`,
#'   `upper`, `forecast_date` and `forecaster` columns (the output of
#'   `get_quantiles_df()`).
#' @return `g` with a `geom_ribbon()` layer added for every level of
#'   `quantiles_df$interval`; `g` unchanged when there are no levels.
plot_quantiles <- function(g, quantiles_df) {
  interval_levels <- levels(quantiles_df$interval)

  # Transparency by level position. Levels beyond the third reuse the last
  # value instead of indexing past the end (which would give alpha = NA).
  alp <- c(.4, .2, .1)

  # Layers are added from the last level to the first. rev(seq_along()) yields
  # an empty sequence when there are no levels, unlike `n:1`, which would
  # iterate over 0 and 1.
  for (qq in rev(seq_along(interval_levels))) {
    g <- g +
      geom_ribbon(
        data = quantiles_df %>%
          filter(.data$interval == interval_levels[qq]),
        mapping = aes(
          ymin = .data$lower,
          ymax = .data$upper,
          group = interaction(.data$forecast_date, .data$forecaster),
          color = NULL
        ),
        alpha = alp[min(qq, length(alp))]
      )
  }

  g
}
84+
85+
#' Add a layer of point predictions to a plot
#'
#' @param g A ggplot object whose default mapping supplies the x aesthetic.
#' @param points_df A data frame with `value`, `forecast_date` and
#'   `forecaster` columns.
#' @return `g` with a small-point `geom_point()` layer added.
plot_points <- function(g, points_df) {
  # One group per (forecast date, forecaster) pair.
  point_mapping <- aes(
    y = .data$value,
    group = interaction(.data$forecast_date, .data$forecaster)
  )
  point_layer <- geom_point(
    data = points_df,
    mapping = point_mapping,
    size = 0.125
  )

  g + point_layer
}
97+
98+
#' Plot state-level forecasts against HHS admissions and JHU cases
#'
#' Fetches the state-level "hhs" `confirmed_admissions_covid_1d` and "jhu-csse"
#' `confirmed_7dav_incidence_num` signals from `start_day` through today,
#' rescales the JHU series per geo so that its maximum matches the HHS maximum,
#' and draws both together with the point predictions and interval ribbons
#' from `predictions_cards`, faceted by `geo_value`.
#'
#' @param predictions_cards Data frame of predictions, as accepted by
#'   `get_points_df()` and `get_quantiles_df()`.
#' @param exclude_geos Geo values removed from the fetched truth data. The
#'   predictions themselves are not filtered here.
#' @param start_day First day of truth data to fetch; passed unchanged to
#'   `epidatr::epirange()`. NOTE(review): the default is `NULL`; confirm that
#'   `epirange()` accepts it, otherwise callers must always supply a value.
#' @param ncol Number of facet columns.
#' @return A ggplot object, or `NULL` when `predictions_cards` has no rows.
plot_state_forecasters <- function(predictions_cards,
                                   exclude_geos = c(),
                                   start_day = NULL,
                                   ncol = 5) {
  if (nrow(predictions_cards) == 0) {
    return(NULL)
  }

  # Download one state-level signal and reduce it to the plotting columns.
  fetch_state_truth <- function(source, signal, label) {
    epidatr::pub_covidcast(
      source = source,
      signals = signal,
      geo_type = "state",
      time_type = "day",
      geo_values = "*",
      time_values = epidatr::epirange(start_day, Sys.Date())
    ) %>%
      filter(!(.data$geo_value %in% exclude_geos)) %>%
      dplyr::select("geo_value", "time_value", "value") %>%
      dplyr::rename(target_end_date = "time_value") %>%
      mutate(data_source = label)
  }

  # Maximum ignoring missing observations, so one NA does not wipe out a whole
  # geo's scaling factor. All-missing input gives NA rather than -Inf.
  max_or_na <- function(x) {
    if (all(is.na(x))) NA_real_ else max(x, na.rm = TRUE)
  }

  hhs_truth <- fetch_state_truth(
    "hhs", "confirmed_admissions_covid_1d", "hhs"
  )
  jhu_truth <- fetch_state_truth(
    "jhu-csse", "confirmed_7dav_incidence_num", "jhu"
  )

  hhs_max <- hhs_truth %>%
    group_by(geo_value) %>%
    summarize(max_value = max_or_na(.data$value))
  jhu_max <- jhu_truth %>%
    group_by(geo_value) %>%
    summarize(max_value = max_or_na(.data$value))

  # Per-geo factor that puts the case curve on the admissions scale.
  jhu_scale <- jhu_max %>%
    left_join(hhs_max, by = "geo_value", suffix = c(".2", ".1")) %>%
    mutate(max_ratio = .data$max_value.1 / .data$max_value.2)

  jhu_truth <- jhu_truth %>%
    left_join(jhu_scale, by = "geo_value") %>%
    mutate(
      scaled_value = .data$value * .data$max_ratio,
      forecaster = "jhu cases truth"
    )
  hhs_truth <- hhs_truth %>% mutate(forecaster = "hhs hosp truth")

  # Setup plot
  g <- ggplot(
    hhs_truth,
    mapping = aes(
      x = .data$target_end_date,
      color = .data$forecaster,
      fill = .data$forecaster
    )
  )
  g <- plot_points(g, get_points_df(predictions_cards))
  g <- plot_quantiles(g, get_quantiles_df(predictions_cards))
  g <- g +
    geom_line(mapping = aes(y = .data$value)) +
    geom_line(
      data = jhu_truth,
      mapping = aes(x = .data$target_end_date, y = .data$scaled_value)
    ) +
    facet_wrap(~ .data$geo_value, scales = "free_y", ncol = ncol, drop = TRUE) +
    theme(legend.position = "top", legend.text = element_text(size = 7))

  g
}
155+
156+
#' Plot national forecasts against HHS admissions and JHU cases
#'
#' Fetches the nation-level "hhs" `confirmed_admissions_covid_1d` and
#' "jhu-csse" `confirmed_7dav_incidence_num` signals from `start_day` through
#' today, rescales the JHU series so that its maximum matches the HHS maximum,
#' and draws both together with the interval ribbons and point predictions
#' from `predictions_cards`.
#'
#' @param predictions_cards Data frame of predictions, as accepted by
#'   `get_points_df()` and `get_quantiles_df()`.
#' @param exclude_geos Geo values removed from the fetched truth data. The
#'   predictions themselves are not filtered here.
#' @param start_day First day of truth data to fetch; passed unchanged to
#'   `epidatr::epirange()`. NOTE(review): the default is `NULL`; confirm that
#'   `epirange()` accepts it, otherwise callers must always supply a value.
#' @param ncol Unused in this function; kept so the signature matches
#'   `plot_state_forecasters()`.
#' @return A ggplot object, or `NULL` when `predictions_cards` has no rows.
plot_nation_forecasters <- function(predictions_cards,
                                    exclude_geos = c(),
                                    start_day = NULL,
                                    ncol = 5) {
  if (nrow(predictions_cards) == 0) {
    return(NULL)
  }

  # Download one nation-level signal and reduce it to the plotting columns.
  fetch_nation_truth <- function(source, signal, label) {
    epidatr::pub_covidcast(
      source = source,
      signals = signal,
      geo_type = "nation",
      time_type = "day",
      geo_values = "*",
      time_values = epidatr::epirange(start_day, Sys.Date())
    ) %>%
      filter(!(.data$geo_value %in% exclude_geos)) %>%
      dplyr::select("time_value", "value") %>%
      dplyr::rename(target_end_date = "time_value") %>%
      mutate(data_source = label)
  }

  # Maximum ignoring missing observations, so one NA does not turn the whole
  # rescaled series into NA. All-missing input gives NA rather than -Inf.
  max_or_na <- function(x) {
    if (all(is.na(x))) NA_real_ else max(x, na.rm = TRUE)
  }

  hhs_truth <- fetch_nation_truth(
    "hhs", "confirmed_admissions_covid_1d", "hhs"
  )
  jhu_truth <- fetch_nation_truth(
    "jhu-csse", "confirmed_7dav_incidence_num", "jhu"
  )

  # Put the case curve on the admissions scale by matching the two maxima.
  hhs_max <- max_or_na(hhs_truth$value)
  jhu_max <- max_or_na(jhu_truth$value)
  jhu_truth <- jhu_truth %>%
    mutate(scaled_value = .data$value * hhs_max / jhu_max)

  # Setup plot
  g <- ggplot(hhs_truth, mapping = aes(x = .data$target_end_date))
  g <- plot_quantiles(g, get_quantiles_df(predictions_cards))
  g <- plot_points(g, get_points_df(predictions_cards))
  g <- g +
    geom_line(
      mapping = aes(y = .data$value, color = "confirmed admissions")
    ) +
    # NOTE(review): the fetched signal name says "7dav" but the legend label
    # says "sum"; confirm which wording is intended.
    geom_line(
      data = jhu_truth,
      mapping = aes(
        x = .data$target_end_date,
        y = .data$scaled_value,
        color = "7day case sum"
      )
    ) +
    # The two truth lines are distinguished by colour, so the legend title has
    # to be set on `color`; `fill` is kept for backward compatibility.
    labs(color = "Reported Signal", fill = "Reported Signal") +
    theme(legend.position = "top", legend.text = element_text(size = 7))

  g
}

covid_hosp_prod/.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
*
2+
!.gitignore
3+
!meta
4+
!*.R
5+
meta/*
6+
# !meta/meta

scripts/covid_hosp_prod.R

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,90 @@
1-
# TODO
1+
# The COVID Hospitalization Production Forecasting Pipeline.
2+
#
3+
# Ran into some issues with targets:
4+
# https://github.com/ropensci/targets/discussions/666#discussioncomment-9050772
5+
6+
source("scripts/targets-common.R")
7+
8+
9+
#' Look up the geo exclusions recorded for a date
#'
#' @param date A date; it is converted with `as.character()` and used as the
#'   lookup key.
#' @param exclusions_json Path to a JSON file with exclusions in the format:
#'
#'   {"exclusions": {"2024-03-24": "ak,hi"}}
#' @return The entry stored under `date`, or `""` when the file has no entry
#'   for that date.
get_exclusions <- function(date, exclusions_json = here::here("scripts", "geo_exclusions.json")) {
  exclusions_by_date <- jsonlite::read_json(exclusions_json)$exclusions
  excluded <- exclusions_by_date[[as.character(date)]]

  # No entry for this date means nothing is excluded.
  if (is.null(excluded)) {
    return("")
  }
  excluded
}
22+
23+
# One forecast is generated per week, from 2024-01-01 up to today.
forecast_generation_date <- as.character(
  seq.Date(as.Date("2024-01-01"), Sys.Date(), by = "1 week")
)
# Comma-separated geo codes to drop for each forecast date ("" when none).
geo_exclusions <- Vectorize(get_exclusions)(forecast_generation_date)

rlang::list2(
  # Forecast horizons; the `forecast` target below branches over these.
  tar_target(
    aheads,
    command = {
      c(1:7)
    }
  ),
  # One set of targets per forecast generation date.
  tar_map(
    values = tidyr::expand_grid(
      tibble(
        forecast_generation_date = forecast_generation_date,
        geo_exclusions = geo_exclusions
      )
    ),
    names = "forecast_generation_date",
    # HHS admissions as they were known on the forecast generation date.
    tar_target(
      hhs_latest_data,
      command = {
        epidatr::pub_covidcast(
          source = "hhs",
          signals = "confirmed_admissions_covid_1d",
          geo_type = "state",
          time_type = "day",
          geo_values = "*",
          time_values = epidatr::epirange(from = "2020-01-01", to = forecast_generation_date),
          as_of = forecast_generation_date,
          fetch_args = epidatr::fetch_args_list(return_empty = TRUE, timeout_seconds = 400)
        ) %>%
          select(geo_value, time_value, value, issue) %>%
          rename("hhs" := value) %>%
          rename(version = issue)
      }
    ),
    tar_target(
      forecast,
      command = {
        hhs_latest_data %>%
          as_epi_df() %>%
          smoothed_scaled(outcome = "hhs", ahead = aheads)
      },
      pattern = map(aheads)
    ),
    # Drop the geos listed for this date; whitespace around the codes is
    # trimmed so "ak, hi" behaves like "ak,hi".
    tar_target(
      forecast_with_exclusions,
      command = {
        forecast %>%
          filter(!geo_value %in% trimws(strsplit(geo_exclusions, ",")[[1]]))
      }
    ),
    # Render the report from the filtered forecast. Previously the unfiltered
    # `forecast` was passed, so the exclusions never reached the report.
    tar_target(
      notebook,
      command = {
        rmarkdown::render(
          "scripts/covid_hosp_prod.Rmd",
          output_file = here::here(
            "reports",
            sprintf("covid_hosp_prod_%s.html", forecast_generation_date)
          ),
          params = list(
            forecast = forecast_with_exclusions
          )
        )
      }
    )
  )
)

0 commit comments

Comments
 (0)