Skip to content

Commit e4b3d45

Browse files
committed
feat: rolling_mean/sd for a new forecaster
1 parent 77d7e5e commit e4b3d45

File tree

8 files changed

+248
-0
lines changed

8 files changed

+248
-0
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Imports:
2727
purrr,
2828
recipes (>= 1.0.4),
2929
rlang,
30+
slider,
3031
targets,
3132
tibble,
3233
tidyr

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@ export(make_target_param_grid)
3030
export(overprediction)
3131
export(perform_sanity_checks)
3232
export(read_external_predictions_data)
33+
export(rolling_mean)
34+
export(rolling_sd)
3335
export(run_evaluation_measure)
3436
export(run_workflow_and_format)
3537
export(scaled_pop)
3638
export(sharpness)
3739
export(single_id)
3840
export(slide_forecaster)
41+
export(smooth_scaled)
3942
export(underprediction)
4043
export(weighted_interval_score)
4144
importFrom(assertthat,assert_that)
@@ -96,6 +99,8 @@ importFrom(rlang,.data)
9699
importFrom(rlang,quo)
97100
importFrom(rlang,sym)
98101
importFrom(rlang,syms)
102+
importFrom(slider,slide2_dbl)
103+
importFrom(slider,slide_dbl)
99104
importFrom(targets,tar_config_get)
100105
importFrom(targets,tar_group)
101106
importFrom(targets,tar_read)

R/data_transforms.R

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# various reusable transforms to apply before handing to epipredict
2+
3+
#' extract the non-key columns from epi_data
4+
#' @keywords internal
5+
#' @param epi_data the epi_data tibble
6+
#' @param cols vector of column names to use. If `NULL`, fill with all non-key columns
7+
get_trainable_names <- function(epi_data, cols) {
8+
if (is.null(cols)) {
9+
cols <- names(epi_data)
10+
cols <- cols[!(cols %in% c("geo_value", "time_value", attr(epi_data, "metadata")$other_keys))]
11+
}
12+
return(cols)
13+
}
14+
15+
#' get a rolling average for the named columns
16+
#' @description
17+
#' add column(s) that are the rolling means of the specified columns, as
18+
#' implemented by slider. Defaults to the previous 7 days.
19+
#' Currently only group_by's on the geo_value. Should probably extend to more
20+
#' keys if you have them
21+
#' @param epi_data the dataset
22+
#' @param width the number of days (or examples, the sliding isn't time-aware) to use
23+
#' @param cols_to_mean the non-key columns to take the mean over. `NULL` means all
24+
#' @importFrom slider slide_dbl
25+
#' @export
26+
rolling_mean <- function(epi_data, width = 7L, cols_to_mean = NULL) {
27+
cols_to_mean <- get_trainable_names(epi_data, cols_to_mean)
28+
epi_data %<>% group_by(geo_value)
29+
for (col in cols_to_mean) {
30+
mean_name <- paste0(col, width)
31+
epi_data %<>% mutate({{ mean_name }} := slider::slide_dbl(.data[[col]], mean, .before = width))
32+
}
33+
epi_data %<>% ungroup()
34+
return(epi_data)
35+
}
36+
37+
#' get a rolling standard deviation for the named columns
38+
#' @description
39+
#' A rolling standard deviation, based off of a rolling mean. First it
40+
#' calculates a rolling mean with width `mean_width`, and then squares the
41+
#' difference between that and the actual value, averaged over `sd_width`.
42+
#' @param epi_data the dataset
43+
#' @param sd_width the number of days (or examples, the sliding isn't
44+
#' time-aware) to use for the standard deviation calculation
45+
#' @param mean_width like `sd_width`, but it governs the mean. Should be less
46+
#' than the `sd_width`, and if `NULL` (the default) it is half of `sd_width`
47+
#' (so 14 in the complete default case)
48+
#' @param cols_to_sd the non-key columns to take the sd over. `NULL` means all
49+
#' @param keep_mean bool, if `TRUE`, it retains keeps the mean column
50+
#' @importFrom slider slide_dbl slide2_dbl
51+
#' @export
52+
rolling_sd <- function(epi_data, sd_width = 28L, mean_width = NULL, cols_to_sd = NULL, keep_mean = FALSE) {
53+
if (is.null(mean_width)) {
54+
mean_width <- as.integer(ceiling(sd_width / 2))
55+
}
56+
cols_to_sd <- get_trainable_names(epi_data, cols_to_sd)
57+
epi_data %<>% group_by(geo_value)
58+
for (col in cols_to_sd) {
59+
mean_name <- paste0(col, "_m", mean_width)
60+
sd_name <- paste0(col, "_SD", sd_width)
61+
epi_data %<>% mutate({{ mean_name }} := slider::slide_dbl(.data[[col]], mean, .before = mean_width))
62+
epi_data %<>% mutate({{ sd_name }} := slider::slide2_dbl(.data[[col]], .data[[mean_name]], ~ sqrt(mean((.x - .y)^2)), .before = sd_width))
63+
if (!keep_mean) {
64+
epi_data %<>% select(-{{ mean_name }})
65+
}
66+
}
67+
epi_data %<>% ungroup()
68+
return(epi_data)
69+
}

man/get_trainable_names.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rolling_mean.Rd

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rolling_sd.Rd

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/smooth_scaled.Rd

Lines changed: 65 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-transforms.R

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
n_days <- 40
2+
simple_dates <- seq(as.Date("2012-01-01"), by = "day", length.out = n_days)
3+
rand_vals <- rnorm(n_days)
4+
epi_data <- epiprocess::as_epi_df(rbind(tibble(
5+
geo_value = "al",
6+
time_value = simple_dates,
7+
a = 1:n_days,
8+
b = rand_vals
9+
), tibble(
10+
geo_value = "ca",
11+
time_value = simple_dates,
12+
a = n_days:1,
13+
b = rand_vals + 10
14+
)))
15+
test_that("rolling_mean generates correct mean", {
16+
rolled <- rolling_mean(epi_data)
17+
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a7", "b7"))
18+
# hand specified rolling mean with a rear window of 7, noting that mean(1:7) = 4
19+
linear_roll_mean <- c(seq(from=1, to = 4, by = .5), seq(from = 4.5, to = 36.5, by = 1))
20+
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a7"), linear_roll_mean)
21+
# same, but "ca" is reversed, noting mean(40:(40-7)) =36.5
22+
linear_reverse_roll_mean <- c(seq(from=40, to = 36.5, by = -0.5), seq(from = 35.5, to = 4.5, by = -1))
23+
expect_equal(rolled %>% filter(geo_value == "ca") %>% pull("a7"), linear_reverse_roll_mean)
24+
})
25+
26+
test_that("rolling_sd generates correct standard deviation", {
27+
rolled <- rolling_sd(epi_data)
28+
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a_SD28", "b_SD28"))
29+
# hand specified rolling mean with a rear window of 7, noting that mean(1:14) = 7.5
30+
linear_roll_mean <- c(seq(from=1, to = 7.5, by = .5), seq(from = 8, to = 33, by = 1))
31+
# and the standard deviation is
32+
linear_roll_sd <- sqrt(slider::slide_dbl((1:40 - linear_roll_mean)^2, mean, .before = 28))
33+
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a_SD28"), linear_roll_sd)
34+
# even though ca is reversed, the changes are all the same, so the standard deviation is *exactly* the same values
35+
expect_equal(rolled %>% filter(geo_value == "ca") %>% pull("a_SD28"), linear_roll_sd)
36+
})
37+
# TODO example with NA's, example with missing days, only one column, keep_mean

0 commit comments

Comments
 (0)