diff --git a/DESCRIPTION b/DESCRIPTION index 97577506..9d3ce499 100755 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -51,6 +51,8 @@ Imports: tidyr, tsibble Suggests: + testthat (>= 3.0.0), delphi.epidata Remotes: github::cmu-delphi/delphi-epidata-r +Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index b077342e..78363e2b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -64,4 +64,5 @@ importFrom(stats,cor) importFrom(stats,median) importFrom(tidyr,unnest) importFrom(tidyselect,eval_select) +importFrom(tidyselect,starts_with) importFrom(tsibble,as_tsibble) diff --git a/R/outliers.R b/R/outliers.R index f12b99e5..9bd370dc 100644 --- a/R/outliers.R +++ b/R/outliers.R @@ -45,12 +45,15 @@ #' #' @export detect_outlr = function(x = seq_along(y), y, - methods = tibble(method = "rm", + methods = tibble::tibble(method = "rm", args = list(list()), abbr = "rm"), combiner = c("median", "mean", "none")) { # Validate combiner combiner = match.arg(combiner) + + # Validate that x contains all distinct values + if (max(table(x)) > 1) Abort("`x` must not contain duplicate values; did you group your `epi_df` by all relevant key variables?") # Run all outlier detection methods results = purrr::pmap_dfc(methods, function(method, args, abbr) { @@ -187,6 +190,7 @@ detect_outlr_rm = function(x = seq_along(y), y, n = 21, #' description. #' #' @importFrom stats median +#' @importFrom tidyselect starts_with #' @export detect_outlr_stl = function(x = seq_along(y), y, n_trend = 21, @@ -216,11 +220,10 @@ detect_outlr_stl = function(x = seq_along(y), y, fabletools::model(feasts::STL(stl_formula, robust = TRUE)) %>% generics::components() %>% tibble::as_tibble() %>% - dplyr::transmute( - trend = trend, - seasonal = season_week, - resid = remainder) - + dplyr::select(trend:remainder) %>% + dplyr::rename_with(~ "seasonal", tidyselect::starts_with("season")) %>% + dplyr::rename(resid = remainder) + # Allocate the seasonal term from STL to either fitted or resid if (!is.null(seasonal_period)) { stl_components = stl_components %>% @@ -263,15 +266,18 @@ detect_outlr_stl = function(x = seq_along(y), y, # Common function for rolling IQR, using fitted and resid variables roll_iqr = function(z, n, detection_multiplier, min_radius, - replacement_multiplier, min_lower) { + replacement_multiplier, min_lower) { + if (typeof(z$y) == "integer") as_type = as.integer + else as_type = as.numeric + epi_slide(z, roll_iqr = IQR(resid), n = n, align = "center") %>% dplyr::mutate( lower = pmax(min_lower, fitted - pmax(min_radius, detection_multiplier * roll_iqr)), upper = fitted + pmax(min_radius, detection_multiplier * roll_iqr), replacement = dplyr::case_when( - (y < lower) ~ fitted - replacement_multiplier * roll_iqr, - (y > upper) ~ fitted + replacement_multiplier * roll_iqr, + (y < lower) ~ as_type(fitted - replacement_multiplier * roll_iqr), + (y > upper) ~ as_type(fitted + replacement_multiplier * roll_iqr), TRUE ~ y)) %>% dplyr::select(lower, upper, replacement) %>% tibble::as_tibble() diff --git a/man/detect_outlr.Rd b/man/detect_outlr.Rd index 4a3c24c0..d832cd6d 100644 --- a/man/detect_outlr.Rd +++ b/man/detect_outlr.Rd @@ -7,7 +7,7 @@ detect_outlr( x = seq_along(y), y, - methods = tibble(method = "rm", args = list(list()), abbr = "rm"), + methods = tibble::tibble(method = "rm", args = list(list()), abbr = "rm"), combiner = c("median", "mean", "none") ) } diff --git a/tests/testthat.R b/tests/testthat.R new file mode 100644 index 00000000..b26d274b --- /dev/null +++ b/tests/testthat.R @@ -0,0 +1,4 @@ +library(testthat) +library(epiprocess) + +test_check("epiprocess") diff --git a/tests/testthat/test-outliers.R b/tests/testthat/test-outliers.R new file mode 100644 index 00000000..acdc2a35 --- /dev/null +++ b/tests/testthat/test-outliers.R @@ -0,0 +1,7 @@ +test_that("detect_outlr throws error with duplicate x", { + expect_error(detect_outlr(x = c(1, 2, 3, 3, 4), y = 1:5)) +}) + +test_that("detect_outlr throws error with length(x) != length(y)", { + expect_error(detect_outlr(x = 1:3, y = 1:5)) +}) diff --git a/vignettes/outliers.Rmd b/vignettes/outliers.Rmd index 1f91f17d..22bae575 100644 --- a/vignettes/outliers.Rmd +++ b/vignettes/outliers.Rmd @@ -75,7 +75,8 @@ detection_methods = bind_rows( abbr = "rm"), tibble(method = "stl", args = list(list(detect_negatives = TRUE, - detection_multiplier = 2.5)), + detection_multiplier = 2.5, + seasonal_period = 7)), abbr = "stl_seasonal"), tibble(method = "stl", args = list(list(detect_negatives = TRUE,