Skip to content

Commit a606486

Browse files
authored
Merge pull request #1711 from cmu-delphi/ndefries/training-date-from-model
Fetch modeling date from model filenames
2 parents b7005a8 + 9e47f29 commit a606486

18 files changed

+277
-92
lines changed

backfill_corrections/delphiBackfillCorrection/R/beta_prior_estimation.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ objective <- function(theta, x, prob, ...) {
4848
#' @param start the initialization of the points in nlm
4949
#' @param base_pseudo_denom the pseudo counts added to denominator if little data for training
5050
#' @param base_pseudo_num the pseudo counts added to numerator if little data for training
51-
#' @param training_end_date the most recent training date
51+
#' @template training_end_date-template
52+
#' @template training_start_date-template
5253
#' @param model_save_dir directory containing trained models
5354
#'
5455
#' @importFrom stats nlm predict
@@ -58,7 +59,8 @@ objective <- function(theta, x, prob, ...) {
5859
est_priors <- function(train_data, prior_test_data, geo, value_type, dw, taus,
5960
covariates, response, lp_solver, lambda,
6061
indicator, signal, geo_level, signal_suffix,
61-
training_end_date, model_save_dir, start=c(0, log(10)),
62+
training_end_date, training_start_date,
63+
model_save_dir, start=c(0, log(10)),
6264
base_pseudo_denom=1000, base_pseudo_num=10,
6365
train_models = TRUE, make_predictions = TRUE) {
6466
sub_train_data <- train_data %>% filter(train_data[[dw]] == 1)
@@ -76,6 +78,7 @@ est_priors <- function(train_data, prior_test_data, geo, value_type, dw, taus,
7678
geo=geo, dw=dw, tau=tau,
7779
value_type=value_type,
7880
training_end_date=training_end_date,
81+
training_start_date=training_start_date,
7982
beta_prior_mode=TRUE)
8083
model_path <- file.path(model_save_dir, model_file_name)
8184

@@ -123,7 +126,8 @@ frac_adj_with_pseudo <- function(data, dw, pseudo_num, pseudo_denom, num_col, de
123126
#' @template train_data-template
124127
#' @param test_data testing data
125128
#' @param prior_test_data testing data for the lag -1 model
126-
#' @param training_end_date the most recent training date
129+
#' @template training_end_date-template
130+
#' @template training_start_date-template
127131
#' @param model_save_dir directory containing trained models
128132
#' @template indicator-template
129133
#' @template signal-template
@@ -141,7 +145,8 @@ frac_adj_with_pseudo <- function(data, dw, pseudo_num, pseudo_denom, num_col, de
141145
frac_adj <- function(train_data, test_data, prior_test_data,
142146
indicator, signal, geo_level, signal_suffix,
143147
lambda, value_type, geo,
144-
training_end_date, model_save_dir,
148+
training_end_date, training_start_date,
149+
model_save_dir,
145150
taus, lp_solver,
146151
train_models = TRUE,
147152
make_predictions = TRUE) {
@@ -177,7 +182,7 @@ frac_adj <- function(train_data, test_data, prior_test_data,
177182
pseudo_counts <- est_priors(train_data, prior_test_data, geo, value_type, cov, taus,
178183
pre_covariates, "log_value_target", lp_solver, lambda,
179184
indicator, signal, geo_level, signal_suffix,
180-
training_end_date, model_save_dir,
185+
training_end_date, training_start_date, model_save_dir,
181186
train_models = train_models,
182187
make_predictions = make_predictions)
183188
pseudo_denum = pseudo_counts[1]

backfill_corrections/delphiBackfillCorrection/R/io.R

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,46 @@ read_data <- function(input_dir) {
2222
#' @template lambda-template
2323
#' @template value_type-template
2424
#' @template export_dir-template
25-
#' @param training_end_date the most recent training date
25+
#' @template training_end_date-template
26+
#' @template training_start_date-template
2627
#'
2728
#' @importFrom readr write_csv
2829
#' @importFrom stringr str_interp str_split
2930
export_test_result <- function(test_data, coef_data, indicator, signal,
30-
geo_level, geo, signal_suffix, lambda,
31-
training_end_date,
31+
geo_level, signal_suffix, lambda,
32+
training_end_date, training_start_date,
3233
value_type, export_dir) {
3334
base_name <- generate_filename(indicator=indicator, signal=signal,
3435
geo_level=geo_level, signal_suffix=signal_suffix,
3536
lambda=lambda, training_end_date=training_end_date,
36-
geo=geo, value_type=value_type, model_mode=FALSE)
37+
training_start_date=training_start_date,
38+
value_type=value_type, model_mode=FALSE)
3739

38-
signal_info <- str_interp("indicator ${indicator} signal ${signal} geo ${geo} value_type ${value_type}")
40+
signal_info <- str_interp("indicator ${indicator} signal ${signal} geo_level ${geo_level} value_type ${value_type}")
41+
42+
components <- c(indicator, signal, signal_suffix)
43+
signal_dir <- paste(components[components != ""], collapse="_")
44+
45+
dir.create(file.path(export_dir, signal_dir), showWarnings = FALSE)
46+
3947
if (nrow(test_data) == 0) {
4048
warning(str_interp("No test data available for ${signal_info}"))
4149
} else {
4250
msg_ts(str_interp("Saving predictions to disk for ${signal_info} "))
4351
pred_output_file <- str_interp("prediction_${base_name}")
44-
write_csv(test_data, file.path(export_dir, pred_output_file))
52+
53+
prediction_col <- colnames(test_data)[grepl("^predicted", colnames(test_data))]
54+
expected_col <- c("time_value", "issue_date", "lag", "geo_value",
55+
"target_date", "wis", prediction_col)
56+
write_csv(test_data[expected_col], file.path(export_dir, signal_dir, pred_output_file))
4557
}
4658

4759
if (nrow(coef_data) == 0) {
4860
warning(str_interp("No coef data available for ${signal_info}"))
4961
} else {
5062
msg_ts(str_interp("Saving coefficients to disk for ${signal_info}"))
51-
coef_output_file <- str_interp("coefs_${base_name}")
52-
write_csv(coef_data, file.path(export_dir, coef_output_file))
63+
coef_output_file <- str_interp("coefs_${base_name}")
64+
write_csv(coef_data, file.path(export_dir, signal_dir, coef_output_file))
5365
}
5466
}
5567

@@ -99,13 +111,13 @@ subset_valid_files <- function(files_list, file_type = c("daily", "rollup"), par
99111
switch(file_type,
100112
daily = {
101113
start_dates <- as.Date(
102-
sub("^.*/.*_as_of_([0-9]{8}).parquet$", "\\1", files_list),
114+
sub("^.*/.*_as_of_([0-9]{8})[.]parquet$", "\\1", files_list),
103115
format = date_format
104116
)
105117
end_dates <- start_dates
106118
},
107119
rollup = {
108-
rollup_pattern <- "^.*/.*_from_([0-9]{8})_to_([0-9]{8}).parquet$"
120+
rollup_pattern <- "^.*/.*_from_([0-9]{8})_to_([0-9]{8})[.]parquet$"
109121
start_dates <- as.Date(
110122
sub(rollup_pattern, "\\1", files_list),
111123
format = date_format
@@ -117,12 +129,9 @@ subset_valid_files <- function(files_list, file_type = c("daily", "rollup"), par
117129
}
118130
)
119131

120-
# Start_date depends on if we're doing model training or just corrections.
121-
n_addl_days <- params$ref_lag
122-
if (params$train_models) {
123-
n_addl_days <- n_addl_days + params$training_days
124-
}
125-
132+
## TODO: right now, this gets both training and testing data regardless of
133+
# which mode is selected
134+
n_addl_days <- params$ref_lag + params$training_days
126135
start_date <- TODAY - n_addl_days
127136
end_date <- TODAY - 1
128137

@@ -146,7 +155,7 @@ create_name_pattern <- function(indicator, signal,
146155
file_type = c("daily", "rollup")) {
147156
file_type <- match.arg(file_type)
148157
switch(file_type,
149-
daily = str_interp("${indicator}_${signal}_as_of_[0-9]{8}.parquet$"),
150-
rollup = str_interp("${indicator}_${signal}_from_[0-9]{8}_to_[0-9]{8}.parquet$")
158+
daily = str_interp("${indicator}_${signal}_as_of_[0-9]{8}[.]parquet$"),
159+
rollup = str_interp("${indicator}_${signal}_from_[0-9]{8}_to_[0-9]{8}[.]parquet$")
151160
)
152161
}

backfill_corrections/delphiBackfillCorrection/R/main.R

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
#' @template signal_suffixes-template
99
#' @template indicator-template
1010
#' @template signal-template
11-
#' @param training_end_date the most recent training date
1211
#'
1312
#' @importFrom dplyr %>% filter select group_by summarize across everything group_split ungroup
1413
#' @importFrom tidyr drop_na
1514
#' @importFrom rlang .data .env
1615
#'
1716
#' @export
18-
run_backfill <- function(df, params, training_end_date,
17+
run_backfill <- function(df, params,
1918
refd_col = "time_value", lag_col = "lag", issued_col = "issue_date",
2019
signal_suffixes = c(""), indicator = "", signal = "") {
2120
df <- filter(df, .data$lag < params$ref_lag + 30) # a rough filtration to save memory
@@ -57,13 +56,14 @@ run_backfill <- function(df, params, training_end_date,
5756
coef_list[[key]] <- list()
5857
}
5958
}
60-
59+
6160
msg_ts("Splitting data into geo groups")
6261
group_dfs <- group_split(df, .data$geo_value)
6362

6463
# Build model for each location
6564
for (subdf in group_dfs) {
6665
geo <- subdf$geo_value[1]
66+
6767
msg_ts(str_interp("Processing ${geo} geo group"))
6868

6969
min_refd <- min(subdf[[refd_col]])
@@ -115,9 +115,9 @@ run_backfill <- function(df, params, training_end_date,
115115
combined_df <- combined_df %>% filter(.data$lag < params$ref_lag)
116116

117117
geo_train_data <- combined_df %>%
118-
filter(.data$issue_date < training_end_date) %>%
119-
filter(.data$target_date <= training_end_date) %>%
120-
filter(.data$target_date > training_end_date - params$training_days) %>%
118+
filter(.data$issue_date < params$training_end_date) %>%
119+
filter(.data$target_date <= params$training_end_date) %>%
120+
filter(.data$target_date > params$training_start_date) %>%
121121
drop_na()
122122
geo_test_data <- combined_df %>%
123123
filter(.data$issue_date %in% params$test_dates) %>%
@@ -135,7 +135,8 @@ run_backfill <- function(df, params, training_end_date,
135135
indicator = indicator, signal = signal,
136136
geo_level = geo_level, signal_suffix = signal_suffix,
137137
lambda = params$lambda, value_type = value_type, geo = geo,
138-
training_end_date = training_end_date,
138+
training_end_date = params$training_end_date,
139+
training_start_date = params$training_start_date,
139140
model_save_dir = params$cache_dir,
140141
taus = params$taus,
141142
lp_solver = params$lp_solver,
@@ -178,7 +179,9 @@ run_backfill <- function(df, params, training_end_date,
178179
lambda = params$lambda, test_lag = test_lag, geo = geo,
179180
value_type = value_type, model_save_dir = params$cache_dir,
180181
indicator = indicator, signal = signal, geo_level = geo_level,
181-
signal_suffix =signal_suffix, training_end_date = training_end_date,
182+
signal_suffix =signal_suffix,
183+
training_end_date = params$training_end_date,
184+
training_start_date = params$training_start_date,
182185
train_models = params$train_models,
183186
make_predictions = params$make_predictions
184187
)
@@ -199,23 +202,24 @@ run_backfill <- function(df, params, training_end_date,
199202
}# End for test lags
200203
}# End for value types
201204
}# End for signal suffixes
202-
203-
if (params$make_predictions) {
204-
for (value_type in params$value_types) {
205-
for (signal_suffix in signal_suffixes) {
206-
key <- make_key(value_type, signal_suffix)
207-
test_combined <- bind_rows(test_data_list[[key]])
208-
coef_combined <- bind_rows(coef_list[[key]])
209-
export_test_result(test_combined, coef_combined,
210-
indicator, signal,
211-
geo_level, geo, signal_suffix, params$lambda,
212-
training_end_date,
213-
value_type, export_dir=params$export_dir)
214-
}
205+
}# End for geo list
206+
207+
if (params$make_predictions) {
208+
for (value_type in params$value_types) {
209+
for (signal_suffix in signal_suffixes) {
210+
key <- make_key(value_type, signal_suffix)
211+
test_combined <- bind_rows(test_data_list[[key]])
212+
coef_combined <- bind_rows(coef_list[[key]])
213+
export_test_result(test_combined, coef_combined,
214+
indicator=indicator, signal=signal,
215+
signal_suffix=signal_suffix,
216+
geo_level=geo_level, lambda=params$lambda,
217+
training_end_date=params$training_end_date,
218+
training_start_date=params$training_start_date,
219+
value_type=value_type, export_dir=params$export_dir)
215220
}
216221
}
217-
218-
}# End for geo list
222+
}
219223
}# End for geo type
220224
}
221225

@@ -236,14 +240,10 @@ main <- function(params) {
236240

237241
if (params$train_models) {
238242
msg_ts("Removing stored models")
239-
files_list <- list.files(params$cache_dir, pattern="*.model", full.names = TRUE)
243+
files_list <- list.files(params$cache_dir, pattern="[.]model$", full.names = TRUE)
240244
file.remove(files_list)
241245
}
242246

243-
training_end_date <- as.Date(readLines(
244-
file.path(params$cache_dir, "training_end_date.txt")))
245-
msg_ts(str_interp("training_end_date is ${training_end_date}"))
246-
247247
## Set default number of cores for mclapply to half of those available.
248248
if (params$parallel) {
249249
cores <- detectCores()
@@ -255,7 +255,18 @@ main <- function(params) {
255255
options(mc.cores = min(params$parallel_max_cores, max(floor(cores / 2), 1L)))
256256
}
257257
}
258-
258+
259+
# Training start and end dates are the same for all indicators, so we can fetch
260+
# them once at the beginning.
261+
result <- get_training_date_range(params)
262+
params$training_start_date <- result$training_start_date
263+
params$training_end_date <- result$training_end_date
264+
265+
msg_ts(paste0(
266+
str_interp("training_start_date is ${params$training_start_date}, "),
267+
str_interp("training_end_date is ${params$training_end_date}")
268+
))
269+
259270
# Loop over every indicator + signal combination.
260271
for (group_i in seq_len(nrow(INDICATORS_AND_SIGNALS))) {
261272
input_group <- INDICATORS_AND_SIGNALS[group_i,]
@@ -302,14 +313,8 @@ main <- function(params) {
302313
training_days_check(input_data$issue_date, params$training_days)
303314

304315
# Perform backfill corrections and save result
305-
run_backfill(input_data, params, training_end_date,
316+
run_backfill(input_data, params,
306317
indicator = input_group$indicator, signal = input_group$signal,
307318
signal_suffixes = input_group$name_suffix)
308-
309-
if (params$train_models) {
310-
# Save the training end date to a text file.
311-
writeLines(as.character(TODAY),
312-
file.path(params$cache_dir, "training_end_date.txt"))
313-
}
314319
}
315320
}

0 commit comments

Comments
 (0)