Skip to content

Commit c5b2568

Browse files
committed
fix drop na subtleties, unify prep/bake usage
A couple of simultaneous problems that were making this tricky: 1. drop_na can completely remove states 2. checking each column individually misses cases where combinations of the states cause the signal to be left out. 3. checking all columns simultaneously doesn't let the user know which columns to check.
1 parent 5ef9823 commit c5b2568

File tree

3 files changed

+109
-64
lines changed

3 files changed

+109
-64
lines changed

R/check_enough_data.R

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,8 @@ prep.check_enough_data <- function(x, training, info = NULL, ...) {
9494
x$n <- length(col_names)
9595
}
9696

97-
if (x$drop_na) {
98-
training <- tidyr::drop_na(training, any_of(unname(col_names)))
99-
}
100-
cols_not_enough_data <- training %>%
101-
group_by(across(all_of(.env$x$epi_keys))) %>%
102-
summarise(across(all_of(.env$col_names), ~ dplyr::n() < .env$x$n), .groups = "drop") %>%
103-
summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
104-
unlist() %>%
105-
names(.)[.]
97+
check_enough_data_core(training, x, col_names, "train")
10698

107-
if (length(cols_not_enough_data) > 0) {
108-
cli_abort(
109-
"The following columns don't have enough data to predict: {cols_not_enough_data}.",
110-
class = "epipredict__not_enough_data"
111-
)
112-
}
11399

114100
check_enough_data_new(
115101
n = x$n,
@@ -127,24 +113,7 @@ prep.check_enough_data <- function(x, training, info = NULL, ...) {
127113
#' @export
128114
bake.check_enough_data <- function(object, new_data, ...) {
129115
col_names <- object$columns
130-
if (object$drop_na) {
131-
non_na_data <- tidyr::drop_na(new_data, any_of(unname(col_names)))
132-
} else {
133-
non_na_data <- new_data
134-
}
135-
cols_not_enough_data <- non_na_data %>%
136-
group_by(across(all_of(.env$object$epi_keys))) %>%
137-
summarise(across(all_of(.env$col_names), ~ dplyr::n() < .env$object$n), .groups = "drop") %>%
138-
summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
139-
unlist() %>%
140-
names(.)[.]
141-
142-
if (length(cols_not_enough_data) > 0) {
143-
cli_abort(
144-
"The following columns don't have enough data to predict: {cols_not_enough_data}.",
145-
class = "epipredict__not_enough_data"
146-
)
147-
}
116+
check_enough_data_core(new_data, object, col_names, "predict")
148117
new_data
149118
}
150119

@@ -168,3 +137,59 @@ tidy.check_enough_data <- function(x, ...) {
168137
res$drop_na <- x$drop_na
169138
res
170139
}
140+
141+
check_enough_data_core <- function(epi_df, step_obj, col_names, train_or_predict) {
142+
epi_df <- epi_df %>%
143+
group_by(across(all_of(.env$step_obj$epi_keys)))
144+
if (step_obj$drop_na) {
145+
any_missing_data <- epi_df %>%
146+
mutate(any_are_na = rowSums(across(any_of(.env$col_names), ~ is.na(.x))) > 0) %>%
147+
# count the number of rows where they're all not na
148+
summarise(sum(any_are_na == 0) < .env$step_obj$n, .groups = "drop")
149+
any_missing_data <- any_missing_data %>%
150+
summarize(across(all_of(setdiff(names(any_missing_data), step_obj$epi_keys)), any)) %>%
151+
any()
152+
153+
# figuring out which individual columns (if any) are to blame for this darth
154+
# of data
155+
cols_not_enough_data <- epi_df %>%
156+
summarise(
157+
across(
158+
all_of(.env$col_names),
159+
~ sum(!is.na(.x)) < .env$step_obj$n
160+
),
161+
.groups = "drop"
162+
) %>%
163+
summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
164+
unlist() %>%
165+
names(.)[.]
166+
167+
if (length(cols_not_enough_data) == 0) {
168+
cols_not_enough_data <-
169+
glue::glue("no single column, but the combination of {paste0(col_names, collapse = ', ')}")
170+
}
171+
} else {
172+
# if we're not dropping na values, just count
173+
cols_not_enough_data <- epi_df %>%
174+
summarise(
175+
across(
176+
all_of(.env$col_names),
177+
~ dplyr::n() < .env$step_obj$n
178+
)
179+
)
180+
any_missing_data <- cols_not_enough_data %>%
181+
summarize(across(all_of(.env$col_names), all)) %>%
182+
all()
183+
cols_not_enough_data <- cols_not_enough_data %>%
184+
summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
185+
unlist() %>%
186+
names(.)[.]
187+
}
188+
189+
if (any_missing_data) {
190+
cli_abort(
191+
"The following columns don't have enough data to {train_or_predict}: {cols_not_enough_data}.",
192+
class = "epipredict__not_enough_data"
193+
)
194+
}
195+
}

tests/testthat/_snaps/check_enough_data.md

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,53 @@
22

33
Code
44
epi_recipe(toy_epi_df) %>% check_enough_data(x, y, n = 2 * n + 1, drop_na = FALSE) %>%
5-
prep(toy_epi_df) %>% bake(new_data = NULL)
5+
prep(toy_epi_df)
66
Condition
7-
Error in `prep()`:
8-
! The following columns don't have enough data to predict: x and y.
7+
Error in `check_enough_data_core()`:
8+
! The following columns don't have enough data to train: x and y.
99

1010
---
1111

1212
Code
1313
epi_recipe(toy_epi_df) %>% check_enough_data(x, y, n = 2 * n - 1, drop_na = TRUE) %>%
14-
prep(toy_epi_df) %>% bake(new_data = NULL)
14+
prep(toy_epi_df)
1515
Condition
16-
Error in `prep()`:
17-
! The following columns don't have enough data to predict: x and y.
16+
Error in `check_enough_data_core()`:
17+
! The following columns don't have enough data to train: x.
1818

1919
# check_enough_data works on unpooled data
2020

2121
Code
2222
epi_recipe(toy_epi_df) %>% check_enough_data(x, y, n = n + 1, epi_keys = "geo_value",
23-
drop_na = FALSE) %>% prep(toy_epi_df) %>% bake(new_data = NULL)
23+
drop_na = FALSE) %>% prep(toy_epi_df)
2424
Condition
25-
Error in `prep()`:
26-
! The following columns don't have enough data to predict: x and y.
25+
Error in `check_enough_data_core()`:
26+
! The following columns don't have enough data to train: x and y.
2727

2828
---
2929

3030
Code
3131
epi_recipe(toy_epi_df) %>% check_enough_data(x, y, n = 2 * n - 3, epi_keys = "geo_value",
32-
drop_na = TRUE) %>% prep(toy_epi_df) %>% bake(new_data = NULL)
32+
drop_na = TRUE) %>% prep(toy_epi_df)
3333
Condition
34-
Error in `prep()`:
35-
! The following columns don't have enough data to predict: x and y.
34+
Error in `check_enough_data_core()`:
35+
! The following columns don't have enough data to train: x and y.
36+
37+
# check_enough_data only checks train data when skip = FALSE
38+
39+
Code
40+
forecaster %>% predict(new_data = toy_test_data %>% filter(time_value >
41+
"2020-01-08"))
42+
Condition
43+
Error in `check_enough_data_core()`:
44+
! The following columns don't have enough data to predict: x.
3645

3746
# check_enough_data works with all_predictors() downstream of constructed terms
3847

3948
Code
4049
epi_recipe(toy_epi_df) %>% step_epi_lag(x, lag = c(1, 2)) %>% check_enough_data(
41-
all_predictors(), y, n = 2 * n - 4) %>% prep(toy_epi_df) %>% bake(new_data = NULL)
50+
all_predictors(), y, n = 2 * n - 4) %>% prep(toy_epi_df)
4251
Condition
43-
Error in `prep()`:
44-
! The following columns don't have enough data to predict: lag_1_x, lag_2_x, and y.
52+
Error in `check_enough_data_core()`:
53+
! The following columns don't have enough data to train: no single column, but the combination of lag_1_x, lag_2_x, y.
4554

tests/testthat/test-check_enough_data.R

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,14 @@ test_that("check_enough_data works on pooled data", {
2727
error = TRUE,
2828
epi_recipe(toy_epi_df) %>%
2929
check_enough_data(x, y, n = 2 * n + 1, drop_na = FALSE) %>%
30-
prep(toy_epi_df) %>%
31-
bake(new_data = NULL)
30+
prep(toy_epi_df)
3231
)
3332
# Check drop_na works
3433
expect_snapshot(
3534
error = TRUE,
3635
epi_recipe(toy_epi_df) %>%
3736
check_enough_data(x, y, n = 2 * n - 1, drop_na = TRUE) %>%
38-
prep(toy_epi_df) %>%
39-
bake(new_data = NULL)
37+
prep(toy_epi_df)
4038
)
4139
})
4240

@@ -53,16 +51,14 @@ test_that("check_enough_data works on unpooled data", {
5351
error = TRUE,
5452
epi_recipe(toy_epi_df) %>%
5553
check_enough_data(x, y, n = n + 1, epi_keys = "geo_value", drop_na = FALSE) %>%
56-
prep(toy_epi_df) %>%
57-
bake(new_data = NULL)
54+
prep(toy_epi_df)
5855
)
5956
# Check drop_na works
6057
expect_snapshot(
6158
error = TRUE,
6259
epi_recipe(toy_epi_df) %>%
6360
check_enough_data(x, y, n = 2 * n - 3, epi_keys = "geo_value", drop_na = TRUE) %>%
64-
prep(toy_epi_df) %>%
65-
bake(new_data = NULL)
61+
prep(toy_epi_df)
6662
)
6763
})
6864

@@ -85,7 +81,7 @@ test_that("check_enough_data outputs the correct recipe values", {
8581
expect_equal(p$geo_value, rep(c("ca", "hi"), each = n))
8682
})
8783

88-
test_that("check_enough_train_data only checks train data", {
84+
test_that("check_enough_data only checks train data when skip = FALSE", {
8985
# Check that the train data has enough data, the test data does not, but
9086
# the check passes anyway (because it should be applied to training data)
9187
toy_test_data <- toy_epi_df %>%
@@ -94,16 +90,32 @@ test_that("check_enough_train_data only checks train data", {
9490
epiprocess::as_epi_df()
9591
expect_no_error(
9692
epi_recipe(toy_epi_df) %>%
97-
check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value", skip = TRUE) %>%
93+
check_enough_data(x, y, n = n - 2, epi_keys = "geo_value") %>%
9894
prep(toy_epi_df) %>%
9995
bake(new_data = toy_test_data)
10096
)
101-
# Same thing, but skip = FALSE
97+
# Making sure `skip = TRUE` is working correctly in `predict`
10298
expect_no_error(
10399
epi_recipe(toy_epi_df) %>%
104-
check_enough_train_data(y, n = n - 2, epi_keys = "geo_value") %>%
105-
prep(toy_epi_df) %>%
106-
bake(new_data = toy_test_data)
100+
add_role(y, new_role = "outcome") %>%
101+
check_enough_data(x, n = n - 2, epi_keys = "geo_value") %>%
102+
epi_workflow(linear_reg()) %>%
103+
fit(toy_epi_df) %>%
104+
predict(new_data = toy_test_data %>% filter(time_value > "2020-01-08"))
105+
)
106+
# making sure it works for skip = FALSE, where there's enough data to train
107+
# but not enough to predict
108+
expect_no_error(
109+
forecaster <- epi_recipe(toy_epi_df) %>%
110+
add_role(y, new_role = "outcome") %>%
111+
check_enough_data(x, n = 1, epi_keys = "geo_value", skip = FALSE) %>%
112+
epi_workflow(linear_reg()) %>%
113+
fit(toy_epi_df)
114+
)
115+
expect_snapshot(
116+
error = TRUE,
117+
forecaster %>%
118+
predict(new_data = toy_test_data %>% filter(time_value > "2020-01-08"))
107119
)
108120
})
109121

@@ -122,7 +134,6 @@ test_that("check_enough_data works with all_predictors() downstream of construct
122134
epi_recipe(toy_epi_df) %>%
123135
step_epi_lag(x, lag = c(1, 2)) %>%
124136
check_enough_data(all_predictors(), y, n = 2 * n - 4) %>%
125-
prep(toy_epi_df) %>%
126-
bake(new_data = NULL)
137+
prep(toy_epi_df)
127138
)
128139
})

0 commit comments

Comments
 (0)