
Commit eb56faf

Merge branch 'ndefries/training-date-from-model' into ndefries/bc-input-file-dates
2 parents: 5acee86 + 8ce7082

15 files changed: +120 -55 lines

backfill_corrections/delphiBackfillCorrection/R/beta_prior_estimation.R

Lines changed: 10 additions & 5 deletions
@@ -48,7 +48,8 @@ objective <- function(theta, x, prob, ...) {
 #' @param start the initialization of the the points in nlm
 #' @param base_pseudo_denom the pseudo counts added to denominator if little data for training
 #' @param base_pseudo_num the pseudo counts added to numerator if little data for training
-#' @param training_end_date the most recent training date
+#' @template training_end_date-template
+#' @template training_start_date-template
 #' @param model_save_dir directory containing trained models
 #'
 #' @importFrom stats nlm predict
@@ -58,7 +59,8 @@ objective <- function(theta, x, prob, ...) {
 est_priors <- function(train_data, prior_test_data, geo, value_type, dw, taus,
                        covariates, response, lp_solver, lambda,
                        indicator, signal, geo_level, signal_suffix,
-                       training_end_date, model_save_dir, start=c(0, log(10)),
+                       training_end_date, training_start_date,
+                       model_save_dir, start=c(0, log(10)),
                        base_pseudo_denom=1000, base_pseudo_num=10,
                        train_models = TRUE, make_predictions = TRUE) {
   sub_train_data <- train_data %>% filter(train_data[[dw]] == 1)
@@ -76,6 +78,7 @@ est_priors <- function(train_data, prior_test_data, geo, value_type, dw, taus,
                                          geo=geo, dw=dw, tau=tau,
                                          value_type=value_type,
                                          training_end_date=training_end_date,
+                                         training_start_date=training_start_date,
                                          beta_prior_mode=TRUE)
     model_path <- file.path(model_save_dir, model_file_name)
 
@@ -123,7 +126,8 @@ frac_adj_with_pseudo <- function(data, dw, pseudo_num, pseudo_denom, num_col, de
 #' @template train_data-template
 #' @param test_data testing data
 #' @param prior_test_data testing data for the lag -1 model
-#' @param training_end_date the most recent training date
+#' @template training_end_date-template
+#' @template training_start_date-template
 #' @param model_save_dir directory containing trained models
 #' @template indicator-template
 #' @template signal-template
@@ -141,7 +145,8 @@ frac_adj_with_pseudo <- function(data, dw, pseudo_num, pseudo_denom, num_col, de
 frac_adj <- function(train_data, test_data, prior_test_data,
                      indicator, signal, geo_level, signal_suffix,
                      lambda, value_type, geo,
-                     training_end_date, model_save_dir,
+                     training_end_date, training_start_date,
+                     model_save_dir,
                      taus, lp_solver,
                      train_models = TRUE,
                      make_predictions = TRUE) {
@@ -177,7 +182,7 @@ frac_adj <- function(train_data, test_data, prior_test_data,
   pseudo_counts <- est_priors(train_data, prior_test_data, geo, value_type, cov, taus,
                               pre_covariates, "log_value_target", lp_solver, lambda,
                               indicator, signal, geo_level, signal_suffix,
-                              training_end_date, model_save_dir,
+                              training_end_date, training_start_date, model_save_dir,
                               train_models = train_models,
                               make_predictions = make_predictions)
   pseudo_denum = pseudo_counts[1]

backfill_corrections/delphiBackfillCorrection/R/io.R

Lines changed: 8 additions & 6 deletions
@@ -22,17 +22,19 @@ read_data <- function(input_dir) {
 #' @template lambda-template
 #' @template value_type-template
 #' @template export_dir-template
-#' @param training_end_date the most recent training date
+#' @template training_end_date-template
+#' @template training_start_date-template
 #'
 #' @importFrom readr write_csv
 #' @importFrom stringr str_interp str_split
 export_test_result <- function(test_data, coef_data, indicator, signal,
                                geo_level, geo, signal_suffix, lambda,
-                               training_end_date,
+                               training_end_date, training_start_date,
                                value_type, export_dir) {
   base_name <- generate_filename(indicator=indicator, signal=signal,
                                  geo_level=geo_level, signal_suffix=signal_suffix,
                                  lambda=lambda, training_end_date=training_end_date,
+                                 training_start_date=training_start_date,
                                  geo=geo, value_type=value_type, model_mode=FALSE)
 
   signal_info <- str_interp("indicator ${indicator} signal ${signal} geo ${geo} value_type ${value_type}")
@@ -102,13 +104,13 @@ subset_valid_files <- function(files_list, file_type = c("daily", "rollup"), par
   switch(file_type,
     daily = {
       start_issue_dates <- as.Date(
-        sub("^.*/.*_as_of_([0-9]{8}).parquet$", "\\1", files_list),
+        sub("^.*/.*_as_of_([0-9]{8})[.]parquet$", "\\1", files_list),
         format = date_format
       )
       end_issue_dates <- start_issue_dates
     },
     rollup = {
-      rollup_pattern <- "^.*/.*_from_([0-9]{8})_to_([0-9]{8}).parquet$"
+      rollup_pattern <- "^.*/.*_from_([0-9]{8})_to_([0-9]{8})[.]parquet$"
       start_issue_dates <- as.Date(
         sub(rollup_pattern, "\\1", files_list),
         format = date_format
@@ -187,7 +189,7 @@ create_name_pattern <- function(indicator, signal,
                                 file_type = c("daily", "rollup")) {
   file_type <- match.arg(file_type)
   switch(file_type,
-    daily = str_interp("${indicator}_${signal}_as_of_[0-9]{8}.parquet$"),
-    rollup = str_interp("${indicator}_${signal}_from_[0-9]{8}_to_[0-9]{8}.parquet$")
+    daily = str_interp("${indicator}_${signal}_as_of_[0-9]{8}[.]parquet$"),
+    rollup = str_interp("${indicator}_${signal}_from_[0-9]{8}_to_[0-9]{8}[.]parquet$")
   )
 }
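A quick illustration of the "[.]parquet$" change above (a sketch only, not part of the commit; the filenames are hypothetical): these patterns are regular expressions, and an unescaped "." matches any character, so the old pattern would also accept names that do not end in a literal ".parquet".

# Sketch only -- hypothetical filenames.
files <- c("receiving/chng_covid_as_of_20220628.parquet",
           "receiving/chng_covid_as_of_20220628_parquet")
grepl("_as_of_([0-9]{8}).parquet$", files)    # TRUE TRUE  - unescaped dot is too permissive
grepl("_as_of_([0-9]{8})[.]parquet$", files)  # TRUE FALSE - only a literal ".parquet" matches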

backfill_corrections/delphiBackfillCorrection/R/main.R

Lines changed: 22 additions & 15 deletions
@@ -17,10 +17,6 @@
 run_backfill <- function(df, params,
                          refd_col = "time_value", lag_col = "lag", issued_col = "issue_date",
                          signal_suffixes = c(""), indicator = "", signal = "") {
-  result <- get_training_date_range(params)
-  training_start_date <- result$training_start_date
-  training_end_date <- result$training_end_date
-
   df <- filter(df, .data$lag < params$ref_lag + 30) # a rough filtration to save memory
 
   geo_levels <- params$geo_levels
@@ -118,9 +114,9 @@ run_backfill <- function(df, params,
       combined_df <- combined_df %>% filter(.data$lag < params$ref_lag)
 
       geo_train_data <- combined_df %>%
-        filter(.data$issue_date < training_end_date) %>%
-        filter(.data$target_date <= training_end_date) %>%
-        filter(.data$target_date > training_start_date) %>%
+        filter(.data$issue_date < params$training_end_date) %>%
+        filter(.data$target_date <= params$training_end_date) %>%
+        filter(.data$target_date > params$training_start_date) %>%
         drop_na()
       geo_test_data <- combined_df %>%
         filter(.data$issue_date %in% params$test_dates) %>%
@@ -138,7 +134,8 @@ run_backfill <- function(df, params,
            indicator = indicator, signal = signal,
            geo_level = geo_level, signal_suffix = signal_suffix,
            lambda = params$lambda, value_type = value_type, geo = geo,
-           training_end_date = training_end_date,
+           training_end_date = params$training_end_date,
+           training_start_date = params$training_start_date,
            model_save_dir = params$cache_dir,
            taus = params$taus,
            lp_solver = params$lp_solver,
@@ -181,7 +178,9 @@ run_backfill <- function(df, params,
            lambda = params$lambda, test_lag = test_lag, geo = geo,
            value_type = value_type, model_save_dir = params$cache_dir,
            indicator = indicator, signal = signal, geo_level = geo_level,
-           signal_suffix =signal_suffix, training_end_date = training_end_date,
+           signal_suffix =signal_suffix,
+           training_end_date = params$training_end_date,
+           training_start_date = params$training_start_date,
            train_models = params$train_models,
            make_predictions = params$make_predictions
          )
@@ -210,10 +209,12 @@ run_backfill <- function(df, params,
       test_combined <- bind_rows(test_data_list[[key]])
       coef_combined <- bind_rows(coef_list[[key]])
       export_test_result(test_combined, coef_combined,
-                         indicator, signal,
-                         geo_level, geo, signal_suffix, params$lambda,
-                         training_end_date,
-                         value_type, export_dir=params$export_dir)
+                         indicator=indicator, signal=signal,
+                         geo_level=geo_level, geo=geo,
+                         signal_suffix=signal_suffix, lambda=params$lambda,
+                         training_end_date=params$training_end_date,
+                         training_start_date=params$training_start_date,
+                         value_type=value_type, export_dir=params$export_dir)
     }
   }
 }
@@ -239,7 +240,7 @@ main <- function(params) {
 
   if (params$train_models) {
     msg_ts("Removing stored models")
-    files_list <- list.files(params$cache_dir, pattern="*.model", full.names = TRUE)
+    files_list <- list.files(params$cache_dir, pattern="[.]model$", full.names = TRUE)
     file.remove(files_list)
   }
 
@@ -254,7 +255,13 @@ main <- function(params) {
       options(mc.cores = min(params$parallel_max_cores, max(floor(cores / 2), 1L)))
     }
   }
-
+
+  # Training start and end dates are the same for all indicators, so we can fetch
+  # at the beginning.
+  result <- get_training_date_range(params)
+  params$training_start_date <- result$training_start_date
+  params$training_end_date <- result$training_end_date
+
   # Loop over every indicator + signal combination.
   for (group_i in seq_len(nrow(INDICATORS_AND_SIGNALS))) {
     input_group <- INDICATORS_AND_SIGNALS[group_i,]
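A note on the list.files() change above (a sketch only; the cache entries are hypothetical): the pattern argument is interpreted as a regular expression, not a shell glob, so the anchored "[.]model$" selects exactly the files whose names end in a literal ".model" and leaves anything else in the cache untouched.

cached <- c("20220628_20211001_changehc_covid_state_lambda0.1_count_ca_lag5_tau0.9.model",
            "20220628_20211001_changehc_covid_state_lambda0.1_count_ca_lag5_tau0.9.model.bak",
            "notes.txt")
cached[grepl("[.]model$", cached)]
# [1] "20220628_20211001_changehc_covid_state_lambda0.1_count_ca_lag5_tau0.9.model"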

backfill_corrections/delphiBackfillCorrection/R/model.R

Lines changed: 24 additions & 11 deletions
@@ -89,7 +89,8 @@ add_sqrtscale<- function(train_data, test_data, max_raw, value_col) {
 #' @template train_models-template
 #' @template make_predictions-template
 #' @param model_save_dir directory containing trained models
-#' @param training_end_date Most recent training date
+#' @template training_end_date-template
+#' @template training_start_date-template
 #'
 #' @importFrom stats predict coef
 #' @importFrom stringr str_interp
@@ -101,6 +102,7 @@ model_training_and_testing <- function(train_data, test_data, taus, covariates,
                                        indicator, signal,
                                        geo_level, signal_suffix,
                                        training_end_date,
+                                       training_start_date,
                                        train_models = TRUE,
                                        make_predictions = TRUE) {
   success = 0
@@ -112,6 +114,7 @@ model_training_and_testing <- function(train_data, test_data, taus, covariates,
       model_file_name <- generate_filename(indicator=indicator, signal=signal,
                                            geo_level=geo_level, signal_suffix=signal_suffix,
                                            lambda=lambda, training_end_date=training_end_date,
+                                           training_start_date=training_start_date,
                                            geo=geo, value_type=value_type,
                                            test_lag=test_lag, tau=tau)
       model_path <- file.path(model_save_dir, model_file_name)
@@ -220,15 +223,16 @@ get_model <- function(model_path, train_data, covariates, tau,
 #' @param tau decimal quantile to be predicted. Values must be between 0 and 1.
 #' @param beta_prior_mode bool, indicate whether it is for a beta prior model
 #' @param model_mode bool, indicate whether the file name is for a model
-#' @param training_end_date the most recent training date
+#' @template training_end_date-template
+#' @template training_start_date-template
 #'
 #' @return path to file containing model object
 #'
 #' @importFrom stringr str_interp
 #'
 generate_filename <- function(indicator, signal,
                               geo_level, signal_suffix, lambda,
-                              training_end_date="", geo="",
+                              training_end_date, training_start_date, geo="",
                               value_type = "", test_lag="", tau="", dw="",
                               beta_prior_mode = FALSE, model_mode = TRUE) {
   if (lambda != "") {
@@ -250,7 +254,8 @@ generate_filename <- function(indicator, signal,
   } else {
     file_type <- ".csv.gz"
   }
-  components <- c(as.character(training_end_date), beta_prior,
+  components <- c(format(training_end_date, "%Y%m%d"),
+                  format(training_start_date, "%Y%m%d"), beta_prior,
                   indicator, signal, signal_suffix,
                   geo_level, lambda, value_type,
                   geo, test_lag, dw, tau)
@@ -265,8 +270,10 @@ generate_filename <- function(indicator, signal,
 
 #' Get date range of data to use for training models
 #'
-#' Calculate training end date, input data start date, and input
-#' data end date based on user settings.
+#' Calculate training start and end dates based on user settings.
+#' `training_start_date` is the minimum allowed target date when selecting
+#' training data to use. `training_end_date` is the maximum allowed target
+#' date and maximum allowed issue date.
 #'
 #' Cases:
 #' 1. We are training new models.
@@ -290,20 +297,26 @@ get_training_date_range <- function(params) {
       training_end_date <- default_end_date
     }
   } else {
-    # Get end date from cached model files.
-    # Assumes filename format like `2022-06-28_changehc_covid_state_lambda0.1_count_ca_lag5_tau0.9.model`
-    # where the leading date is the training end date for that model.
-    model_files <- list.files(params$cache_dir, "202[0-9]-[0-9]{2}-[0-9]{2}*.model")
+    # Get end date from cached model files. Assumes filename format like
+    # `20220628_20220529_changehc_covid_state_lambda0.1_count_ca_lag5_tau0.9.model`
+    # where the leading date is the training end date for that model, and the
+    # second date is the training start date.
+    model_files <- list.files(params$cache_dir, "^202[0-9]{5}_202[0-9]{5}.*[.]model$")
     if (length(model_files) == 0) {
       # We know we'll be retraining models today.
       training_end_date <- default_end_date
     } else {
       # If only some models are in the cache, they will be used and those
       # missing will be regenerated as-of the training end date.
-      training_end_date <- max(as.Date(substr(model_files, 1, 10)))
+      training_end_date <- max(as.Date(substr(model_files, 1, 8), "%Y%m%d"))
     }
   }
 
+  # Calculate start date instead of reading from cached files. This assumes
+  # that the user-provided `params$training_days` is more up-to-date. If
+  # `params$training_days` has changed such that for a given training end
+  # date, the calculated training start date differs from the start date
+  # referenced in cached file names, then those cached files will not be used.
   training_start_date <- training_end_date - params$training_days
 
   msg_ts(paste0(
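A small sketch (not part of the commit) of how the two dates now round-trip through file names: generate_filename() writes both as YYYYMMDD, and get_training_date_range() recovers the end date from the first eight characters of a cached name. The 270-day window below is only an assumed value of params$training_days.

training_end_date <- as.Date("2022-06-28")
training_start_date <- training_end_date - 270   # assuming params$training_days == 270
paste(format(training_end_date, "%Y%m%d"),
      format(training_start_date, "%Y%m%d"), sep = "_")
# [1] "20220628_20211001"
as.Date(substr("20220628_20211001_changehc_covid_state_lambda0.1_count_ca_lag5_tau0.9.model", 1, 8), "%Y%m%d")
# [1] "2022-06-28"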
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+#' @param training_end_date the latest target date and issue date included in training data
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+#' @param training_start_date the earliest target date included in training data
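The two one-line additions above are roxygen2 template files (the diff view omits their paths; by roxygen2 convention @template tags resolve to man-roxygen/<name>.R inside the package). Each "#' @template training_end_date-template" tag added in the R sources expands to the shared @param line when the documentation is regenerated, which is what produces the .Rd updates listed below:

# Regenerate the documentation so the @template tags are expanded into man/*.Rd
# (package path taken from the file headers in this diff).
roxygen2::roxygenise("backfill_corrections/delphiBackfillCorrection")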

backfill_corrections/delphiBackfillCorrection/man/est_priors.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default.

backfill_corrections/delphiBackfillCorrection/man/export_test_result.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default.

backfill_corrections/delphiBackfillCorrection/man/frac_adj.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default.

backfill_corrections/delphiBackfillCorrection/man/generate_filename.Rd

Lines changed: 5 additions & 2 deletions
Some generated files are not rendered by default.

backfill_corrections/delphiBackfillCorrection/man/get_training_date_range.Rd

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default.
