Skip to content

Commit 96c83d6

Browse files
committed
use closure to fetch min_ref_time_values from starts instead of
recalculating
1 parent f16eb60 commit 96c83d6

File tree

2 files changed

+61
-42
lines changed

2 files changed

+61
-42
lines changed

R/slide.R

+17-42
Original file line numberDiff line numberDiff line change
@@ -230,37 +230,15 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
230230
after <- time_step(after)
231231
}
232232

233+
# Do set up to let us recover `ref_time_value`s later.
233234
min_ref_time_values <- ref_time_values - before
234235
min_ref_time_values_not_in_x <- min_ref_time_values[!(min_ref_time_values %in% unique(x$time_value))]
235236

236-
# Do set up to let us recover `ref_time_value`s later.
237-
# A helper column marking real observations.
238-
x$.real <- TRUE
239-
240-
# Create df containing phony data. Df has the same columns and attributes as
241-
# `x`, but filled with `NA`s aside from grouping columns. Number of rows is
242-
# equal to the number of `min_ref_time_values_not_in_x` we have * the
243-
# number of unique levels seen in the grouping columns.
244-
before_time_values_df <- data.frame(time_value = min_ref_time_values_not_in_x)
245-
if (length(group_vars(x)) != 0) {
246-
before_time_values_df <- dplyr::cross_join(
247-
# Get unique combinations of grouping columns seen in real data.
248-
unique(x[, group_vars(x)]),
249-
before_time_values_df
250-
)
251-
}
252-
# Automatically fill in all other columns from `x` with `NA`s, and carry
253-
# attributes over to new df.
254-
before_time_values_df <- bind_rows(x[0, ], before_time_values_df)
255-
before_time_values_df$.real <- FALSE
256-
257-
x <- bind_rows(before_time_values_df, x)
258-
259237
# Arrange by increasing time_value
260238
x <- arrange(x, time_value)
261239

262240
# Now set up starts and stops for sliding/hopping
263-
time_range <- range(unique(x$time_value))
241+
time_range <- range(unique(c(x$time_value, min_ref_time_values_not_in_x)))
264242
starts <- in_range(ref_time_values - before, time_range)
265243
stops <- in_range(ref_time_values + after, time_range)
266244

@@ -273,7 +251,7 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
273251

274252
# Computation for one group, all time values
275253
slide_one_grp <- function(.data_group,
276-
f, ...,
254+
f_factory, ...,
277255
starts,
278256
stops,
279257
time_values,
@@ -288,6 +266,8 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
288266
stops <- stops[o]
289267
time_values <- time_values[o]
290268

269+
f <- f_factory(starts)
270+
291271
# Compute the slide values
292272
slide_values_list <- slider::hop_index(
293273
.x = .data_group,
@@ -349,7 +329,6 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
349329
# fills with NA equivalent.
350330
vctrs::vec_slice(slide_values, o) <- orig_values
351331
} else {
352-
# This implicitly removes phony (`.real` == FALSE) observations.
353332
.data_group <- filter(.data_group, o)
354333
}
355334
return(mutate(.data_group, !!new_col := slide_values))
@@ -372,15 +351,20 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
372351

373352
f <- as_slide_computation(f, ...)
374353
# Create a wrapper that calculates and passes `.ref_time_value` to the
375-
# computation.
376-
f_wrapper <- function(.x, .group_key, ...) {
377-
.ref_time_value <- min(.x$time_value) + before
378-
.x <- .x[.x$.real, ]
379-
.x$.real <- NULL
380-
f(.x, .group_key, .ref_time_value, ...)
354+
# computation. `i` is contained in the `f_wrapper_factory` environment such
355+
# that when called within `slide_one_grp` `i` is reset for every group.
356+
f_wrapper_factory <- function(starts) {
357+
# Use `i` to advance through list of start dates.
358+
i <- 1L
359+
f_wrapper <- function(.x, .group_key, ...) {
360+
.ref_time_value <- starts[[i]] + before
361+
i <<- i + 1L
362+
f(.x, .group_key, .ref_time_value, ...)
363+
}
364+
return(f_wrapper)
381365
}
382366
x <- group_modify(x, slide_one_grp,
383-
f = f_wrapper, ...,
367+
f_factory = f_wrapper_factory, ...,
384368
starts = starts,
385369
stops = stops,
386370
time_values = ref_time_values,
@@ -394,14 +378,5 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
394378
x <- unnest(x, !!new_col, names_sep = names_sep)
395379
}
396380

397-
# Remove any remaining phony observations. When `all_rows` is TRUE, phony
398-
# observations aren't necessarily removed in `slide_one_grp`.
399-
if (all_rows) {
400-
x <- x[x$.real, ]
401-
}
402-
403-
# Drop helper column `.real`.
404-
x$.real <- NULL
405-
406381
return(x)
407382
}

tests/testthat/test-epi_slide.R

+44
Original file line numberDiff line numberDiff line change
@@ -626,3 +626,47 @@ test_that("`epi_slide` can access objects inside of helper functions", {
626626
NA
627627
)
628628
})
629+
630+
test_that("epi_slide basic behavior is correct when groups have non-overlapping date ranges", {
631+
small_x_misaligned_dates <- dplyr::bind_rows(
632+
dplyr::tibble(geo_value = "ak", time_value = d + 1:5, value = 11:15),
633+
dplyr::tibble(geo_value = "al", time_value = d + 151:155, value = -(1:5))
634+
) %>%
635+
as_epi_df(as_of = d + 6) %>%
636+
group_by(geo_value)
637+
638+
expected_output <- dplyr::bind_rows(
639+
dplyr::tibble(geo_value = "ak", time_value = d + 1:5, value = 11:15, slide_value = cumsum(11:15)),
640+
dplyr::tibble(geo_value = "al", time_value = d + 151:155, value = -(1:5), slide_value = cumsum(-(1:5)))
641+
) %>%
642+
group_by(geo_value) %>%
643+
as_epi_df(as_of = d + 6)
644+
645+
result1 <- epi_slide(small_x_misaligned_dates, f = ~ sum(.x$value), before = 50)
646+
expect_identical(result1, expected_output)
647+
})
648+
649+
650+
test_that("epi_slide gets correct ref_time_value when groups have non-overlapping date ranges", {
651+
small_x_misaligned_dates <- dplyr::bind_rows(
652+
dplyr::tibble(geo_value = "ak", time_value = d + 1:5, value = 11:15),
653+
dplyr::tibble(geo_value = "al", time_value = d + 151:155, value = -(1:5))
654+
) %>%
655+
as_epi_df(as_of = d + 6) %>%
656+
group_by(geo_value)
657+
658+
expected_output <- dplyr::bind_rows(
659+
dplyr::tibble(geo_value = "ak", time_value = d + 1:5, value = 11:15, slide_value = d + 1:5),
660+
dplyr::tibble(geo_value = "al", time_value = d + 151:155, value = -(1:5), slide_value = d + 151:155)
661+
) %>%
662+
group_by(geo_value) %>%
663+
as_epi_df(as_of = d + 6)
664+
665+
result1 <- small_x_misaligned_dates %>%
666+
epi_slide(
667+
before = 50,
668+
slide_value = .ref_time_value
669+
)
670+
671+
expect_identical(result1, expected_output)
672+
})

0 commit comments

Comments
 (0)