Skip to content

Commit d6d4cdd

Browse files
committed
fix before behavior: mean tests simplified
1 parent c4c7430 commit d6d4cdd

File tree

2 files changed

+25
-23
lines changed

2 files changed

+25
-23
lines changed

R/data_transforms.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ rolling_mean <- function(epi_data, width = 7L, cols_to_mean = NULL) {
7171
epi_data %<>% group_by(geo_value)
7272
for (col in cols_to_mean) {
7373
mean_name <- paste0(col, "_m", width)
74-
epi_data %<>% epi_slide(~ mean(.x[[col]]), before = width, new_col_name = mean_name)
74+
epi_data %<>% epi_slide(~ mean(.x[[col]]), before = width-1L, new_col_name = mean_name)
7575
}
7676
epi_data %<>% ungroup()
7777
return(epi_data)
@@ -102,8 +102,8 @@ rolling_sd <- function(epi_data, sd_width = 28L, mean_width = NULL, cols_to_sd =
102102
result %<>% group_by(geo_value)
103103
mean_name <- paste0(col, "_m", mean_width)
104104
sd_name <- paste0(col, "_sd", sd_width)
105-
result %<>% epi_slide(~ mean(.x[[col]]), before = mean_width, new_col_name = mean_name)
106-
result %<>% epi_slide(~ sqrt(mean((.x[[mean_name]] - .x[[col]])^2)), before = sd_width, new_col_name = sd_name)
105+
result %<>% epi_slide(~ mean(.x[[col]]), before = mean_width-1L, new_col_name = mean_name)
106+
result %<>% epi_slide(~ sqrt(mean((.x[[mean_name]] - .x[[col]])^2)), before = sd_width-1, new_col_name = sd_name)
107107
if (!keep_mean) {
108108
# TODO make sure the extra info sticks around
109109
result %<>% select(-{{ mean_name }})

tests/testthat/test-transforms.R

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,42 +3,44 @@ n_days <- 40
33
removed_date <- 10
44
simple_dates <- seq(as.Date("2012-01-01"), by = "day", length.out = n_days)
55
simple_dates <- simple_dates[-removed_date]
6-
rand_vals <- rnorm(n_days-1)
6+
rand_vals <- rnorm(n_days - 1)
77

88
# Two states, with 2 variables. a is linear, going up in one state and down in the other
99
# b is just random
1010
# note that day 10 is missing
1111
epi_data <- epiprocess::as_epi_df(rbind(tibble(
1212
geo_value = "al",
1313
time_value = simple_dates,
14-
a = 1:(n_days-1),
14+
a = 1:(n_days - 1),
1515
b = rand_vals
1616
), tibble(
1717
geo_value = "ca",
1818
time_value = simple_dates,
19-
a = (n_days-1):1,
19+
a = (n_days - 1):1,
2020
b = rand_vals + 10
2121
)))
2222
test_that("rolling_mean generates correct mean", {
2323
rolled <- rolling_mean(epi_data)
2424
rolled
2525
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a_m7", "b_m7"))
2626
# hand specified rolling mean with a rear window of 7, noting that mean(1:7) = 4
27-
linear_roll_mean <- c(seq(from = 1, to = 4, by = .5), seq(from = 4.5, to = 35.5, by = 1))
28-
# day 10 is missing, so days 11-18 are thrown off
29-
lag_st <- 10
30-
unusual_days <- c(mean(c((lag_st):(lag_st-6))), mean(c((lag_st+1):(lag_st+1-6))), mean(c((lag_st+2):(lag_st+2-6))), mean(c((lag_st+3):(lag_st+3-6))), mean(c((lag_st+4):(lag_st+4-6))), mean(c((lag_st+5):(lag_st+5-6))), mean(c((lag_st+6):(lag_st+6-6))))
27+
linear_roll_mean <- c(seq(from = 1, by = .5, length.out = 7), seq(from = 5, to = 36, by = 1))
28+
# day 10 is missing, so the average days 11-16 are thrown off, only using 6 values instead of 7
29+
gap_starts <- epi_data %>% filter(geo_value == "al" & time_value == as.Date("2012-01-11")) %>% pull(a)
30+
unusual_days <- map_vec(seq(from = 0, to = 5), \(d) mean(((gap_starts + d) - 0):((gap_starts + d) - 5)))
3131
# stitching the lag induced hiccup into the "normal" mean values
32-
expected_mean <- c(linear_roll_mean[1:9], unusual_days, linear_roll_mean[17:(n_days-1)])
32+
expected_mean <- c(linear_roll_mean[1:9], unusual_days, linear_roll_mean[16:(n_days - 1)])
33+
expected_mean
3334

3435
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a_m7"), expected_mean)
36+
# Doing the same for California
3537
# same, but "ca" is reversed, noting mean(40:(40-7)) =36.5
36-
linear_reverse_roll_mean <- c(seq(from = 39, to = 35.5, by = -0.5), seq(from = 34.5, to = 4.5, by = -1))
37-
lag_st <- 36
38-
# day 10 is missing, so days 11-18 are thrown off
39-
unusual_days <- c(mean(c((lag_st):(lag_st-6))), mean(c((lag_st-1):(lag_st-1-6))), mean(c((lag_st-2):(lag_st-2-6))), mean(c((lag_st-3):(lag_st-3-6))), mean(c((lag_st-4):(lag_st-4-6))), mean(c((lag_st-5):(lag_st-5-6))), mean(c((lag_st-6):(lag_st-6-6))))
38+
linear_reverse_roll_mean <- c(seq(from = 39, by = -0.5, length.out = 7), seq(from = 35, to = 4, by = -1))
39+
# day 10 is missing, so days 11-16 are thrown off
40+
gap_starts <- epi_data %>% filter(geo_value == "ca" & time_value == as.Date("2012-01-11")) %>% pull(a)
41+
unusual_days <- map_vec(seq(from = 0, to = 5), \(d) mean(((gap_starts - d) + 0):((gap_starts - d) + 5)))
4042
# stitching the lag induced hiccup into the "normal" mean values
41-
expected_mean <- c(linear_reverse_roll_mean[1:9], unusual_days, linear_reverse_roll_mean[17:(n_days-1)])
43+
expected_mean <- c(linear_reverse_roll_mean[1:9], unusual_days, linear_reverse_roll_mean[16:(n_days - 1)])
4244
# actually testing
4345
expect_equal(rolled %>% filter(geo_value == "ca") %>% pull("a_m7"), expected_mean)
4446

@@ -47,16 +49,16 @@ test_that("rolling_mean generates correct mean", {
4749
})
4850

4951
test_that("rolling_sd generates correct standard deviation", {
50-
rolled <- rolling_sd(epi_data,keep_mean = TRUE)
52+
rolled <- rolling_sd(epi_data, keep_mean = TRUE)
5153
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a_m14", "a_sd28", "b_m14", "b_sd28"))
5254
# hand specified rolling mean with a rear window of 7, noting that mean(1:14) = 7.5
5355
linear_roll_mean <- c(seq(from = 1, to = 7.5, by = .5), seq(from = 8.5, to = 16.5, by = 1), seq(from = 17, to = 32, by = 1))
5456
linear_roll_mean
5557
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a_m14"), linear_roll_mean)
5658
# and the standard deviation is
57-
linear_roll_mean <- append(linear_roll_mean, NA, after = removed_date-1)
59+
linear_roll_mean <- append(linear_roll_mean, NA, after = removed_date - 1)
5860
linear_values <- 1:39
59-
linear_values <- append(linear_values, NA, after = removed_date-1)
61+
linear_values <- append(linear_values, NA, after = removed_date - 1)
6062
linear_roll_sd <- sqrt(slider::slide_dbl((linear_values - linear_roll_mean)^2, \(x) mean(x, na.rm = TRUE), .before = 28))
6163
# drop the extra date caused by the inclusion of the NAs
6264
linear_roll_sd <- linear_roll_sd[-(removed_date)]
@@ -75,10 +77,10 @@ test_that("get_trainable_names pulls out mean and sd columns", {
7577
# TODO example with NA's, example with missing days, only one column, keep_mean
7678

7779
test_that("update_predictors keeps unmodified predictors", {
78-
epi_data["c"] = NaN
79-
epi_data["d"] = NaN
80-
epi_data["b_m14"] = NaN
81-
epi_data["b_sd28"] = NaN
80+
epi_data["c"] <- NaN
81+
epi_data["d"] <- NaN
82+
epi_data["b_m14"] <- NaN
83+
epi_data["b_sd28"] <- NaN
8284
predictors <- c("a", "b", "c") # everything but d
8385
modified <- c("b", "c") # we want to exclude b but not its modified versions
8486
expected_predictors <- c("a", "b_m14", "b_sd28")

0 commit comments

Comments
 (0)