Skip to content

Make epi_slide faster #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ importFrom(dplyr,group_by_drop_default)
importFrom(dplyr,group_modify)
importFrom(dplyr,group_vars)
importFrom(dplyr,groups)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
importFrom(dplyr,relocate)
importFrom(dplyr,rename)
Expand All @@ -103,6 +104,7 @@ importFrom(rlang,enquos)
importFrom(rlang,env)
importFrom(rlang,f_env)
importFrom(rlang,f_rhs)
importFrom(rlang,hash)
importFrom(rlang,is_environment)
importFrom(rlang,is_formula)
importFrom(rlang,is_function)
Expand Down
84 changes: 42 additions & 42 deletions R/slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,16 @@
#' through the `new_col_name` argument.
#'
#' @importFrom lubridate days weeks
#' @importFrom dplyr bind_rows group_vars filter select
#' @importFrom rlang .data .env !! enquo enquos sym env missing_arg
#' @importFrom dplyr bind_rows group_vars filter select left_join group_vars
#' @importFrom rlang .data .env !! enquo enquos sym env missing_arg hash
#' @export
#' @examples
#' @examples
#' # slide a 7-day trailing average formula on cases
#' jhu_csse_daily_subset %>%
#' group_by(geo_value) %>%
#' epi_slide(cases_7dav = mean(cases), before = 6) %>%
#' epi_slide(cases_7dav = mean(cases), before = 6) %>%
#' # rmv a nonessential var. to ensure new col is printed
#' dplyr::select(-death_rate_7d_av)
#' dplyr::select(-death_rate_7d_av)
#'
#' # slide a 7-day leading average
#' jhu_csse_daily_subset %>%
Expand All @@ -143,16 +143,16 @@
#' # slide a 7-day centre-aligned average
#' jhu_csse_daily_subset %>%
#' group_by(geo_value) %>%
#' epi_slide(cases_7dav = mean(cases), before = 3, after = 3) %>%
#' epi_slide(cases_7dav = mean(cases), before = 3, after = 3) %>%
#' # rmv a nonessential var. to ensure new col is printed
#' dplyr::select(-death_rate_7d_av)
#' dplyr::select(-death_rate_7d_av)
#'
#' # slide a 14-day centre-aligned average
#' jhu_csse_daily_subset %>%
#' group_by(geo_value) %>%
#' epi_slide(cases_7dav = mean(cases), before = 6, after = 7) %>%
#' epi_slide(cases_7dav = mean(cases), before = 6, after = 7) %>%
#' # rmv a nonessential var. to ensure new col is printed
#' dplyr::select(-death_rate_7d_av)
#' dplyr::select(-death_rate_7d_av)
#'
#' # nested new columns
#' jhu_csse_daily_subset %>%
Expand All @@ -161,17 +161,17 @@
#' cases_2dma = mad(cases)),
#' before = 1, as_list_col = TRUE)
epi_slide = function(x, f, ..., before, after, ref_time_values,
time_step,
time_step,
new_col_name = "slide_value", as_list_col = FALSE,
names_sep = "_", all_rows = FALSE) {
names_sep = "_", all_rows = FALSE) {

# Check we have an `epi_df` object
if (!inherits(x, "epi_df")) Abort("`x` must be of class `epi_df`.")

if (missing(ref_time_values)) {
ref_time_values = unique(x$time_value)
}

# Some of these `ref_time_values` checks and processing steps also apply to
# the `ref_time_values` default; for simplicity, just apply all the steps
# regardless of whether we are working with a default or user-provided
Expand All @@ -187,7 +187,7 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
} else {
ref_time_values = sort(ref_time_values)
}

# Validate and pre-process `before`, `after`:
if (!missing(before)) {
before <- vctrs::vec_cast(before, integer())
Expand Down Expand Up @@ -231,9 +231,6 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
min_ref_time_values_not_in_x <- min_ref_time_values[!(min_ref_time_values %in% unique(x$time_value))]

# Do set up to let us recover `ref_time_value`s later.
# A helper column marking real observations.
x$.real = TRUE

# Create df containing phony data. Df has the same columns and attributes as
# `x`, but filled with `NA`s aside from grouping columns. Number of rows is
# equal to the number of `min_ref_time_values_not_in_x` we have * the
Expand All @@ -249,19 +246,16 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
# Automatically fill in all other columns from `x` with `NA`s, and carry
# attributes over to new df.
before_time_values_df <- bind_rows(x[0,], before_time_values_df)
before_time_values_df$.real <- FALSE

x <- bind_rows(before_time_values_df, x)

# Arrange by increasing time_value
x = arrange(x, time_value)

# Now set up starts and stops for sliding/hopping
time_range = range(unique(x$time_value))
time_range = range(unique(c(x$time_value, before_time_values_df$time_value)))
starts = in_range(ref_time_values - before, time_range)
stops = in_range(ref_time_values + after, time_range)
if( length(starts) == 0 || length(stops) == 0 ) {

if( length(starts) == 0 || length(stops) == 0 ) {
Abort("The starting and/or stopping times for sliding are out of bounds with respect to the range of times in your data. Check your settings for ref_time_values and align (and before, if specified).")
}

Expand All @@ -270,7 +264,7 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,

# Computation for one group, all time values
slide_one_grp = function(.data_group,
f, ...,
f, ...,
starts,
stops,
time_values,
Expand All @@ -283,9 +277,9 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
o = time_values %in% .data_group$time_value
starts = starts[o]
stops = stops[o]
time_values = time_values[o]
# Compute the slide values
time_values = time_values[o]

# Compute the slide values
slide_values_list = slider::hop_index(.x = .data_group,
.i = .data_group$time_value,
.f = f, ...,
Expand All @@ -296,7 +290,7 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
# values; this will be useful for all sorts of checks that follow
o = .data_group$time_value %in% time_values
num_ref_rows = sum(o)

# Count the number of appearances of each reference time value (these
# appearances should all be real for now, but if we allow ref time values
# outside of .data_group's time values):
Expand Down Expand Up @@ -344,7 +338,6 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
# fills with NA equivalent.
vctrs::vec_slice(slide_values, o) = orig_values
} else {
# This implicitly removes phony (`.real` == FALSE) observations.
.data_group = filter(.data_group, o)
}
return(mutate(.data_group, !!new_col := slide_values))
Expand All @@ -359,19 +352,35 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
if (length(quos) > 1) {
Abort("If `f` is missing then only a single computation can be specified via `...`.")
}

f = quos[[1]]
new_col = sym(names(rlang::quos_auto_name(quos)))
... = missing_arg() # magic value that passes zero args as dots in calls below
}

# Pre-calculate min ref_time_value for each group, instead of doing so in
# `f_wrapper`.
min_before_time_values_df <- summarize(
before_time_values_df, min_ref_date = min(time_value)
)
# Turn by-group ref_time_value into a named list for fast and easy
# retrieval. Names are hashed lists of group key names and values.
min_before_time_values_map <- as.list(min_before_time_values_df$min_ref_date)
names(min_before_time_values_map) <- purrr::pmap_chr(
as.list(min_before_time_values_df[, dplyr::group_vars(before_time_values_df)]),
function(...) {
rlang::hash(list(...))
}
)

f = as_slide_computation(f, ...)
# Create a wrapper that calculates and passes `.ref_time_value` to the
# computation.
f_wrapper = function(.x, .group_key, ...) {
.ref_time_value = min(.x$time_value) + before
.x <- .x[.x$.real,]
.x$.real <- NULL
# Hash current group key using the same approach as above so that we can
# access pre-calculated by-group min ref_time_values.
min_phony_ref_date <- min_before_time_values_map[[rlang::hash(as.list(.group_key))]]
.ref_time_value = min(.x$time_value, min_phony_ref_date ) + before
f(.x, .group_key, .ref_time_value, ...)
}
x = group_modify(x, slide_one_grp,
Expand All @@ -388,14 +397,5 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
x = unnest(x, !!new_col, names_sep = names_sep)
}

# Remove any remaining phony observations. When `all_rows` is TRUE, phony
# observations aren't necessarily removed in `slide_one_grp`.
if (all_rows) {
x <- x[x$.real,]
}

# Drop helper column `.real`.
x$.real <- NULL

return(x)
}