Skip to content

Commit 961f4e5

Browse files
authored
Merge pull request #134 from cmu-delphi/133-fix-residuals
133 fix residuals
2 parents fc0af3f + 7c37ba6 commit 961f4e5

13 files changed

+128
-39
lines changed

DESCRIPTION

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ URL: https://github.com/cmu-delphi/epipredict/,
2121
https://cmu-delphi.github.io/epipredict
2222
BugReports: https://github.com/cmu-delphi/epipredict/issues/
2323
Depends:
24-
R (>= 3.5.0)
24+
R (>= 3.5.0),
25+
epiprocess
2526
Imports:
2627
assertthat,
2728
cli,
2829
distributional,
2930
dplyr,
30-
epiprocess,
3131
fs,
3232
generics,
3333
glue,

NAMESPACE

-1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,5 @@ importFrom(stats,predict)
137137
importFrom(stats,qnorm)
138138
importFrom(stats,quantile)
139139
importFrom(stats,residuals)
140-
importFrom(stats,setNames)
141140
importFrom(tibble,is_tibble)
142141
importFrom(tibble,tibble)

R/arx_forecaster.R

+5-5
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
#' out <- arx_forecaster(jhu, "death_rate",
2828
#' c("case_rate", "death_rate"))
2929
arx_forecaster <- function(epi_data,
30-
outcome,
31-
predictors,
32-
trainer = parsnip::linear_reg(),
33-
args_list = arx_args_list()) {
30+
outcome,
31+
predictors,
32+
trainer = parsnip::linear_reg(),
33+
args_list = arx_args_list()) {
3434

3535
validate_forecaster_inputs(epi_data, outcome, predictors)
3636
if (!is.list(trainer) || trainer$mode != "regression")
@@ -60,7 +60,7 @@ arx_forecaster <- function(epi_data,
6060
layer_add_target_date(target_date = target_date)
6161
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
6262

63-
latest <- get_test_data(r, epi_data)
63+
latest <- get_test_data(r, epi_data, TRUE)
6464

6565
wf <- epi_workflow(r, trainer, f) %>% generics::fit(epi_data)
6666
list(

R/get_test_data.R

+45-12
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,25 @@
66
#' and other variables in the original dataset,
77
#' which will be used to create test data.
88
#'
9+
#' It also optionally fills missing values
10+
#' using the last-observation-carried-forward (LOCF) method. If this
11+
#' is not possible (say because there would be only `NA`'s in some location),
12+
#' it will produce an error suggesting alternative options to handle missing
13+
#' values with more advanced techniques.
14+
#'
915
#' @param recipe A recipe object. The step will be added to the
1016
#' sequence of operations for this recipe.
1117
#' @param x A data frame, tibble, or epi_df data set.
18+
#' @param fill_locf Logical. Should we use LOCF to fill in missing data?
19+
#' @param n_recent Integer or NULL. If filling missing data with `fill_locf = TRUE`,
20+
#' how far back are we willing to tolerate missing data? Larger values allow
21+
#' more filling. The default `NULL` will determine this from the maximum
22+
#' lags used in the `recipe`. For example, suppose n_recent = 3, then if the
23+
#' 3 most recent observations in some region are all `NA`'s, we won't be able
24+
#' to fill anything, and an error message will be thrown.
1225
#'
13-
#' @return A tibble with columns `geo_value`, `time_value`
14-
#' and other variables in the original dataset.
26+
#' @return A tibble with columns `geo_value`, `time_value`, any additional
27+
#' keys, as well as other variables in the original dataset.
1528
#' @examples
1629
#' # create recipe
1730
#' rec <- epi_recipe(case_death_rate_subset) %>%
@@ -21,11 +34,17 @@
2134
#' get_test_data(recipe = rec, x = case_death_rate_subset)
2235
#' @export
2336

24-
get_test_data <- function(recipe, x) {
37+
get_test_data <- function(recipe, x, fill_locf = FALSE, n_recent = NULL) {
2538
stopifnot(is.data.frame(x))
26-
if (! all(colnames(x) %in% colnames(recipe$template)))
39+
arg_is_lgl(fill_locf)
40+
arg_is_scalar(fill_locf)
41+
arg_is_pos_int(n_recent, allow_null = TRUE)
42+
arg_is_scalar(n_recent, allow_null = TRUE)
43+
44+
if (!all(colnames(x) %in% colnames(recipe$template)))
2745
cli_stop("some variables used for training are not available in `x`.")
2846
max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)), 0)
47+
if (is.null(n_recent)) n_recent <- max_lags + 1
2948

3049
# CHECK: Return NA if insufficient training data
3150
if (dplyr::n_distinct(x$time_value) < max_lags) {
@@ -36,15 +55,29 @@ get_test_data <- function(recipe, x) {
3655
groups <- epi_keys(recipe)
3756
groups <- groups[groups != "time_value"]
3857

39-
x %>%
40-
dplyr::filter(
41-
dplyr::if_any(
42-
.cols = recipe$term_info$variable[which(recipe$var_info$role == 'raw')],
43-
.fns = ~ !is.na(.x)
44-
)
45-
) %>%
58+
x <- x %>%
4659
epiprocess::group_by(dplyr::across(dplyr::all_of(groups))) %>%
60+
dplyr::slice_tail(n = max(n_recent, max_lags + 1))
61+
62+
if (fill_locf) {
63+
cannot_be_used <- x %>%
64+
dplyr::slice_tail(n = n_recent) %>%
65+
dplyr::summarize(dplyr::across(
66+
!time_value, ~ !is.na(.x[1])), .groups = "drop") %>%
67+
dplyr::summarise(dplyr::across(-dplyr::all_of(groups), ~ any(!.x))) %>%
68+
unlist()
69+
if (any(cannot_be_used)) {
70+
bad_vars <- names(cannot_be_used)[cannot_be_used]
71+
cli_stop("The variables {bad_vars} have ",
72+
"too many recent missing values to be filled automatically. ",
73+
"You should either choose `n_recent` larger than its current ",
74+
"value {n_recent}, or perform NA imputation manually, perhaps with ",
75+
"{.code recipes::step_impute_*()} or with {.code tidyr::fill()}.")
76+
}
77+
x <- x %>% tidyr::fill(!time_value)
78+
}
79+
80+
x %>%
4781
dplyr::slice_tail(n = max_lags + 1) %>%
4882
epiprocess::ungroup()
49-
5083
}

R/layer_residual_quantiles.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#' @param probs numeric vector of probabilities with values in (0,1)
66
#' referring to the desired quantile.
77
#' @param symmetrize logical. If `TRUE` then interval will be symmetric.
8-
#' @param by_key A character vector of keys to group the residuls by before
8+
#' @param by_key A character vector of keys to group the residuals by before
99
#' calculating quantiles. The default, `c()` performs no grouping.
1010
#' @param name character. The name for the output column.
1111
#' @param .flag a logical to determine if the layer is added. Passed on to
@@ -45,7 +45,7 @@
4545
#'
4646
#' p2 <- predict(wf2, latest)
4747
layer_residual_quantiles <- function(frosting, ...,
48-
probs = c(0.0275, 0.975),
48+
probs = c(0.05, 0.95),
4949
symmetrize = TRUE,
5050
by_key = character(0L),
5151
name = ".pred_distn",

R/utils_arg.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ arg_is_pos_int = function(..., allow_null = FALSE) {
5959
...,
6060
tests = function(name, value) {
6161
if (!((is.numeric(value) && all(value > 0) && all(value%%1 == 0)) |
62-
(is.null(value) & !allow_null)))
62+
(is.null(value) & allow_null)))
6363
cli_stop("All {.val {name}} must be whole positive number(s).")
6464
}
6565
)

man/bake.Rd

+6-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/epi_juice.Rd

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/epi_workflow.Rd

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/flatline.Rd

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_test_data.Rd

+19-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_residual_quantiles.Rd

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-get_test_data.R

+39-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ test_that("return expected number of rows and returned dataset is ungrouped", {
1010
test <- get_test_data(recipe = r, x = case_death_rate_subset)
1111

1212
expect_equal(nrow(test),
13-
dplyr::n_distinct(case_death_rate_subset$geo_value)* 29)
13+
dplyr::n_distinct(case_death_rate_subset$geo_value) * 29)
1414

1515
expect_false(dplyr::is.grouped_df(test))
1616
})
@@ -39,3 +39,41 @@ test_that("expect error that geo_value or time_value does not exist", {
3939
expect_error(get_test_data(recipe = r, x = wrong_epi_df))
4040
})
4141

42+
43+
test_that("NA fill behaves as desired", {
44+
df <- tibble::tibble(
45+
geo_value = rep(c("ca", "ny"), each = 10),
46+
time_value = rep(1:10, times = 2),
47+
x1 = rnorm(20),
48+
x2 = rnorm(20)) %>%
49+
epiprocess::as_epi_df()
50+
51+
r <- epi_recipe(df) %>%
52+
step_epi_ahead(x1, ahead = 3) %>%
53+
step_epi_lag(x1, x2, lag = c(1,3)) %>%
54+
step_epi_naomit()
55+
56+
expect_silent(tt <- get_test_data(r, df))
57+
expect_s3_class(tt, "epi_df")
58+
59+
expect_error(get_test_data(r, df, "A"))
60+
expect_error(get_test_data(r, df, TRUE, -3))
61+
62+
df2 <- df
63+
df2$x1[df2$geo_value == "ca"] <- NA
64+
65+
td <- get_test_data(r, df2)
66+
expect_true(any(is.na(td)))
67+
expect_error(get_test_data(r, df2, TRUE))
68+
69+
df1 <- df2
70+
df1$x1[1:4] <- 1:4
71+
td1 <- get_test_data(r, df1, TRUE, n_recent = 7)
72+
expect_true(!any(is.na(td1)))
73+
74+
df2$x1[7:8] <- 1:2
75+
td2 <- get_test_data(r, df2, TRUE)
76+
expect_true(!any(is.na(td2)))
77+
78+
79+
})

0 commit comments

Comments
 (0)