Skip to content

Commit 2af0296

Browse files
Jingjing TangJingjing Tang
Jingjing Tang
authored and
Jingjing Tang
committed
decision format implementation, update unittest
1 parent 3ed7d7e commit 2af0296

File tree

4 files changed

+44
-34
lines changed

4 files changed

+44
-34
lines changed

backfill_corrections/delphiBackfillCorrection/R/io.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,21 @@ read_data <- function(input_dir) {
2828
#' @importFrom readr write_csv
2929
#' @importFrom stringr str_interp str_split
3030
export_test_result <- function(test_data, coef_data, indicator, signal,
31-
geo_level, geo, signal_suffix, lambda,
31+
geo_level, signal_suffix, lambda,
3232
training_end_date, training_start_date,
3333
value_type, export_dir) {
3434
base_name <- generate_filename(indicator=indicator, signal=signal,
3535
geo_level=geo_level, signal_suffix=signal_suffix,
3636
lambda=lambda, training_end_date=training_end_date,
3737
training_start_date=training_start_date,
38-
geo=geo, value_type=value_type, model_mode=FALSE)
38+
value_type=value_type, model_mode=FALSE)
3939

40-
signal_info <- str_interp("indicator ${indicator} signal ${signal} geo ${geo} value_type ${value_type}")
40+
signal_info <- str_interp("indicator ${indicator} signal ${signal} geo_level ${geo_level} value_type ${value_type}")
4141

4242
components <- c(indicator, signal, signal_suffix)
43-
signal_dir = paste0(
44-
# Drop any empty strings.
45-
paste(components[components != ""], collapse="_"),
46-
)
43+
signal_dir <- paste(components[components != ""], collapse="_")
44+
45+
dir.create(file.path(export_dir, signal_dir), showWarnings = FALSE)
4746

4847
if (nrow(test_data) == 0) {
4948
warning(str_interp("No test data available for ${signal_info}"))
@@ -52,7 +51,8 @@ export_test_result <- function(test_data, coef_data, indicator, signal,
5251
pred_output_file <- str_interp("prediction_${base_name}")
5352

5453
prediction_col <- colnames(test_data)[grepl("^predicted", colnames(test_data))]
55-
expected_col <- c("time_value", "issue_date", "lag", "target_date", "wis", prediction_col)
54+
expected_col <- c("time_value", "issue_date", "lag", "geo_value",
55+
"target_date", "wis", prediction_col)
5656
write_csv(test_data[expected_col], file.path(export_dir, signal_dir, pred_output_file))
5757
}
5858

backfill_corrections/delphiBackfillCorrection/R/main.R

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,16 @@ run_backfill <- function(df, params,
5656
coef_list[[key]] <- list()
5757
}
5858
}
59-
59+
6060
msg_ts("Splitting data into geo groups")
6161
group_dfs <- group_split(df, .data$geo_value)
6262

6363
# Build model for each location
6464
for (subdf in group_dfs) {
6565
geo <- subdf$geo_value[1]
66+
67+
if (!(geo %in% c("pa", "ma", "ny"))) {next}
68+
6669
msg_ts(str_interp("Processing ${geo} geo group"))
6770

6871
min_refd <- min(subdf[[refd_col]])
@@ -201,25 +204,24 @@ run_backfill <- function(df, params,
201204
}# End for test lags
202205
}# End for value types
203206
}# End for signal suffixes
204-
205-
if (params$make_predictions) {
206-
for (value_type in params$value_types) {
207-
for (signal_suffix in signal_suffixes) {
208-
key <- make_key(value_type, signal_suffix)
209-
test_combined <- bind_rows(test_data_list[[key]])
210-
coef_combined <- bind_rows(coef_list[[key]])
211-
export_test_result(test_combined, coef_combined,
212-
indicator=indicator, signal=signal,
213-
geo_level=geo_level, geo=geo,
214-
signal_suffix=signal_suffix, lambda=params$lambda,
215-
training_end_date=params$training_end_date,
216-
training_start_date=params$training_start_date,
217-
value_type=value_type, export_dir=params$export_dir)
218-
}
207+
}# End for geo list
208+
209+
if (params$make_predictions) {
210+
for (value_type in params$value_types) {
211+
for (signal_suffix in signal_suffixes) {
212+
key <- make_key(value_type, signal_suffix)
213+
test_combined <- bind_rows(test_data_list[[key]])
214+
coef_combined <- bind_rows(coef_list[[key]])
215+
export_test_result(test_combined, coef_combined,
216+
indicator=indicator, signal=signal,
217+
signal_suffix=signal_suffix,
218+
geo_level=geo_level, lambda=params$lambda,
219+
training_end_date=params$training_end_date,
220+
training_start_date=params$training_start_date,
221+
value_type=value_type, export_dir=params$export_dir)
219222
}
220223
}
221-
222-
}# End for geo list
224+
}
223225
}# End for geo type
224226
}
225227

backfill_corrections/delphiBackfillCorrection/R/model.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ model_training_and_testing <- function(train_data, test_data, taus, covariates,
136136
if (success < length(taus)) {return (NULL)}
137137
if (!make_predictions) {return (list())}
138138

139-
coef_combined_result = data.frame(tau=taus, geo=geo, test_lag=test_lag,
139+
test_data$geo_value = geo
140+
coef_combined_result = data.frame(tau=taus, geo_value=geo, test_lag=test_lag,
140141
training_end_date=training_end_date,
141142
training_start_date=training_start_date,
142143
lambda=lambda)
@@ -263,7 +264,7 @@ generate_filename <- function(indicator, signal,
263264
geo_level, lambda, value_type,
264265
geo, test_lag, dw, tau)
265266

266-
filename = paste0(
267+
filename <- paste0(
267268
# Drop any empty strings.
268269
paste(components[components != ""], collapse="_"),
269270
file_type

backfill_corrections/delphiBackfillCorrection/unit-tests/testthat/test-io.R

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,30 @@ create_dir_not_exist("./cache")
2121

2222
test_that("testing exporting the output file", {
2323
params <- read_params("params-run.json", "params-run.json.template")
24-
24+
25+
expected_col <- c("time_value", "issue_date", "lag", "geo_value",
26+
"target_date", "wis", "predicted_tau0.5")
2527
test_data <- data.frame(test=TRUE)
28+
test_data[expected_col] = TRUE
2629
coef_data <- data.frame(test=TRUE)
2730

31+
components <- c(indicator, signal, signal_suffix)
32+
signal_dir <- paste(components[components != ""], collapse="_")
33+
2834
export_test_result(test_data, coef_data, indicator, signal,
29-
geo_level, geo="", signal_suffix, lambda,
35+
geo_level, signal_suffix, lambda,
3036
training_end_date, training_start_date,
3137
value_type, params$export_dir)
32-
prediction_file <- file.path(params$export_dir, "prediction_20220101_20211225_chng_outpatient_state_lambda0.1_fraction.csv.gz")
33-
coefs_file <- file.path(params$export_dir, "coefs_20220101_20211225_chng_outpatient_state_lambda0.1_fraction.csv.gz")
38+
prediction_file <- file.path(params$export_dir, signal_dir,
39+
"prediction_20220101_20211225_chng_outpatient_state_lambda0.1_fraction.csv.gz")
40+
coefs_file <- file.path(params$export_dir, signal_dir,
41+
"coefs_20220101_20211225_chng_outpatient_state_lambda0.1_fraction.csv.gz")
3442

3543
expect_true(file.exists(prediction_file))
3644
expect_true(file.exists(coefs_file))
3745

3846
# Remove
39-
file.remove(prediction_file)
40-
file.remove(coefs_file)
47+
unlink(file.path(params$export_dir, signal_dir),recursive = TRUE)
4148
file.remove("params-run.json")
4249
})
4350

0 commit comments

Comments
 (0)