diff --git a/.Rbuildignore b/.Rbuildignore index d0c10332d..e4bff0b18 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -18,3 +18,4 @@ ^.git-blame-ignore-revs$ ^doc$ ^Meta$ +^.lintr$ \ No newline at end of file diff --git a/.lintr b/.lintr new file mode 100644 index 000000000..c7c90554d --- /dev/null +++ b/.lintr @@ -0,0 +1,9 @@ +linters: linters_with_defaults( + line_length_linter(120), + cyclocomp_linter = NULL, + object_length_linter(length = 40L) + ) +exclusions: list( + "renv", + "venv" + ) diff --git a/NEWS.md b/NEWS.md index 5f629f45f..5941a8f3b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,34 +4,35 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat # epipredict 0.1 -- add `check_enough_train_data` that will error if training data is too small -- added `check_enough_train_data` to `arx_forecaster` -- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()` -- rename the `dist_quantiles()` to be more descriptive, breaking change -- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change) -- add `pivot_quantiles_wider()` for easier plotting -- add complement `pivot_quantiles_longer()` -- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()` -- add `smooth_quantile_reg()` -- improved printing of various methods / internals -- canned forecasters get a class -- fixed quantile bug in `flatline_forecaster()` -- add functionality to output the unfit workflow from the canned forecasters -- add quantile_reg() -- clean up documentation bugs -- add smooth_quantile_reg() -- add classifier -- training window step debugged -- `min_train_window` argument removed from canned forecasters -- add forecasters -- implement postprocessing -- vignettes avaliable -- arx_forecaster -- pkgdown -- Publish public for easy navigation -- Two simple forecasters as test beds -- Working vignette -- use `checkmate` for input validation -- refactor quantile extrapolation (possibly creates different results) -- force 
`target_date` + `forecast_date` handling to match the time_type of - the epi_df. allows for annual and weekly data +- `layer_residual_quantiles()` will now error if any of the residual quantiles are NA +- add `check_enough_train_data` that will error if training data is too small +- added `check_enough_train_data` to `arx_forecaster` +- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()` +- rename the `dist_quantiles()` to be more descriptive, breaking change +- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change) +- add `pivot_quantiles_wider()` for easier plotting +- add complement `pivot_quantiles_longer()` +- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()` +- add `smooth_quantile_reg()` +- improved printing of various methods / internals +- canned forecasters get a class +- fixed quantile bug in `flatline_forecaster()` +- add functionality to output the unfit workflow from the canned forecasters +- add quantile_reg() +- clean up documentation bugs +- add smooth_quantile_reg() +- add classifier +- training window step debugged +- `min_train_window` argument removed from canned forecasters +- add forecasters +- implement postprocessing +- vignettes available +- arx_forecaster +- pkgdown +- Publish public for easy navigation +- Two simple forecasters as test beds +- Working vignette +- use `checkmate` for input validation +- refactor quantile extrapolation (possibly creates different results) +- force `target_date` + `forecast_date` handling to match the time_type of + the epi_df.
allows for annual and weekly data diff --git a/R/arx_classifier.R b/R/arx_classifier.R index d42247426..1ce6a5b3c 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -197,7 +197,7 @@ arx_class_epi_workflow <- function( } forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) - target_date <- args_list$target_date %||% forecast_date + args_list$ahead + target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) # --- postprocessor f <- frosting() %>% layer_predict() # %>% layer_naomit() diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index ce2fa57b0..d84f65df1 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -143,7 +143,7 @@ arx_fcast_epi_workflow <- function( } forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) - target_date <- args_list$target_date %||% forecast_date + args_list$ahead + target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) # --- postprocessor f <- frosting() %>% layer_predict() # %>% layer_naomit() @@ -289,8 +289,8 @@ print.arx_fcast <- function(x, ...) 
{ } compare_quantile_args <- function(alist, tlist) { - default_alist <- eval(formals(arx_args_list)$quantile_level) - default_tlist <- eval(formals(quantile_reg)$quantile_level) + default_alist <- eval(formals(arx_args_list)$quantile_levels) + default_tlist <- eval(formals(quantile_reg)$quantile_levels) if (setequal(alist, default_alist)) { if (setequal(tlist, default_tlist)) { return(sort(unique(union(alist, tlist)))) diff --git a/R/cdc_baseline_forecaster.R b/R/cdc_baseline_forecaster.R index abb231bca..4af6d6f3f 100644 --- a/R/cdc_baseline_forecaster.R +++ b/R/cdc_baseline_forecaster.R @@ -75,7 +75,7 @@ cdc_baseline_forecaster <- function( step_training_window(n_recent = args_list$n_training) forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) - # target_date <- args_list$target_date %||% forecast_date + args_list$ahead + # target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) latest <- get_test_data( diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index 99ebc8694..c7b060caa 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -47,8 +47,7 @@ flatline_forecaster <- function( step_training_window(n_recent = args_list$n_training) forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) - target_date <- args_list$target_date %||% forecast_date + args_list$ahead - + target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) latest <- get_test_data( epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer, diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index b09956c2e..a6bc93f18 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -123,6 +123,13 @@ slather.layer_residual_quantiles <- probs = object$quantile_levels, na.rm = TRUE )) ) + # Check for NA + if (any(sapply(r$dstn, is.na))) { + cli::cli_abort(c( + "Residual quantiles could not be calculated due to missing residuals.", + i = "This may be due to 
`n_train` < `ahead` in your {.cls epi_recipe}." + )) + } estimate <- components$predictions$.pred res <- tibble::tibble( diff --git a/tests/testthat/test-dist_quantiles.R b/tests/testthat/test-dist_quantiles.R index 99ce742d5..9c01ca103 100644 --- a/tests/testthat/test-dist_quantiles.R +++ b/tests/testthat/test-dist_quantiles.R @@ -28,17 +28,16 @@ test_that("single dist_quantiles works, quantiles are accessible", { test_that("quantile extrapolator works", { dstn <- dist_normal(c(10, 2), c(5, 10)) - qq <- extrapolate_quantiles(dstn, p = c(.25, 0.5, .75)) + qq <- extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) expect_s3_class(qq, "distribution") expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles") - expect_length(parameters(qq[1])$q[[1]], 3L) - + expect_length(parameters(qq[1])$quantile_levels[[1]], 3L) dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) - qq <- extrapolate_quantiles(dstn, p = c(.25, 0.5, .75)) + qq <- extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) expect_s3_class(qq, "distribution") expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles") - expect_length(parameters(qq[1])$q[[1]], 7L) + expect_length(parameters(qq[1])$quantile_levels[[1]], 7L) }) test_that("small deviations of quantile requests work", { diff --git a/tests/testthat/test-layer_residual_quantiles.R b/tests/testthat/test-layer_residual_quantiles.R index 5723fbce9..a346c62a3 100644 --- a/tests/testthat/test-layer_residual_quantiles.R +++ b/tests/testthat/test-layer_residual_quantiles.R @@ -98,3 +98,10 @@ test_that("Canned forecasters work with / without", { ) ) }) + +test_that("flatline_forecaster correctly errors when n_training < ahead", { + expect_error( + flatline_forecaster(jhu, "death_rate", args_list = flatline_args_list(ahead = 10, n_training = 9)), + "This may be due to `n_train` < `ahead`" + ) +}) diff --git a/tests/testthat/test-target_date_bug.R b/tests/testthat/test-target_date_bug.R new file mode 100644 index 000000000..4a7e7d2e8 --- 
/dev/null +++ b/tests/testthat/test-target_date_bug.R @@ -0,0 +1,77 @@ +# These tests address #290: +# https://github.com/cmu-delphi/epipredict/issues/290 + +library(dplyr) +train <- jhu_csse_daily_subset |> + filter(time_value >= as.Date("2021-10-01")) |> + select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av) +ngeos <- n_distinct(train$geo_value) + +test_that("flatline determines target_date where forecast_date exists", { + flat <- flatline_forecaster( + train, "dr", + args_list = flatline_args_list( + forecast_date = as.Date("2021-12-31"), + target_date = as.Date("2022-01-01"), + ahead = 1L + ) + ) + # previously, if target_date existed, it could be + # erroneously incremented by the ahead + expect_identical( + flat$predictions$target_date, + rep(as.Date("2022-01-01"), ngeos) + ) + expect_identical( + flat$predictions$forecast_date, + rep(as.Date("2021-12-31"), ngeos) + ) + expect_true(all(!is.na(flat$predictions$.pred_distn))) + expect_true(all(!is.na(flat$predictions$.pred))) +}) + +test_that("arx_forecaster determines target_date where forecast_date exists", { + arx <- arx_forecaster( + train, "dr", c("dr", "cr"), + args_list = arx_args_list( + forecast_date = as.Date("2021-12-31"), + target_date = as.Date("2022-01-01"), + ahead = 1L + ) + ) + # previously, if target_date existed, it could be + # erroneously incremented by the ahead + expect_identical( + arx$predictions$target_date, + rep(as.Date("2022-01-01"), ngeos) + ) + expect_identical( + arx$predictions$forecast_date, + rep(as.Date("2021-12-31"), ngeos) + ) + expect_true(all(!is.na(arx$predictions$.pred_distn))) + expect_true(all(!is.na(arx$predictions$.pred))) +}) + +test_that("arx_classifier determines target_date where forecast_date exists", { + arx <- arx_classifier( + train, "dr", c("dr"), + trainer = parsnip::boost_tree(mode = "classification", trees = 5), + args_list = arx_class_args_list( + forecast_date = as.Date("2021-12-31"), + target_date = as.Date("2022-01-01"), + ahead 
= 1L + ) + ) + # previously, if target_date existed, it could be + # erroneously incremented by the ahead + expect_identical( + arx$predictions$target_date, + rep(as.Date("2022-01-01"), ngeos) + ) + expect_identical( + arx$predictions$forecast_date, + rep(as.Date("2021-12-31"), ngeos) + ) + expect_true(all(!is.na(arx$predictions$.pred_class))) +})