Skip to content

Commit c73af47

Browse files
committed
test: modify pipeline and refactor ab comparison script
1 parent a595f89 commit c73af47

File tree

3 files changed

+103
-42
lines changed

3 files changed

+103
-42
lines changed

scripts/covid_hosp_explore.R

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,41 @@ source("scripts/targets-common.R")
66
# Add custom parameter combinations in the list below.
77
make_unique_grids <- function() {
88
list(
9+
# tidyr::expand_grid(
10+
# forecaster = "scaled_pop",
11+
# trainer = c("linreg", "quantreg"),
12+
# ahead = c(1:7, 14, 21, 28),
13+
# pop_scaling = TRUE
14+
# ),
915
tidyr::expand_grid(
1016
forecaster = "scaled_pop",
11-
trainer = c("linreg", "quantreg"),
17+
trainer = c("linreg"),
1218
ahead = c(1:7, 14, 21, 28),
13-
pop_scaling = TRUE
14-
),
15-
tidyr::expand_grid(
16-
forecaster = "scaled_pop",
17-
trainer = c("linreg", "quantreg"),
18-
ahead = c(1:7, 14, 21, 28),
19-
lags = list(c(0, 3, 5, 7, 14), c(0, 7, 14)),
20-
pop_scaling = TRUE
21-
),
22-
tidyr::expand_grid(
23-
forecaster = "smoothed_scaled",
24-
trainer = c("quantreg"),
25-
ahead = c(1:7, 14, 21, 28),
26-
#
27-
lags = list(
28-
# smoothed, sd, smoothed, sd
29-
list(c(0, 3, 5, 7, 14), c(0)),
30-
list(c(0, 7, 14, 21, 28), c(0)),
31-
list(c(0, 2, 4, 7, 14, 21, 28), c(0))
32-
),
19+
lags = list(c(0, 3, 5, 7, 14)),
3320
pop_scaling = TRUE
3421
)
22+
23+
# tidyr::expand_grid(
24+
# forecaster = "scaled_pop",
25+
# trainer = c("linreg", "quantreg"),
26+
# ahead = c(1:7, 14, 21, 28),
27+
# lags = list(c(0, 3, 5, 7, 14), c(0, 7, 14)),
28+
# pop_scaling = TRUE
29+
# )
30+
# ,
31+
# tidyr::expand_grid(
32+
# forecaster = "smoothed_scaled",
33+
# trainer = c("quantreg"),
34+
# ahead = c(1:7, 14, 21, 28),
35+
# #
36+
# lags = list(
37+
# # smoothed, sd, smoothed, sd
38+
# list(c(0, 3, 5, 7, 14), c(0)),
39+
# list(c(0, 7, 14, 21, 28), c(0)),
40+
# list(c(0, 2, 4, 7, 14, 21, 28), c(0))
41+
# ),
42+
# pop_scaling = TRUE
43+
# )
3544
)
3645
}
3746
#
@@ -50,24 +59,25 @@ make_unique_ensemble_grid <- function() {
5059
lags = c(0L, 3L, 5L, 7L, 14L)
5160
),
5261
list(forecaster = "flatline_fc")
53-
),
54-
# median forecaster
55-
"ensemble_average",
56-
list(average_type = "median"),
57-
list(
58-
list(
59-
forecaster = "scaled_pop",
60-
trainer = "linreg",
61-
pop_scaling = TRUE,
62-
lags = c(0, 3, 5, 7, 14)
63-
),
64-
list(
65-
forecaster = "scaled_pop",
66-
trainer = "linreg",
67-
pop_scaling = FALSE,
68-
lags = c(0, 3, 5, 7, 14)
69-
)
70-
),
62+
)
63+
# ,
64+
# # median forecaster
65+
# "ensemble_average",
66+
# list(average_type = "median"),
67+
# list(
68+
# list(
69+
# forecaster = "scaled_pop",
70+
# trainer = "linreg",
71+
# pop_scaling = TRUE,
72+
# lags = c(0, 3, 5, 7, 14)
73+
# ),
74+
# list(
75+
# forecaster = "scaled_pop",
76+
# trainer = "linreg",
77+
# pop_scaling = FALSE,
78+
# lags = c(0, 3, 5, 7, 14)
79+
# )
80+
# ),
7181
)
7282
}
7383

scripts/one_offs/r6_refactor.R

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# R6 refactor comparison script.
#
# Compares the cached pipeline objects produced before and after the R6
# refactor to confirm that the refactor did not change the forecast output.
# This script assumes that you:
#
# 1. Ran the covid_hosp_explore pipeline (or downloaded the objects using
#    `make download`)
# 2. Copied the cache objects to a new directory (e.g.
#    `covid_hosp_explore copy`)
# 3. Installed the new epiprocess branch
#    `renv::install("cmu-delphi/epiprocess@ds/r6-clean")`
# 4. Ran the covid_hosp_explore pipeline again (should take about 3.5 hours)
#
# Once that is done, you should be able to run the script below and find no
# differences in the forecasts.

library(dplyr)
library(magrittr)
library(purrr)
library(qs)

# Manifest of the targets in the current pipeline; only its `name` column is
# used below, to select which cache files to compare.
df <- targets::tar_manifest()

# List the cache files in `objects_dir` whose file name is a target name,
# returned as sorted full paths so two directories can be compared pairwise.
list_target_objects <- function(objects_dir, target_names) {
  paths <- list.files(objects_dir, full.names = TRUE)
  sort(paths[basename(paths) %in% target_names])
}

# Both sets of objects have already been produced, so we only read them in.
old_forecasts <- list_target_objects("covid_hosp_explore copy/objects", df$name)
new_forecasts <- list_target_objects("covid_hosp_explore/objects", df$name)

# Make sure both directories hold the same set of objects. The length check
# runs first so a mismatch fails cleanly; identical() then compares the names
# without any vector recycling.
assertthat::assert_that(
  length(old_forecasts) == length(new_forecasts),
  identical(basename(old_forecasts), basename(new_forecasts))
)

# Compare each old/new pair. all.equal() returns TRUE or a character vector
# with one entry per difference, while map2_chr() requires a single string per
# pair, so multiple differences are collapsed into one string.
tib <- tibble::tibble(
  old_forecasts = old_forecasts,
  new_forecasts = new_forecasts,
  compare = purrr::map2_chr(
    old_forecasts,
    new_forecasts,
    function(old_path, new_path) {
      result <- all.equal(qs::qread(old_path), qs::qread(new_path))
      if (isTRUE(result)) "TRUE" else paste(result, collapse = "; ")
    }
  )
)

# Show the first few mismatching pairs (if any) as a plain list of columns.
tib %>%
  filter(compare != "TRUE") %>%
  slice(1:5) %>%
  as.list()

# Load one old/new pair into the session (presumably for manual inspection).
x <- qread("covid_hosp_explore copy/objects/joined_archive_data_2022")
y <- qread("covid_hosp_explore/objects/joined_archive_data_2022")

scripts/run.R

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@
2626
# # Save to disk
2727
# saveRDS(scorecards, "exploration-scorecards-2023-10-04.RDS")
2828

29-
30-
31-
3229
tar_project <- Sys.getenv("TAR_PROJECT", "covid_hosp_explore")
3330
external_scores_path <- Sys.getenv("EXTERNAL_SCORES_PATH", "")
3431
debug_mode <- as.logical(Sys.getenv("DEBUG_MODE", TRUE))

0 commit comments

Comments
 (0)