Skip to content

Commit 678a7e4

Browse files
committed
Merge branch 'ndefries/backfill/speed2' into ndefries/backfill/speed-join-v-merge-order-matters
2 parents 849834e + 78b9d85 commit 678a7e4

File tree

9 files changed

+65
-47
lines changed

9 files changed

+65
-47
lines changed

backfill_corrections/delphiBackfillCorrection/DESCRIPTION

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,11 @@ Imports:
2323
jsonlite,
2424
lubridate,
2525
tidyr,
26-
zoo,
2726
utils,
28-
parallel
27+
parallel,
28+
purrr,
29+
vctrs,
30+
RcppRoll
2931
Suggests:
3032
knitr (>= 1.15),
3133
rmarkdown (>= 1.4),

backfill_corrections/delphiBackfillCorrection/NAMESPACE

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ export(read_data)
2121
export(read_params)
2222
export(run_backfill)
2323
import(covidcast)
24+
importFrom(RcppRoll,roll_mean)
2425
importFrom(arrow,read_parquet)
2526
importFrom(dplyr,"%>%")
2627
importFrom(dplyr,across)
@@ -48,6 +49,7 @@ importFrom(lubridate,make_date)
4849
importFrom(lubridate,month)
4950
importFrom(lubridate,year)
5051
importFrom(parallel,detectCores)
52+
importFrom(purrr,map_dfc)
5153
importFrom(quantgen,quantile_lasso)
5254
importFrom(readr,write_csv)
5355
importFrom(stats,coef)
@@ -60,8 +62,7 @@ importFrom(stringr,str_split)
6062
importFrom(tibble,tribble)
6163
importFrom(tidyr,crossing)
6264
importFrom(tidyr,drop_na)
63-
importFrom(tidyr,fill)
6465
importFrom(tidyr,pivot_longer)
6566
importFrom(tidyr,pivot_wider)
6667
importFrom(utils,head)
67-
importFrom(zoo,rollmeanr)
68+
importFrom(vctrs,vec_fill_missing)

backfill_corrections/delphiBackfillCorrection/R/constants.R

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,12 @@ SQRTSCALE <-c('sqrty0', 'sqrty1', "sqrty2")
1919
LOG_LAG <-"inv_log_lag"
2020

2121
# Dates
22+
DATE_FORMAT <- "%Y-%m-%d"
2223
WEEKDAYS_ABBR <- c("Mon", "Tue", "Wed", "Thurs", "Fri", "Sat") # wd
2324
WEEK_ISSUES <- c("W1_issue", "W2_issue", "W3_issue") # wm
2425
TODAY <- Sys.Date()
2526

27+
# Signals we want to make predictions for
2628
INDICATORS_AND_SIGNALS <- tibble::tribble(
2729
~indicator, ~signal, ~name_suffix, ~sub_dir,
2830
"changehc", "covid", "", "chng",

backfill_corrections/delphiBackfillCorrection/R/io.R

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -244,7 +244,7 @@ get_training_date_range <- function(params) {
244244
if (params$train_models) {
245245
if (params_element_exists_and_valid(params, "training_end_date")) {
246246
# Use user-provided end date.
247-
training_end_date <- as.Date(params$training_end_date)
247+
training_end_date <- as.Date(params$training_end_date, DATE_FORMAT)
248248
} else {
249249
# Default end date is today.
250250
training_end_date <- default_end_date

backfill_corrections/delphiBackfillCorrection/R/main.R

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ run_backfill <- function(df, params,
121121
) %>%
122122
drop_na()
123123
geo_test_data <- filter(combined_df,
124-
issue_date %in% params$test_dates
124+
issue_date %in% as.character(params$test_dates)
125125
) %>%
126126
drop_na()
127127

@@ -300,7 +300,7 @@ main <- function(params,
300300
input_group$indicator, input_group$signal, params, input_group$sub_dir
301301
)
302302
if (length(files_list) == 0) {
303-
warning("No files found for indicator indicator ", input_group$indicator,
303+
warning("No files found for indicator ", input_group$indicator,
304304
" signal ", input_group$signal, ", skipping")
305305
next
306306
}
@@ -331,9 +331,6 @@ main <- function(params,
331331
refd_col = refd_col, lag_col = lag_col, issued_col = issued_col
332332
)
333333

334-
input_data[[refd_col]] <- as.Date(input_data[[refd_col]], "%Y-%m-%d")
335-
input_data[[issued_col]] <- as.Date(input_data[[issued_col]], "%Y-%m-%d")
336-
337334
# Check available training days
338335
training_days_check(input_data[[issued_col]], params$training_days)
339336

backfill_corrections/delphiBackfillCorrection/R/model.R

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -183,11 +183,8 @@ evaluate <- function(test_data, taus) {
183183
exponentiate_preds <- function(test_data, taus) {
184184
pred_cols = paste0("predicted_tau", taus)
185185

186-
# Drop original predictions and join on exponentiated versions
187-
test_data = bind_cols(
188-
select(test_data, -starts_with("predicted")),
189-
exp(test_data[, pred_cols])
190-
)
186+
# Replace original predictions with exponentiated versions
187+
test_data[, pred_cols] <- exp(test_data[, pred_cols])
191188

192189
return(test_data)
193190
}
@@ -211,17 +208,22 @@ get_model <- function(model_path, train_data, covariates, tau,
211208
" does not exist; training new model")
212209
}
213210
# Quantile regression
211+
## TODO: how does the speed compare using GLPK? Apparently it's faster on smaller
212+
# models.
214213
obj <- quantile_lasso(as.matrix(train_data[covariates]),
215214
train_data$log_value_target, tau = tau,
216215
lambda = lambda, standardize = FALSE, lp_solver = lp_solver)
217216

218217
# Save model to cache.
219218
create_dir_not_exist(dirname(model_path))
219+
## TODO: save() is fairly slow. Since we're not sharing the model files, can we
220+
# use saveRDS() instead?
220221
save(obj, file=model_path)
221222
} else {
222223
# Load model from cache invisibly. Object has the same name as the original
223224
# model object, `obj`.
224225
msg_ts("Loading from ", model_path)
226+
## TODO: readRDS()
225227
load(model_path)
226228
}
227229

backfill_corrections/delphiBackfillCorrection/R/preprocessing.R

Lines changed: 39 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,11 @@
2525
fill_rows <- function(df, refd_col, lag_col, min_refd, max_refd, ref_lag) {
2626
# Full list of lags
2727
# +30 to have values for calculating 7-day averages
28-
lags <- min(df[[lag_col]]): (ref_lag + 30)
29-
refds <- seq(min_refd, max_refd, by="day") # Full list reference date
28+
lags <- min(df[[lag_col]]): (ref_lag + 30)
29+
# Full list reference dates
30+
refds <- as.character(
31+
seq(as.Date(min_refd, DATE_FORMAT), as.Date(max_refd, DATE_FORMAT), by="day")
32+
)
3033
row_inds_df <- setNames(
3134
as.data.frame(crossing(refds, lags)),
3235
c(refd_col, lag_col)
@@ -45,49 +48,59 @@ fill_rows <- function(df, refd_col, lag_col, min_refd, max_refd, ref_lag) {
4548
#' @template refd_col-template
4649
#' @template lag_col-template
4750
#'
48-
#' @importFrom tidyr fill pivot_wider pivot_longer
49-
#' @importFrom dplyr %>% everything select
51+
#' @importFrom tidyr pivot_wider pivot_longer
52+
#' @importFrom purrr map_dfc
53+
#' @importFrom vctrs vec_fill_missing
5054
#'
5155
#' @export
5256
fill_missing_updates <- function(df, value_col, refd_col, lag_col) {
53-
pivot_df <- df[order(df[[lag_col]], decreasing=FALSE), ] %>%
54-
pivot_wider(id_cols=lag_col, names_from=refd_col, values_from=value_col)
57+
pivot_df <- pivot_wider(
58+
df[order(df[[lag_col]], decreasing=FALSE), ],
59+
id_cols=lag_col, names_from=refd_col, values_from=value_col
60+
)
5561

5662
if (any(diff(pivot_df[[lag_col]]) != 1)) {
5763
stop("Risk exists in forward filling")
5864
}
59-
pivot_df <- fill(pivot_df, everything(), .direction="down")
65+
66+
pivot_df <- map_dfc(pivot_df, function(col) {
67+
vec_fill_missing(col, direction="down")
68+
})
6069

6170
# Fill NAs with 0s
6271
pivot_df[is.na(pivot_df)] <- 0
6372

6473
backfill_df <- pivot_longer(pivot_df,
6574
-lag_col, values_to="value_raw", names_to=refd_col
6675
)
67-
backfill_df[[refd_col]] = as.Date(backfill_df[[refd_col]])
6876

6977
return (as.data.frame(backfill_df))
7078
}
7179

7280
#' Calculate 7 day moving average for each issue date
81+
#'
7382
#' The 7dav for date D reported on issue date D_i is the average from D-7 to D-1
83+
#'
7484
#' @param pivot_df Data Frame where the columns are issue dates and the rows are
7585
#' reference dates
7686
#' @template refd_col-template
7787
#'
78-
#' @importFrom zoo rollmeanr
79-
#'
88+
#' @importFrom RcppRoll roll_mean
89+
#'
8090
#' @export
8191
get_7dav <- function(pivot_df, refd_col) {
82-
for (col in colnames(pivot_df)) {
83-
if (col == refd_col) next
84-
pivot_df[, col] <- rollmeanr(pivot_df[, col], 7, align="right", fill=NA)
85-
}
92+
pivot_df <- cbind(
93+
# Keep time values at the front
94+
pivot_df[, refd_col],
95+
# Compute moving average of all non-refd columns
96+
RcppRoll::roll_mean(
97+
as.matrix(pivot_df[, names(pivot_df)[names(pivot_df) != refd_col]]),
98+
7L, align = "right", fill = NA
99+
)
100+
)
86101
backfill_df <- pivot_longer(pivot_df,
87102
-refd_col, values_to="value_raw", names_to="issue_date"
88103
)
89-
backfill_df[[refd_col]] = as.Date(backfill_df[[refd_col]])
90-
backfill_df[["issue_date"]] = as.Date(backfill_df[["issue_date"]])
91104
return (as.data.frame(backfill_df))
92105
}
93106

@@ -99,7 +112,7 @@ get_7dav <- function(pivot_df, refd_col) {
99112
#'
100113
#' @export
101114
add_shift <- function(df, n_day, refd_col) {
102-
df[, refd_col] <- as.Date(df[, refd_col]) + n_day
115+
df[[refd_col]] <- as.character(as.Date(df[[refd_col]], DATE_FORMAT) + n_day)
103116
return (df)
104117
}
105118

@@ -113,7 +126,7 @@ add_shift <- function(df, n_day, refd_col) {
113126
#'
114127
#' @export
115128
add_dayofweek <- function(df, time_col, suffix, wd = WEEKDAYS_ABBR) {
116-
dayofweek <- as.numeric(format(df[[time_col]], format="%u"))
129+
dayofweek <- as.numeric(format(as.Date(df[[time_col]], DATE_FORMAT), format="%u"))
117130
for (i in seq_along(wd)) {
118131
df[, paste0(wd[i], suffix)] <- as.numeric(dayofweek == i)
119132
}
@@ -159,7 +172,7 @@ get_weekofmonth <- function(date) {
159172
#'
160173
#' @export
161174
add_weekofmonth <- function(df, time_col, wm = WEEK_ISSUES) {
162-
weekofmonth <- get_weekofmonth(df[[time_col]])
175+
weekofmonth <- get_weekofmonth(as.Date(df[[time_col]], DATE_FORMAT))
163176
for (i in seq_along(wm)) {
164177
df[, paste0(wm[i])] <- as.numeric(weekofmonth == i)
165178
}
@@ -174,16 +187,17 @@ add_weekofmonth <- function(df, time_col, wm = WEEK_ISSUES) {
174187
#' @template lag_col-template
175188
#' @template ref_lag-template
176189
#'
177-
#' @importFrom dplyr %>% full_join left_join
190+
#' @importFrom dplyr full_join left_join
178191
#' @importFrom tidyr pivot_wider drop_na
179192
#'
180193
#' @export
181194
add_7davs_and_target <- function(df, value_col, refd_col, lag_col, ref_lag) {
182-
df$issue_date <- df[[refd_col]] + df[[lag_col]]
183-
pivot_df <- df[order(df$issue_date, decreasing=FALSE), ] %>%
184-
pivot_wider(id_cols=refd_col, names_from="issue_date",
185-
values_from=value_col)
186-
195+
df$issue_date <- as.character(as.Date(df[[refd_col]], DATE_FORMAT) + df[[lag_col]])
196+
pivot_df <- pivot_wider(
197+
df[order(df$issue_date, decreasing=FALSE), ],
198+
id_cols=refd_col, names_from="issue_date", values_from=value_col
199+
)
200+
187201
# Add 7dav avg
188202
avg_df <- get_7dav(pivot_df, refd_col)
189203
avg_df <- add_shift(avg_df, 1, refd_col) # 7dav until yesterday

backfill_corrections/delphiBackfillCorrection/R/utils.R

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -91,13 +91,13 @@ read_params <- function(path = "params.json", template_path = "params.json.templ
9191
stop("`test_dates` setting in params must be a length-2 list of dates")
9292
}
9393
params$test_dates <- seq(
94-
as.Date(params$test_dates[1]),
95-
as.Date(params$test_dates[2]),
94+
as.Date(params$test_dates[1], DATE_FORMAT),
95+
as.Date(params$test_dates[2], DATE_FORMAT),
9696
by="days"
9797
)
9898
}
9999
if (params_element_exists_and_valid(params, "training_end_date")) {
100-
if (as.Date(params$training_end_date) > TODAY) {
100+
if (as.Date(params$training_end_date, DATE_FORMAT) > TODAY) {
101101
stop("training_end_date can't be in the future")
102102
}
103103
}
@@ -189,7 +189,9 @@ validity_checks <- function(df, value_types, num_col, denom_col, signal_suffixes
189189
#' @param issue_date contents of input data's `issue_date` column
190190
#' @template training_days-template
191191
training_days_check <- function(issue_date, training_days) {
192-
valid_training_days = as.integer(max(issue_date) - min(issue_date)) + 1
192+
valid_training_days = as.integer(
193+
as.Date(max(issue_date), DATE_FORMAT) - as.Date(min(issue_date), DATE_FORMAT)
194+
) + 1
193195
if (training_days > valid_training_days) {
194196
warning(sprintf("Only %d days are available at most for training.", valid_training_days))
195197
}

backfill_corrections/delphiBackfillCorrection/man/get_7dav.Rd

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)