Commit b4caa58

Merge pull request #1802 from cmu-delphi/ndefries/backfill/speed
Make backfill corrections faster pt 1
2 parents ec9a30f + 523b1a6 commit b4caa58

File tree

12 files changed: 112 additions, 137 deletions

backfill_corrections/delphiBackfillCorrection/DESCRIPTION

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ Imports:
     tidyr,
     zoo,
     utils,
-    rlang,
     parallel
 Suggests:
     knitr (>= 1.15),

backfill_corrections/delphiBackfillCorrection/NAMESPACE

Lines changed: 0 additions & 5 deletions
@@ -33,9 +33,7 @@ importFrom(dplyr,filter)
 importFrom(dplyr,group_by)
 importFrom(dplyr,group_split)
 importFrom(dplyr,if_else)
-importFrom(dplyr,mutate)
 importFrom(dplyr,pull)
-importFrom(dplyr,rename)
 importFrom(dplyr,select)
 importFrom(dplyr,starts_with)
 importFrom(dplyr,summarize)
@@ -50,9 +48,6 @@ importFrom(lubridate,year)
 importFrom(parallel,detectCores)
 importFrom(quantgen,quantile_lasso)
 importFrom(readr,write_csv)
-importFrom(rlang,":=")
-importFrom(rlang,.data)
-importFrom(rlang,.env)
 importFrom(stats,coef)
 importFrom(stats,nlm)
 importFrom(stats,pbeta)
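
Note: the dropped imports above summarize the pattern applied throughout this PR: the rlang tidy-eval helpers (`.data`, `.env`, `:=`) and the dplyr verbs `mutate()` and `rename()` are swapped for base R equivalents, which carry less per-call overhead and remove a dependency. A minimal sketch of the rename idiom on a toy data frame, not the package's real inputs:

    library(dplyr)

    df <- data.frame(fips = c("42003", "42101"), value = c(1, 2))

    # Before: tidy evaluation, which needs rlang's .data pronoun imported
    # df <- rename(df, geo_value = .data$fips)

    # After: plain base R; no rlang needed (note geo_value ends up as the
    # last column, unlike rename(), which keeps its position)
    df$geo_value <- df$fips
    df$fips <- NULL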

backfill_corrections/delphiBackfillCorrection/R/beta_prior_estimation.R

Lines changed: 3 additions & 3 deletions
@@ -53,7 +53,7 @@ objective <- function(theta, x, prob, ...) {
 #' @param model_save_dir directory containing trained models
 #'
 #' @importFrom stats nlm predict
-#' @importFrom dplyr %>% filter
+#' @importFrom dplyr filter
 #' @importFrom quantgen quantile_lasso
 #'
 est_priors <- function(train_data, prior_test_data, geo, value_type, dw, taus,
@@ -63,8 +63,8 @@ est_priors <- function(train_data, prior_test_data, geo, value_type, dw, taus,
                        model_save_dir, start=c(0, log(10)),
                        base_pseudo_denom=1000, base_pseudo_num=10,
                        train_models = TRUE, make_predictions = TRUE) {
-  sub_train_data <- train_data %>% filter(train_data[[dw]] == 1)
-  sub_test_data <- prior_test_data %>% filter(prior_test_data[[dw]] == 1)
+  sub_train_data <- filter(train_data, train_data[[dw]] == 1)
+  sub_test_data <- filter(prior_test_data, prior_test_data[[dw]] == 1)
   if (nrow(sub_test_data) == 0) {
     pseudo_denom <- base_pseudo_denom
     pseudo_num <- base_pseudo_num
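
Note: dropping the magrittr pipe here is a micro-optimization: `x %>% f(y)` builds and evaluates an extra closure on each call, while `f(x, y)` is a direct call with the same result. A small benchmark sketch on assumed toy data (`microbenchmark` is not a package dependency, just a convenient way to see the fixed per-call cost):

    library(dplyr)
    library(microbenchmark)

    df <- data.frame(dw = sample(0:1, 1e5, replace = TRUE), x = rnorm(1e5))

    # Same rows either way; the piped form pays a small fixed cost per call.
    microbenchmark(
      piped  = df %>% filter(df$dw == 1),
      direct = filter(df, df$dw == 1)
    )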

backfill_corrections/delphiBackfillCorrection/R/io.R

Lines changed: 8 additions & 11 deletions
@@ -13,18 +13,15 @@ read_data <- function(input_file) {
 #' Make sure data contains a `geo_value` field
 #'
 #' @template df-template
-#'
-#' @importFrom dplyr rename select
-#' @importFrom rlang .data
 fips_to_geovalue <- function(df) {
   if ( !("geo_value" %in% colnames(df)) ) {
     if ( !("fips" %in% colnames(df)) ) {
       stop("Either `fips` or `geo_value` field must be available")
     }
-    df <- rename(df, geo_value = .data$fips)
+    df$geo_value <- df$fips
   }
   if ( "fips" %in% colnames(df) ) {
-    df <- select(df, -.data$fips)
+    df$fips <- NULL
   }
   return(df)
 }
@@ -63,10 +60,10 @@ export_test_result <- function(test_data, coef_data, indicator, signal,
   dir.create(file.path(export_dir, signal_dir), showWarnings = FALSE)
 
   if (nrow(test_data) == 0) {
-    warning(str_interp("No test data available for ${signal_info}"))
+    warning("No test data available for ", signal_info)
   } else {
-    msg_ts(str_interp("Saving predictions to disk for ${signal_info} "))
-    pred_output_file <- str_interp("prediction_${base_name}")
+    msg_ts("Saving predictions to disk for ", signal_info)
+    pred_output_file <- paste0("prediction_", base_name)
 
     prediction_col <- colnames(test_data)[grepl("^predicted", colnames(test_data))]
     expected_col <- c("time_value", "issue_date", "lag", "geo_value",
@@ -75,10 +72,10 @@ export_test_result <- function(test_data, coef_data, indicator, signal,
   }
 
   if (nrow(coef_data) == 0) {
-    warning(str_interp("No coef data available for ${signal_info}"))
+    warning("No coef data available for ", signal_info)
   } else {
-    msg_ts(str_interp("Saving coefficients to disk for ${signal_info}"))
-    coef_output_file <- str_interp("coefs_${base_name}")
+    msg_ts("Saving coefficients to disk for ", signal_info)
+    coef_output_file <- paste0("coefs_", base_name)
     write_csv(coef_data, file.path(export_dir, signal_dir, coef_output_file))
   }
 }
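
Note: `str_interp()` re-parses its `${...}` template on every call, while `paste0()` simply concatenates its arguments (and base `warning()` already pastes its `...` together itself). The new multi-argument `msg_ts()` calls imply the helper is now variadic; a hypothetical sketch of such a logger, since the real `msg_ts()` is defined elsewhere in the package:

    # Hypothetical variadic timestamped logger; the package's actual
    # msg_ts() may differ in format and behavior.
    msg_ts <- function(...) {
      message(format(Sys.time(), "%Y-%m-%d %H:%M:%S"), " ", paste0(...))
    }

    signal_info <- "indicator foo signal bar"  # toy value
    msg_ts("Saving predictions to disk for ", signal_info)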

backfill_corrections/delphiBackfillCorrection/R/main.R

Lines changed: 46 additions & 57 deletions
@@ -9,16 +9,14 @@
 #' @template indicator-template
 #' @template signal-template
 #'
-#' @importFrom dplyr %>% filter select group_by summarize across everything group_split ungroup
+#' @importFrom dplyr %>% filter group_by summarize across everything group_split ungroup
 #' @importFrom tidyr drop_na
-#' @importFrom rlang .data .env
-#' @importFrom stringr str_interp
 #'
 #' @export
 run_backfill <- function(df, params,
                          refd_col = "time_value", lag_col = "lag", issued_col = "issue_date",
                          signal_suffixes = c(""), indicator = "", signal = "") {
-  df <- filter(df, .data$lag < params$ref_lag + 30) # a rough filtration to save memory
+  df <- filter(df, lag < params$ref_lag + 30) # a rough filtration to save memory
 
   geo_levels <- params$geo_levels
   if ("state" %in% geo_levels) {
@@ -28,23 +26,24 @@ run_backfill <- function(df, params,
   }
 
   for (geo_level in geo_levels) {
-    msg_ts(str_interp("geo level ${geo_level}"))
+    msg_ts("geo level ", geo_level)
     # Get full list of interested locations
     if (geo_level == "state") {
       # Drop county field and make new "geo_value" field from "state_id".
       # Aggregate counties up to state level
       agg_cols <- c("geo_value", issued_col, refd_col, lag_col)
       # Sum all non-agg columns. Summarized columns keep original names
+      df$geo_value <- df$state_id
+      df$state_id <- NULL
       df <- df %>%
-        select(-.data$geo_value, geo_value = .data$state_id) %>%
         group_by(across(agg_cols)) %>%
         summarize(across(everything(), sum)) %>%
         ungroup()
     }
     if (geo_level == "county") {
       # Keep only 200 most populous (within the US) counties
       top_200_geos <- get_populous_counties()
-      df <- filter(df, .data$geo_value %in% top_200_geos)
+      df <- filter(df, geo_value %in% top_200_geos)
     }
 
     test_data_list <- list()
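
Note: the recurring edit in run_backfill() is dropping rlang's `.data` pronoun inside dplyr verbs. In package code, `.data$lag` must be imported from rlang and resolved through the data mask; the bare column name behaves identically as long as no local variable shadows it, and removing the pronoun is what allows dropping rlang from DESCRIPTION. A minimal sketch on a toy frame:

    library(dplyr)

    df <- data.frame(lag = 0:99, geo_value = "pa")
    ref_lag <- 60

    # Identical results; the bare-name form needs no rlang import, but can
    # be shadowed if a local variable is also named `lag`.
    with_pronoun <- filter(df, .data$lag < ref_lag + 30)
    bare_name    <- filter(df, lag < ref_lag + 30)
    identical(with_pronoun, bare_name)  # TRUE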
@@ -59,13 +58,13 @@ run_backfill <- function(df, params,
     }
 
     msg_ts("Splitting data into geo groups")
-    group_dfs <- group_split(df, .data$geo_value)
+    group_dfs <- group_split(df, geo_value)
 
     # Build model for each location
     for (subdf in group_dfs) {
       geo <- subdf$geo_value[1]
 
-      msg_ts(str_interp("Processing ${geo} geo group"))
+      msg_ts("Processing ", geo, " geo group")
 
       min_refd <- min(subdf[[refd_col]])
       max_refd <- max(subdf[[refd_col]])
@@ -78,7 +77,7 @@ run_backfill <- function(df, params,
         # process again. Main use case is for quidel which has overall and
         # age-based signals.
         if (signal_suffix != "") {
-          msg_ts(str_interp("signal suffix ${signal_suffix}"))
+          msg_ts("signal suffix ", signal_suffix)
           num_col <- paste(params$num_col, signal_suffix, sep = "_")
           denom_col <- paste(params$denom_col, signal_suffix, sep = "_")
         } else {
@@ -87,7 +86,7 @@ run_backfill <- function(df, params,
         }
 
         for (value_type in params$value_types) {
-          msg_ts(str_interp("value type ${value_type}"))
+          msg_ts("value type ", value_type)
           # Handle different signal types
           if (value_type == "count") { # For counts data only
             combined_df <- fill_missing_updates(subdf, num_col, refd_col, lag_col)
@@ -113,15 +112,17 @@ run_backfill <- function(df, params,
             )
           }
           combined_df <- add_params_for_dates(combined_df, refd_col, lag_col)
-          combined_df <- combined_df %>% filter(.data$lag < params$ref_lag)
+          combined_df <- filter(combined_df, lag < params$ref_lag)
 
-          geo_train_data <- combined_df %>%
-            filter(.data$issue_date < params$training_end_date) %>%
-            filter(.data$target_date <= params$training_end_date) %>%
-            filter(.data$target_date > params$training_start_date) %>%
+          geo_train_data <- filter(combined_df,
+            issue_date < params$training_end_date,
+            target_date <= params$training_end_date,
+            target_date > params$training_start_date,
+          ) %>%
             drop_na()
-          geo_test_data <- combined_df %>%
-            filter(.data$issue_date %in% params$test_dates) %>%
+          geo_test_data <- filter(combined_df,
+            issue_date %in% params$test_dates
+          ) %>%
             drop_na()
 
           if (nrow(geo_test_data) == 0) {
@@ -135,9 +136,10 @@ run_backfill <- function(df, params,
 
           if (value_type == "fraction") {
             # Use beta prior approach to adjust fractions
-            geo_prior_test_data = combined_df %>%
-              filter(.data$issue_date > min(params$test_dates) - 7) %>%
-              filter(.data$issue_date <= max(params$test_dates))
+            geo_prior_test_data = filter(combined_df,
+              issue_date > min(params$test_dates) - 7,
+              issue_date <= max(params$test_dates)
+            )
             updated_data <- frac_adj(geo_train_data, geo_test_data, geo_prior_test_data,
                                      indicator = indicator, signal = signal,
                                      geo_level = geo_level, signal_suffix = signal_suffix,
@@ -154,16 +156,15 @@ run_backfill <- function(df, params,
           }
           max_raw = sqrt(max(geo_train_data$value_raw))
           for (test_lag in params$test_lags) {
-            msg_ts(str_interp("test lag ${test_lag}"))
+            msg_ts("test lag ", test_lag)
             filtered_data <- data_filteration(test_lag, geo_train_data,
                                               geo_test_data, params$lag_pad)
             train_data <- filtered_data[[1]]
             test_data <- filtered_data[[2]]
 
             if (nrow(train_data) == 0 || nrow(test_data) == 0) {
-              msg_ts(str_interp(
-                "Not enough data to either train or test for test_lag ${test_lag}, skipping"
-              ))
+              msg_ts("Not enough data to either train or test for test_lag ",
+                     test_lag, ", skipping")
               next
             }
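
Note: collapsing chained `filter()` calls into one call with several conditions subsets the data frame once instead of materializing an intermediate copy per condition. A sketch of the equivalence on assumed toy data:

    library(dplyr)
    library(tidyr)

    combined_df <- data.frame(
      issue_date  = Sys.Date() - 0:99,
      target_date = Sys.Date() - 5:104,
      value       = rnorm(100)
    )
    end_date   <- Sys.Date() - 10
    start_date <- Sys.Date() - 80

    # Chained: each filter() allocates a new intermediate data frame.
    chained <- combined_df %>%
      filter(issue_date < end_date) %>%
      filter(target_date <= end_date) %>%
      filter(target_date > start_date) %>%
      drop_na()

    # Single call: one pass, one result allocation, same rows.
    single <- filter(combined_df,
                     issue_date < end_date,
                     target_date <= end_date,
                     target_date > start_date) %>%
      drop_na()
    identical(chained, single)  # TRUE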

@@ -238,9 +239,8 @@ run_backfill <- function(df, params,
 #' @template lag_col-template
 #' @template issued_col-template
 #'
-#' @importFrom dplyr bind_rows mutate %>%
+#' @importFrom dplyr bind_rows %>%
 #' @importFrom parallel detectCores
-#' @importFrom rlang .data :=
 #' @importFrom stringr str_interp
 #'
 #' @export
@@ -253,7 +253,7 @@ main <- function(params,
 
   indicators_subset <- INDICATORS_AND_SIGNALS
   if (params$indicators != "all") {
-    indicators_subset <- filter(indicators_subset, .data$indicator == params$indicators)
+    indicators_subset <- filter(indicators_subset, indicator == params$indicators)
   }
   if (nrow(indicators_subset) == 0) {
     stop("no indicators to process")
@@ -288,62 +288,51 @@ main <- function(params,
   params$training_start_date <- result$training_start_date
   params$training_end_date <- result$training_end_date
 
-  msg_ts(paste0(
-    str_interp("training_start_date is ${params$training_start_date}, "),
-    str_interp("training_end_date is ${params$training_end_date}")
-  ))
+  msg_ts("training_start_date is ", params$training_start_date,
+         ", training_end_date is ", params$training_end_date)
 
   # Loop over every indicator + signal combination.
   for (group_i in seq_len(nrow(indicators_subset))) {
     input_group <- indicators_subset[group_i,]
-    msg_ts(str_interp(
-      "Processing indicator ${input_group$indicator} signal ${input_group$signal}"
-    ))
+    msg_ts("Processing indicator ", input_group$indicator, " signal ", input_group$signal)
 
     files_list <- get_files_list(
       input_group$indicator, input_group$signal, params, input_group$sub_dir
     )
     if (length(files_list) == 0) {
-      warning(str_interp(
-        "No files found for indicator ${input_group$indicator} signal ${input_group$signal}, skipping"
-      ))
+      warning("No files found for indicator ", input_group$indicator,
+              " signal ", input_group$signal, ", skipping")
       next
     }
 
     msg_ts("Reading in and combining associated files")
     input_data <- lapply(
       files_list,
       function(file) {
+        # refd_col and issued_col are read in as strings
         read_data(file) %>%
-          fips_to_geovalue() %>%
-          mutate(
-            # Use `glue` syntax to construct a new field by variable,
-            # from https://stackoverflow.com/a/26003971/14401472
-            "{refd_col}" := as.Date(.data[[refd_col]], "%Y-%m-%d"),
-            "{issued_col}" := as.Date(.data[[issued_col]], "%Y-%m-%d")
-          )
+          fips_to_geovalue()
       }
     ) %>%
       bind_rows()
 
     if (nrow(input_data) == 0) {
-      warning(str_interp(
-        "No data available for indicator ${input_group$indicator} signal ${input_group$signal}, skipping"
-      ))
+      warning("No data available for indicator ", input_group$indicator,
+              " signal ", input_group$signal, ", skipping")
      next
     }
 
     # Check data type and required columns
     msg_ts("Validating input data")
-    for (value_type in params$value_types) {
-      msg_ts(str_interp("for ${value_type}"))
-      result <- validity_checks(
-        input_data, value_type,
-        params$num_col, params$denom_col, input_group$name_suffix,
-        refd_col = refd_col, lag_col = lag_col, issued_col = issued_col
-      )
-      input_data <- result[["df"]]
-    }
+    # Validate while the date fields are still stored as strings, for speed.
+    input_data <- validity_checks(
+      input_data, params$value_types,
+      params$num_col, params$denom_col, input_group$name_suffix,
+      refd_col = refd_col, lag_col = lag_col, issued_col = issued_col
+    )
+
+    input_data[[refd_col]] <- as.Date(input_data[[refd_col]], "%Y-%m-%d")
+    input_data[[issued_col]] <- as.Date(input_data[[issued_col]], "%Y-%m-%d")
 
     # Check available training days
     training_days_check(input_data[[issued_col]], params$training_days)
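
Note: the restructured block above validates all value types in one `validity_checks()` call and converts the date columns once on the combined data frame, instead of running `as.Date()` via `mutate()` inside the per-file `lapply()`. Converting after `bind_rows()` incurs the call overhead once rather than once per file, and drops the glue-`:=` machinery entirely. A minimal before/after sketch, with toy per-file frames standing in for `read_data()` output:

    library(dplyr)

    file_dfs <- list(
      data.frame(time_value = c("2022-01-01", "2022-01-02"), lag = 1:2),
      data.frame(time_value = c("2022-01-03", "2022-01-04"), lag = 1:2)
    )

    # Before (sketch): as.Date() ran inside the per-file loop via mutate().
    # After: bind all files first, then convert each date column once.
    input_data <- bind_rows(file_dfs)
    input_data$time_value <- as.Date(input_data$time_value, "%Y-%m-%d")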
