Skip to content

Commit d542be2

Browse files
committed
various suggestions from logan, before=n_points-1
1 parent d6d4cdd commit d542be2

10 files changed

+75
-50
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,4 @@ importFrom(tidyr,drop_na)
113113
importFrom(tidyr,expand_grid)
114114
importFrom(tidyr,pivot_wider)
115115
importFrom(tidyr,unnest)
116+
importFrom(zeallot,"%<-%")

R/data_transforms.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ rolling_mean <- function(epi_data, width = 7L, cols_to_mean = NULL) {
7171
epi_data %<>% group_by(geo_value)
7272
for (col in cols_to_mean) {
7373
mean_name <- paste0(col, "_m", width)
74-
epi_data %<>% epi_slide(~ mean(.x[[col]]), before = width-1L, new_col_name = mean_name)
74+
epi_data %<>% epi_slide(~ mean(.x[[col]], rm.na = TRUE), before = width-1L, new_col_name = mean_name)
7575
}
7676
epi_data %<>% ungroup()
7777
return(epi_data)
@@ -102,8 +102,8 @@ rolling_sd <- function(epi_data, sd_width = 28L, mean_width = NULL, cols_to_sd =
102102
result %<>% group_by(geo_value)
103103
mean_name <- paste0(col, "_m", mean_width)
104104
sd_name <- paste0(col, "_sd", sd_width)
105-
result %<>% epi_slide(~ mean(.x[[col]]), before = mean_width-1L, new_col_name = mean_name)
106-
result %<>% epi_slide(~ sqrt(mean((.x[[mean_name]] - .x[[col]])^2)), before = sd_width-1, new_col_name = sd_name)
105+
result %<>% epi_slide(~ mean(.x[[col]], na.rm = TRUE), before = mean_width-1L, new_col_name = mean_name)
106+
result %<>% epi_slide(~ sqrt(mean((.x[[mean_name]] - .x[[col]])^2, na.rm = TRUE)), before = sd_width-1, new_col_name = sd_name)
107107
if (!keep_mean) {
108108
# TODO make sure the extra info sticks around
109109
result %<>% select(-{{ mean_name }})

R/forecaster_scaled_pop.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ scaled_pop <- function(epi_data,
7373
args_input[["ahead"]] <- effective_ahead
7474
args_input[["quantile_levels"]] <- quantile_levels
7575
args_list <- do.call(arx_args_list, args_input)
76-
# if you want to ignore extra_sources, setting predictors is the way to do it
76+
# if you want to hardcode particular predictors in a particular forecaster
7777
predictors <- c(outcome, extra_sources)
7878
# TODO: Partial match quantile_level coming from here (on Dmitry's machine)
7979
argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, trainer, args_list)

R/forecaster_smoothed_scaled.R

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,17 @@
1414
#' to the max time value of the `epi_df`. how to handle this is a modelling
1515
#' question left up to each forecaster; see latency_adjusting.R for the
1616
#' existing examples)
17-
#' @param pop_scaling an example extra parameter unique to this forecaster
18-
#' @param trainer an example extra parameter that is fairly common
17+
#' @param pop_scaling bool; if `TRUE`, assume all numeric columns are on the
18+
#' count scale and translate them to a rate scale for model fitting.
19+
#' Predictions will be translated back to count scale. Any
20+
#' `layer_residual_quantiles` (for non-`"quantile_reg"` `trainer`s) will be
21+
#' done on the rate scale. When specifying predictor lags, note that rate
22+
#' variables will use the same names as and overwrite the count variables.
23+
#' Rates here will be counts per 100k population, based on
24+
#' `epipredict::state_census`.
25+
#' @param trainer optional; parsnip model specification to use for the core
26+
#' fitting & prediction (the `spec` of the internal
27+
#' [`epipredict::epi_workflow`]). Default is `parsnip::linear_reg()`.
1928
#' @param smooth_width the number of days over which to do smoothing. If `NULL`,
2029
#' then no smoothing is applied.
2130
#' @param smooth_cols the names of the columns to smooth. If `NULL` it smooths
@@ -34,57 +43,52 @@
3443
#' @importFrom epipredict epi_recipe step_population_scaling frosting arx_args_list layer_population_scaling
3544
#' @importFrom tibble tibble
3645
#' @importFrom recipes all_numeric
46+
#' @importFrom zeallot %<-%
3747
#' @export
3848
smoothed_scaled <- function(epi_data,
39-
outcome,
40-
extra_sources = "",
41-
ahead = 1,
42-
pop_scaling = TRUE,
43-
trainer = parsnip::linear_reg(),
44-
quantile_levels = covidhub_probs(),
45-
smooth_width = 7,
46-
smooth_cols = NULL,
47-
sd_width = 28,
48-
sd_mean_width = 14,
49-
sd_cols = NULL,
50-
...) {
49+
outcome,
50+
extra_sources = "",
51+
ahead = 1,
52+
pop_scaling = TRUE,
53+
trainer = parsnip::linear_reg(),
54+
quantile_levels = covidhub_probs(),
55+
smooth_width = 7,
56+
smooth_cols = NULL,
57+
sd_width = 28,
58+
sd_mean_width = 14,
59+
sd_cols = NULL,
60+
...) {
5161
# perform any preprocessing not supported by epipredict
5262
# this is a temp fix until a real fix gets put into epipredict
5363
epi_data <- clear_lastminute_nas(epi_data)
5464
# one that every forecaster will need to handle: how to manage max(time_value)
5565
# that's older than the `as_of` date
56-
epidataAhead <- extend_ahead(epi_data, ahead)
66+
c(epi_data, effective_ahead) %<-% extend_ahead(epi_data, ahead)
5767
# see latency_adjusting for other examples
58-
# this next part is basically unavoidable boilerplate you'll want to copy
59-
epi_data <- epidataAhead[[1]]
60-
effective_ahead <- epidataAhead[[2]]
6168
args_input <- list(...)
6269
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
6370
if (!confirm_sufficient_data(epi_data, effective_ahead, args_input)) {
64-
null_result <- tibble(
65-
geo_value = character(),
66-
forecast_date = lubridate::Date(),
67-
target_end_date = lubridate::Date(),
68-
quantile = numeric(),
69-
value = numeric()
70-
)
71+
null_result <- epi_data[0L, c("geo_value", attr(epi_data, "metadata", exact = TRUE)[["other_keys"]])] %>%
72+
mutate(
73+
forecast_date = epi_data$time_value[0],
74+
target_end_date = epi_data$time_value[0],
75+
quantile = numeric(),
76+
value = numeric()
77+
)
7178
return(null_result)
7279
}
7380
args_input[["ahead"]] <- effective_ahead
7481
args_input[["quantile_levels"]] <- quantile_levels
7582
args_list <- do.call(arx_args_list, args_input)
76-
# if you want to ignore extra_sources, setting predictors is the way to do it
83+
# `extra_sources` sets which variables beyond the outcome are lagged and used as predictors
84+
# any which are modified by `rolling_mean` or `rolling_sd` have their original values dropped later
7785
predictors <- c(outcome, extra_sources)
78-
# TODO: Partial match quantile_level coming from here (on Dmitry's machine)
79-
argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, trainer, args_list)
80-
args_list <- argsPredictorsTrainer[[1]]
81-
predictors <- argsPredictorsTrainer[[2]]
82-
trainer <- argsPredictorsTrainer[[3]]
8386
# end of the copypasta
8487
# finally, any other pre-processing (e.g. smoothing) that isn't performed by
8588
# epipredict
8689
# smoothing
87-
keep_mean <- (smooth_width == sd_mean_width) # do we need to do the mean separately?
90+
keep_mean <- !is.null(smooth_width) && !is.null(sd_mean_width) &&
91+
smooth_width == sd_mean_width # TRUE when the smoothing and sd mean widths match, so the mean is computed once and kept rather than in a separate rolling_mean pass
8892
if (!is.null(smooth_width) && !keep_mean) {
8993
epi_data %<>% rolling_mean(
9094
width = smooth_width,
@@ -101,8 +105,10 @@ smoothed_scaled <- function(epi_data,
101105
keep_mean = keep_mean
102106
)
103107
}
104-
# and need to make sure we exclude the original varialbes as predictors
108+
# and need to make sure we exclude the original variables as predictors
105109
predictors <- update_predictors(epi_data, c(smooth_cols, sd_cols), predictors)
110+
# TODO: Partial match quantile_level coming from here (on Dmitry's machine)
111+
c(args_list, predictors, trainer) %<-% perform_sanity_checks(epi_data, outcome, predictors, trainer, args_list)
106112
# preprocessing supported by epipredict
107113
preproc <- epi_recipe(epi_data)
108114
if (pop_scaling) {

R/targets_utils.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ make_shared_grids <- function() {
133133
forecaster = "scaled_pop",
134134
trainer = c("linreg", "quantreg"),
135135
ahead = c(1:7, 14, 21, 28),
136-
lags = list(c(0, 3, 5, 7, 14), c(0, 7, 14)),
136+
lags = list(c(0, 3, 5, 7, 14), c(0, 7, 14), c(0,7,14,24)),
137137
pop_scaling = c(FALSE)
138138
),
139139
tidyr::expand_grid(
@@ -144,7 +144,7 @@ make_shared_grids <- function() {
144144
forecaster = "smoothed_scaled",
145145
trainer = c("quantreg"),
146146
ahead = c(1:7, 14, 21, 28),
147-
lags = list(list(c(0, 3, 5, 7, 14), c(0),), c(0, 7, 14)),
147+
lags = list(list(c(0, 3, 5, 7, 14), c(0),c(0, 3, 5, 7, 14), c(0),), c(0, 7, 14), c(0,2,4,7,14,21,28)),
148148
pop_scaling = c(FALSE)
149149
)
150150
)

man/get_trainable_names.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/smoothed_scaled.Rd

Lines changed: 11 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/update_predictors.Rd

Lines changed: 9 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-forecasters-basics.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ forecasters <- list(
55
c("flatline_fc", flatline_fc),
66
c("smoothed_scaled", smoothed_scaled)
77
)
8-
forecaster <- forecasters[[3]]
98
for (forecaster in forecasters) {
109
test_that(paste(forecaster[[1]], "gets the date and columns right"), {
1110
jhu <- epipredict::case_death_rate_subset %>%

tests/testthat/test-transforms.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,19 @@ test_that("rolling_sd generates correct standard deviation", {
5252
rolled <- rolling_sd(epi_data, keep_mean = TRUE)
5353
expect_equal(names(rolled), c("geo_value", "time_value", "a", "b", "a_m14", "a_sd28", "b_m14", "b_sd28"))
5454
# hand specified rolling mean with a trailing window of 14 (the `a_m14` column), noting that mean(1:14) = 7.5
55-
linear_roll_mean <- c(seq(from = 1, to = 7.5, by = .5), seq(from = 8.5, to = 16.5, by = 1), seq(from = 17, to = 32, by = 1))
56-
linear_roll_mean
55+
linear_roll_mean <- c(seq(from = 1, to = 7, by = .5), seq(from = 8, to = 16, by = 1), seq(from = 16.5, to = 32.5, by = 1))
56+
## linear_roll_mean <- c(seq(from = 1, by = .5, length.out = 14), seq(from = 8.5, to = 32.5, by = 1))
57+
## gap_starts <- epi_data %>% filter(geo_value == "al" & time_value == as.Date("2012-01-11")) %>% pull(a)
58+
## unusual_days <- map_vec(seq(from = 0, to = 5), \(d) mean(((gap_starts + d) - 0):max((gap_starts + d) - 14, 1)))
59+
## map(seq(from = 0, to = 5), \(d) mean(((gap_starts + d) - 0):max((gap_starts + d) - 13, 1)))
60+
## linear_roll_mean
61+
## rolled %>% filter(geo_value == "al") %>% pull("a_m14")
5762
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a_m14"), linear_roll_mean)
5863
# and the standard deviation is
5964
linear_roll_mean <- append(linear_roll_mean, NA, after = removed_date - 1)
6065
linear_values <- 1:39
6166
linear_values <- append(linear_values, NA, after = removed_date - 1)
62-
linear_roll_sd <- sqrt(slider::slide_dbl((linear_values - linear_roll_mean)^2, \(x) mean(x, na.rm = TRUE), .before = 28))
67+
linear_roll_sd <- sqrt(slider::slide_dbl((linear_values - linear_roll_mean)^2, \(x) mean(x, na.rm = TRUE), .before = 28 - 1))
6368
# drop the extra date caused by the inclusion of the NAs
6469
linear_roll_sd <- linear_roll_sd[-(removed_date)]
6570
expect_equal(rolled %>% filter(geo_value == "al") %>% pull("a_sd28"), linear_roll_sd)

0 commit comments

Comments
 (0)