diff --git a/R/get_test_data.R b/R/get_test_data.R index 3c29c3748..40992daa9 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -42,6 +42,8 @@ get_test_data <- function(recipe, x){ stop("insufficient training data") } + groups <- epi_keys(recipe)[epi_keys(recipe) != "time_value"] + test_data <- x %>% dplyr::filter( dplyr::if_any( @@ -49,7 +51,7 @@ get_test_data <- function(recipe, x){ .fns = ~ !is.na(.x) ) ) %>% - epiprocess::group_by(geo_value) %>% + epiprocess::group_by(dplyr::across(dplyr::all_of(groups))) %>% dplyr::slice_tail(n = max(max_lags) + 1) %>% epiprocess::ungroup() diff --git a/tests/testthat/test-get_test_data.R b/tests/testthat/test-get_test_data.R index e063523a1..d6f5256c9 100644 --- a/tests/testthat/test-get_test_data.R +++ b/tests/testthat/test-get_test_data.R @@ -1,4 +1,5 @@ -test_that("return expected number of rows", { +library(dplyr) +test_that("return expected number of rows and returned dataset is ungrouped", { r <- epi_recipe(case_death_rate_subset) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag(death_rate, lag = c(0, 7, 14, 21, 28)) %>% @@ -10,6 +11,8 @@ test_that("return expected number of rows", { expect_equal(nrow(test), dplyr::n_distinct(case_death_rate_subset$geo_value)* 29) + + expect_false(dplyr::is.grouped_df(test)) }) @@ -35,3 +38,4 @@ test_that("expect error that geo_value or time_value does not exist", { expect_error(get_test_data(recipe = r, x = wrong_epi_df)) }) + diff --git a/tests/testthat/test-layer_predict.R b/tests/testthat/test-layer_predict.R index 5e6148ff5..db9888779 100644 --- a/tests/testthat/test-layer_predict.R +++ b/tests/testthat/test-layer_predict.R @@ -20,7 +20,6 @@ test_that("predict layer works alone", { expect_s3_class(p, "epi_df") expect_equal(nrow(p), 108L) expect_named(p, c("time_value", "geo_value", ".pred")) - }) test_that("prediction with interval works", {