Skip to content

Commit e2353b7

Browse files
committed
smoothed_scaled passes all forecaster tests
1 parent 2f61069 commit e2353b7

10 files changed

+131
-12
lines changed

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ export(single_id)
4040
export(slide_forecaster)
4141
export(smoothed_scaled)
4242
export(underprediction)
43+
export(update_predictors)
4344
export(weighted_interval_score)
4445
importFrom(assertthat,assert_that)
4546
importFrom(cli,cli_abort)
@@ -88,9 +89,12 @@ importFrom(epiprocess,epix_slide)
8889
importFrom(magrittr,"%<>%")
8990
importFrom(magrittr,"%>%")
9091
importFrom(purrr,imap)
92+
importFrom(purrr,list_modify)
9193
importFrom(purrr,map)
9294
importFrom(purrr,map2_vec)
95+
importFrom(purrr,map_chr)
9396
importFrom(purrr,map_vec)
97+
importFrom(purrr,reduce)
9498
importFrom(purrr,transpose)
9599
importFrom(recipes,all_numeric)
96100
importFrom(rlang,"!!")

R/data_transforms.R

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,48 @@
66
#' @param cols vector of column names to use. If `NULL`, fill with all non-key columns
77
get_trainable_names <- function(epi_data, cols) {
88
if (is.null(cols)) {
9-
cols <- names(epi_data)
10-
cols <- cols[!(cols %in% c("geo_value", "time_value", attr(epi_data, "metadata")$other_keys))]
9+
cols <- get_nonkey_names(epi_data)
1110
# exclude anything with the same naming schema as the rolling average/sd created below
1211
cols <- cols[!grepl("_\\w{1,2}\\d+", cols)]
1312
}
1413
return(cols)
1514
}
1615

16+
#' Names of the non-key columns of an epi_df
#' @description
#' Column names of `epi_data`, excluding `geo_value`, `time_value` and any
#' columns listed in the `other_keys` entry of its `metadata` attribute.
#' @param epi_data the epi_df
#' @return a character vector of the non-key column names, in column order
get_nonkey_names <- function(epi_data) {
  key_names <- c("geo_value", "time_value", attr(epi_data, "metadata")$other_keys)
  col_names <- names(epi_data)
  # end on the subsetting expression itself: ending on an assignment would
  # hand the result back invisibly
  col_names[!(col_names %in% key_names)]
}
24+
25+
26+
#' update the predictors to only contain the smoothed/sd versions of cols
#' @description
#' Should only be applied after both rolling_mean and rolling_sd. Predictors
#' whose name contains an entry of `cols_modified` are dropped. In their
#' place, every non-key column of `epi_data` whose name contains a predictor
#' name, but which is not itself a predictor, is added. Names are matched as
#' literal substrings, not as regular expressions.
#' @param epi_data the epi_df
#' @param cols_modified character vector of the column names that were
#'   modified. If `NULL`, every predictor is treated as modified, so none of
#'   the original predictors are kept.
#' @param predictors character vector of the original predictor names
#' @return a character vector: the unmodified predictors, followed by the
#'   non-key columns of `epi_data` derived from a predictor
#' @importFrom purrr map map_chr reduce
#' @export
update_predictors <- function(epi_data, cols_modified, predictors) {
  if (!is.null(cols_modified)) {
    # if cols_modified isn't null, make sure we include predictors that weren't modified
    # `.init` keeps reduce() from erroring when cols_modified has length zero
    is_unmodified <- map(cols_modified, ~ !grepl(.x, predictors, fixed = TRUE)) %>%
      reduce(`&`, .init = rep(TRUE, length(predictors)))
    other_predictors <- predictors[is_unmodified]
  } else {
    other_predictors <- c()
  }
  # all the non-key names
  col_names <- get_nonkey_names(epi_data)
  # TRUE for columns whose name contains `x` but which aren't predictors themselves
  is_present <- function(x) {
    grepl(x, col_names, fixed = TRUE) & !(col_names %in% predictors)
  }
  # `.init` keeps reduce() from erroring when there are no predictors at all
  is_modified <- map(predictors, is_present) %>%
    reduce(`|`, .init = rep(FALSE, length(col_names)))
  new_predictors <- col_names[is_modified]
  return(c(other_predictors, new_predictors))
}
50+
1751
#' get a rolling average for the named columns
1852
#' @description
1953
#' add column(s) that are the rolling means of the specified columns, as
@@ -36,6 +70,25 @@ rolling_mean <- function(epi_data, width = 7L, cols_to_mean = NULL) {
3670
return(epi_data)
3771
}
3872

73+
#' store the metadata in an easy to reapply way
#' @description
#' Splits the `metadata` attribute of `epi_data` into its `as_of`, `geo_type`
#' and `time_type` entries plus a list of whatever else it contains.
#' @param epi_data the epi_df
#' @return a list with entries `as_of`, `geo_type`, `time_type` and
#'   `all_others` (a list of the remaining metadata entries; empty if none)
#' @importFrom purrr list_modify
cache_metadata <- function(epi_data) {
  meta <- attributes(epi_data)$metadata
  # whatever is left once the three core fields are removed
  extras <- meta
  for (core_field in c("geo_type", "time_type", "as_of")) {
    extras[core_field] <- NULL
  }
  if (length(extras) == 0) {
    # normalise to a plain empty list
    extras <- list()
  }
  return(list(
    as_of = meta$as_of,
    geo_type = meta$geo_type,
    time_type = meta$time_type,
    all_others = extras
  ))
}
91+
3992
#' get a rolling standard deviation for the named columns
4093
#' @description
4194
#' A rolling standard deviation, based off of a rolling mean. First it
@@ -56,15 +109,18 @@ rolling_sd <- function(epi_data, sd_width = 28L, mean_width = NULL, cols_to_sd =
56109
mean_width <- as.integer(ceiling(sd_width / 2))
57110
}
58111
cols_to_sd <- get_trainable_names(epi_data, cols_to_sd)
112+
metadata <- cache_metadata(epi_data)
59113
epi_data %<>% group_by(geo_value)
60114
for (col in cols_to_sd) {
61115
mean_name <- paste0(col, "_m", mean_width)
62116
sd_name <- paste0(col, "_SD", sd_width)
63117
epi_data %<>% mutate({{ mean_name }} := slider::slide_dbl(.data[[col]], mean, .before = mean_width))
64118
epi_data %<>% mutate({{ sd_name }} := slider::slide2_dbl(.data[[col]], .data[[mean_name]], ~ sqrt(mean((.x - .y)^2)), .before = sd_width))
65119
if (!keep_mean) {
120+
# TODO make sure the extra info sticks around
66121
epi_data %<>% select(-{{ mean_name }})
67122
}
123+
epi_data %<>% as_epi_df(metadata$geo_type, metadata$time_type, metadata$as_of, metadata$all_others)
68124
}
69125
epi_data %<>% ungroup()
70126
return(epi_data)

R/epipredict_utilities.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@
99
#' @seealso [arx_postprocess] for the layer equivalent
1010
#' @importFrom epipredict step_epi_lag step_epi_ahead step_epi_naomit step_training_window
1111
#' @export
12-
arx_preprocess <- function(rec, outcome, predictors, args_list) {
12+
arx_preprocess <- function(preproc, outcome, predictors, args_list) {
1313
# input already validated
1414
lags <- args_list$lags
1515
for (l in seq_along(lags)) {
1616
p <- predictors[l]
17-
rec %<>% step_epi_lag(!!p, lag = lags[[l]])
17+
preproc %<>% step_epi_lag(!!p, lag = lags[[l]])
1818
}
19-
rec %<>%
19+
preproc %<>%
2020
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
2121
step_epi_naomit() %>%
2222
step_training_window(n_recent = args_list$n_training)
23-
return(rec)
23+
return(preproc)
2424
}
2525

2626
# TODO replace with `layer_arx_forecaster`

R/forecaster_smoothed_scaled.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ smoothed_scaled <- function(epi_data,
100100
keep_mean = keep_mean
101101
)
102102
}
103-
# even
104-
103+
# and need to make sure we exclude the original variables as predictors
104+
predictors <- update_predictors(epi_data, c(smooth_cols, sd_cols), predictors)
105105
# preprocessing supported by epipredict
106106
preproc <- epi_recipe(epi_data)
107107
if (pop_scaling) {

man/arx_preprocess.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/cache_metadata.Rd

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_nonkey_names.Rd

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/update_predictors.Rd

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-forecasters-basics.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ forecasters <- list(
55
c("flatline_fc", flatline_fc),
66
c("smoothed_scaled", smoothed_scaled)
77
)
8+
forecaster <- forecasters[[3]]
89
for (forecaster in forecasters) {
910
test_that(paste(forecaster[[1]], "gets the date and columns right"), {
1011
jhu <- epipredict::case_death_rate_subset %>%

tests/testthat/test-transforms.R

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ test_that("rolling_mean generates correct mean", {
2121
# same, but "ca" is reversed, noting mean(40:(40-7)) =36.5
2222
linear_reverse_roll_mean <- c(seq(from = 40, to = 36.5, by = -0.5), seq(from = 35.5, to = 4.5, by = -1))
2323
expect_equal(rolled %>% filter(geo_value == "ca") %>% pull("a_m7"), linear_reverse_roll_mean)
24+
expect_true("epi_df" %in% class(rolled))
2425
})
2526

2627
test_that("rolling_sd generates correct standard deviation", {
@@ -33,10 +34,26 @@ test_that("rolling_sd generates correct standard deviation", {
3334
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a_SD28"), linear_roll_sd)
3435
# even though ca is reversed, the changes are all the same, so the standard deviation is *exactly* the same values
3536
expect_equal(rolled %>% filter(geo_value == "ca") %>% pull("a_SD28"), linear_roll_sd)
37+
# doesn't break types
38+
expect_true("epi_df" %in% class(rolled))
3639
})
37-
testthat("get_trainable_names pulls out mean and sd columns", {
40+
41+
test_that("get_trainable_names pulls out mean and sd columns", {
3842
rolled <- rolling_sd(epi_data, keep_mean = TRUE)
3943
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a_m14", "a_SD28", "b_m14", "b_SD28"))
4044
expect_equal(get_trainable_names(rolled, NULL), c("a", "b"))
4145
})
4246
# TODO example with NA's, example with missing days, only one column, keep_mean
47+
48+
test_that("update_predictors keeps unmodified predictors", {
  # add two extra raw columns plus the rolling mean/sd columns derived from b
  epi_data["c"] <- NaN
  epi_data["d"] <- NaN
  epi_data["b_m14"] <- NaN
  epi_data["b_SD28"] <- NaN
  predictors <- c("a", "b", "c") # everything but d
  modified <- c("b", "c") # we want to exclude b but not its modified versions
  expected_predictors <- c("a", "b_m14", "b_SD28")
  expect_equal(update_predictors(epi_data, modified, predictors), expected_predictors)
  # passing NULL treats every predictor as modified, so only derived columns remain
  expected_if_all_modified <- c("b_m14", "b_SD28")
  expect_equal(update_predictors(epi_data, NULL, predictors), expected_if_all_modified)
})

0 commit comments

Comments
 (0)