diff --git a/.Rbuildignore b/.Rbuildignore index 3a77bb347..510725267 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -3,6 +3,7 @@ ^epipredict\.Rproj$ ^\.Rproj\.user$ ^LICENSE\.md$ +^DEVELOPMENT\.md$ ^drafts$ ^\.Rprofile$ ^man-roxygen$ @@ -15,5 +16,8 @@ ^data-raw$ ^vignettes/articles$ ^.git-blame-ignore-revs$ +^DEVELOPMENT\.md$ ^doc$ ^Meta$ +^.lintr$ +^.venv$ \ No newline at end of file diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..201f22b93 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,24 @@ +### Checklist + +Please: + +- [ ] Make sure this PR is against "dev", not "main". +- [ ] Request a review from one of the current epipredict main reviewers: + dajmcdon. +- [ ] Make sure to bump the version number in `DESCRIPTION` and `NEWS.md`. + Always increment the patch version number (the third number), unless you are + making a release PR from dev to main, in which case increment the minor + version number (the second number). +- [ ] Describe changes made in NEWS.md, making sure breaking changes + (backwards-incompatible changes to the documented interface) are noted. + Collect the changes under the next release number (e.g. if you are on + 0.7.2, then write your changes under the 0.8 heading). +- [ ] Consider pinning the `epiprocess` version in the `DESCRIPTION` file if + - You anticipate breaking changes in `epiprocess` soon + - You want to co-develop features in `epipredict` and `epiprocess` + +### Change explanations for reviewer + +### Magic GitHub syntax to mark associated Issue(s) as resolved when this is merged into the default branch + +- Resolves #{issue number} diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index c4bcd6b68..1c8055ff0 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -4,9 +4,9 @@ # Created with usethis + edited to use API key. 
on: push: - branches: [main, master] + branches: [main, dev] pull_request: - branches: [main, master] + branches: [main, dev] name: R-CMD-check diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml index 245e43c59..cc940bc8b 100644 --- a/.github/workflows/pkgdown.yaml +++ b/.github/workflows/pkgdown.yaml @@ -1,10 +1,10 @@ -# Workflow derived from https://github.com/r-lib/actions/tree/master/examples +# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help # -# Created with usethis + edited to use API key. +# Created with usethis + edited to run on PRs to dev, use API key. on: push: - branches: [main, master] + branches: [main, dev] release: types: [published] workflow_dispatch: @@ -13,6 +13,8 @@ name: pkgdown jobs: pkgdown: + # only build docs on the main repository and not forks + if: github.repository_owner == 'cmu-delphi' runs-on: ubuntu-latest # Only restrict concurrency for non-PR jobs concurrency: @@ -20,7 +22,7 @@ jobs: env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: r-lib/actions/setup-pandoc@v2 @@ -35,13 +37,19 @@ jobs: - name: Build site env: - DELPHI_EPIDATA_KEY: ${{ secrets.SECRET_EPIPREDICT_GHACTIONS_DELPHI_EPIDATA_KEY }} - run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) + DELPHI_EPIDATA_KEY: ${{ secrets.SECRET_EPIPROCESS_GHACTIONS_DELPHI_EPIDATA_KEY }} + run: | + if (startsWith("${{ github.event_name }}", "pull_request")) { + mode <- ifelse("${{ github.base_ref }}" == "main", "release", "devel") + } else { + mode <- ifelse("${{ github.ref_name }}" == "main", "release", "devel") + } + pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE, override=list(PKGDOWN_DEV_MODE=mode)) shell: Rscript {0} - name: Deploy to GitHub pages 🚀 if: github.event_name != 'pull_request' - 
uses: JamesIves/github-pages-deploy-action@4.1.4 + uses: JamesIves/github-pages-deploy-action@v4.4.1 with: clean: false branch: gh-pages diff --git a/.lintr b/.lintr new file mode 100644 index 000000000..c7c90554d --- /dev/null +++ b/.lintr @@ -0,0 +1,9 @@ +linters: linters_with_defaults( + line_length_linter(120), + cyclocomp_linter = NULL, + object_length_linter(length = 40L) + ) +exclusions: list( + "renv", + "venv" + ) diff --git a/DESCRIPTION b/DESCRIPTION index b0b592e2a..5cd468fb9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,13 +1,14 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.8 +Version: 0.1.0 Authors@R: c( - person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), + person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), + person("Dmitry", "Shemetov", email = "dshemeto@andrew.cmu.edu", role = "aut"), + person("David", "Weber", email = "davidweb@andrew.cmu.edu", role = "aut"), + person("CMU's Delphi Research Group", role = c("cph", "fnd")), person("Logan", "Brooks", role = "aut"), person("Rachel", "Lobay", role = "aut"), - person("Dmitry", "Shemetov", email = "dshemeto@andrew.cmu.edu", role = "ctb"), - person("David", "Weber", email = "davidweb@andrew.cmu.edu", role = "ctb"), person("Maggie", "Liu", role = "ctb"), person("Ken", "Mawer", role = "ctb"), person("Chloe", "You", role = "ctb"), @@ -23,43 +24,46 @@ URL: https://github.com/cmu-delphi/epipredict/, https://cmu-delphi.github.io/epipredict BugReports: https://github.com/cmu-delphi/epipredict/issues/ Depends: - epiprocess (>= 0.6.0), + epiprocess (>= 0.9.0), parsnip (>= 1.0.0), R (>= 3.5.0) Imports: + checkmate, cli, distributional, dplyr, - fs, generics, + ggplot2, glue, hardhat (>= 1.3.0), lifecycle, magrittr, - methods, - quantreg, recipes (>= 1.0.4), - rlang, - smoothqr, + rlang (>= 1.1.0), stats, tibble, tidyr, tidyselect, - 
usethis, + tsibble, vctrs, workflows (>= 1.0.0) Suggests: covidcast, data.table, epidatr (>= 1.0.0), - ggplot2, + fs, + grf, knitr, lubridate, poissonreg, + purrr, + quantreg, ranger, RcppRoll, rmarkdown, + smoothqr, testthat (>= 3.0.0), + usethis, xgboost VignetteBuilder: knitr @@ -71,4 +75,4 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.0 +RoxygenNote: 7.3.2 diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md new file mode 100644 index 000000000..67f1b3003 --- /dev/null +++ b/DEVELOPMENT.md @@ -0,0 +1,45 @@ +## Setting up the development environment + +```r +install.packages(c('devtools', 'pkgdown', 'styler', 'lintr')) # install dev dependencies +devtools::install_deps(dependencies = TRUE) # install package dependencies +devtools::document() # generate package meta data and man files +devtools::build() # build package +``` + +## Validating the package + +```r +styler::style_pkg() # format code +lintr::lint_package() # lint code + +devtools::test() # test package +devtools::check() # check package for errors +``` + +## Developing the documentation site + +The [documentation site](https://cmu-delphi.github.io/epipredict/) is built off of the `main` branch. The `dev` version of the site is available at https://cmu-delphi.github.io/epipredict/dev. + +The documentation site can be previewed locally by running in R + +```r +pkgdown::build_site(preview=TRUE) +``` + +The `main` version is available at `file:////epidatr/epipredict/index.html` and `dev` at `file:////epipredict/docs/dev/index.html`. + +You can also build the docs manually and launch the site with python. From the terminal, this looks like + +```bash +R -e 'devtools::document()' +python -m http.server -d docs +``` + +## Versioning + +Please follow the guidelines in the [PR template document](.github/pull_request_template.md). 
+ +## Release process + +TBD diff --git a/NAMESPACE b/NAMESPACE index fc7a7ea00..e815203eb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,8 +1,13 @@ # Generated by roxygen2: do not edit by hand +S3method(Add_model,epi_workflow) +S3method(Add_model,workflow) S3method(Math,dist_quantiles) S3method(Ops,dist_quantiles) -S3method(add_model,epi_workflow) +S3method(Remove_model,epi_workflow) +S3method(Remove_model,workflow) +S3method(Update_model,epi_workflow) +S3method(Update_model,workflow) S3method(adjust_epi_recipe,epi_recipe) S3method(adjust_epi_recipe,epi_workflow) S3method(adjust_frosting,epi_workflow) @@ -10,20 +15,19 @@ S3method(adjust_frosting,frosting) S3method(apply_frosting,default) S3method(apply_frosting,epi_workflow) S3method(augment,epi_workflow) +S3method(autoplot,canned_epipred) +S3method(autoplot,epi_workflow) S3method(bake,check_enough_train_data) S3method(bake,epi_recipe) S3method(bake,step_epi_ahead) S3method(bake,step_epi_lag) +S3method(bake,step_epi_slide) S3method(bake,step_growth_rate) S3method(bake,step_lag_difference) S3method(bake,step_population_scaling) S3method(bake,step_training_window) S3method(detect_layer,frosting) S3method(detect_layer,workflow) -S3method(epi_keys,data.frame) -S3method(epi_keys,default) -S3method(epi_keys,epi_df) -S3method(epi_keys,recipe) S3method(epi_recipe,default) S3method(epi_recipe,epi_df) S3method(epi_recipe,formula) @@ -42,9 +46,12 @@ S3method(extrapolate_quantiles,distribution) S3method(fit,epi_workflow) S3method(flusight_hub_formatter,canned_epipred) S3method(flusight_hub_formatter,data.frame) +S3method(forecast,epi_workflow) S3method(format,dist_quantiles) S3method(is.na,dist_quantiles) S3method(is.na,distribution) +S3method(key_colnames,epi_workflow) +S3method(key_colnames,recipe) S3method(mean,dist_quantiles) S3method(median,dist_quantiles) S3method(predict,epi_workflow) @@ -53,6 +60,7 @@ S3method(prep,check_enough_train_data) S3method(prep,epi_recipe) S3method(prep,step_epi_ahead) 
S3method(prep,step_epi_lag) +S3method(prep,step_epi_slide) S3method(prep,step_growth_rate) S3method(prep,step_lag_difference) S3method(prep,step_population_scaling) @@ -81,6 +89,7 @@ S3method(print,layer_threshold) S3method(print,layer_unnest) S3method(print,step_epi_ahead) S3method(print,step_epi_lag) +S3method(print,step_epi_slide) S3method(print,step_growth_rate) S3method(print,step_lag_difference) S3method(print,step_naomit) @@ -88,7 +97,6 @@ S3method(print,step_population_scaling) S3method(print,step_training_window) S3method(quantile,dist_quantiles) S3method(refresh_blueprint,default_epi_recipe_blueprint) -S3method(remove_model,epi_workflow) S3method(residuals,flatline) S3method(run_mold,default_epi_recipe_blueprint) S3method(slather,layer_add_forecast_date) @@ -111,10 +119,16 @@ S3method(tidy,check_enough_train_data) S3method(tidy,frosting) S3method(tidy,layer) S3method(update,layer) -S3method(update_model,epi_workflow) S3method(vec_ptype_abbr,dist_quantiles) S3method(vec_ptype_full,dist_quantiles) +S3method(weighted_interval_score,default) +S3method(weighted_interval_score,dist_default) +S3method(weighted_interval_score,dist_quantiles) +S3method(weighted_interval_score,distribution) export("%>%") +export(Add_model) +export(Remove_model) +export(Update_model) export(add_epi_recipe) export(add_frosting) export(add_layer) @@ -128,15 +142,15 @@ export(arx_class_epi_workflow) export(arx_classifier) export(arx_fcast_epi_workflow) export(arx_forecaster) +export(autoplot) export(bake) export(cdc_baseline_args_list) export(cdc_baseline_forecaster) export(check_enough_train_data) -export(create_layer) +export(clean_f_name) export(default_epi_recipe_blueprint) export(detect_layer) export(dist_quantiles) -export(epi_keys) export(epi_recipe) export(epi_recipe_blueprint) export(epi_workflow) @@ -149,9 +163,9 @@ export(flatline) export(flatline_args_list) export(flatline_forecaster) export(flusight_hub_formatter) +export(forecast) export(frosting) export(get_test_data) 
-export(grab_names) export(is_epi_recipe) export(is_epi_workflow) export(is_layer) @@ -175,6 +189,7 @@ export(pivot_quantiles_longer) export(pivot_quantiles_wider) export(prep) export(quantile_reg) +export(rand_id) export(remove_epi_recipe) export(remove_frosting) export(remove_model) @@ -183,46 +198,89 @@ export(smooth_quantile_reg) export(step_epi_ahead) export(step_epi_lag) export(step_epi_naomit) +export(step_epi_slide) export(step_growth_rate) export(step_lag_difference) export(step_population_scaling) export(step_training_window) +export(tibble) +export(tidy) export(update_epi_recipe) export(update_frosting) export(update_model) export(validate_layer) +export(weighted_interval_score) import(distributional) import(epiprocess) import(parsnip) import(recipes) +importFrom(checkmate,assert_class) +importFrom(checkmate,assert_numeric) +importFrom(checkmate,test_character) +importFrom(checkmate,test_date) +importFrom(checkmate,test_function) +importFrom(checkmate,test_integerish) +importFrom(checkmate,test_logical) +importFrom(checkmate,test_numeric) +importFrom(checkmate,test_scalar) importFrom(cli,cli_abort) +importFrom(cli,cli_warn) importFrom(dplyr,across) importFrom(dplyr,all_of) +importFrom(dplyr,any_of) +importFrom(dplyr,arrange) +importFrom(dplyr,bind_cols) +importFrom(dplyr,bind_rows) +importFrom(dplyr,everything) +importFrom(dplyr,filter) +importFrom(dplyr,full_join) importFrom(dplyr,group_by) -importFrom(dplyr,n) +importFrom(dplyr,left_join) +importFrom(dplyr,mutate) +importFrom(dplyr,relocate) +importFrom(dplyr,rename) +importFrom(dplyr,select) importFrom(dplyr,summarise) +importFrom(dplyr,summarize) importFrom(dplyr,ungroup) +importFrom(epiprocess,epi_slide) importFrom(epiprocess,growth_rate) importFrom(generics,augment) importFrom(generics,fit) +importFrom(generics,forecast) +importFrom(generics,tidy) +importFrom(ggplot2,aes) +importFrom(ggplot2,autoplot) +importFrom(ggplot2,geom_line) +importFrom(ggplot2,geom_linerange) +importFrom(ggplot2,geom_point) 
+importFrom(ggplot2,geom_ribbon) importFrom(hardhat,refresh_blueprint) importFrom(hardhat,run_mold) -importFrom(lifecycle,deprecated) importFrom(magrittr,"%>%") -importFrom(methods,is) -importFrom(quantreg,rq) importFrom(recipes,bake) importFrom(recipes,prep) +importFrom(recipes,rand_id) +importFrom(rlang,"!!!") importFrom(rlang,"!!") importFrom(rlang,"%@%") importFrom(rlang,"%||%") importFrom(rlang,":=") importFrom(rlang,abort) +importFrom(rlang,arg_match) +importFrom(rlang,as_function) +importFrom(rlang,caller_arg) importFrom(rlang,caller_env) -importFrom(rlang,is_empty) +importFrom(rlang,enquo) +importFrom(rlang,enquos) +importFrom(rlang,expr) +importFrom(rlang,global_env) +importFrom(rlang,inject) +importFrom(rlang,is_logical) importFrom(rlang,is_null) -importFrom(rlang,quos) -importFrom(smoothqr,smooth_qr) +importFrom(rlang,is_true) +importFrom(rlang,set_names) +importFrom(rlang,sym) importFrom(stats,as.formula) importFrom(stats,family) importFrom(stats,lm) @@ -234,9 +292,8 @@ importFrom(stats,qnorm) importFrom(stats,quantile) importFrom(stats,residuals) importFrom(tibble,as_tibble) -importFrom(tibble,is_tibble) importFrom(tibble,tibble) -importFrom(tidyr,drop_na) +importFrom(tidyr,crossing) importFrom(vctrs,as_list_of) importFrom(vctrs,field) importFrom(vctrs,new_rcrd) diff --git a/NEWS.md b/NEWS.md index d2cbc0d29..8edddae92 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,49 +1,63 @@ # epipredict (development) -# epipredict 0.0.8 - -- add `check_enough_train_data` that will error if training data is too small -- added `check_enough_train_data` to `arx_forecaster` - -# epipredict 0.0.7 - -- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()` - -# epipredict 0.0.6 - -- rename the `dist_quantiles()` to be more descriptive, breaking change) -- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change) -- add `pivot_quantiles_wider()` for easier plotting -- add complement `pivot_quantiles_longer()` -- add 
`cdc_baseline_forecaster()` and `flusight_hub_formatter()` - -# epipredict 0.0.5 - -- add `smooth_quantile_reg()` -- improved printing of various methods / internals -- canned forecasters get a class -- fixed quantile bug in `flatline_forecaster()` -- add functionality to output the unfit workflow from the canned forecasters - -# epipredict 0.0.4 - -- add quantile_reg() -- clean up documentation bugs -- add smooth_quantile_reg() -- add classifier -- training window step debugged -- `min_train_window` argument removed from canned forecasters - -# epipredict 0.0.3 - -- add forecasters -- implement postprocessing -- vignettes avaliable -- arx_forecaster -- pkgdown - -# epipredict 0.0.0.9000 - -- Publish public for easy navigation -- Two simple forecasters as test beds -- Working vignette +Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicate PR's. + +# epipredict 0.1 + +- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()` +- rename the `dist_quantiles()` to be more descriptive, breaking change +- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change) +- add `pivot_quantiles_wider()` for easier plotting +- add complement `pivot_quantiles_longer()` +- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()` +- add `smooth_quantile_reg()` +- improved printing of various methods / internals +- canned forecasters get a class +- fixed quantile bug in `flatline_forecaster()` +- add functionality to output the unfit workflow from the canned forecasters +- add quantile_reg() +- clean up documentation bugs +- add smooth_quantile_reg() +- add classifier +- training window step debugged +- `min_train_window` argument removed from canned forecasters +- add forecasters +- implement postprocessing +- vignettes avaliable +- arx_forecaster +- pkgdown +- Publish public for easy navigation +- Two simple forecasters as test beds +- Working vignette +- use `checkmate` for input validation +- refactor quantile 
extrapolation (possibly creates different results) +- force `target_date` + `forecast_date` handling to match the time_type of the + epi_df. allows for annual and weekly data +- add `check_enough_train_data()` that will error if training data is too small +- added `check_enough_train_data()` to `arx_forecaster()` +- `layer_residual_quantiles()` will now error if any of the residual quantiles + are NA +- `*_args_list()` functions now warn if `forecast_date + ahead != target_date` +- the `predictor` argument in `arx_forecaster()` now defaults to the value of + the `outcome` argument +- `arx_fcast_epi_workflow()` and `arx_class_epi_workflow()` now default to + `trainer = parsnip::logistic_reg()` to match their more canned versions. +- add a `forecast()` method simplify generating forecasts +- refactor `bake.epi_recipe()` and remove `epi_juice()`. +- Revise `compat-purrr` to use the r-lang `standalone-*` version (via + `{usethis}`) +- Replaced old version-faithful example in sliding AR & ARX forecasters vignette +- `epi_recipe()` will now warn when given non-`epi_df` data +- `layer_predict()` and `predict.epi_workflow()` will now appropriately forward + `...` args intended for `predict.model_fit()` +- `bake.epi_recipe()` will now re-infer the geo and time type in case baking the + steps has changed the appropriate values +- produce length 0 `dist_quantiles()` +- add functionality to calculate weighted interval scores for `dist_quantiles()` +- Add `step_epi_slide` to produce generic sliding computations over an `epi_df` +- Add quantile random forests (via `{grf}`) as a parsnip engine +- Replace `epi_keys()` with `epiprocess::key_colnames()`, #352 +- More descriptive error messages from `arg_is_*()`, #287 +- Fix bug where `fit()` drops the `epi_workflow` class (also error if + non-`epi_df` data is given to `epi_recipe()`), #363 +- Try to retain the `epi_df` class during baking to the extent possible, #376 diff --git a/R/arx_classifier.R b/R/arx_classifier.R index 
d42247426..ca6a3537b 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -1,7 +1,7 @@ #' Direct autoregressive classifier with covariates #' #' This is an autoregressive classification model for -#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning #' that it estimates a class at a particular target horizon. #' #' @inheritParams arx_forecaster @@ -26,8 +26,9 @@ #' @seealso [arx_class_epi_workflow()], [arx_class_args_list()] #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value >= as.Date("2021-11-01")) +#' filter(time_value >= as.Date("2021-11-01")) #' #' out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate")) #' @@ -45,25 +46,23 @@ arx_classifier <- function( epi_data, outcome, predictors, - trainer = parsnip::logistic_reg(), + trainer = logistic_reg(), args_list = arx_class_args_list()) { if (!is_classification(trainer)) { - cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.") + cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.") } - wf <- arx_class_epi_workflow( - epi_data, outcome, predictors, trainer, args_list - ) + wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list) + wf <- fit(wf, epi_data) - latest <- get_test_data( - hardhat::extract_preprocessor(wf), epi_data, TRUE, args_list$nafill_buffer, - args_list$forecast_date %||% max(epi_data$time_value) - ) - - wf <- generics::fit(wf, epi_data) - preds <- predict(wf, new_data = latest) %>% - tibble::as_tibble() %>% - dplyr::select(-time_value) + preds <- forecast( + wf, + fill_locf = TRUE, + n_recent = args_list$nafill_buffer, + forecast_date = args_list$forecast_date %||% max(epi_data$time_value) + ) %>% + as_tibble() %>% + select(-time_value) structure( list( @@ -87,19 +86,19 @@ arx_classifier <- function( #' may alter the returned `epi_workflow` 
object but can be omitted. #' #' @inheritParams arx_classifier -#' @param trainer A `{parsnip}` model describing the type of estimation. -#' For now, we enforce `mode = "classification"`. Typical values are +#' @param trainer A `{parsnip}` model describing the type of estimation. For +#' now, we enforce `mode = "classification"`. Typical values are #' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated -#' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can -#' also be used. May be `NULL` (the default). +#' trainers like [parsnip::naive_Bayes()] or [parsnip::rand_forest()] can also +#' be used. May be `NULL` if you'd like to decide later. #' #' @return An unfit `epi_workflow`. #' @export #' @seealso [arx_classifier()] #' @examples -#' +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value >= as.Date("2021-11-01")) +#' filter(time_value >= as.Date("2021-11-01")) #' #' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate")) #' @@ -107,7 +106,7 @@ arx_classifier <- function( #' jhu, #' "death_rate", #' c("case_rate", "death_rate"), -#' trainer = parsnip::multinom_reg(), +#' trainer = multinom_reg(), #' args_list = arx_class_args_list( #' breaks = c(-.05, .1), ahead = 14, #' horizon = 14, method = "linear_reg" @@ -117,14 +116,14 @@ arx_class_epi_workflow <- function( epi_data, outcome, predictors, - trainer = NULL, + trainer = parsnip::logistic_reg(), args_list = arx_class_args_list()) { validate_forecaster_inputs(epi_data, outcome, predictors) if (!inherits(args_list, c("arx_class", "alist"))) { - rlang::abort("args_list was not created using `arx_class_args_list().") + cli_abort("`args_list` was not created using `arx_class_args_list()`.") } if (!(is.null(trainer) || is_classification(trainer))) { - rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.") + cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.") } lags <- 
arx_lags_validator(predictors, args_list$lags) @@ -132,7 +131,7 @@ arx_class_epi_workflow <- function( # ------- predictors r <- epi_recipe(epi_data) %>% step_growth_rate( - tidyselect::all_of(predictors), + dplyr::all_of(predictors), role = "grp", horizon = args_list$horizon, method = args_list$method, @@ -175,29 +174,27 @@ arx_class_epi_workflow <- function( o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o)) r <- r %>% step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>% - step_mutate( + recipes::step_mutate( outcome_class = cut(!!o2, breaks = args_list$breaks), role = "outcome" ) %>% step_epi_naomit() %>% - step_training_window(n_recent = args_list$n_training) %>% - { - if (!is.null(args_list$check_enough_data_n)) { - check_enough_train_data( - ., - all_predictors(), - !!outcome, - n = args_list$check_enough_data_n, - epi_keys = args_list$check_enough_data_epi_keys, - drop_na = FALSE - ) - } else { - . - } - } + step_training_window(n_recent = args_list$n_training) + + if (!is.null(args_list$check_enough_data_n)) { + r <- check_enough_train_data( + r, + recipes::all_predictors(), + recipes::all_outcomes(), + n = args_list$check_enough_data_n, + epi_keys = args_list$check_enough_data_epi_keys, + drop_na = FALSE + ) + } + forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) - target_date <- args_list$target_date %||% forecast_date + args_list$ahead + target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) # --- postprocessor f <- frosting() %>% layer_predict() # %>% layer_naomit() @@ -266,7 +263,7 @@ arx_class_args_list <- function( outcome_transform = c("growth_rate", "lag_difference"), breaks = 0.25, horizon = 7L, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + method = c("rel_change", "linear_reg"), log_scale = FALSE, additional_gr_args = list(), nafill_buffer = Inf, @@ -276,8 +273,8 @@ arx_class_args_list <- function( rlang::check_dots_empty() .lags <- lags if 
(is.list(lags)) lags <- unlist(lags) - method <- match.arg(method) - outcome_transform <- match.arg(outcome_transform) + method <- rlang::arg_match(method) + outcome_transform <- rlang::arg_match(outcome_transform) arg_is_scalar(ahead, n_training, horizon, log_scale) arg_is_scalar(forecast_date, target_date, allow_null = TRUE) @@ -289,16 +286,24 @@ arx_class_args_list <- function( if (is.finite(n_training)) arg_is_pos_int(n_training) if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE) if (!is.list(additional_gr_args)) { - cli::cli_abort( - c("`additional_gr_args` must be a {.cls list}.", - "!" = "This is a {.cls {class(additional_gr_args)}}.", - i = "See `?epiprocess::growth_rate` for available arguments." - ) - ) + cli_abort(c( + "`additional_gr_args` must be a {.cls list}.", + "!" = "This is a {.cls {class(additional_gr_args)}}.", + i = "See `?epiprocess::growth_rate` for available arguments." + )) } arg_is_pos(check_enough_data_n, allow_null = TRUE) arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE) + if (!is.null(forecast_date) && !is.null(target_date)) { + if (forecast_date + ahead != target_date) { + cli::cli_warn(c( + "`forecast_date` + `ahead` must equal `target_date`.", + i = "{.val {forecast_date}} + {.val {ahead}} != {.val {target_date}}." + )) + } + } + breaks <- sort(breaks) if (min(breaks) > -Inf) breaks <- c(-Inf, breaks) if (max(breaks) < Inf) breaks <- c(breaks, Inf) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index ce2fa57b0..37c9aae86 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -1,15 +1,18 @@ #' Direct autoregressive forecaster with covariates #' #' This is an autoregressive forecasting model for -#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning #' that it estimates a model for a particular target horizon. 
#' #' #' @param epi_data An `epi_df` object #' @param outcome A character (scalar) specifying the outcome (in the #' `epi_df`). -#' @param predictors A character vector giving column(s) of predictor -#' variables. +#' @param predictors A character vector giving column(s) of predictor variables. +#' This defaults to the `outcome`. However, if manually specified, only those variables +#' specifically mentioned will be used. (The `outcome` will not be added.) +#' By default, equals the outcome. If manually specified, does not add the +#' outcome variable, so make sure to specify it. #' @param trainer A `{parsnip}` model describing the type of estimation. #' For now, we enforce `mode = "regression"`. #' @param args_list A list of customization arguments to determine @@ -35,28 +38,27 @@ #' trainer = quantile_reg(), #' args_list = arx_args_list(quantile_levels = 1:9 / 10) #' ) -arx_forecaster <- function(epi_data, - outcome, - predictors, - trainer = parsnip::linear_reg(), - args_list = arx_args_list()) { +arx_forecaster <- function( + epi_data, + outcome, + predictors = outcome, + trainer = linear_reg(), + args_list = arx_args_list()) { if (!is_regression(trainer)) { - cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") + cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") } - wf <- arx_fcast_epi_workflow( - epi_data, outcome, predictors, trainer, args_list - ) + wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list) + wf <- fit(wf, epi_data) - latest <- get_test_data( - hardhat::extract_preprocessor(wf), epi_data, TRUE, args_list$nafill_buffer, - args_list$forecast_date %||% max(epi_data$time_value) - ) - - wf <- generics::fit(wf, epi_data) - preds <- predict(wf, new_data = latest) %>% - tibble::as_tibble() %>% - dplyr::select(-time_value) + preds <- forecast( + wf, + fill_locf = TRUE, + n_recent = args_list$nafill_buffer, + forecast_date = args_list$forecast_date %||% 
max(epi_data$time_value) + ) %>% + as_tibble() %>% + select(-time_value) structure( list( @@ -80,16 +82,18 @@ arx_forecaster <- function(epi_data, #' use [quantile_reg()]) but can be omitted. #' #' @inheritParams arx_forecaster -#' @param trainer A `{parsnip}` model describing the type of estimation. -#' For now, we enforce `mode = "regression"`. May be `NULL` (the default). +#' @param trainer A `{parsnip}` model describing the type of estimation. For +#' now, we enforce `mode = "regression"`. May be `NULL` if you'd like to +#' decide later. #' #' @return An unfitted `epi_workflow`. #' @export #' @seealso [arx_forecaster()] #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value >= as.Date("2021-12-01")) +#' filter(time_value >= as.Date("2021-12-01")) #' #' arx_fcast_epi_workflow( #' jhu, "death_rate", @@ -104,16 +108,16 @@ arx_forecaster <- function(epi_data, arx_fcast_epi_workflow <- function( epi_data, outcome, - predictors, - trainer = NULL, + predictors = outcome, + trainer = linear_reg(), args_list = arx_args_list()) { # --- validation validate_forecaster_inputs(epi_data, outcome, predictors) if (!inherits(args_list, c("arx_fcast", "alist"))) { - cli::cli_abort("args_list was not created using `arx_args_list().") + cli_abort("`args_list` was not created using `arx_args_list()`.") } if (!(is.null(trainer) || is_regression(trainer))) { - cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.") + cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") } lags <- arx_lags_validator(predictors, args_list$lags) @@ -126,24 +130,22 @@ arx_fcast_epi_workflow <- function( r <- r %>% step_epi_ahead(!!outcome, ahead = args_list$ahead) %>% step_epi_naomit() %>% - step_training_window(n_recent = args_list$n_training) %>% - { - if (!is.null(args_list$check_enough_data_n)) { - check_enough_train_data( - ., - all_predictors(), - !!outcome, - n = args_list$check_enough_data_n, - epi_keys = 
args_list$check_enough_data_epi_keys, - drop_na = FALSE - ) - } else { - . - } - } + step_training_window(n_recent = args_list$n_training) + + if (!is.null(args_list$check_enough_data_n)) { + r <- check_enough_train_data( + r, + all_predictors(), + !!outcome, + n = args_list$check_enough_data_n, + epi_keys = args_list$check_enough_data_epi_keys, + drop_na = FALSE + ) + } + forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) - target_date <- args_list$target_date %||% forecast_date + args_list$ahead + target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) # --- postprocessor f <- frosting() %>% layer_predict() # %>% layer_naomit() @@ -154,7 +156,7 @@ arx_fcast_epi_workflow <- function( rlang::eval_tidy(trainer$args$quantile_levels) )) args_list$quantile_levels <- quantile_levels - trainer$args$quantile_levels <- rlang::enquo(quantile_levels) + trainer$args$quantile_levels <- enquo(quantile_levels) f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>% layer_point_from_distn() } else { @@ -260,6 +262,15 @@ arx_args_list <- function( arg_is_pos(check_enough_data_n, allow_null = TRUE) arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE) + if (!is.null(forecast_date) && !is.null(target_date)) { + if (forecast_date + ahead != target_date) { + cli_warn(c( + "`forecast_date` + `ahead` must equal `target_date`.", + i = "{.val {forecast_date}} + {.val {ahead}} != {.val {target_date}}." + )) + } + } + max_lags <- max(lags) structure( enlist( @@ -289,8 +300,8 @@ print.arx_fcast <- function(x, ...) 
{ } compare_quantile_args <- function(alist, tlist) { - default_alist <- eval(formals(arx_args_list)$quantile_level) - default_tlist <- eval(formals(quantile_reg)$quantile_level) + default_alist <- eval(formals(arx_args_list)$quantile_levels) + default_tlist <- eval(formals(quantile_reg)$quantile_levels) if (setequal(alist, default_alist)) { if (setequal(tlist, default_tlist)) { return(sort(unique(union(alist, tlist)))) @@ -304,7 +315,7 @@ compare_quantile_args <- function(alist, tlist) { if (setequal(alist, tlist)) { return(sort(unique(alist))) } - rlang::abort(c( + cli_abort(c( "You have specified different, non-default, quantiles in the trainier and `arx_args` options.", i = "Please only specify quantiles in one location." )) diff --git a/R/autoplot.R b/R/autoplot.R new file mode 100644 index 000000000..648c74e33 --- /dev/null +++ b/R/autoplot.R @@ -0,0 +1,294 @@ +#' @importFrom ggplot2 autoplot aes geom_point geom_line geom_ribbon geom_linerange +#' @export +ggplot2::autoplot + +#' Automatically plot an `epi_workflow` or `canned_epipred` object +#' +#' For a fit workflow, the training data will be displayed, the response by +#' default. If `predictions` is not `NULL` then point and interval forecasts +#' will be shown as well. Unfit workflows will result in an error, (you +#' can simply call `autoplot()` on the original `epi_df`). +#' +#' +#' +#' +#' @inheritParams epiprocess::autoplot.epi_df +#' @param object An `epi_workflow` +#' @param predictions A data frame with predictions. If `NULL`, only the +#' original data is shown. +#' @param .levels A numeric vector of levels to plot for any prediction bands. +#' More than 3 levels begins to be difficult to see. +#' @param ... Ignored +#' @param .facet_by Similar to `.color_by` except that the default is to +#' display the response. +#' @param .base_color If available, prediction bands will be shown with this +#' color. +#' @param .point_pred_color If available, point forecasts will be shown with +#' this color. 
+#' +#' @name autoplot-epipred +#' @examples +#' library(dplyr) +#' jhu <- case_death_rate_subset %>% +#' filter(time_value >= as.Date("2021-11-01")) +#' +#' r <- epi_recipe(jhu) %>% +#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% +#' step_epi_ahead(death_rate, ahead = 7) %>% +#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% +#' step_epi_naomit() +#' +#' f <- frosting() %>% +#' layer_residual_quantiles( +#' quantile_levels = c(.025, .1, .25, .75, .9, .975) +#' ) %>% +#' layer_threshold(starts_with(".pred")) %>% +#' layer_add_target_date() +#' +#' wf <- epi_workflow(r, linear_reg(), f) %>% fit(jhu) +#' +#' autoplot(wf) +#' +#' latest <- jhu %>% filter(time_value >= max(time_value) - 14) +#' preds <- predict(wf, latest) +#' autoplot(wf, preds, .max_facets = 4) +#' +#' # ------- Show multiple horizons +#' +#' p <- lapply(c(7, 14, 21, 28), function(h) { +#' r <- epi_recipe(jhu) %>% +#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% +#' step_epi_ahead(death_rate, ahead = h) %>% +#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% +#' step_epi_naomit() +#' ewf <- epi_workflow(r, linear_reg(), f) %>% fit(jhu) +#' forecast(ewf) +#' }) +#' +#' p <- do.call(rbind, p) +#' autoplot(wf, p, .max_facets = 4) +#' +#' # ------- Plotting canned forecaster output +#' +#' jhu <- case_death_rate_subset %>% +#' filter(time_value >= as.Date("2021-11-01")) +#' flat <- flatline_forecaster(jhu, "death_rate") +#' autoplot(flat, .max_facets = 4) +#' +#' arx <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), +#' args_list = arx_args_list(ahead = 14L) +#' ) +#' autoplot(arx, .max_facets = 6) +NULL + +#' @export +#' @rdname autoplot-epipred +autoplot.epi_workflow <- function( + object, predictions = NULL, + .levels = c(.5, .8, .95), ..., + .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"), + .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"), + .base_color = "dodgerblue4", + .point_pred_color = "orange", + 
.max_facets = Inf) { + rlang::check_dots_empty() + arg_is_probabilities(.levels) + rlang::arg_match(.color_by) + rlang::arg_match(.facet_by) + + if (!workflows::is_trained_workflow(object)) { + cli_abort(c( + "Can't plot an untrained {.cls epi_workflow}.", + i = "Do you need to call `fit()`?" + )) + } + + mold <- workflows::extract_mold(object) + y <- mold$outcomes + if (ncol(y) > 1) { + y <- y[, 1] + cli_warn("Multiple outcome variables were detected. Displaying only 1.") + } + keys <- c("geo_value", "time_value", "key") + mold_roles <- names(mold$extras$roles) + edf <- bind_cols(mold$extras$roles[mold_roles %in% keys], y) + if (starts_with_impl("ahead_", names(y))) { + old_name_y <- unlist(strsplit(names(y), "_")) + shift <- as.numeric(old_name_y[2]) + new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_") + edf <- rename(edf, !!new_name_y := !!names(y)) + } else if (starts_with_impl("lag_", names(y))) { + old_name_y <- unlist(strsplit(names(y), "_")) + shift <- -as.numeric(old_name_y[2]) + new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_") + edf <- rename(edf, !!new_name_y := !!names(y)) + } + + if (!is.null(shift)) { + edf <- mutate(edf, time_value = time_value + shift) + } + extra_keys <- setdiff(key_colnames(object), c("geo_value", "time_value")) + if (length(extra_keys) == 0L) extra_keys <- NULL + edf <- as_epi_df(edf, + as_of = object$fit$meta$as_of, + other_keys = extra_keys %||% character() + ) + if (is.null(predictions)) { + return(autoplot( + edf, new_name_y, + .color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color, + .max_facets = .max_facets + )) + } + + if ("target_date" %in% names(predictions)) { + if ("time_value" %in% names(predictions)) { + predictions <- select(predictions, -time_value) + } + predictions <- rename(predictions, time_value = target_date) + } + pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(edf)) + if (!pred_cols_ok$ok) { + cli_warn(c( + "`predictions` is missing required variables: 
{.var {pred_cols_ok$missing_names}}.", + i = "Plotting the original data." + )) + return(autoplot( + edf, !!new_name_y, + .color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color, + .max_facets = .max_facets + )) + } + + # First we plot the history, always faceted by everything + bp <- autoplot(edf, !!new_name_y, + .color_by = "none", .facet_by = "all_keys", + .base_color = "black", .max_facets = .max_facets + ) + + # Now, prepare matching facets in the predictions + ek <- epi_keys_only(edf) + predictions <- predictions %>% + mutate( + .facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"), + ) + if (.max_facets < Inf) { + top_n <- levels(as.factor(bp$data$.facets))[seq_len(.max_facets)] + predictions <- filter(predictions, .facets %in% top_n) %>% + mutate(.facets = droplevels(.facets)) + } + + + if (".pred_distn" %in% names(predictions)) { + bp <- plot_bands(bp, predictions, .levels, .base_color) + } + + if (".pred" %in% names(predictions)) { + ntarget_dates <- n_distinct(predictions$time_value) + if (ntarget_dates > 1L) { + bp <- bp + + geom_line( + data = predictions, aes(y = .data$.pred), + color = .point_pred_color + ) + } else { + bp <- bp + + geom_point( + data = predictions, aes(y = .data$.pred), + color = .point_pred_color + ) + } + } + bp +} + +#' @export +#' @rdname autoplot-epipred +autoplot.canned_epipred <- function( + object, ..., + .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"), + .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"), + .base_color = "dodgerblue4", + .point_pred_color = "orange", + .max_facets = Inf) { + rlang::check_dots_empty() + rlang::arg_match(.color_by) + rlang::arg_match(.facet_by) + + ewf <- object$epi_workflow + predictions <- object$predictions %>% + dplyr::rename(time_value = target_date) + + autoplot(ewf, predictions, + .color_by = .color_by, .facet_by = .facet_by, + .base_color = .base_color, .max_facets = .max_facets + ) +} + 
+starts_with_impl <- function(x, vars) { + n <- nchar(x) + x == substr(vars, 1, n) +} + +plot_bands <- function( + base_plot, predictions, + levels = c(.5, .8, .95), + fill = "blue4", + alpha = 0.6, + linewidth = 0.05) { + innames <- names(predictions) + n <- length(levels) + alpha <- alpha / (n - 1) + l <- (1 - levels) / 2 + l <- c(rev(l), 1 - l) + + ntarget_dates <- dplyr::n_distinct(predictions$time_value) + + predictions <- predictions %>% + mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, l), l)) %>% + pivot_quantiles_wider(.pred_distn) + qnames <- setdiff(names(predictions), innames) + + for (i in 1:n) { + bottom <- qnames[i] + top <- rev(qnames)[i] + if (i == 1) { + if (ntarget_dates > 1L) { + base_plot <- base_plot + + geom_ribbon( + data = predictions, + aes(ymin = .data[[bottom]], ymax = .data[[top]]), + alpha = 0.2, linewidth = linewidth, fill = fill + ) + } else { + base_plot <- base_plot + + geom_linerange( + data = predictions, + aes(ymin = .data[[bottom]], ymax = .data[[top]]), + alpha = 0.2, linewidth = 2, color = fill + ) + } + } else { + if (ntarget_dates > 1L) { + base_plot <- base_plot + + geom_ribbon( + data = predictions, + aes(ymin = .data[[bottom]], ymax = .data[[top]]), + fill = fill, alpha = alpha + ) + } else { + base_plot <- base_plot + + geom_linerange( + data = predictions, + aes(ymin = .data[[bottom]], ymax = .data[[top]]), + color = fill, alpha = alpha, linewidth = 2 + ) + } + } + } + base_plot +} + +find_level <- function(x) { + unique((x < .5) * (1 - 2 * x) + (x > .5) * (1 - 2 * (1 - x))) +} diff --git a/R/bake.epi_recipe.R b/R/bake.epi_recipe.R deleted file mode 100644 index 6857df4ef..000000000 --- a/R/bake.epi_recipe.R +++ /dev/null @@ -1,105 +0,0 @@ -#' Bake an epi_recipe -#' -#' @param object A trained object such as a [recipe()] with at least -#' one preprocessing operation. -#' @param new_data An `epi_df`, data frame or tibble for whom the -#' preprocessing will be applied. 
If `NULL` is given to `new_data`, -#' the pre-processed _training data_ will be returned. -#' @param ... One or more selector functions to choose which variables will be -#' returned by the function. See [recipes::selections()] for -#' more details. If no selectors are given, the default is to -#' use [tidyselect::everything()]. -#' @return An `epi_df` that may have different columns than the -#' original columns in `new_data`. -#' @importFrom rlang is_empty quos -#' @importFrom tibble is_tibble as_tibble -#' @importFrom methods is -#' @rdname bake -#' @export -bake.epi_recipe <- function(object, new_data, ...) { - if (rlang::is_missing(new_data)) { - rlang::abort("'new_data' must be either an epi_df or NULL. No value is not allowed.") - } - - if (is.null(new_data)) { - return(epi_juice(object, ...)) - } - - if (!fully_trained(object)) { - rlang::abort("At least one step has not been trained. Please run `prep`.") - } - - terms <- quos(...) - if (is_empty(terms)) { - terms <- quos(tidyselect::everything()) - } - - # In case someone used the deprecated `newdata`: - if (is.null(new_data) || is.null(ncol(new_data))) { - if (any(names(terms) == "newdata")) { - rlang::abort("Please use `new_data` instead of `newdata` with `bake`.") - } else { - rlang::abort("Please pass a data set to `new_data`.") - } - } - - if (!is_tibble(new_data)) { - new_data <- as_tibble(new_data) - } - - recipes:::check_role_requirements(object, new_data) - - recipes:::check_nominal_type(new_data, object$orig_lvls) - - # Drop completely new columns from `new_data` and reorder columns that do - # still exist to match the ordering used when training - original_names <- names(new_data) - original_training_names <- unique(object$var_info$variable) - bakeable_names <- intersect(original_training_names, original_names) - new_data <- new_data[, bakeable_names] - - n_steps <- length(object$steps) - - for (i in seq_len(n_steps)) { - step <- object$steps[[i]] - - if (recipes:::is_skipable(step)) { - next - 
} - - new_data <- bake(step, new_data = new_data) - - if (!is_tibble(new_data)) { - abort("bake() methods should always return tibbles") - } - } - - # Use `last_term_info`, which maintains info on all columns that got added - # and removed from the training data. This is important for skipped steps - # which might have resulted in columns not being added/removed in the test - # set. - info <- object$last_term_info - - # Now reduce to only user selected columns - out_names <- recipes_eval_select(terms, new_data, info, - check_case_weights = FALSE - ) - new_data <- new_data[, out_names] - - # The levels are not null when no nominal data are present or - # if strings_as_factors = FALSE in `prep` - if (!is.null(object$levels)) { - var_levels <- object$levels - var_levels <- var_levels[out_names] - check_values <- - vapply(var_levels, function(x) { - (!all(is.na(x))) - }, c(all = TRUE)) - var_levels <- var_levels[check_values] - if (length(var_levels) > 0) { - new_data <- recipes:::strings2factors(new_data, var_levels) - } - } - - new_data -} diff --git a/R/canned-epipred.R b/R/canned-epipred.R index 802d0f7e4..0adc0536a 100644 --- a/R/canned-epipred.R +++ b/R/canned-epipred.R @@ -1,18 +1,18 @@ validate_forecaster_inputs <- function(epi_data, outcome, predictors) { - if (!epiprocess::is_epi_df(epi_data)) { - cli::cli_abort(c( + if (!is_epi_df(epi_data)) { + cli_abort(c( "`epi_data` must be an {.cls epi_df}.", "!" = "This one is a {.cls {class(epi_data)}}." )) } - arg_is_chr(predictors) arg_is_chr_scalar(outcome) + arg_is_chr(predictors) if (!outcome %in% names(epi_data)) { - cli::cli_abort("{.var {outcome}} was not found in the training data.") + cli_abort("{.var {outcome}} was not found in the training data.") } check <- hardhat::check_column_names(epi_data, predictors) if (!check$ok) { - cli::cli_abort(c( + cli_abort(c( "At least one predictor was not found in the training data.", "!" = "The following required columns are missing: {.val {check$missing_names}}." 
)) @@ -29,7 +29,7 @@ arx_lags_validator <- function(predictors, lags) { if (l == 1) { lags <- rep(lags, p) } else if (length(lags) != p) { - cli::cli_abort(c( + cli_abort(c( "You have requested {p} predictor(s) but {l} different lags.", i = "Lags must be a vector or a list with length == number of predictors." )) @@ -39,7 +39,7 @@ arx_lags_validator <- function(predictors, lags) { lags <- lags[order(match(names(lags), predictors))] } else { predictors_miss <- setdiff(predictors, names(lags)) - cli::cli_abort(c( + cli_abort(c( "If lags is a named list, then all predictors must be present.", i = "The predictors are {.var {predictors}}.", i = "So lags is missing {.var {predictors_miss}}'." @@ -51,9 +51,15 @@ arx_lags_validator <- function(predictors, lags) { } + #' @export print.alist <- function(x, ...) { - utils::str(x) + nm <- names(x) + for (i in seq_along(x)) { + if (is.null(x[[i]])) x[[i]] <- "NULL" + if (length(x[[i]]) == 0L) x[[i]] <- "_empty_" + cli::cli_bullets(c("*" = "{nm[[i]]} : {.val {x[[i]]}}")) + } } #' @export @@ -67,11 +73,17 @@ print.canned_epipred <- function(x, name, ...) { ) cli::cli_text("") cli::cli_text("Training data was an {.cls epi_df} with:") - cli::cli_ul(c( - "Geography: {.field {x$metadata$training$geo_type}},", - "Time type: {.field {x$metadata$training$time_type}},", - "Using data up-to-date as of: {.field {format(x$metadata$training$as_of)}}." 
- )) + fn_meta <- function() { + cli::cli_ul() + cli::cli_li("Geography: {.field {x$metadata$training$geo_type}},") + if (!is.null(x$metadata$training$other_keys)) { + cli::cli_li("Other keys: {.field {x$metadata$training$other_keys}},") + } + cli::cli_li("Time type: {.field {x$metadata$training$time_type}},") + cli::cli_li("Using data up-to-date as of: {.field {format(x$metadata$training$as_of)}}.") + cli::cli_end() + } + fn_meta() cli::cli_text("") cli::cli_rule("Predictions") diff --git a/R/cdc_baseline_forecaster.R b/R/cdc_baseline_forecaster.R index abb231bca..b2e7434e2 100644 --- a/R/cdc_baseline_forecaster.R +++ b/R/cdc_baseline_forecaster.R @@ -1,7 +1,7 @@ #' Predict the future with the most recent value #' #' This is a simple forecasting model for -#' [epiprocess::epi_df] data. It uses the most recent observation as the +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the #' forecast for any future date, and produces intervals by shuffling the quantiles #' of the residuals of such a "flatline" forecast and incrementing these #' forward over all available training data. @@ -12,7 +12,7 @@ #' This forecaster is meant to produce exactly the CDC Baseline used for #' [COVID19ForecastHub](https://covid19forecasthub.org) #' -#' @param epi_data An [`epiprocess::epi_df`] +#' @param epi_data An [`epiprocess::epi_df`][epiprocess::as_epi_df] #' @param outcome A scalar character for the column name we wish to predict. #' @param args_list A list of additional arguments as created by the #' [cdc_baseline_args_list()] constructor function. 
@@ -29,11 +29,11 @@ #' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>% #' select(-pop, -death_rate) %>% #' group_by(geo_value) %>% -#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>% +#' epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") %>% #' ungroup() %>% #' filter(weekdays(time_value) == "Saturday") #' -#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") +#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum") #' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn) #' #' if (require(ggplot2)) { @@ -47,7 +47,7 @@ #' geom_line(aes(y = .pred), color = "orange") + #' geom_line( #' data = weekly_deaths %>% filter(geo_value %in% four_states), -#' aes(x = time_value, y = deaths) +#' aes(x = time_value, y = deaths_7dsum) #' ) + #' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) + #' labs(x = "Date", y = "Weekly deaths") + @@ -61,9 +61,9 @@ cdc_baseline_forecaster <- function( args_list = cdc_baseline_args_list()) { validate_forecaster_inputs(epi_data, outcome, "time_value") if (!inherits(args_list, c("cdc_flat_fcast", "alist"))) { - cli_stop("args_list was not created using `cdc_baseline_args_list().") + cli_abort("`args_list` was not created using `cdc_baseline_args_list()`.") } - keys <- epi_keys(epi_data) + keys <- key_colnames(epi_data) ek <- kill_time_value(keys) outcome <- rlang::sym(outcome) @@ -75,7 +75,7 @@ cdc_baseline_forecaster <- function( step_training_window(n_recent = args_list$n_training) forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) - # target_date <- args_list$target_date %||% forecast_date + args_list$ahead + # target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) latest <- get_test_data( @@ -98,14 +98,14 @@ # layer_add_target_date(target_date = target_date) if (args_list$nonneg) f <- layer_threshold(f, ".pred") - eng <- parsnip::linear_reg() %>% 
parsnip::set_engine("flatline") + eng <- linear_reg(engine = "flatline") wf <- epi_workflow(r, eng, f) - wf <- generics::fit(wf, epi_data) + wf <- fit(wf, epi_data) preds <- suppressWarnings(predict(wf, new_data = latest)) %>% - tibble::as_tibble() %>% - dplyr::select(-time_value) %>% - dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency) + as_tibble() %>% + select(-time_value) %>% + mutate(target_date = forecast_date + ahead * args_list$data_frequency) structure( list( @@ -218,11 +218,11 @@ parse_period <- function(x) { mult <- switch(mult, day = 1L, wee = 7L, - cli::cli_abort("incompatible timespan in `aheads`.") + cli_abort("incompatible timespan in `aheads`.") ) x <- as.numeric(x[1]) * mult } - if (length(x) > 2L) cli::cli_abort("incompatible timespan in `aheads`.") + if (length(x) > 2L) cli_abort("incompatible timespan in `aheads`.") } stopifnot(rlang::is_integerish(x)) as.integer(x) diff --git a/R/check_enough_train_data.R b/R/check_enough_train_data.R index af2183d15..1279a3712 100644 --- a/R/check_enough_train_data.R +++ b/R/check_enough_train_data.R @@ -49,13 +49,13 @@ check_enough_train_data <- columns = NULL, skip = TRUE, id = rand_id("enough_train_data")) { - add_check( + recipes::add_check( recipe, check_enough_train_data_new( n = n, epi_keys = epi_keys, drop_na = drop_na, - terms = rlang::enquos(...), + terms = enquos(...), role = role, trained = trained, columns = columns, @@ -67,7 +67,7 @@ check_enough_train_data <- check_enough_train_data_new <- function(n, epi_keys, drop_na, terms, role, trained, columns, skip, id) { - check( + recipes::check( subclass = "enough_train_data", prefix = "check_", n = n, @@ -83,30 +83,24 @@ check_enough_train_data_new <- } #' @export -#' @importFrom dplyr group_by summarise ungroup across all_of n -#' @importFrom tidyr drop_na prep.check_enough_train_data <- function(x, training, info = NULL, ...) 
{ - col_names <- recipes_eval_select(x$terms, training, info) + col_names <- recipes::recipes_eval_select(x$terms, training, info) if (is.null(x$n)) { x$n <- length(col_names) } + if (x$drop_na) { + training <- tidyr::drop_na(training) + } cols_not_enough_data <- training %>% - { - if (x$drop_na) { - drop_na(.) - } else { - . - } - } %>% group_by(across(all_of(.env$x$epi_keys))) %>% - summarise(across(all_of(.env$col_names), ~ n() < .env$x$n), .groups = "drop") %>% + summarise(across(all_of(.env$col_names), ~ dplyr::n() < .env$x$n), .groups = "drop") %>% summarise(across(all_of(.env$col_names), any), .groups = "drop") %>% unlist() %>% names(.)[.] if (length(cols_not_enough_data) > 0) { - cli::cli_abort( + cli_abort( "The following columns don't have enough data to predict: {cols_not_enough_data}." ) } @@ -132,16 +126,16 @@ bake.check_enough_train_data <- function(object, new_data, ...) { #' @export print.check_enough_train_data <- function(x, width = max(20, options()$width - 30), ...) { title <- paste0("Check enough data (n = ", x$n, ") for ") - print_step(x$columns, x$terms, x$trained, title, width) + recipes::print_step(x$columns, x$terms, x$trained, title, width) invisible(x) } #' @export tidy.check_enough_train_data <- function(x, ...) { - if (is_trained(x)) { + if (recipes::is_trained(x)) { res <- tibble(terms = unname(x$columns)) } else { - res <- tibble(terms = sel2char(x$terms)) + res <- tibble(terms = recipes::sel2char(x$terms)) } res$id <- x$id res$n <- x$n diff --git a/R/compat-purrr.R b/R/compat-purrr.R index 712926f73..e06038e44 100644 --- a/R/compat-purrr.R +++ b/R/compat-purrr.R @@ -1,37 +1,8 @@ -# See https://github.com/r-lib/rlang/blob/main/R/compat-purrr.R - - -map <- function(.x, .f, ...) { - .f <- rlang::as_function(.f, env = rlang::global_env()) - lapply(.x, .f, ...) -} - -walk <- function(.x, .f, ...) { - map(.x, .f, ...) - invisible(.x) -} - walk2 <- function(.x, .y, .f, ...) { map2(.x, .y, .f, ...) 
invisible(.x) } -map_lgl <- function(.x, .f, ...) { - .rlang_purrr_map_mold(.x, .f, logical(1), ...) -} - -map_int <- function(.x, .f, ...) { - .rlang_purrr_map_mold(.x, .f, integer(1), ...) -} - -map_dbl <- function(.x, .f, ...) { - .rlang_purrr_map_mold(.x, .f, double(1), ...) -} - -map_chr <- function(.x, .f, ...) { - .rlang_purrr_map_mold(.x, .f, character(1), ...) -} - map_vec <- function(.x, .f, ...) { out <- map(.x, .f, ...) vctrs::list_unchop(out) @@ -48,61 +19,3 @@ map2_dfr <- function(.x, .y, .f, ..., .id = NULL) { res <- map2(.x, .y, .f, ...) dplyr::bind_rows(res, .id = .id) } - -.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) { - .f <- rlang::as_function(.f, env = rlang::global_env()) - out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE) - names(out) <- names(.x) - out -} - -.rlang_purrr_args_recycle <- function(args) { - lengths <- map_int(args, length) - n <- max(lengths) - - stopifnot(all(lengths == 1L | lengths == n)) - to_recycle <- lengths == 1L - args[to_recycle] <- map(args[to_recycle], function(x) rep.int(x, n)) - - args -} - -map2 <- function(.x, .y, .f, ...) { - .f <- rlang::as_function(.f, env = rlang::global_env()) - out <- mapply(.f, .x, .y, MoreArgs = list(...), SIMPLIFY = FALSE) - if (length(out) == length(.x)) { - rlang::set_names(out, names(.x)) - } else { - rlang::set_names(out, NULL) - } -} -map2_lgl <- function(.x, .y, .f, ...) { - as.vector(map2(.x, .y, .f, ...), "logical") -} -map2_int <- function(.x, .y, .f, ...) { - as.vector(map2(.x, .y, .f, ...), "integer") -} -map2_dbl <- function(.x, .y, .f, ...) { - as.vector(map2(.x, .y, .f, ...), "double") -} -map2_chr <- function(.x, .y, .f, ...) { - as.vector(map2(.x, .y, .f, ...), "character") -} -imap <- function(.x, .f, ...) { - map2(.x, names(.x) %||% seq_along(.x), .f, ...) -} - -pmap <- function(.l, .f, ...) 
{ - .f <- as.function(.f) - args <- .rlang_purrr_args_recycle(.l) - do.call("mapply", c( - FUN = list(quote(.f)), - args, MoreArgs = quote(list(...)), - SIMPLIFY = FALSE, USE.NAMES = FALSE - )) -} - -reduce <- function(.x, .f, ..., .init) { - f <- function(x, y) .f(x, y, ...) - Reduce(f, .x, init = .init) -} diff --git a/R/create-layer.R b/R/create-layer.R index 69aeee7eb..0268a906f 100644 --- a/R/create-layer.R +++ b/R/create-layer.R @@ -7,8 +7,8 @@ #' @inheritParams usethis::use_test #' #' @importFrom rlang %||% -#' @export -#' +#' @noRd +#' @keywords internal #' @examples #' \dontrun{ #' diff --git a/R/data.R b/R/data.R index 622d1c7d7..71e5bdcd3 100644 --- a/R/data.R +++ b/R/data.R @@ -56,3 +56,32 @@ #' \url{https://www.census.gov/data/tables/time-series/demo/popest/2010s-total-puerto-rico-municipios.html}, #' and \url{https://www.census.gov/data/tables/2010/dec/2010-island-areas.html} "state_census" + +#' Subset of Statistics Canada median employment income for postsecondary graduates +#' +#' @format An [epiprocess::epi_df][epiprocess::as_epi_df] with 10193 rows and 8 variables: +#' \describe{ +#' \item{geo_value}{The province in Canada associated with each +#' row of measurements.} +#' \item{time_value}{The time value, a year integer in YYYY format} +#' \item{edu_qual}{The education qualification} +#' \item{fos}{The field of study} +#' \item{age_group}{The age group; either 15 to 34 or 35 to 64} +#' \item{num_graduates}{The number of graduates for the given row of characteristics} +#' \item{med_income_2y}{The median employment income two years after graduation} +#' \item{med_income_5y}{The median employment income five years after graduation} +#' } +#' @source This object contains modified data from the following Statistics Canada +#' data table: \href{https://www150.statcan.gc.ca/t1/tbl1/en/tv.action?pid=3710011501}{ +#' Characteristics and median employment income of longitudinal cohorts of postsecondary +#' graduates two and five years after graduation, by 
educational qualification and +#' field of study (primary groupings) +#' } +#' +#' Modifications: +#' * Only provincial-level geo_values are kept +#' * Only age group, field of study, and educational qualification are kept as +#' covariates. For the remaining covariates, we keep aggregated values and +#' drop the level-specific rows. +#' * No modifications were made to the time range of the data +"grad_employ_subset" diff --git a/R/dist_quantiles.R b/R/dist_quantiles.R index 750e9560d..dd97ec809 100644 --- a/R/dist_quantiles.R +++ b/R/dist_quantiles.R @@ -1,10 +1,21 @@ #' @importFrom vctrs field vec_cast new_rcrd -new_quantiles <- function(values = double(), quantile_levels = double()) { +new_quantiles <- function(values = double(1), quantile_levels = double(1)) { arg_is_probabilities(quantile_levels) vec_cast(values, double()) vec_cast(quantile_levels, double()) + values <- unname(values) + if (length(values) == 0L) { + return(new_rcrd( + list( + values = rep(NA_real_, length(quantile_levels)), + quantile_levels = quantile_levels + ), + class = c("dist_quantiles", "dist_default") + )) + } stopifnot(length(values) == length(quantile_levels)) + stopifnot(!vctrs::vec_duplicate_any(quantile_levels)) if (is.unsorted(quantile_levels)) { o <- vctrs::vec_order(quantile_levels) @@ -37,13 +48,23 @@ format.dist_quantiles <- function(x, digits = 2, ...) { #' A distribution parameterized by a set of quantiles #' -#' @param values A vector of values -#' @param quantile_levels A vector of probabilities corresponding to `values` +#' @param values A vector (or list of vectors) of values. +#' @param quantile_levels A vector (or list of vectors) of probabilities +#' corresponding to `values`. +#' +#' When creating multiple sets of `values`/`quantile_levels` resulting in +#' different distributions, the sizes must match. See the examples below. +#' +#' @return A vector of class `"distribution"`. 
#' #' @export #' #' @examples -#' dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) +#' dist_quantiles(1:4, 1:4 / 5) +#' dist_quantiles(list(1:3, 1:4), list(1:3 / 4, 1:4 / 5)) +#' dstn <- dist_quantiles(list(1:4, 8:11), c(.2, .4, .6, .8)) +#' dstn +#' #' quantile(dstn, p = c(.1, .25, .5, .9)) #' median(dstn) #' @@ -51,16 +72,25 @@ format.dist_quantiles <- function(x, digits = 2, ...) { #' distributional::parameters(dstn[1]) #' nested_quantiles(dstn[1])[[1]] #' -#' dist_quantiles(1:4, 1:4 / 5) #' @importFrom vctrs as_list_of vec_recycle_common new_vctr dist_quantiles <- function(values, quantile_levels) { - if (!is.list(values)) values <- list(values) - if (!is.list(quantile_levels)) quantile_levels <- list(quantile_levels) + if (!is.list(quantile_levels)) { + assert_numeric(quantile_levels, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L) + quantile_levels <- list(quantile_levels) + } + if (!is.list(values)) { + if (length(values) == 0L) values <- NA_real_ + values <- list(values) + } values <- as_list_of(values, .ptype = double()) quantile_levels <- as_list_of(quantile_levels, .ptype = double()) args <- vec_recycle_common(values = values, quantile_levels = quantile_levels) - qntls <- as_list_of(map2(args$values, args$quantile_levels, new_quantiles)) + + qntls <- as_list_of( + map2(args$values, args$quantile_levels, new_quantiles), + .ptype = new_quantiles(NA_real_, 0.5) + ) new_vctr(qntls, class = "distribution") } @@ -87,59 +117,6 @@ validate_dist_quantiles <- function(values, quantile_levels) { } -#' Summarize a distribution with a set of quantiles -#' -#' @param x a `distribution` vector -#' @param probs a vector of probabilities at which to calculate quantiles -#' @param ... 
additional arguments passed on to the `quantile` method -#' -#' @return a `distribution` vector containing `dist_quantiles` -#' @export -#' -#' @examples -#' library(distributional) -#' dstn <- dist_normal(c(10, 2), c(5, 10)) -#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) -#' -#' dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) -#' # because this distribution is already quantiles, any extra quantiles are -#' # appended -#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) -#' -#' dstn <- c( -#' dist_normal(c(10, 2), c(5, 10)), -#' dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) -#' ) -#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) -extrapolate_quantiles <- function(x, probs, ...) { - UseMethod("extrapolate_quantiles") -} - -#' @export -#' @importFrom vctrs vec_data -extrapolate_quantiles.distribution <- function(x, probs, ...) { - arg_is_probabilities(probs) - dstn <- lapply(vec_data(x), extrapolate_quantiles, probs = probs, ...) - new_vctr(dstn, vars = NULL, class = "distribution") -} - -#' @export -extrapolate_quantiles.dist_default <- function(x, probs, ...) { - values <- quantile(x, probs, ...) - new_quantiles(values = values, quantile_levels = probs) -} - -#' @export -extrapolate_quantiles.dist_quantiles <- function(x, probs, ...) { - new_values <- quantile(x, probs, ...) 
- quantile_levels <- field(x, "quantile_levels") - values <- field(x, "values") - new_quantiles( - values = c(values, new_values), - quantile_levels = c(quantile_levels, probs) - ) -} - is_dist_quantiles <- function(x) { is_distribution(x) & all(stats::family(x) == "quantiles") } @@ -172,35 +149,31 @@ mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "line #' @export #' @importFrom stats quantile #' @import distributional -quantile.dist_quantiles <- function( - x, p, ..., - middle = c("cubic", "linear"), - left_tail = c("normal", "exponential"), - right_tail = c("normal", "exponential")) { +quantile.dist_quantiles <- function(x, p, ..., middle = c("cubic", "linear")) { arg_is_probabilities(p) + p <- sort(p) middle <- match.arg(middle) - left_tail <- match.arg(left_tail) - right_tail <- match.arg(right_tail) - quantile_extrapolate(x, p, middle, left_tail, right_tail) + quantile_extrapolate(x, p, middle) } -quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { +quantile_extrapolate <- function(x, tau_out, middle) { tau <- field(x, "quantile_levels") qvals <- field(x, "values") - r <- range(tau, na.rm = TRUE) + nas <- is.na(qvals) qvals_out <- rep(NA, length(tau_out)) + qvals <- qvals[!nas] + tau <- tau[!nas] # short circuit if we aren't actually extrapolating # matches to ~15 decimals if (all(tau_out %in% tau)) { return(qvals[match(tau_out, tau)]) } - if (length(qvals) < 3 || r[1] > .25 || r[2] < .75) { - cli::cli_warn(c( - "Quantile extrapolation is not possible with fewer than", - "3 quantiles or when the probs don't span [.25, .75]" - )) + if (length(tau) < 2) { + cli::cli_abort( + "Quantile extrapolation is not possible with fewer than 2 quantiles." 
+ ) return(qvals_out) } @@ -213,7 +186,6 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { result <- tryCatch( { Q <- stats::splinefun(tau, qvals, method = "hyman") - qvals_out[indm] <- Q(tau_out[indm]) quartiles <- Q(c(.25, .5, .75)) }, error = function(e) { @@ -225,75 +197,47 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) { method <- "linear" quartiles <- stats::approx(tau, qvals, c(.25, .5, .75))$y } - - if (any(indm)) { qvals_out[indm] <- switch(method, linear = stats::approx(tau, qvals, tau_out[indm])$y, cubic = Q(tau_out[indm]) ) } + if (any(indl) || any(indr)) { + qv <- data.frame( + q = c(tau, tau_out[indm]), + v = c(qvals, qvals_out[indm]) + ) %>% + dplyr::distinct(q, .keep_all = TRUE) %>% + dplyr::arrange(q) + } if (any(indl)) { - qvals_out[indl] <- tail_extrapolate( - tau_out[indl], quartiles, "left", left_tail - ) + qvals_out[indl] <- tail_extrapolate(tau_out[indl], utils::head(qv, 2)) } if (any(indr)) { - qvals_out[indr] <- tail_extrapolate( - tau_out[indr], quartiles, "right", right_tail - ) + qvals_out[indr] <- tail_extrapolate(tau_out[indr], utils::tail(qv, 2)) } qvals_out } -tail_extrapolate <- function(tau_out, quartiles, tail, type) { - if (tail == "left") { - p <- c(.25, .5) - par <- quartiles[1:2] - } - if (tail == "right") { - p <- c(.75, .5) - par <- quartiles[3:2] - } - if (type == "normal") { - return(norm_tail_q(p, par, tau_out)) - } - if (type == "exponential") { - return(exp_tail_q(p, par, tau_out)) - } -} - - -exp_q_par <- function(q) { - # tau should always be c(.75, .5) or c(.25, .5) - iqr <- 2 * abs(diff(q)) - s <- iqr / (2 * log(2)) - m <- q[2] - return(list(m = m, s = s)) +logit <- function(p) { + p <- pmax(pmin(p, 1), 0) + log(p) - log(1 - p) } -exp_tail_q <- function(p, q, target) { - ms <- exp_q_par(q) - qlaplace(target, ms$m, ms$s) -} - -qlaplace <- function(p, centre = 0, b = 1) { - # lower.tail = TRUE, log.p = FALSE - centre - b * sign(p - 0.5) * log(1 - 2 * abs(p 
- 0.5)) -} - -norm_q_par <- function(q) { - # tau should always be c(.75, .5) or c(.25, .5) - iqr <- 2 * abs(diff(q)) - s <- iqr / 1.34897950039 # abs(diff(qnorm(c(.75, .25)))) - m <- q[2] - return(list(m = m, s = s)) +# extrapolates linearly on the logistic scale using +# the two points nearest the tail +tail_extrapolate <- function(tau_out, qv) { + if (nrow(qv) == 1L) { + return(rep(qv$v[1], length(tau_out))) + } + x <- logit(qv$q) + x0 <- logit(tau_out) + y <- qv$v + m <- diff(y) / diff(x) + m * (x0 - x[1]) + y[1] } -norm_tail_q <- function(p, q, target) { - ms <- norm_q_par(q) - stats::qnorm(target, ms$m, ms$s) -} #' @method Math dist_quantiles #' @export diff --git a/R/epi_check_training_set.R b/R/epi_check_training_set.R index 0c7dc9036..596e99887 100644 --- a/R/epi_check_training_set.R +++ b/R/epi_check_training_set.R @@ -16,7 +16,7 @@ epi_check_training_set <- function(x, rec) { if (!is.null(old_ok)) { if (all(old_ok %in% colnames(x))) { # case 1 if (!all(old_ok %in% new_ok)) { - cli::cli_warn(c( + cli_warn(c( "The recipe specifies additional keys. Because these are available,", "they are being added to the metadata of the training data." )) @@ -25,7 +25,7 @@ epi_check_training_set <- function(x, rec) { } missing_ok <- setdiff(old_ok, colnames(x)) if (length(missing_ok) > 0) { # case 2 - cli::cli_abort(c( + cli_abort(c( "The recipe specifies keys which are not in the training data.", i = "The training set is missing columns for {missing_ok}." )) @@ -45,8 +45,8 @@ validate_meta_match <- function(x, template, meta, warn_or_abort = "warn") { ) if (new_meta != old_meta) { switch(warn_or_abort, - warn = cli::cli_warn(msg), - abort = cli::cli_abort(msg) + warn = cli_warn(msg), + abort = cli_abort(msg) ) } } diff --git a/R/epi_juice.R b/R/epi_juice.R deleted file mode 100644 index d9d23df97..000000000 --- a/R/epi_juice.R +++ /dev/null @@ -1,43 +0,0 @@ -#' Extract transformed training set -#' -#' @inheritParams bake.epi_recipe -epi_juice <- function(object, ...) 
{ - if (!fully_trained(object)) { - rlang::abort("At least one step has not been trained. Please run `prep()`.") - } - - if (!isTRUE(object$retained)) { - rlang::abort(paste0( - "Use `retain = TRUE` in `prep()` to be able ", - "to extract the training set" - )) - } - - terms <- quos(...) - if (is_empty(terms)) { - terms <- quos(dplyr::everything()) - } - - # Get user requested columns - new_data <- object$template - out_names <- recipes_eval_select(terms, new_data, object$term_info, - check_case_weights = FALSE - ) - new_data <- new_data[, out_names] - - # Since most models require factors, do the conversion from character - if (!is.null(object$levels)) { - var_levels <- object$levels - var_levels <- var_levels[out_names] - check_values <- - vapply(var_levels, function(x) { - (!all(is.na(x))) - }, c(all = TRUE)) - var_levels <- var_levels[check_values] - if (length(var_levels) > 0) { - new_data <- recipes:::strings2factors(new_data, var_levels) - } - } - - new_data -} diff --git a/R/epi_keys.R b/R/epi_keys.R deleted file mode 100644 index 4a00cbd46..000000000 --- a/R/epi_keys.R +++ /dev/null @@ -1,47 +0,0 @@ -#' Grab any keys associated to an epi_df -#' -#' @param x a data.frame, tibble, or epi_df -#' @param ... additional arguments passed on to methods -#' -#' @return If an `epi_df`, this returns all "keys". Otherwise `NULL` -#' @keywords internal -#' @export -epi_keys <- function(x, ...) { - UseMethod("epi_keys") -} - -#' @export -epi_keys.default <- function(x, ...) { - character(0L) -} - -#' @export -epi_keys.data.frame <- function(x, other_keys = character(0L), ...) { - arg_is_chr(other_keys, allow_empty = TRUE) - nm <- c("time_value", "geo_value", other_keys) - intersect(nm, names(x)) -} - -#' @export -epi_keys.epi_df <- function(x, ...) { - c("time_value", "geo_value", attributes(x)$metadata$other_keys) -} - -#' @export -epi_keys.recipe <- function(x, ...) 
{ - x$var_info$variable[x$var_info$role %in% c("time_value", "geo_value", "key")] -} - -# a mold is a list extracted from a fitted workflow, gives info about -# training data. But it doesn't have a class -epi_keys_mold <- function(mold) { - keys <- c("time_value", "geo_value", "key") - molded_names <- names(mold$extras$roles) - mold_keys <- map(mold$extras$roles[molded_names %in% keys], names) - unname(unlist(mold_keys)) -} - -kill_time_value <- function(v) { - arg_is_chr(v) - v[v != "time_value"] -} diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 3e5607dbb..f8216c2af 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -16,11 +16,10 @@ epi_recipe <- function(x, ...) { #' @rdname epi_recipe #' @export epi_recipe.default <- function(x, ...) { - ## if not a formula or an epi_df, we just pass to recipes::recipe - if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x)) { - x <- x[1, , drop = FALSE] - } - recipes::recipe(x, ...) + cli_abort(paste( + "`x` must be an {.cls epi_df} or a {.cls formula},", + "not a {.cls {class(x)[[1]]}}." + )) } #' @rdname epi_recipe @@ -42,21 +41,24 @@ epi_recipe.default <- function(x, ...) 
{ #' #' @export #' @examples +#' library(dplyr) +#' library(recipes) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-08-01") %>% -#' dplyr::arrange(geo_value, time_value) +#' filter(time_value > "2021-08-01") %>% +#' arrange(geo_value, time_value) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% -#' recipes::step_naomit(recipes::all_predictors()) %>% +#' step_naomit(all_predictors()) %>% #' # below, `skip` means we don't do this at predict time -#' recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) +#' step_naomit(all_outcomes(), skip = TRUE) #' #' r epi_recipe.epi_df <- function(x, formula = NULL, ..., vars = NULL, roles = NULL) { + attr(x, "decay_to_tibble") <- FALSE if (!is.null(formula)) { if (!is.null(vars)) { rlang::abort( @@ -86,10 +88,10 @@ epi_recipe.epi_df <- rlang::abort("1 or more elements of `vars` are not in the data") } - keys <- epi_keys(x) # we know x is an epi_df + keys <- key_colnames(x) # we know x is an epi_df var_info <- tibble(variable = vars) - key_roles <- c("time_value", "geo_value", rep("key", length(keys) - 2)) + key_roles <- c("geo_value", rep("key", length(keys) - 2), "time_value") ## Check and add roles when available if (!is.null(roles)) { @@ -147,12 +149,16 @@ epi_recipe.formula <- function(formula, data, ...) { data <- data[1, ] # check for minus: if (!epiprocess::is_epi_df(data)) { - return(recipes::recipe(formula, data, ...)) + cli_abort(paste( + "`epi_recipe()` has been called with a non-{.cls epi_df} object.", + "Use `recipe()` instead." + )) } - f_funcs <- recipes:::fun_calls(formula) + attr(data, "decay_to_tibble") <- FALSE + f_funcs <- recipes:::fun_calls(formula, data) if (any(f_funcs == "-")) { - abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.") + cli_abort("`-` is not allowed in a recipe formula. 
Use `step_rm()` instead.") } # Check for other in-line functions @@ -173,12 +179,12 @@ epi_form2args <- function(formula, data, ...) { if (!rlang::is_formula(formula)) formula <- as.formula(formula) ## check for in-line formulas - recipes:::inline_check(formula) + recipes:::inline_check(formula, data) ## use rlang to get both sides of the formula outcomes <- recipes:::get_lhs_vars(formula, data) predictors <- recipes:::get_rhs_vars(formula, data, no_lhs = TRUE) - keys <- epi_keys(data) + keys <- key_colnames(data) ## if . was used on the rhs, subtract out the outcomes predictors <- predictors[!(predictors %in% outcomes)] @@ -237,7 +243,7 @@ is_epi_recipe <- function(x) { #' @details #' `add_epi_recipe` has the same behaviour as #' [workflows::add_recipe()] but sets a different -#' default blueprint to automatically handle [epiprocess::epi_df] data. +#' default blueprint to automatically handle [epiprocess::epi_df][epiprocess::as_epi_df] data. #' #' @param x A `workflow` or `epi_workflow` #' @@ -333,15 +339,11 @@ update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blu #' illustrations of the different types of updates. #' #' @param x A `epi_workflow` or `epi_recipe` object -#' #' @param which_step the number or name of the step to adjust -#' #' @param ... Used to input a parameter adjustment -#' #' @param blueprint A hardhat blueprint used for fine tuning the preprocessing. #' -#' @return -#' `x`, updated with the adjustment to the specified `epi_recipe` step. +#' @return `x`, updated with the adjustment to the specified `epi_recipe` step. 
#' #' @export #' @examples @@ -383,8 +385,7 @@ adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe #' @rdname adjust_epi_recipe #' @export -adjust_epi_recipe.epi_workflow <- function( - x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { +adjust_epi_recipe.epi_workflow <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...) update_epi_recipe(x, recipe, blueprint = blueprint) @@ -392,8 +393,7 @@ adjust_epi_recipe.epi_workflow <- function( #' @rdname adjust_epi_recipe #' @export -adjust_epi_recipe.epi_recipe <- function( - x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { +adjust_epi_recipe.epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { if (!(is.numeric(which_step) || is.character(which_step))) { cli::cli_abort( c("`which_step` must be a number or a character.", @@ -442,9 +442,9 @@ prep.epi_recipe <- function( } training <- recipes:::check_training_set(training, x, fresh) training <- epi_check_training_set(training, x) - training <- dplyr::relocate(training, tidyselect::all_of(epi_keys(training))) + training <- dplyr::relocate(training, dplyr::all_of(key_colnames(training))) tr_data <- recipes:::train_info(training) - keys <- epi_keys(x) + keys <- key_colnames(x) orig_lvls <- lapply(training, recipes:::get_levels) orig_lvls <- kill_levels(orig_lvls, keys) @@ -495,11 +495,14 @@ prep.epi_recipe <- function( if (!is_epi_df(training)) { # tidymodels killed our class # for now, we only allow step_epi_* to alter the metadata - training <- dplyr::dplyr_reconstruct( - epiprocess::as_epi_df(training), before_template + metadata <- attr(before_template, "metadata") + training <- as_epi_df( + training, + as_of = metadata$as_of, + other_keys = metadata$other_keys %||% character() ) } - training <- dplyr::relocate(training, tidyselect::all_of(epi_keys(training))) + training <- 
dplyr::relocate(training, all_of(key_colnames(training))) x$term_info <- recipes:::merge_term_info(get_types(training), x$term_info) if (!is.na(x$steps[[i]]$role)) { new_vars <- setdiff(x$term_info$variable, running_info$variable) @@ -557,6 +560,31 @@ prep.epi_recipe <- function( x } +#' @export +bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") { + meta <- NULL + if (composition == "epi_df") { + if (is_epi_df(new_data)) { + meta <- attr(new_data, "metadata") + } else if (is_epi_df(object$template)) { + meta <- attr(object$template, "metadata") + } + composition <- "tibble" + } + new_data <- NextMethod("bake") + if (!is.null(meta)) { + # Baking should have dropped epi_df-ness and metadata. Re-infer some + # metadata and assume others remain the same as the object/template: + new_data <- as_epi_df( + new_data, + as_of = meta$as_of, + other_keys = meta$other_keys %||% character() + ) + } + new_data +} + + kill_levels <- function(x, keys) { for (i in which(names(x) %in% keys)) x[[i]] <- list(values = NA, ordered = NA) x diff --git a/R/epi_shift.R b/R/epi_shift.R index b40b36ecc..eb534f1ea 100644 --- a/R/epi_shift.R +++ b/R/epi_shift.R @@ -17,9 +17,9 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") { if (!is.data.frame(x)) x <- data.frame(x) if (is.null(keys)) keys <- rep("empty", nrow(x)) p_in <- ncol(x) - out_list <- tibble::tibble(i = 1:p_in, shift = shifts) %>% + out_list <- tibble(i = 1:p_in, shift = shifts) %>% tidyr::unchop(shift) %>% # what is chop - dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>% + mutate(name = paste0(out_name, 1:nrow(.))) %>% # One list element for each shifted feature pmap(function(i, shift, name) { tibble(keys, @@ -38,7 +38,7 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") { epi_shift_single <- function(x, col, shift_val, newname, key_cols) { x %>% - dplyr::select(tidyselect::all_of(c(key_cols, col))) %>% - dplyr::mutate(time_value = time_value + 
shift_val) %>% - dplyr::rename(!!newname := {{ col }}) + select(all_of(c(key_cols, col))) %>% + mutate(time_value = time_value + shift_val) %>% + rename(!!newname := {{ col }}) } diff --git a/R/epi_workflow.R b/R/epi_workflow.R index d5e7d13a2..af4555303 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -59,107 +59,6 @@ is_epi_workflow <- function(x) { } -#' Add a model to an `epi_workflow` -#' -#' @seealso [workflows::add_model()] -#' - `add_model()` adds a parsnip model to the `epi_workflow`. -#' -#' - `remove_model()` removes the model specification as well as any fitted -#' model object. Any extra formulas are also removed. -#' -#' - `update_model()` first removes the model then adds the new -#' specification to the workflow. -#' -#' @details -#' Has the same behaviour as [workflows::add_model()] but also ensures -#' that the returned object is an `epi_workflow`. -#' -#' @inheritParams workflows::add_model -#' -#' @param x An `epi_workflow`. -#' -#' @param spec A parsnip model specification. -#' -#' @param ... Not used. -#' -#' @return -#' `x`, updated with a new, updated, or removed model. 
-#' -#' @export -#' @examples -#' jhu <- case_death_rate_subset %>% -#' dplyr::filter( -#' time_value > "2021-11-01", -#' geo_value %in% c("ak", "ca", "ny") -#' ) -#' -#' r <- epi_recipe(jhu) %>% -#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% -#' step_epi_ahead(death_rate, ahead = 7) -#' -#' rf_model <- rand_forest(mode = "regression") -#' -#' wf <- epi_workflow(r) -#' -#' wf <- wf %>% add_model(rf_model) -#' wf -#' -#' lm_model <- parsnip::linear_reg() -#' -#' wf <- update_model(wf, lm_model) -#' wf -#' -#' wf <- remove_model(wf) -#' wf -#' @export -add_model <- function(x, spec, ..., formula = NULL) { - UseMethod("add_model") -} - -#' @rdname add_model -#' @export -remove_model <- function(x) { - UseMethod("remove_model") -} - -#' @rdname add_model -#' @export -update_model <- function(x, spec, ..., formula = NULL) { - UseMethod("update_model") -} - -#' @rdname add_model -#' @export -add_model.epi_workflow <- function(x, spec, ..., formula = NULL) { - workflows::add_model(x, spec, ..., formula = formula) -} - -#' @rdname add_model -#' @export -remove_model.epi_workflow <- function(x) { - workflows:::validate_is_workflow(x) - - if (!workflows:::has_spec(x)) { - rlang::warn("The workflow has no model to remove.") - } - - new_epi_workflow( - pre = x$pre, - fit = workflows:::new_stage_fit(), - post = x$post, - trained = FALSE - ) -} - -#' @rdname add_model -#' @export -update_model.epi_workflow <- function(x, spec, ..., formula = NULL) { - rlang::check_dots_empty() - x <- remove_model(x) - workflows::add_model(x, spec, ..., formula = formula) -} - - #' Fit an `epi_workflow` object #' #' @description @@ -197,9 +96,16 @@ update_model.epi_workflow <- function(x, spec, ..., formula = NULL) { #' #' @export fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()) { - object$fit$meta <- list(max_time_value = max(data$time_value), as_of = attributes(data)$metadata$as_of) + object$fit$meta <- list( + max_time_value = max(data$time_value), 
+ as_of = attr(data, "metadata")$as_of, + other_keys = attr(data, "metadata")$other_keys + ) + object$original_data <- data - NextMethod() + res <- NextMethod() + class(res) <- c("epi_workflow", class(res)) + res } #' Predict from an epi_workflow @@ -216,18 +122,18 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' - Call [parsnip::predict.model_fit()] for you using the underlying fit #' parsnip model. #' -#' - Ensure that the returned object is an [epiprocess::epi_df] where +#' - Ensure that the returned object is an [epiprocess::epi_df][epiprocess::as_epi_df] where #' possible. Specifically, the output will have `time_value` and #' `geo_value` columns as well as the prediction. #' -#' @inheritParams parsnip::predict.model_fit -#' #' @param object An epi_workflow that has been fit by #' [workflows::fit.workflow()] #' #' @param new_data A data frame containing the new predictors to preprocess #' and predict on #' +#' @inheritParams parsnip::predict.model_fit +#' #' @return #' A data frame of model predictions, with as many rows as `new_data` has. #' If `new_data` is an `epi_df` or a data frame with `time_value` or @@ -249,24 +155,20 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' #' preds <- predict(wf, latest) #' preds -predict.epi_workflow <- function(object, new_data, ...) { +predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) { if (!workflows::is_trained_workflow(object)) { - rlang::abort( - c("Can't predict on an untrained epi_workflow.", - i = "Do you need to call `fit()`?" - ) - ) + cli::cli_abort(c( + "Can't predict on an untrained epi_workflow.", + i = "Do you need to call `fit()`?" 
+ )) } components <- list() components$mold <- workflows::extract_mold(object) components$forged <- hardhat::forge(new_data, blueprint = components$mold$blueprint ) - components$keys <- grab_forged_keys( - components$forged, - components$mold, new_data - ) - components <- apply_frosting(object, components, new_data, ...) + components$keys <- grab_forged_keys(components$forged, object, new_data) + components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...) components$predictions } @@ -282,27 +184,23 @@ predict.epi_workflow <- function(object, new_data, ...) { #' @export augment.epi_workflow <- function(x, new_data, ...) { predictions <- predict(x, new_data, ...) - if (epiprocess::is_epi_df(predictions)) { - join_by <- epi_keys(predictions) + if (is_epi_df(predictions)) { + join_by <- key_colnames(predictions) } else { - rlang::abort( - c( - "Cannot determine how to join new_data with the predictions.", - "Try converting new_data to an epi_df with `as_epi_df(new_data)`." - ) - ) + cli_abort(c( + "Cannot determine how to join new_data with the predictions.", + "Try converting new_data to an epi_df with `as_epi_df(new_data)`." + )) } complete_overlap <- intersect(names(new_data), join_by) if (length(complete_overlap) < length(join_by)) { - rlang::warn( - glue::glue( - "Your original training data had keys {join_by}, but", - "`new_data` only has {complete_overlap}. The output", - "may be strange." - ) - ) + rlang::warn(glue::glue( + "Your original training data had keys {join_by}, but", + "`new_data` only has {complete_overlap}. The output", + "may be strange." + )) } - dplyr::full_join(predictions, new_data, by = join_by) + full_join(predictions, new_data, by = join_by) } new_epi_workflow <- function( @@ -327,3 +225,54 @@ print.epi_workflow <- function(x, ...) { print_postprocessor(x) invisible(x) } + + +#' Produce a forecast from an epi workflow +#' +#' @param object An epi workflow. +#' @param ... Not used. +#' @param fill_locf Logical. 
Should we use locf to fill in missing data? +#' @param n_recent Integer or NULL. If filling missing data with locf = TRUE, +#' how far back are we willing to tolerate missing data? Larger values allow +#' more filling. The default NULL will determine this from the the recipe. For +#' example, suppose n_recent = 3, then if the 3 most recent observations in any +#' geo_value are all NA’s, we won’t be able to fill anything, and an error +#' message will be thrown. (See details.) +#' @param forecast_date By default, this is set to the maximum time_value in x. +#' But if there is data latency such that recent NA's should be filled, this may +#' be after the last available time_value. +#' +#' @return A forecast tibble. +#' +#' @export +forecast.epi_workflow <- function(object, ..., fill_locf = FALSE, n_recent = NULL, forecast_date = NULL) { + rlang::check_dots_empty() + + if (!object$trained) { + cli_abort(c( + "You cannot `forecast()` a {.cls workflow} that has not been trained.", + i = "Please use `fit()` before forecasting." + )) + } + + frosting_fd <- NULL + if (has_postprocessor(object) && detect_layer(object, "layer_add_forecast_date")) { + frosting_fd <- extract_argument(object, "layer_add_forecast_date", "forecast_date") + if (!is.null(frosting_fd) && class(frosting_fd) != class(object$original_data$time_value)) { + cli_abort(c( + "Error with layer_add_forecast_date():", + i = "The type of `forecast_date` must match the type of the `time_value` column in the data." 
+ )) + } + } + + test_data <- get_test_data( + hardhat::extract_preprocessor(object), + object$original_data, + fill_locf = fill_locf, + n_recent = n_recent %||% Inf, + forecast_date = forecast_date %||% frosting_fd %||% max(object$original_data$time_value) + ) + + predict(object, new_data = test_data) +} diff --git a/R/epipredict-package.R b/R/epipredict-package.R index 11e2ec833..ad0f95295 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -1,9 +1,17 @@ ## usethis namespace: start -#' @importFrom tibble tibble -#' @importFrom rlang := !! -#' @importFrom stats poly predict lm residuals quantile -#' @importFrom cli cli_abort -#' @importFrom lifecycle deprecated #' @import epiprocess parsnip +#' @importFrom checkmate assert_class assert_numeric +#' @importFrom checkmate test_character test_date test_function +#' @importFrom checkmate test_integerish test_logical +#' @importFrom checkmate test_numeric test_scalar +#' @importFrom cli cli_abort cli_warn +#' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by +#' @importFrom dplyr full_join relocate summarise everything +#' @importFrom dplyr summarize filter mutate select left_join rename ungroup +#' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg +#' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match +#' @importFrom stats poly predict lm residuals quantile +#' @importFrom tibble as_tibble +na_chr <- NA_character_ ## usethis namespace: end NULL diff --git a/R/extrapolate_quantiles.R b/R/extrapolate_quantiles.R new file mode 100644 index 000000000..3362e339e --- /dev/null +++ b/R/extrapolate_quantiles.R @@ -0,0 +1,65 @@ +#' Summarize a distribution with a set of quantiles +#' +#' @param x a `distribution` vector +#' @param probs a vector of probabilities at which to calculate quantiles +#' @param replace_na logical. If `x` contains `NA`'s, these are imputed if +#' possible (if `TRUE`) or retained (if `FALSE`). 
This only effects +#' elements of class `dist_quantiles`. +#' @param ... additional arguments passed on to the `quantile` method +#' +#' @return a `distribution` vector containing `dist_quantiles`. Any elements +#' of `x` which were originally `dist_quantiles` will now have a superset +#' of the original `quantile_values` (the union of those and `probs`). +#' @export +#' +#' @examples +#' library(distributional) +#' dstn <- dist_normal(c(10, 2), c(5, 10)) +#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) +#' +#' dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) +#' # because this distribution is already quantiles, any extra quantiles are +#' # appended +#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) +#' +#' dstn <- c( +#' dist_normal(c(10, 2), c(5, 10)), +#' dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) +#' ) +#' extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) +extrapolate_quantiles <- function(x, probs, replace_na = TRUE, ...) { + UseMethod("extrapolate_quantiles") +} + +#' @export +#' @importFrom vctrs vec_data +extrapolate_quantiles.distribution <- function(x, probs, replace_na = TRUE, ...) { + rlang::check_dots_empty() + arg_is_lgl_scalar(replace_na) + arg_is_probabilities(probs) + if (is.unsorted(probs)) probs <- sort(probs) + dstn <- lapply(vec_data(x), extrapolate_quantiles, probs = probs, replace_na = replace_na) + new_vctr(dstn, vars = NULL, class = "distribution") +} + +#' @export +extrapolate_quantiles.dist_default <- function(x, probs, replace_na = TRUE, ...) { + values <- quantile(x, probs, ...) + new_quantiles(values = values, quantile_levels = probs) +} + +#' @export +extrapolate_quantiles.dist_quantiles <- function(x, probs, replace_na = TRUE, ...) 
{ + orig_probs <- field(x, "quantile_levels") + orig_values <- field(x, "values") + new_probs <- c(orig_probs, probs) + dups <- duplicated(new_probs) + if (!replace_na || !anyNA(orig_values)) { + new_values <- c(orig_values, quantile(x, probs, ...)) + } else { + nas <- is.na(orig_values) + orig_values[nas] <- quantile(x, orig_probs[nas], ...) + new_values <- c(orig_values, quantile(x, probs, ...)) + } + new_quantiles(new_values[!dups], new_probs[!dups]) +} diff --git a/R/flatline.R b/R/flatline.R index 0f98b0e2b..fb60c920d 100644 --- a/R/flatline.R +++ b/R/flatline.R @@ -43,26 +43,26 @@ flatline <- function(formula, data) { observed <- rhs[n] # DANGER!! ek <- rhs[-n] if (length(response) > 1) { - cli_stop("flatline forecaster can accept only 1 observed time series.") + cli_abort("flatline forecaster can accept only 1 observed time series.") } keys <- kill_time_value(ek) preds <- data %>% - dplyr::mutate( + mutate( .pred = !!rlang::sym(observed), .resid = !!rlang::sym(response) - .pred ) .pred <- preds %>% - dplyr::filter(!is.na(.pred)) %>% - dplyr::group_by(!!!rlang::syms(keys)) %>% - dplyr::arrange(time_value) %>% + filter(!is.na(.pred)) %>% + group_by(!!!rlang::syms(keys)) %>% + arrange(time_value) %>% dplyr::slice_tail(n = 1L) %>% - dplyr::ungroup() %>% - dplyr::select(tidyselect::all_of(c(keys, ".pred"))) + ungroup() %>% + select(all_of(c(keys, ".pred"))) structure( list( - residuals = dplyr::select(preds, dplyr::all_of(c(keys, ".resid"))), + residuals = select(preds, all_of(c(keys, ".resid"))), .pred = .pred ), class = "flatline" @@ -80,14 +80,13 @@ predict.flatline <- function(object, newdata, ...) 
{ metadata <- names(object)[names(object) != ".pred"] ek <- names(newdata) if (!all(metadata %in% ek)) { - cli_stop( + cli_abort(c( "`newdata` has different metadata than was used", "to fit the flatline forecaster" - ) + )) } - dplyr::left_join(newdata, object, by = metadata) %>% - dplyr::pull(.pred) + left_join(newdata, object, by = metadata)$.pred } #' @export diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index 99ebc8694..55808b803 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -1,8 +1,9 @@ #' Predict the future with today's value #' #' This is a simple forecasting model for -#' [epiprocess::epi_df] data. It uses the most recent observation as the -#' forcast for any future date, and produces intervals based on the quantiles +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent +#' observation as the +#' forecast for any future date, and produces intervals based on the quantiles #' of the residuals of such a "flatline" forecast over all available training #' data. #' @@ -13,7 +14,7 @@ #' This forecaster is very similar to that used by the #' [COVID19ForecastHub](https://covid19forecasthub.org) #' -#' @param epi_data An [epiprocess::epi_df] +#' @param epi_data An [epiprocess::epi_df][epiprocess::as_epi_df] #' @param outcome A scalar character for the column name we wish to predict. #' @param args_list A list of dditional arguments as created by the #' [flatline_args_list()] constructor function. 
@@ -33,9 +34,9 @@ flatline_forecaster <- function( args_list = flatline_args_list()) { validate_forecaster_inputs(epi_data, outcome, "time_value") if (!inherits(args_list, c("flat_fcast", "alist"))) { - cli_stop("args_list was not created using `flatline_args_list().") + cli_abort("`args_list` was not created using `flatline_args_list()`.") } - keys <- epi_keys(epi_data) + keys <- key_colnames(epi_data) ek <- kill_time_value(keys) outcome <- rlang::sym(outcome) @@ -47,13 +48,7 @@ flatline_forecaster <- function( step_training_window(n_recent = args_list$n_training) forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) - target_date <- args_list$target_date %||% forecast_date + args_list$ahead - - - latest <- get_test_data( - epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer, - forecast_date - ) + target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) f <- frosting() %>% layer_predict() %>% @@ -66,13 +61,18 @@ flatline_forecaster <- function( layer_add_target_date(target_date = target_date) if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred")) - eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") + eng <- linear_reg(engine = "flatline") wf <- epi_workflow(r, eng, f) - wf <- generics::fit(wf, epi_data) - preds <- suppressWarnings(predict(wf, new_data = latest)) %>% - tibble::as_tibble() %>% - dplyr::select(-time_value) + wf <- fit(wf, epi_data) + preds <- suppressWarnings(forecast( + wf, + fill_locf = TRUE, + n_recent = args_list$nafill_buffer, + forecast_date = forecast_date + )) %>% + as_tibble() %>% + select(-time_value) structure( list( @@ -131,6 +131,15 @@ flatline_args_list <- function( if (is.finite(n_training)) arg_is_pos_int(n_training) if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE) + if (!is.null(forecast_date) && !is.null(target_date)) { + if (forecast_date + ahead != target_date) { + cli::cli_warn(c( + "`forecast_date` + `ahead` must equal 
`target_date`.", + i = "{.val {forecast_date}} + {.val {ahead}} != {.val {target_date}}." + )) + } + } + structure( enlist( ahead, diff --git a/R/flusight_hub_formatter.R b/R/flusight_hub_formatter.R index 0dbd1a954..3e0eb1aaa 100644 --- a/R/flusight_hub_formatter.R +++ b/R/flusight_hub_formatter.R @@ -56,23 +56,26 @@ abbr_to_location <- function(abbr) { #' @export #' #' @examples -#' if (require(dplyr)) { -#' weekly_deaths <- case_death_rate_subset %>% -#' select(geo_value, time_value, death_rate) %>% -#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>% -#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>% -#' select(-pop, -death_rate) %>% -#' group_by(geo_value) %>% -#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>% -#' ungroup() %>% -#' filter(weekdays(time_value) == "Saturday") +#' library(dplyr) +#' weekly_deaths <- case_death_rate_subset %>% +#' filter( +#' time_value >= as.Date("2021-09-01"), +#' geo_value %in% c("ca", "ny", "dc", "ga", "vt") +#' ) %>% +#' select(geo_value, time_value, death_rate) %>% +#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>% +#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>% +#' select(-pop, -death_rate) %>% +#' group_by(geo_value) %>% +#' epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") %>% +#' ungroup() %>% +#' filter(weekdays(time_value) == "Saturday") #' -#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") -#' flusight_hub_formatter(cdc) -#' flusight_hub_formatter(cdc, target = "wk inc covid deaths") -#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths")) -#' flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile") -#' } +#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum") +#' flusight_hub_formatter(cdc) +#' flusight_hub_formatter(cdc, target = "wk inc covid deaths") +#' flusight_hub_formatter(cdc, target = paste(horizon, "wk 
inc covid deaths")) +#' flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile") flusight_hub_formatter <- function( object, ..., .fcast_period = c("daily", "weekly")) { @@ -94,7 +97,7 @@ flusight_hub_formatter.data.frame <- function( optional_names <- c("ahead", "target_date") hardhat::validate_column_names(object, required_names) if (!any(optional_names %in% names(object))) { - cli::cli_abort("At least one of {.val {optional_names}} must be present.") + cli_abort("At least one of {.val {optional_names}} must be present.") } dots <- enquos(..., .named = TRUE) @@ -102,38 +105,38 @@ flusight_hub_formatter.data.frame <- function( object <- object %>% # combine the predictions and the distribution - dplyr::mutate(.pred_distn = nested_quantiles(.pred_distn)) %>% + mutate(.pred_distn = nested_quantiles(.pred_distn)) %>% tidyr::unnest(.pred_distn) %>% # now we create the correct column names - dplyr::rename( + rename( value = values, output_type_id = quantile_levels, reference_date = forecast_date ) %>% # convert to fips codes, and add any constant cols passed in ... 
- dplyr::mutate(location = abbr_to_location(tolower(geo_value)), geo_value = NULL) + mutate(location = abbr_to_location(tolower(geo_value)), geo_value = NULL) # create target_end_date / horizon, depending on what is available pp <- ifelse(match.arg(.fcast_period) == "daily", 1L, 7L) has_ahead <- charmatch("ahead", names(object)) if ("target_date" %in% names(object) && !is.na(has_ahead)) { object <- object %>% - dplyr::rename( + rename( target_end_date = target_date, horizon = !!names(object)[has_ahead] ) } else if (!is.na(has_ahead)) { # ahead present, not target date object <- object %>% - dplyr::rename(horizon = !!names(object)[has_ahead]) %>% - dplyr::mutate(target_end_date = horizon * pp + reference_date) + rename(horizon = !!names(object)[has_ahead]) %>% + mutate(target_end_date = horizon * pp + reference_date) } else { # target_date present, not ahead object <- object %>% - dplyr::rename(target_end_date = target_date) %>% - dplyr::mutate(horizon = as.integer((target_end_date - reference_date)) / pp) + rename(target_end_date = target_date) %>% + mutate(horizon = as.integer((target_end_date - reference_date)) / pp) } object %>% - dplyr::relocate( + relocate( reference_date, horizon, target_end_date, location, output_type_id, value ) %>% - dplyr::mutate(!!!dots) + mutate(!!!dots) } diff --git a/R/frosting.R b/R/frosting.R index 505fd5bcc..8474edbdf 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -8,15 +8,16 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% #' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' latest <- jhu %>% -#' dplyr::filter(time_value >= max(time_value) - 14) +#' filter(time_value >= max(time_value) - 14) #' #' # Add frosting to a workflow and predict 
#' f <- frosting() %>% @@ -84,7 +85,8 @@ validate_has_postprocessor <- function(x, ..., call = caller_env()) { rlang::check_dots_empty() has_postprocessor <- has_postprocessor_frosting(x) if (!has_postprocessor) { - message <- c("The workflow must have a frosting postprocessor.", + message <- c( + "The workflow must have a frosting postprocessor.", i = "Provide one with `add_frosting()`." ) rlang::abort(message, call = call) @@ -125,6 +127,7 @@ update_frosting <- function(x, frosting, ...) { #' #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% #' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% @@ -132,7 +135,7 @@ update_frosting <- function(x, frosting, ...) { #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' # in the frosting from the workflow #' f1 <- frosting() %>% @@ -177,11 +180,10 @@ adjust_frosting.epi_workflow <- function( adjust_frosting.frosting <- function( x, which_layer, ...) { if (!(is.numeric(which_layer) || is.character(which_layer))) { - cli::cli_abort( - c("`which_layer` must be a number or a character.", - i = "`which_layer` has class {.cls {class(which_layer)[1]}}." - ) - ) + cli_abort(c( + "`which_layer` must be a number or a character.", + i = "`which_layer` has class {.cls {class(which_layer)[1]}}." + )) } else if (is.numeric(which_layer)) { x$layers[[which_layer]] <- update(x$layers[[which_layer]], ...) } else { @@ -190,7 +192,7 @@ adjust_frosting.frosting <- function( if (!starts_with_layer) which_layer <- paste0("layer_", which_layer) if (!(which_layer %in% layer_names)) { - cli::cli_abort(c( + cli_abort(c( "`which_layer` does not appear in the available `frosting` layer names. ", i = "The layer names are {.val {layer_names}}." 
)) @@ -199,7 +201,7 @@ adjust_frosting.frosting <- function( if (length(which_layer_idx) == 1) { x$layers[[which_layer_idx]] <- update(x$layers[[which_layer_idx]], ...) } else { - cli::cli_abort(c( + cli_abort(c( "`which_layer` is not unique. Matches layers: {.val {which_layer_idx}}.", i = "Please use the layer number instead for precise alterations." )) @@ -216,7 +218,7 @@ add_postprocessor <- function(x, postprocessor, ..., call = caller_env()) { if (is_frosting(postprocessor)) { return(add_frosting(x, postprocessor)) } - cli::cli_abort("`postprocessor` must be a frosting object.", call = call) + cli_abort("`postprocessor` must be a frosting object.", call = call) } is_frosting <- function(x) { @@ -227,7 +229,7 @@ is_frosting <- function(x) { validate_frosting <- function(x, ..., arg = "`x`", call = caller_env()) { rlang::check_dots_empty() if (!is_frosting(x)) { - cli::cli_abort( + cli_abort( "{arg} must be a frosting postprocessor, not a {.cls {class(x)[[1]]}}.", .call = call ) @@ -260,14 +262,14 @@ new_frosting <- function() { #' @export #' #' @examples -#' +#' library(dplyr) #' # Toy example to show that frosting can be created and added for postprocessing #' f <- frosting() #' wf <- epi_workflow() %>% add_frosting(f) #' #' # A more realistic example #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% @@ -275,7 +277,6 @@ new_frosting <- function() { #' step_epi_naomit() #' #' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' latest <- get_test_data(recipe = r, x = jhu) #' #' f <- frosting() %>% #' layer_predict() %>% @@ -283,7 +284,7 @@ new_frosting <- function() { #' #' wf1 <- wf %>% add_frosting(f) #' -#' p <- predict(wf1, latest) +#' p <- forecast(wf1) #' p frosting <- function(layers = NULL, requirements = NULL) { if 
(!is_null(layers) || !is_null(requirements)) { @@ -308,7 +309,7 @@ extract_frosting <- function(x, ...) { #' @export extract_frosting.default <- function(x, ...) { - cli::cli_abort(c( + cli_abort(c( "Frosting is only available for epi_workflows currently.", i = "Can you use `epi_workflow()` instead of `workflow()`?" )) @@ -320,7 +321,7 @@ extract_frosting.epi_workflow <- function(x, ...) { if (has_postprocessor_frosting(x)) { return(x$post$actions$frosting$frosting) } else { - cli_stop("The epi_workflow does not have a postprocessor.") + cli_abort("The epi_workflow does not have a postprocessor.") } } @@ -343,7 +344,7 @@ apply_frosting <- function(workflow, ...) { #' @export apply_frosting.default <- function(workflow, components, ...) { if (has_postprocessor(workflow)) { - cli::cli_abort(c( + cli_abort(c( "Postprocessing is only available for epi_workflows currently.", i = "Can you use `epi_workflow()` instead of `workflow()`?" )) @@ -356,9 +357,11 @@ apply_frosting.default <- function(workflow, components, ...) { #' @rdname apply_frosting #' @importFrom rlang is_null #' @importFrom rlang abort +#' @param type,opts forwarded (along with `...`) to [`predict.model_fit()`] and +#' [`slather()`] for supported layers #' @export apply_frosting.epi_workflow <- - function(workflow, components, new_data, ...) { + function(workflow, components, new_data, type = NULL, opts = list(), ...) { the_fit <- workflows::extract_fit_parsnip(workflow) if (!has_postprocessor(workflow)) { @@ -372,12 +375,12 @@ apply_frosting.epi_workflow <- } if (!has_postprocessor_frosting(workflow)) { - cli::cli_warn(c( + cli_warn(c( "Only postprocessors of class {.cls frosting} are allowed.", "Returning unpostprocessed predictions." )) components$predictions <- predict( - the_fit, components$forged$predictors, ... + the_fit, components$forged$predictors, type, opts, ... 
) components$predictions <- dplyr::bind_cols( components$keys, components$predictions @@ -398,10 +401,28 @@ apply_frosting.epi_workflow <- layers ) } + if (length(layers) > 1L && + (!is.null(type) || !identical(opts, list()) || rlang::dots_n(...) > 0L)) { + cli_abort(" + Passing `type`, `opts`, or `...` into `predict.epi_workflow()` is not + supported if you have frosting layers other than `layer_predict`. Please + provide these arguments earlier (i.e. while constructing the frosting + object) by passing them into an explicit call to `layer_predict()`, and + adjust the remaining layers to account for resulting differences in + output format from these settings. + ", class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers") + } for (l in seq_along(layers)) { la <- layers[[l]] - components <- slather(la, components, workflow, new_data) + if (inherits(la, "layer_predict")) { + components <- slather(la, components, workflow, new_data, type = type, opts = opts, ...) + } else { + # The check above should ensure we have default `type` and `opts`, and + # empty `...`; don't forward these default `type` and `opts`, to avoid + # upsetting some slather method validation. + components <- slather(la, components, workflow, new_data) + } } return(components) diff --git a/R/get_test_data.R b/R/get_test_data.R index e76715daf..694e73b06 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -1,7 +1,7 @@ #' Get test data for prediction based on longest lag period #' #' Based on the longest lag period in the recipe, -#' `get_test_data()` creates an [epi_df] +#' `get_test_data()` creates an [epi_df][epiprocess::as_epi_df] #' with columns `geo_value`, `time_value` #' and other variables in the original dataset, #' which will be used to create features necessary to produce forecasts. 
@@ -42,14 +42,13 @@ #' get_test_data(recipe = rec, x = case_death_rate_subset) #' @importFrom rlang %@% #' @export - get_test_data <- function( recipe, x, fill_locf = FALSE, n_recent = NULL, forecast_date = max(x$time_value)) { - if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.") + if (!is_epi_df(x)) cli_abort("`x` must be an `epi_df`.") arg_is_lgl(fill_locf) arg_is_scalar(fill_locf) arg_is_scalar(n_recent, allow_null = TRUE) @@ -60,25 +59,23 @@ get_test_data <- function( check <- hardhat::check_column_names(x, colnames(recipe$template)) if (!check$ok) { - cli::cli_abort(c( + cli_abort(c( "Some variables used for training are not available in {.arg x}.", i = "The following required columns are missing: {check$missing_names}" )) } - if (class(forecast_date) != class(x$time_value)) { - cli::cli_abort("`forecast_date` must be the same class as `x$time_value`.") + cli_abort("`forecast_date` must be the same class as `x$time_value`.") } - - if (forecast_date < max(x$time_value)) { - cli::cli_abort("`forecast_date` must be no earlier than `max(x$time_value)`") + cli_abort("`forecast_date` must be no earlier than `max(x$time_value)`") } min_lags <- min(map_dbl(recipe$steps, ~ min(.x$lag %||% Inf)), Inf) max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)), 0) max_horizon <- max(map_dbl(recipe$steps, ~ max(.x$horizon %||% 0)), 0) - min_required <- max_lags + max_horizon + max_slide <- max(map_dbl(recipe$steps, ~ max(.x$before %||% 0)), 0) + min_required <- max_lags + max_horizon + max_slide if (is.null(n_recent)) n_recent <- min_required + 1 # one extra for filling if (n_recent <= min_required) n_recent <- min_required + n_recent @@ -86,7 +83,7 @@ get_test_data <- function( # Probably needs a fix based on the time_type of the epi_df avail_recent <- diff(range(x$time_value)) if (avail_recent < min_required) { - cli::cli_abort(c( + cli_abort(c( "You supplied insufficient recent data for this recipe. ", "!" 
= "You need at least {min_required} days of data,", "!" = "but `x` contains only {avail_recent}." @@ -94,44 +91,42 @@ get_test_data <- function( } x <- arrange(x, time_value) - groups <- kill_time_value(epi_keys(recipe)) + groups <- epi_keys_only(recipe) # If we skip NA completion, we remove undesirably early time values # Happens globally, over all groups keep <- max(n_recent, min_required + 1) - x <- dplyr::filter(x, forecast_date - time_value <= keep) + x <- filter(x, forecast_date - time_value <= keep) # Pad with explicit missing values up to and including the forecast_date # x is grouped here x <- pad_to_end(x, groups, forecast_date) %>% - epiprocess::group_by(dplyr::across(dplyr::all_of(groups))) + group_by(across(all_of(groups))) # If all(lags > 0), then we get rid of recent data if (min_lags > 0 && min_lags < Inf) { - x <- dplyr::filter(x, forecast_date - time_value >= min_lags) + x <- filter(x, forecast_date - time_value >= min_lags) } # Now, fill forward missing data if requested if (fill_locf) { cannot_be_used <- x %>% - dplyr::filter(forecast_date - time_value <= n_recent) %>% - dplyr::mutate(fillers = forecast_date - time_value > min_required) %>% - dplyr::summarize( - dplyr::across( - -tidyselect::any_of(epi_keys(recipe)), + filter(forecast_date - time_value <= n_recent) %>% + mutate(fillers = forecast_date - time_value > min_required) %>% + summarize( + across( + -any_of(key_colnames(recipe)), ~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1)) ), .groups = "drop" ) %>% - dplyr::select(-fillers) %>% - dplyr::summarise(dplyr::across( - -tidyselect::any_of(epi_keys(recipe)), ~ any(.x) - )) %>% + select(-fillers) %>% + summarise(across(-any_of(key_colnames(recipe)), ~ any(.x))) %>% unlist() if (any(cannot_be_used)) { bad_vars <- names(cannot_be_used)[cannot_be_used] if (recipes::is_trained(recipe)) { - cli::cli_abort(c( + cli_abort(c( "The variables {.var {bad_vars}} have too many recent missing", `!` = "values to be filled automatically. 
", i = "You should either choose `n_recent` larger than its current ", @@ -143,15 +138,15 @@ get_test_data <- function( x <- tidyr::fill(x, !time_value) } - dplyr::filter(x, forecast_date - time_value <= min_required) %>% - epiprocess::ungroup() + filter(x, forecast_date - time_value <= min_required) %>% + ungroup() } pad_to_end <- function(x, groups, end_date) { - itval <- epiprocess:::guess_period(c(x$time_value, end_date), "time_value") + itval <- guess_period(c(x$time_value, end_date), "time_value") completed_time_values <- x %>% - dplyr::group_by(dplyr::across(tidyselect::all_of(groups))) %>% - dplyr::summarise( + group_by(across(all_of(groups))) %>% + summarise( time_value = rlang::list2( time_value = Seq(max(time_value) + itval, end_date, itval) ) @@ -159,8 +154,8 @@ pad_to_end <- function(x, groups, end_date) { unnest("time_value") %>% mutate(time_value = vctrs::vec_cast(time_value, x$time_value)) - dplyr::bind_rows(x, completed_time_values) %>% - dplyr::arrange(dplyr::across(tidyselect::all_of(c("time_value", groups)))) + bind_rows(x, completed_time_values) %>% + arrange(across(all_of(c("time_value", groups)))) } Seq <- function(from, to, by) { diff --git a/R/grab_names.R b/R/grab_names.R deleted file mode 100644 index 7ff3ac77e..000000000 --- a/R/grab_names.R +++ /dev/null @@ -1,23 +0,0 @@ -#' Get the names from a data frame via tidy select -#' -#' Given a data.frame, use `` syntax to choose -#' some variables. Return the names of those variables -#' -#' As this is an internal function, no checks are performed. -#' -#' @param dat a data.frame -#' @param ... <[`tidy-select`][dplyr::dplyr_tidy_select]> One or more unquoted -#' expressions separated by commas. Variable names can be used as if they -#' were positions in the data frame, so expressions like `x:y` can -#' be used to select a range of variables. 
-#' -#' @export -#' @keywords internal -#' @return a character vector -#' @examples -#' df <- data.frame(a = 1, b = 2, cc = rep(NA, 3)) -#' grab_names(df, dplyr::starts_with("c")) -grab_names <- function(dat, ...) { - x <- rlang::expr(c(...)) - names(tidyselect::eval_select(x, dat)) -} diff --git a/R/import-standalone-purrr.R b/R/import-standalone-purrr.R new file mode 100644 index 000000000..623142a0e --- /dev/null +++ b/R/import-standalone-purrr.R @@ -0,0 +1,240 @@ +# Standalone file: do not edit by hand +# Source: +# ---------------------------------------------------------------------- +# +# --- +# repo: r-lib/rlang +# file: standalone-purrr.R +# last-updated: 2023-02-23 +# license: https://unlicense.org +# imports: rlang +# --- +# +# This file provides a minimal shim to provide a purrr-like API on top of +# base R functions. They are not drop-in replacements but allow a similar style +# of programming. +# +# ## Changelog +# +# 2023-02-23: +# * Added `list_c()` +# +# 2022-06-07: +# * `transpose()` is now more consistent with purrr when inner names +# are not congruent (#1346). +# +# 2021-12-15: +# * `transpose()` now supports empty lists. +# +# 2021-05-21: +# * Fixed "object `x` not found" error in `imap()` (@mgirlich) +# +# 2020-04-14: +# * Removed `pluck*()` functions +# * Removed `*_cpl()` functions +# * Used `as_function()` to allow use of `~` +# * Used `.` prefix for helpers +# +# nocov start + +map <- function(.x, .f, ...) { + .f <- as_function(.f, env = global_env()) + lapply(.x, .f, ...) +} +walk <- function(.x, .f, ...) { + map(.x, .f, ...) + invisible(.x) +} + +map_lgl <- function(.x, .f, ...) { + .rlang_purrr_map_mold(.x, .f, logical(1), ...) +} +map_int <- function(.x, .f, ...) { + .rlang_purrr_map_mold(.x, .f, integer(1), ...) +} +map_dbl <- function(.x, .f, ...) { + .rlang_purrr_map_mold(.x, .f, double(1), ...) +} +map_chr <- function(.x, .f, ...) { + .rlang_purrr_map_mold(.x, .f, character(1), ...) 
+} +.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) { + .f <- as_function(.f, env = global_env()) + out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE) + names(out) <- names(.x) + out +} + +map2 <- function(.x, .y, .f, ...) { + .f <- as_function(.f, env = global_env()) + out <- mapply(.f, .x, .y, MoreArgs = list(...), SIMPLIFY = FALSE) + if (length(out) == length(.x)) { + set_names(out, names(.x)) + } else { + set_names(out, NULL) + } +} +map2_lgl <- function(.x, .y, .f, ...) { + as.vector(map2(.x, .y, .f, ...), "logical") +} +map2_int <- function(.x, .y, .f, ...) { + as.vector(map2(.x, .y, .f, ...), "integer") +} +map2_dbl <- function(.x, .y, .f, ...) { + as.vector(map2(.x, .y, .f, ...), "double") +} +map2_chr <- function(.x, .y, .f, ...) { + as.vector(map2(.x, .y, .f, ...), "character") +} +imap <- function(.x, .f, ...) { + map2(.x, names(.x) %||% seq_along(.x), .f, ...) +} + +pmap <- function(.l, .f, ...) { + .f <- as.function(.f) + args <- .rlang_purrr_args_recycle(.l) + do.call("mapply", c( + FUN = list(quote(.f)), + args, MoreArgs = quote(list(...)), + SIMPLIFY = FALSE, USE.NAMES = FALSE + )) +} +.rlang_purrr_args_recycle <- function(args) { + lengths <- map_int(args, length) + n <- max(lengths) + + stopifnot(all(lengths == 1L | lengths == n)) + to_recycle <- lengths == 1L + args[to_recycle] <- map(args[to_recycle], function(x) rep.int(x, n)) + + args +} + +keep <- function(.x, .f, ...) { + .x[.rlang_purrr_probe(.x, .f, ...)] +} +discard <- function(.x, .p, ...) { + sel <- .rlang_purrr_probe(.x, .p, ...) + .x[is.na(sel) | !sel] +} +map_if <- function(.x, .p, .f, ...) { + matches <- .rlang_purrr_probe(.x, .p) + .x[matches] <- map(.x[matches], .f, ...) + .x +} +.rlang_purrr_probe <- function(.x, .p, ...) { + if (is_logical(.p)) { + stopifnot(length(.p) == length(.x)) + .p + } else { + .p <- as_function(.p, env = global_env()) + map_lgl(.x, .p, ...) 
+ } +} + +compact <- function(.x) { + Filter(length, .x) +} + +transpose <- function(.l) { + if (!length(.l)) { + return(.l) + } + + inner_names <- names(.l[[1]]) + + if (is.null(inner_names)) { + fields <- seq_along(.l[[1]]) + } else { + fields <- set_names(inner_names) + .l <- map(.l, function(x) { + if (is.null(names(x))) { + set_names(x, inner_names) + } else { + x + } + }) + } + + # This way missing fields are subsetted as `NULL` instead of causing + # an error + .l <- map(.l, as.list) + + map(fields, function(i) { + map(.l, .subset2, i) + }) +} + +every <- function(.x, .p, ...) { + .p <- as_function(.p, env = global_env()) + + for (i in seq_along(.x)) { + if (!rlang::is_true(.p(.x[[i]], ...))) return(FALSE) + } + TRUE +} +some <- function(.x, .p, ...) { + .p <- as_function(.p, env = global_env()) + + for (i in seq_along(.x)) { + if (rlang::is_true(.p(.x[[i]], ...))) return(TRUE) + } + FALSE +} +negate <- function(.p) { + .p <- as_function(.p, env = global_env()) + function(...) !.p(...) +} + +reduce <- function(.x, .f, ..., .init) { + f <- function(x, y) .f(x, y, ...) + Reduce(f, .x, init = .init) +} +reduce_right <- function(.x, .f, ..., .init) { + f <- function(x, y) .f(y, x, ...) + Reduce(f, .x, init = .init, right = TRUE) +} +accumulate <- function(.x, .f, ..., .init) { + f <- function(x, y) .f(x, y, ...) + Reduce(f, .x, init = .init, accumulate = TRUE) +} +accumulate_right <- function(.x, .f, ..., .init) { + f <- function(x, y) .f(y, x, ...) 
+ Reduce(f, .x, init = .init, right = TRUE, accumulate = TRUE) +} + +detect <- function(.x, .f, ..., .right = FALSE, .p = is_true) { + .p <- as_function(.p, env = global_env()) + .f <- as_function(.f, env = global_env()) + + for (i in .rlang_purrr_index(.x, .right)) { + if (.p(.f(.x[[i]], ...))) { + return(.x[[i]]) + } + } + NULL +} +detect_index <- function(.x, .f, ..., .right = FALSE, .p = is_true) { + .p <- as_function(.p, env = global_env()) + .f <- as_function(.f, env = global_env()) + + for (i in .rlang_purrr_index(.x, .right)) { + if (.p(.f(.x[[i]], ...))) { + return(i) + } + } + 0L +} +.rlang_purrr_index <- function(x, right = FALSE) { + idx <- seq_along(x) + if (right) { + idx <- rev(idx) + } + idx +} + +list_c <- function(x) { + inject(c(!!!x)) +} + +# nocov end diff --git a/R/key_colnames.R b/R/key_colnames.R new file mode 100644 index 000000000..b9ebde5dc --- /dev/null +++ b/R/key_colnames.R @@ -0,0 +1,27 @@ +#' @export +key_colnames.recipe <- function(x, ...) { + geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"] + time_key <- x$var_info$variable[x$var_info$role %in% "time_value"] + keys <- x$var_info$variable[x$var_info$role %in% "key"] + c(geo_key, keys, time_key) %||% character(0L) +} + +#' @export +key_colnames.epi_workflow <- function(x, ...) { + # safer to look at the mold than the preprocessor + mold <- hardhat::extract_mold(x) + molded_names <- names(mold$extras$roles) + geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value) + time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value) + keys <- names(mold$extras$roles[molded_names %in% "key"]$key) + c(geo_key, keys, time_key) %||% character(0L) +} + +kill_time_value <- function(v) { + arg_is_chr(v) + v[v != "time_value"] +} + +epi_keys_only <- function(x, ...) 
{ + kill_time_value(key_colnames(x, ...)) +} diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 6bb2cf572..3d5ea010b 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -19,15 +19,16 @@ #' #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' latest <- jhu %>% -#' dplyr::filter(time_value >= max(time_value) - 14) +#' filter(time_value >= max(time_value) - 14) #' #' # Don't specify `forecast_date` (by default, this should be last date in latest) #' f <- frosting() %>% @@ -68,6 +69,9 @@ #' p3 layer_add_forecast_date <- function(frosting, forecast_date = NULL, id = rand_id("add_forecast_date")) { + arg_is_chr_scalar(id) + arg_is_scalar(forecast_date, allow_null = TRUE) + # can't validate the type of forecast_date until we know the time_type add_layer( frosting, layer_add_forecast_date_new( @@ -78,39 +82,40 @@ layer_add_forecast_date <- } layer_add_forecast_date_new <- function(forecast_date, id) { - forecast_date <- arg_to_date(forecast_date, allow_null = TRUE) - arg_is_chr_scalar(id) layer("add_forecast_date", forecast_date = forecast_date, id = id) } #' @export -slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { +slather.layer_add_forecast_date <- function(object, components, workflow, + new_data, ...) 
{ + rlang::check_dots_empty() if (is.null(object$forecast_date)) { - max_time_value <- max( + max_time_value <- as.Date(max( workflows::extract_preprocessor(workflow)$max_time_value, workflow$fit$meta$max_time_value, max(new_data$time_value) - ) - object$forecast_date <- max_time_value + )) + forecast_date <- max_time_value + } else { + forecast_date <- object$forecast_date } - as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of - as_of_fit <- workflow$fit$meta$as_of - as_of_post <- attributes(new_data)$metadata$as_of - - as_of_date <- as.Date(max(as_of_pre, as_of_fit, as_of_post)) - if (object$forecast_date < as_of_date) { - cli_warn( - c("The forecast_date is less than the most ", - "recent update date of the data: ", - i = "forecast_date = {object$forecast_date} while data is from {as_of_date}." - ) - ) - } - components$predictions <- dplyr::bind_cols( + expected_time_type <- attr( + workflows::extract_preprocessor(workflow)$template, "metadata" + )$time_type + if (expected_time_type == "week") expected_time_type <- "day" + if (expected_time_type == "integer") expected_time_type <- "year" + validate_date( + forecast_date, expected_time_type, + call = rlang::expr(layer_add_forecast_date()) + ) + forecast_date <- coerce_time_type(forecast_date, expected_time_type) + object$forecast_date <- forecast_date + components$predictions <- bind_cols( components$predictions, - forecast_date = as.Date(object$forecast_date) + forecast_date = forecast_date ) + components } diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index f2fee889f..094ec8501 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -20,25 +20,25 @@ #' #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' 
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' latest <- get_test_data(r, jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' # Use ahead + forecast date #' f <- frosting() %>% #' layer_predict() %>% -#' layer_add_forecast_date(forecast_date = "2022-05-31") %>% +#' layer_add_forecast_date(forecast_date = as.Date("2022-05-31")) %>% #' layer_add_target_date() %>% #' layer_naomit(.pred) #' wf1 <- wf %>% add_frosting(f) #' -#' p <- predict(wf1, latest) +#' p <- forecast(wf1) #' p #' #' # Use ahead + max time value from pre, fit, post @@ -49,7 +49,7 @@ #' layer_naomit(.pred) #' wf2 <- wf %>% add_frosting(f2) #' -#' p2 <- predict(wf2, latest) +#' p2 <- forecast(wf2) #' p2 #' #' # Specify own target date @@ -59,12 +59,13 @@ #' layer_naomit(.pred) #' wf3 <- wf %>% add_frosting(f3) #' -#' p3 <- predict(wf3, latest) +#' p3 <- forecast(wf3) #' p3 layer_add_target_date <- function(frosting, target_date = NULL, id = rand_id("add_target_date")) { - target_date <- arg_to_date(target_date, allow_null = TRUE) arg_is_chr_scalar(id) + arg_is_scalar(target_date, allow_null = TRUE) + # can't validate the type of target_date until we know the time_type add_layer( frosting, layer_add_target_date_new( @@ -79,41 +80,50 @@ layer_add_target_date_new <- function(id = id, target_date = target_date) { } #' @export -slather.layer_add_target_date <- function(object, components, workflow, new_data, ...) { +slather.layer_add_target_date <- function(object, components, workflow, + new_data, ...) 
{ rlang::check_dots_empty() the_recipe <- workflows::extract_recipe(workflow) the_frosting <- extract_frosting(workflow) + expected_time_type <- attr( + workflows::extract_preprocessor(workflow)$template, "metadata" + )$time_type + if (expected_time_type == "week") expected_time_type <- "day" + if (expected_time_type == "integer") expected_time_type <- "year" + if (!is.null(object$target_date)) { - target_date <- as.Date(object$target_date) - } else { # null target date case - if (detect_layer(the_frosting, "layer_add_forecast_date") && - !is.null(extract_argument( - the_frosting, - "layer_add_forecast_date", "forecast_date" + target_date <- object$target_date + validate_date( + target_date, expected_time_type, + call = expr(layer_add_target_date()) + ) + target_date <- coerce_time_type(target_date, expected_time_type) + } else if ( + detect_layer(the_frosting, "layer_add_forecast_date") && + !is.null(forecast_date <- extract_argument( + the_frosting, "layer_add_forecast_date", "forecast_date" ))) { - forecast_date <- extract_argument( - the_frosting, - "layer_add_forecast_date", "forecast_date" - ) - - ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") - - target_date <- forecast_date + ahead - } else { - max_time_value <- max( - workflows::extract_preprocessor(workflow)$max_time_value, - workflow$fit$meta$max_time_value, - max(new_data$time_value) - ) - - ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") - - target_date <- max_time_value + ahead - } + validate_date( + forecast_date, expected_time_type, + call = rlang::expr(layer_add_forecast_date()) + ) + forecast_date <- coerce_time_type(forecast_date, expected_time_type) + ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead") + target_date <- forecast_date + ahead + } else { + max_time_value <- as.Date(max( + workflows::extract_preprocessor(workflow)$max_time_value, + workflow$fit$meta$max_time_value, + max(new_data$time_value) + )) + ahead <- extract_argument(the_recipe, 
"step_epi_ahead", "ahead") + target_date <- max_time_value + ahead } - components$predictions <- dplyr::bind_cols(components$predictions, + object$target_date <- target_date + components$predictions <- bind_cols( + components$predictions, target_date = target_date ) components diff --git a/R/layer_cdc_flatline_quantiles.R b/R/layer_cdc_flatline_quantiles.R index 760fdf068..8d16ba32f 100644 --- a/R/layer_cdc_flatline_quantiles.R +++ b/R/layer_cdc_flatline_quantiles.R @@ -55,6 +55,7 @@ #' @export #' #' @examples +#' library(dplyr) #' r <- epi_recipe(case_death_rate_subset) %>% #' # data is "daily", so we fit this to 1 ahead, the result will contain #' # 1 day ahead residuals @@ -64,24 +65,20 @@ #' #' forecast_date <- max(case_death_rate_subset$time_value) #' -#' latest <- get_test_data( -#' epi_recipe(case_death_rate_subset), case_death_rate_subset -#' ) -#' #' f <- frosting() %>% #' layer_predict() %>% #' layer_cdc_flatline_quantiles(aheads = c(7, 14, 21, 28), symmetrize = TRUE) #' -#' eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") +#' eng <- linear_reg(engine = "flatline") #' #' wf <- epi_workflow(r, eng, f) %>% fit(case_death_rate_subset) -#' preds <- suppressWarnings(predict(wf, new_data = latest)) %>% -#' dplyr::select(-time_value) %>% -#' dplyr::mutate(forecast_date = forecast_date) +#' preds <- forecast(wf) %>% +#' select(-time_value) %>% +#' mutate(forecast_date = forecast_date) #' preds #' #' preds <- preds %>% -#' unnest(.pred_distn_all) %>% +#' tidyr::unnest(.pred_distn_all) %>% #' pivot_quantiles_wider(.pred_distn) %>% #' mutate(target_date = forecast_date + ahead) #' @@ -166,15 +163,13 @@ slather.layer_cdc_flatline_quantiles <- } the_fit <- workflows::extract_fit_parsnip(workflow) if (!inherits(the_fit, "_flatline")) { - cli::cli_warn( - c( - "Predictions for this workflow were not produced by the {.cls flatline}", - "{.pkg parsnip} engine. Results may be unexpected. See {.fn epipredict::flatline}." 
- ) - ) + cli::cli_warn(c( + "Predictions for this workflow were not produced by the {.cls flatline}", + "{.pkg parsnip} engine. Results may be unexpected. See {.fn epipredict::flatline}." + )) } p <- components$predictions - ek <- kill_time_value(epi_keys_mold(components$mold)) + ek <- epi_keys_only(workflow) r <- grab_residuals(the_fit, components) avail_grps <- character(0L) @@ -200,7 +195,7 @@ slather.layer_cdc_flatline_quantiles <- c(cols_in_preds$missing_names, cols_in_resids$missing_names) )) } else { # not flatline, but we'll try - key_cols <- dplyr::bind_cols( + key_cols <- bind_cols( geo_value = components$mold$extras$roles$geo_value, components$mold$extras$roles$key ) @@ -215,26 +210,26 @@ slather.layer_cdc_flatline_quantiles <- object$by_key, c(cols_in_preds$missing_names, cols_in_resids$missing_names) )) - r <- dplyr::bind_cols(key_cols, r) + r <- bind_cols(key_cols, r) } } r <- r %>% - dplyr::select(tidyselect::all_of(c(avail_grps, ".resid"))) %>% - dplyr::group_by(!!!rlang::syms(avail_grps)) %>% - dplyr::summarise(.resid = list(.resid), .groups = "drop") + select(all_of(c(avail_grps, ".resid"))) %>% + group_by(!!!rlang::syms(avail_grps)) %>% + summarise(.resid = list(.resid), .groups = "drop") - res <- dplyr::left_join(p, r, by = avail_grps) %>% + res <- left_join(p, r, by = avail_grps) %>% dplyr::rowwise() %>% - dplyr::mutate( + mutate( .pred_distn_all = propagate_samples( .resid, .pred, object$quantile_levels, object$aheads, object$nsim, object$symmetrize, object$nonneg ) ) %>% - dplyr::select(tidyselect::all_of(c(avail_grps, ".pred_distn_all"))) + select(all_of(c(avail_grps, ".pred_distn_all"))) # res <- check_pname(res, components$predictions, object) - components$predictions <- dplyr::left_join( + components$predictions <- left_join( components$predictions, res, by = avail_grps @@ -271,7 +266,7 @@ propagate_samples <- function( } } res <- res[aheads] - list(tibble::tibble( + list(tibble( ahead = aheads, .pred_distn = map_vec( res, ~ 
dist_quantiles(quantile(.x, quantile_levels), quantile_levels) diff --git a/R/layer_naomit.R b/R/layer_naomit.R index 33c93f0ab..209a663b4 100644 --- a/R/layer_naomit.R +++ b/R/layer_naomit.R @@ -11,16 +11,15 @@ #' @return an updated `frosting` postprocessor #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' -#' latest <- get_test_data(recipe = r, x = jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' f <- frosting() %>% #' layer_predict() %>% @@ -28,14 +27,14 @@ #' #' wf1 <- wf %>% add_frosting(f) #' -#' p <- predict(wf1, latest) +#' p <- forecast(wf1) #' p layer_naomit <- function(frosting, ..., id = rand_id("naomit")) { arg_is_chr_scalar(id) add_layer( frosting, layer_naomit_new( - terms = dplyr::enquos(...), + terms = enquos(...), id = id ) ) @@ -47,11 +46,12 @@ layer_naomit_new <- function(terms, id) { #' @export slather.layer_naomit <- function(object, components, workflow, new_data, ...) 
{ + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) components$predictions <- components$predictions %>% - dplyr::filter(dplyr::if_any(dplyr::all_of(col_names), ~ !is.na(.x))) + filter(dplyr::if_any(all_of(col_names), ~ !is.na(.x))) components } diff --git a/R/layer_point_from_distn.R b/R/layer_point_from_distn.R index d1bc50f62..f14008748 100644 --- a/R/layer_point_from_distn.R +++ b/R/layer_point_from_distn.R @@ -16,17 +16,17 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>% fit(jhu) -#' -#' latest <- get_test_data(recipe = r, x = jhu) +#' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>% +#' fit(jhu) #' #' f1 <- frosting() %>% #' layer_predict() %>% @@ -35,7 +35,7 @@ #' layer_naomit(.pred) #' wf1 <- wf %>% add_frosting(f1) #' -#' p1 <- predict(wf1, latest) +#' p1 <- forecast(wf1) #' p1 #' #' f2 <- frosting() %>% @@ -44,7 +44,7 @@ #' layer_naomit(.pred) #' wf2 <- wf %>% add_frosting(f2) #' -#' p2 <- predict(wf2, latest) +#' p2 <- forecast(wf2) #' p2 layer_point_from_distn <- function(frosting, ..., @@ -78,7 +78,6 @@ layer_point_from_distn_new <- function(type, name, id) { #' @export slather.layer_point_from_distn <- function(object, components, workflow, new_data, ...) 
{ - rlang::check_dots_empty() dstn <- components$predictions$.pred if (!inherits(dstn, "distribution")) { rlang::warn( @@ -88,14 +87,15 @@ slather.layer_point_from_distn <- ) return(components) } + rlang::check_dots_empty() dstn <- match.fun(object$type)(dstn) if (is.null(object$name)) { components$predictions$.pred <- dstn } else { - dstn <- tibble::tibble(dstn = dstn) + dstn <- tibble(dstn = dstn) dstn <- check_pname(dstn, components$predictions, object) - components$predictions <- dplyr::mutate(components$predictions, !!!dstn) + components$predictions <- mutate(components$predictions, !!!dstn) } components } diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 2b7057bef..9275d910c 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -47,9 +47,10 @@ #' @return an updated `frosting` postprocessor #' @export #' @examples -#' jhu <- epiprocess::jhu_csse_daily_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% -#' dplyr::select(geo_value, time_value, cases) +#' library(dplyr) +#' jhu <- jhu_csse_daily_subset %>% +#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% +#' select(geo_value, time_value, cases) #' #' pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) #' @@ -74,21 +75,11 @@ #' df_pop_col = "value" #' ) #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% +#' wf <- epi_workflow(r, linear_reg()) %>% #' fit(jhu) %>% #' add_frosting(f) #' -#' latest <- get_test_data( -#' recipe = r, -#' x = epiprocess::jhu_csse_daily_subset %>% -#' dplyr::filter( -#' time_value > "2021-11-01", -#' geo_value %in% c("ca", "ny") -#' ) %>% -#' dplyr::select(geo_value, time_value, cases) -#' ) -#' -#' predict(wf, latest) +#' forecast(wf) layer_population_scaling <- function(frosting, ..., df, @@ -103,7 +94,7 @@ layer_population_scaling <- function(frosting, arg_is_chr(df_pop_col, suffix, id) arg_is_chr(by, allow_null = TRUE) if (rate_rescaling <= 0) { 
- cli_stop("`rate_rescaling` should be a positive number") + cli_abort("`rate_rescaling` must be a positive number.") } add_layer( @@ -138,27 +129,22 @@ layer_population_scaling_new <- #' @export slather.layer_population_scaling <- function(object, components, workflow, new_data, ...) { - rlang::check_dots_empty() stopifnot( "Only one population column allowed for scaling" = length(object$df_pop_col) == 1 ) + rlang::check_dots_empty() - try_join <- try( - dplyr::left_join(components$predictions, object$df, - by = object$by - ), - silent = TRUE + object$by <- object$by %||% intersect( + epi_keys_only(components$predictions), + colnames(select(object$df, !object$df_pop_col)) ) - if (any(grepl("Join columns must be present in data", unlist(try_join)))) { - cli_stop(c( - "columns in `by` selectors of `layer_population_scaling` ", - "must be present in data and match" - )) - } + joinby <- list(x = names(object$by) %||% object$by, y = object$by) + hardhat::validate_column_names(components$predictions, joinby$x) + hardhat::validate_column_names(object$df, joinby$y) - object$df <- object$df %>% - dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) + # object$df <- object$df %>% + # dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) pop_col <- rlang::sym(object$df_pop_col) exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) @@ -166,18 +152,18 @@ slather.layer_population_scaling <- suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions)) - components$predictions <- dplyr::left_join( + components$predictions <- left_join( components$predictions, object$df, by = object$by, suffix = c("", ".df") ) %>% - dplyr::mutate(dplyr::across( - dplyr::all_of(col_names), + mutate(across( + all_of(col_names), ~ .x * !!pop_col / object$rate_rescaling, .names = "{.col}{suffix}" )) %>% - 
dplyr::select(-dplyr::any_of(col_to_remove)) + select(-any_of(col_to_remove)) components } diff --git a/R/layer_predict.R b/R/layer_predict.R index b40c24be5..6ca17ac24 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -16,6 +16,7 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% #' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' @@ -24,7 +25,7 @@ #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' latest <- jhu %>% filter(time_value >= max(time_value) - 14) #' #' # Predict layer alone @@ -45,12 +46,19 @@ layer_predict <- id = rand_id("predict_default")) { arg_is_chr_scalar(id) arg_is_chr_scalar(type, allow_null = TRUE) + assert_class(opts, "list") + dots_list <- rlang::dots_list(..., .homonyms = "error", .check_assign = TRUE) + if (any(rlang::names2(dots_list) == "")) { + cli_abort("All `...` arguments must be named.", + class = "epipredict__layer_predict__unnamed_dot" + ) + } add_layer( frosting, layer_predict_new( type = type, opts = opts, - dots_list = rlang::list2(...), # can't figure how to use this + dots_list = dots_list, id = id ) ) @@ -62,17 +70,28 @@ layer_predict_new <- function(type, opts, dots_list, id) { } #' @export -slather.layer_predict <- function(object, components, workflow, new_data, ...) { +slather.layer_predict <- function(object, components, workflow, new_data, type = NULL, opts = list(), ...) { + arg_is_chr_scalar(type, allow_null = TRUE) + if (!is.null(object$type) && !is.null(type) && !identical(object$type, type)) { + cli_abort(" + Conflicting `type` settings were specified during frosting construction + (in call to `layer_predict()`) and while slathering (in call to + `slather()`/ `predict()`/etc.): {object$type} vs. {type}. Please remove + one of these `type` settings. 
+ ", class = "epipredict__layer_predict__conflicting_type_settings") + } + assert_class(opts, "list") + the_fit <- workflows::extract_fit_parsnip(workflow) - components$predictions <- predict( + components$predictions <- rlang::inject(predict( the_fit, components$forged$predictors, - type = object$type, opts = object$opts - ) - components$predictions <- dplyr::bind_cols( - components$keys, components$predictions - ) + type = object$type %||% type, + opts = c(object$opts, opts), + !!!object$dots_list, ... + )) + components$predictions <- bind_cols(components$keys, components$predictions) components } diff --git a/R/layer_predictive_distn.R b/R/layer_predictive_distn.R index b72be6ec3..b28e0c765 100644 --- a/R/layer_predictive_distn.R +++ b/R/layer_predictive_distn.R @@ -20,17 +20,16 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' -#' latest <- get_test_data(recipe = r, x = jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' f <- frosting() %>% #' layer_predict() %>% @@ -38,7 +37,7 @@ #' layer_naomit(.pred) #' wf1 <- wf %>% add_frosting(f) #' -#' p <- predict(wf1, latest) +#' p <- forecast(wf1) #' p layer_predictive_distn <- function(frosting, ..., @@ -75,6 +74,7 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) { slather.layer_predictive_distn <- function(object, components, workflow, new_data, ...) 
{ the_fit <- workflows::extract_fit_parsnip(workflow) + rlang::check_dots_empty() m <- components$predictions$.pred r <- grab_residuals(the_fit, components) @@ -92,9 +92,9 @@ slather.layer_predictive_distn <- if (!all(is.infinite(truncate))) { dstn <- distributional::dist_truncated(dstn, truncate[1], truncate[2]) } - dstn <- tibble::tibble(dstn = dstn) + dstn <- tibble(dstn = dstn) dstn <- check_pname(dstn, components$predictions, object) - components$predictions <- dplyr::mutate(components$predictions, !!!dstn) + components$predictions <- mutate(components$predictions, !!!dstn) components } diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index a99eed326..5f87ded29 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -1,9 +1,14 @@ #' Returns predictive quantiles #' #' This function calculates quantiles when the prediction was _distributional_. -#' Currently, the only distributional engine is `quantile_reg()`. -#' If this engine is used, then this layer will grab out estimated (or extrapolated) -#' quantiles at the requested quantile values. +#' +#' Currently, the only distributional modes/engines are +#' * `quantile_reg()` +#' * `smooth_quantile_reg()` +#' * `rand_forest(mode = "regression") %>% set_engine("grf_quantiles")` +#' +#' If these engines were used, then this layer will grab out estimated +#' (or extrapolated) quantiles at the requested quantile values. #' #' @param frosting a `frosting` postprocessor #' @param ... Unused, include for consistency with other layers. 
@@ -17,8 +22,9 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% @@ -28,15 +34,13 @@ #' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>% #' fit(jhu) #' -#' latest <- get_test_data(recipe = r, x = jhu) -#' #' f <- frosting() %>% #' layer_predict() %>% #' layer_quantile_distn() %>% #' layer_naomit(.pred) #' wf1 <- wf %>% add_frosting(f) #' -#' p <- predict(wf1, latest) +#' p <- forecast(wf1) #' p layer_quantile_distn <- function(frosting, ..., @@ -81,6 +85,8 @@ slather.layer_quantile_distn <- "These are of class {.cls {class(dstn)}}." )) } + rlang::check_dots_empty() + dstn <- dist_quantiles( quantile(dstn, object$quantile_levels), object$quantile_levels @@ -90,9 +96,9 @@ slather.layer_quantile_distn <- if (!all(is.infinite(truncate))) { dstn <- snap(dstn, truncate[1], truncate[2]) } - dstn <- tibble::tibble(dstn = dstn) + dstn <- tibble(dstn = dstn) dstn <- check_pname(dstn, components$predictions, object) - components$predictions <- dplyr::mutate(components$predictions, !!!dstn) + components$predictions <- mutate(components$predictions, !!!dstn) components } diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index bd4ed27e3..eae151905 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -14,33 +14,38 @@ #' residual quantiles added to the prediction #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() 
#' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' -#' latest <- get_test_data(recipe = r, x = jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' f <- frosting() %>% #' layer_predict() %>% -#' layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) %>% +#' layer_residual_quantiles( +#' quantile_levels = c(0.0275, 0.975), +#' symmetrize = FALSE +#' ) %>% #' layer_naomit(.pred) #' wf1 <- wf %>% add_frosting(f) #' -#' p <- predict(wf1, latest) +#' p <- forecast(wf1) #' #' f2 <- frosting() %>% #' layer_predict() %>% -#' layer_residual_quantiles(quantile_levels = c(0.3, 0.7), by_key = "geo_value") %>% +#' layer_residual_quantiles( +#' quantile_levels = c(0.3, 0.7), +#' by_key = "geo_value" +#' ) %>% #' layer_naomit(.pred) #' wf2 <- wf %>% add_frosting(f2) #' -#' p2 <- predict(wf2, latest) +#' p2 <- forecast(wf2) layer_residual_quantiles <- function( frosting, ..., quantile_levels = c(0.05, 0.95), @@ -77,6 +82,8 @@ layer_residual_quantiles_new <- function( #' @export slather.layer_residual_quantiles <- function(object, components, workflow, new_data, ...) 
{ + rlang::check_dots_empty() + the_fit <- workflows::extract_fit_parsnip(workflow) if (is.null(object$quantile_levels)) { @@ -88,7 +95,7 @@ slather.layer_residual_quantiles <- ## Handle any grouping requests if (length(object$by_key) > 0L) { - key_cols <- dplyr::bind_cols( + key_cols <- bind_cols( geo_value = components$mold$extras$roles$geo_value, components$mold$extras$roles$key ) @@ -101,33 +108,42 @@ slather.layer_residual_quantiles <- )) } if (length(common) > 0L) { - r <- r %>% dplyr::select(tidyselect::any_of(c(common, ".resid"))) + r <- r %>% select(any_of(c(common, ".resid"))) common_in_r <- common[common %in% names(r)] - if (length(common_in_r) != length(common)) { + if (length(common_in_r) == length(common)) { + r <- left_join(key_cols, r, by = common_in_r) + } else { cli::cli_warn(c( "Some grouping keys are not in data.frame returned by the", "`residuals()` method. Groupings may not be correct." )) + r <- bind_cols(key_cols, select(r, .resid)) %>% + group_by(!!!rlang::syms(common)) } - r <- dplyr::bind_cols(key_cols, r) %>% - dplyr::group_by(!!!rlang::syms(common)) } } r <- r %>% - dplyr::summarize( + summarize( dstn = list(quantile( c(.resid, s * .resid), probs = object$quantile_levels, na.rm = TRUE )) ) + # Check for NA + if (any(sapply(r$dstn, is.na))) { + cli::cli_abort(c( + "Residual quantiles could not be calculated due to missing residuals.", + i = "This may be due to `n_train` < `ahead` in your {.cls epi_recipe}." 
+ )) + } estimate <- components$predictions$.pred - res <- tibble::tibble( + res <- tibble( .pred_distn = dist_quantiles(map2(estimate, r$dstn, "+"), object$quantile_levels) ) res <- check_pname(res, components$predictions, object) - components$predictions <- dplyr::mutate(components$predictions, !!!res) + components$predictions <- mutate(components$predictions, !!!res) components } diff --git a/R/layer_threshold_preds.R b/R/layer_threshold_preds.R index 957ac2419..56f8059ab 100644 --- a/R/layer_threshold_preds.R +++ b/R/layer_threshold_preds.R @@ -22,23 +22,20 @@ #' @return an updated `frosting` postprocessor #' @export #' @examples - +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value < "2021-03-08", -#' geo_value %in% c("ak", "ca", "ar")) +#' filter(time_value < "2021-03-08", geo_value %in% c("ak", "ca", "ar")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' -#' latest <- get_test_data(r, jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' f <- frosting() %>% #' layer_predict() %>% #' layer_threshold(.pred, lower = 0.180, upper = 0.310) #' wf <- wf %>% add_frosting(f) -#' p <- predict(wf, latest) +#' p <- forecast(wf) #' p layer_threshold <- function(frosting, ..., lower = 0, upper = Inf, id = rand_id("threshold")) { @@ -48,7 +45,7 @@ layer_threshold <- add_layer( frosting, layer_threshold_new( - terms = dplyr::enquos(...), + terms = enquos(...), lower = lower, upper = upper, id = id @@ -100,16 +97,12 @@ snap.dist_quantiles <- function(x, lower, upper, ...) { #' @export slather.layer_threshold <- function(object, components, workflow, new_data, ...) 
{ + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) components$predictions <- components$predictions %>% - dplyr::mutate( - dplyr::across( - dplyr::all_of(col_names), - ~ snap(.x, object$lower, object$upper) - ) - ) + mutate(across(all_of(col_names), ~ snap(.x, object$lower, object$upper))) components } diff --git a/R/layer_unnest.R b/R/layer_unnest.R index 64b17a306..a6fc9f0af 100644 --- a/R/layer_unnest.R +++ b/R/layer_unnest.R @@ -15,7 +15,7 @@ layer_unnest <- function(frosting, ..., id = rand_id("unnest")) { add_layer( frosting, layer_unnest_new( - terms = dplyr::enquos(...), + terms = enquos(...), id = id ) ) @@ -28,6 +28,7 @@ layer_unnest_new <- function(terms, id) { #' @export slather.layer_unnest <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layers.R b/R/layers.R index c93423d32..aa515a917 100644 --- a/R/layers.R +++ b/R/layers.R @@ -41,15 +41,15 @@ layer <- function(subclass, ..., .prefix = "layer_") { #' in the layer, and the values are the new values to update the layer with. 
#' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' latest <- jhu %>% -#' dplyr::filter(time_value >= max(time_value) - 14) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) +#' latest <- jhu %>% filter(time_value >= max(time_value) - 14) #' #' # Specify a `forecast_date` that is greater than or equal to `as_of` date #' f <- frosting() %>% @@ -144,11 +144,12 @@ pull_layer_name <- function(x) { #' @export #' @rdname layer-processors -validate_layer <- function(x, ..., arg = "`x`", call = caller_env()) { +validate_layer <- function(x, ..., arg = rlang::caller_arg(x), + call = caller_env()) { rlang::check_dots_empty() if (!is_layer(x)) { - glubort( - "{arg} must be a frosting layer, not a {class(x)[[1]]}.", + cli::cli_abort( + "{arg} must be a frosting layer, not a {.cls {class(x)[[1]]}}.", .call = call ) } diff --git a/R/make_grf_quantiles.R b/R/make_grf_quantiles.R new file mode 100644 index 000000000..253ea1ac7 --- /dev/null +++ b/R/make_grf_quantiles.R @@ -0,0 +1,193 @@ +#' Random quantile forests via grf +#' +#' [grf::quantile_forest()] fits random forests in a way that makes it easy +#' to calculate _quantile_ forests. Currently, this is the only engine +#' provided here, since quantile regression is the typical use-case. +#' +#' @section Tuning Parameters: +#' +#' This model has 3 tuning parameters: +#' +#' - `mtry`: # Randomly Selected Predictors (type: integer, default: see below) +#' - `trees`: # Trees (type: integer, default: 2000L) +#' - `min_n`: Minimal Node Size (type: integer, default: 5) +#' +#' `mtry` depends on the number of columns in the design matrix. 
+#' The default in [grf::quantile_forest()] is `min(ceiling(sqrt(ncol(X)) + 20), ncol(X))`. +#' +#' For categorical predictors, a one-hot encoding is always used. This makes +#' splitting efficient, but has implications for the `mtry` choice. A factor +#' with many levels will become a large number of columns in the design matrix +#' which means that some of these may be selected frequently for potential splits. +#' This is different than in other implementations of random forest. For more +#' details, see [the `grf` discussion](https://grf-labs.github.io/grf/articles/categorical_inputs.html). +#' +#' @section Translation from parsnip to the original package: +#' +#' ```{r, translate-engine} +#' rand_forest( +#' mode = "regression", # you must specify the `mode = regression` +#' mtry = integer(1), +#' trees = integer(1), +#' min_n = integer(1) +#' ) %>% +#' set_engine("grf_quantiles") %>% +#' translate() +#' ``` +#' +#' @section Case weights: +#' +#' Case weights are not supported. +#' +#' @examples +#' library(grf) +#' tib <- data.frame( +#' y = rnorm(100), x = rnorm(100), z = rnorm(100), +#' f = factor(sample(letters[1:3], 100, replace = TRUE)) +#' ) +#' spec <- rand_forest(engine = "grf_quantiles", mode = "regression") +#' out <- fit(spec, formula = y ~ x + z, data = tib) +#' predict(out, new_data = tib[1:5, ]) %>% +#' pivot_quantiles_wider(.pred) +#' +#' # -- adjusting the desired quantiles +#' +#' spec <- rand_forest(mode = "regression") %>% +#' set_engine(engine = "grf_quantiles", quantiles = c(1:9 / 10)) +#' out <- fit(spec, formula = y ~ x + z, data = tib) +#' predict(out, new_data = tib[1:5, ]) %>% +#' pivot_quantiles_wider(.pred) +#' +#' # -- a more complicated task +#' +#' library(dplyr) +#' dat <- case_death_rate_subset %>% +#' filter(time_value > as.Date("2021-10-01")) +#' rec <- epi_recipe(dat) %>% +#' step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) %>% +#' step_epi_ahead(death_rate, ahead = 7) %>% +#' step_epi_naomit() +#' frost <- frosting() 
%>% +#' layer_predict() %>% +#' layer_threshold(.pred) +#' spec <- rand_forest(mode = "regression") %>% +#' set_engine(engine = "grf_quantiles", quantiles = c(.25, .5, .75)) +#' +#' ewf <- epi_workflow(rec, spec, frost) %>% +#' fit(dat) %>% +#' forecast() +#' ewf %>% +#' rename(forecast_date = time_value) %>% +#' mutate(target_date = forecast_date + 7L) %>% +#' pivot_quantiles_wider(.pred) +#' +#' @name grf_quantiles +NULL + + + +make_grf_quantiles <- function() { + parsnip::set_model_engine( + model = "rand_forest", mode = "regression", eng = "grf_quantiles" + ) + parsnip::set_dependency( + model = "rand_forest", eng = "grf_quantiles", pkg = "grf", + mode = "regression" + ) + + + # These are the arguments to the parsnip::rand_forest() that must be + # translated from grf::quantile_forest + parsnip::set_model_arg( + model = "rand_forest", + eng = "grf_quantiles", + parsnip = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE + ) + parsnip::set_model_arg( + model = "rand_forest", + eng = "grf_quantiles", + parsnip = "trees", + original = "num.trees", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE + ) + parsnip::set_model_arg( + model = "rand_forest", + eng = "grf_quantiles", + parsnip = "min_n", + original = "min.node.size", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE + ) + + # the `value` list describes how grf::quantile_forest expects to receive + # arguments. In particular, it needs X and Y to be passed in as a matrices. + # But the matrix interface in parsnip calls these x and y. 
So the data + # slot translates them + # + # protect - prevents the user from passing X and Y arguments themselves + # defaults - engine specific arguments (not model specific) that we allow + # the user to change + parsnip::set_fit( + model = "rand_forest", + eng = "grf_quantiles", + mode = "regression", + value = list( + interface = "matrix", + protect = c("X", "Y"), + data = c(x = "X", y = "Y"), + func = c(pkg = "grf", fun = "quantile_forest"), + defaults = list( + quantiles = c(0.1, 0.5, 0.9), + num.threads = 1L, + seed = rlang::expr(stats::runif(1, 0, .Machine$integer.max)) + ) + ) + ) + + parsnip::set_encoding( + model = "rand_forest", + eng = "grf_quantiles", + mode = "regression", + options = list( + # one hot is the closest to typical factor handling in randomForest + # (1 vs all splitting), though since we aren't bagging, + # factors with many levels could be visited frequently + predictor_indicators = "one_hot", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) + ) + + # turn the predictions into a tibble with a dist_quantiles column + process_qrf_preds <- function(x, object) { + quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig + x <- x$predictions + out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x))) + out <- dist_quantiles(out, list(quantile_levels)) + return(dplyr::tibble(.pred = out)) + } + + parsnip::set_pred( + model = "rand_forest", + eng = "grf_quantiles", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = process_qrf_preds, + func = c(fun = "predict"), + # map between parsnip::predict args and grf::quantile_forest args + args = list( + object = quote(object$fit), + newdata = quote(new_data), + seed = rlang::expr(sample.int(10^5, 1)), + verbose = FALSE + ) + ) + ) +} diff --git a/R/make_quantile_reg.R b/R/make_quantile_reg.R index 832ef50f8..2157aa470 100644 --- a/R/make_quantile_reg.R +++ b/R/make_quantile_reg.R @@ -3,26 +3,31 @@ #' @description #' 
`quantile_reg()` generates a quantile regression model _specification_ for #' the [tidymodels](https://www.tidymodels.org/) framework. Currently, the -#' only supported engine is "rq" which uses [quantreg::rq()]. +#' only supported engines are "rq", which uses [quantreg::rq()]. +#' Quantile regression is also possible by combining [parsnip::rand_forest()] +#' with the `grf` engine. See [grf_quantiles]. #' #' @param mode A single character string for the type of model. #' The only possible value for this model is "regression". #' @param engine Character string naming the fitting function. Currently, only -#' "rq" is supported. +#' "rq" and "grf" are supported. #' @param quantile_levels A scalar or vector of values in (0, 1) to determine which #' quantiles to estimate (default is 0.5). +#' @param method A fitting method used by [quantreg::rq()]. See the +#' documentation for a list of options. #' #' @export #' #' @seealso [fit.model_spec()], [set_engine()] #' -#' @importFrom quantreg rq +#' #' @examples +#' library(quantreg) #' tib <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100)) #' rq_spec <- quantile_reg(quantile_levels = c(.2, .8)) %>% set_engine("rq") #' ff <- rq_spec %>% fit(y ~ ., data = tib) #' predict(ff, new_data = tib) -quantile_reg <- function(mode = "regression", engine = "rq", quantile_levels = 0.5) { +quantile_reg <- function(mode = "regression", engine = "rq", quantile_levels = 0.5, method = "br") { # Check for correct mode if (mode != "regression") { cli_abort("`mode` must be 'regression'") @@ -35,7 +40,7 @@ quantile_reg <- function(mode = "regression", engine = "rq", quantile_levels = 0 cli::cli_warn("Sorting `quantile_levels` to increasing order.") quantile_levels <- sort(quantile_levels) } - args <- list(quantile_levels = rlang::enquo(quantile_levels)) + args <- list(quantile_levels = rlang::enquo(quantile_levels), method = rlang::enquo(method)) # Save some empty slots for future parts of the specification parsnip::new_model_spec( @@ 
-54,9 +59,6 @@ make_quantile_reg <- function() { parsnip::set_new_model("quantile_reg") } parsnip::set_model_mode("quantile_reg", "regression") - - - parsnip::set_model_engine("quantile_reg", "regression", eng = "rq") parsnip::set_dependency("quantile_reg", eng = "rq", pkg = "quantreg") @@ -68,6 +70,14 @@ make_quantile_reg <- function() { func = list(pkg = "quantreg", fun = "rq"), has_submodel = FALSE ) + parsnip::set_model_arg( + model = "quantile_reg", + eng = "rq", + parsnip = "method", + original = "method", + func = list(pkg = "quantreg", fun = "rq"), + has_submodel = FALSE + ) parsnip::set_fit( model = "quantile_reg", @@ -78,7 +88,6 @@ make_quantile_reg <- function() { protect = c("formula", "data", "weights"), func = c(pkg = "quantreg", fun = "rq"), defaults = list( - method = "br", na.action = rlang::expr(stats::na.omit), model = FALSE ) @@ -101,12 +110,11 @@ make_quantile_reg <- function() { object <- parsnip::extract_fit_engine(object) type <- class(object)[1] - # can't make a method because object is second out <- switch(type, rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile rqs = { - x <- lapply(unname(split(x, seq(nrow(x)))), function(q) sort(q)) + x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x))) dist_quantiles(x, list(object$tau)) }, cli_abort(c( @@ -114,10 +122,9 @@ make_quantile_reg <- function() { i = "See {.fun quantreg::rq}." 
)) ) - return(data.frame(.pred = out)) + return(dplyr::tibble(.pred = out)) } - parsnip::set_pred( model = "quantile_reg", eng = "rq", diff --git a/R/make_smooth_quantile_reg.R b/R/make_smooth_quantile_reg.R index 9ab3a366b..448ee0fa5 100644 --- a/R/make_smooth_quantile_reg.R +++ b/R/make_smooth_quantile_reg.R @@ -21,8 +21,8 @@ #' #' @seealso [fit.model_spec()], [set_engine()] #' -#' @importFrom smoothqr smooth_qr #' @examples +#' library(smoothqr) #' tib <- data.frame( #' y1 = rnorm(100), y2 = rnorm(100), y3 = rnorm(100), #' y4 = rnorm(100), y5 = rnorm(100), y6 = rnorm(100), @@ -62,17 +62,16 @@ #' lines(pl$x, pl$`0.8`, col = "blue") #' lines(pl$x, pl$`0.5`, col = "red") #' -#' if (require("ggplot2")) { -#' ggplot(data.frame(x = x, y = y), aes(x)) + -#' geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") + -#' geom_point(aes(y = y), colour = "grey") + # observed data -#' geom_function(fun = sin, colour = "black") + # truth -#' geom_vline(xintercept = fd, linetype = "dashed") + # end of training data -#' geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction -#' theme_bw() + -#' coord_cartesian(xlim = c(0, NA)) + -#' ylab("y") -#' } +#' library(ggplot2) +#' ggplot(data.frame(x = x, y = y), aes(x)) + +#' geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") + +#' geom_point(aes(y = y), colour = "grey") + # observed data +#' geom_function(fun = sin, colour = "black") + # truth +#' geom_vline(xintercept = fd, linetype = "dashed") + # end of training data +#' geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction +#' theme_bw() + +#' coord_cartesian(xlim = c(0, NA)) + +#' ylab("y") smooth_quantile_reg <- function( mode = "regression", engine = "smoothqr", diff --git a/R/model-methods.R b/R/model-methods.R new file mode 100644 index 000000000..f3b374879 --- /dev/null +++ b/R/model-methods.R @@ -0,0 +1,131 @@ +#' Add a model to an `epi_workflow` +#' +#' @seealso [workflows::add_model()] 
+#' - `Add_model()` adds a parsnip model to the `epi_workflow`. +#' +#' - `Remove_model()` removes the model specification as well as any fitted +#' model object. Any extra formulas are also removed. +#' +#' - `Update_model()` first removes the model then adds the new +#' specification to the workflow. +#' +#' @details +#' Has the same behaviour as [workflows::add_model()] but also ensures +#' that the returned object is an `epi_workflow`. +#' +#' This family is called `Add_*` / `Update_*` / `Remove_*` to avoid +#' masking the related functions in `{workflows}`. We also provide +#' aliases with the lower-case names. However, in the event that +#' `{workflows}` is loaded after `{epipredict}`, these may fail to function +#' properly. +#' +#' @inheritParams workflows::add_model +#' +#' @param x An `epi_workflow`. +#' +#' @param spec A parsnip model specification. +#' +#' @param ... Not used. +#' +#' @return +#' `x`, updated with a new, updated, or removed model. +#' +#' @export +#' @examples +#' library(dplyr) +#' jhu <- case_death_rate_subset %>% +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' +#' r <- epi_recipe(jhu) %>% +#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% +#' step_epi_ahead(death_rate, ahead = 7) +#' +#' rf_model <- rand_forest(mode = "regression") +#' +#' wf <- epi_workflow(r) +#' +#' wf <- wf %>% Add_model(rf_model) +#' wf +#' +#' lm_model <- linear_reg() +#' +#' wf <- Update_model(wf, lm_model) +#' wf +#' +#' wf <- Remove_model(wf) +#' wf +#' @export +Add_model <- function(x, spec, ..., formula = NULL) { + UseMethod("Add_model") +} + +#' @rdname Add_model +#' @export +Remove_model <- function(x) { + UseMethod("Remove_model") +} + +#' @rdname Add_model +#' @export +Update_model <- function(x, spec, ..., formula = NULL) { + UseMethod("Update_model") +} + +#' @rdname Add_model +#' @export +Add_model.epi_workflow <- function(x, spec, ..., formula = NULL) { + workflows::add_model(x, spec, ..., formula = formula) +} + +#' 
@rdname Add_model +#' @export +Remove_model.epi_workflow <- function(x) { + workflows:::validate_is_workflow(x) + + if (!workflows:::has_spec(x)) { + rlang::warn("The workflow has no model to remove.") + } + + new_epi_workflow( + pre = x$pre, + fit = workflows:::new_stage_fit(), + post = x$post, + trained = FALSE + ) +} + +#' @rdname Add_model +#' @export +Update_model.epi_workflow <- function(x, spec, ..., formula = NULL) { + rlang::check_dots_empty() + x <- Remove_model(x) + Add_model(x, spec, ..., formula = formula) +} + + +#' @rdname Add_model +#' @export +Add_model.workflow <- workflows::add_model + +#' @rdname Add_model +#' @export +Remove_model.workflow <- workflows::remove_model + +#' @rdname Add_model +#' @export +Update_model.workflow <- workflows::update_model + + +# Aliases ----------------------------------------------------------------- + +#' @rdname Add_model +#' @export +add_model <- Add_model + +#' @rdname Add_model +#' @export +remove_model <- Remove_model + +#' @rdname Add_model +#' @export +update_model <- Update_model diff --git a/R/pivot_quantiles.R b/R/pivot_quantiles.R index e632748df..f014961e6 100644 --- a/R/pivot_quantiles.R +++ b/R/pivot_quantiles.R @@ -6,16 +6,18 @@ #' @export #' #' @examples +#' library(dplyr) +#' library(tidyr) #' edf <- case_death_rate_subset[1:3, ] #' edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11)) #' -#' edf_nested <- edf %>% dplyr::mutate(q = nested_quantiles(q)) -#' edf_nested %>% tidyr::unnest(q) +#' edf_nested <- edf %>% mutate(q = nested_quantiles(q)) +#' edf_nested %>% unnest(q) nested_quantiles <- function(x) { stopifnot(is_dist_quantiles(x)) distributional:::dist_apply(x, .f = function(z) { - tibble::as_tibble(vec_data(z)) %>% - dplyr::mutate(dplyr::across(tidyselect::everything(), as.double)) %>% + as_tibble(vec_data(z)) %>% + mutate(across(everything(), as.double)) %>% vctrs::list_of() }) } @@ -47,31 +49,26 @@ nested_quantiles <- function(x) { #' @examples #' d1 <- 
c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:4, 1:3 / 4)) #' d2 <- c(dist_quantiles(2:4, 2:4 / 5), dist_quantiles(3:5, 2:4 / 5)) -#' tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2) +#' tib <- tibble(g = c("a", "b"), d1 = d1, d2 = d2) #' #' pivot_quantiles_longer(tib, "d1") -#' pivot_quantiles_longer(tib, tidyselect::ends_with("1")) +#' pivot_quantiles_longer(tib, dplyr::ends_with("1")) #' pivot_quantiles_longer(tib, d1, d2) pivot_quantiles_longer <- function(.data, ..., .ignore_length_check = FALSE) { cols <- validate_pivot_quantiles(.data, ...) - .data <- .data %>% - dplyr::mutate(dplyr::across(tidyselect::all_of(cols), nested_quantiles)) + .data <- .data %>% mutate(across(all_of(cols), nested_quantiles)) if (length(cols) > 1L) { lengths_check <- .data %>% - dplyr::transmute(dplyr::across( - tidyselect::all_of(cols), - ~ map_int(.x, vctrs::vec_size) - )) %>% + dplyr::transmute(across(all_of(cols), ~ map_int(.x, vctrs::vec_size))) %>% as.matrix() %>% apply(1, function(x) dplyr::n_distinct(x) == 1L) %>% all() if (lengths_check) { - .data <- tidyr::unnest(.data, tidyselect::all_of(cols), names_sep = "_") + .data <- tidyr::unnest(.data, all_of(cols), names_sep = "_") } else { if (.ignore_length_check) { for (col in cols) { - .data <- .data %>% - tidyr::unnest(tidyselect::all_of(col), names_sep = "_") + .data <- .data %>% tidyr::unnest(all_of(col), names_sep = "_") } } else { cli::cli_abort(c( @@ -82,7 +79,7 @@ pivot_quantiles_longer <- function(.data, ..., .ignore_length_check = FALSE) { } } } else { - .data <- .data %>% tidyr::unnest(tidyselect::all_of(cols)) + .data <- .data %>% tidyr::unnest(all_of(cols)) } .data } @@ -110,25 +107,29 @@ pivot_quantiles_longer <- function(.data, ..., .ignore_length_check = FALSE) { #' tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2) #' #' pivot_quantiles_wider(tib, c("d1", "d2")) -#' pivot_quantiles_wider(tib, tidyselect::starts_with("d")) +#' pivot_quantiles_wider(tib, dplyr::starts_with("d")) #' 
pivot_quantiles_wider(tib, d2) pivot_quantiles_wider <- function(.data, ...) { cols <- validate_pivot_quantiles(.data, ...) - .data <- .data %>% - dplyr::mutate(dplyr::across(tidyselect::all_of(cols), nested_quantiles)) + .data <- .data %>% mutate(across(all_of(cols), nested_quantiles)) checks <- map_lgl(cols, ~ diff(range(vctrs::list_sizes(.data[[.x]]))) == 0L) if (!all(checks)) { nms <- cols[!checks] - cli::cli_abort( - c("Quantiles must be the same length and have the same set of taus.", - i = "Check failed for variables(s) {.var {nms}}." - ) - ) + cli::cli_abort(c( + "Quantiles must be the same length and have the same set of taus.", + i = "Check failed for variables(s) {.var {nms}}." + )) } + + # tidyr::pivot_wider can crash if there are duplicates, this generally won't + # happen in our context. To avoid, silently add an index column and remove it + # later + .hidden_index <- seq_len(nrow(.data)) + .data <- tibble::add_column(.data, .hidden_index = .hidden_index) if (length(cols) > 1L) { for (col in cols) { .data <- .data %>% - tidyr::unnest(tidyselect::all_of(col)) %>% + tidyr::unnest(all_of(col)) %>% tidyr::pivot_wider( names_from = "quantile_levels", values_from = "values", names_prefix = paste0(col, "_") @@ -136,14 +137,18 @@ pivot_quantiles_wider <- function(.data, ...) { } } else { .data <- .data %>% - tidyr::unnest(tidyselect::all_of(cols)) %>% + tidyr::unnest(all_of(cols)) %>% tidyr::pivot_wider(names_from = "quantile_levels", values_from = "values") } - .data + select(.data, -.hidden_index) } pivot_quantiles <- function(.data, ...) { - lifecycle::deprecate_stop("0.0.6", "pivot_quantiles()", "pivot_quantiles_wider()") + msg <- c( + "{.fn pivot_quantiles} was deprecated in {.pkg epipredict} 0.0.6", + i = "Please use {.fn pivot_quantiles_wider} instead." + ) + lifecycle::deprecate_stop(msg) } validate_pivot_quantiles <- function(.data, ...) 
{ diff --git a/R/reexports-tidymodels.R b/R/reexports-tidymodels.R index 250bae962..3b28ac5c5 100644 --- a/R/reexports-tidymodels.R +++ b/R/reexports-tidymodels.R @@ -2,6 +2,10 @@ #' @export generics::fit +#' @importFrom generics forecast +#' @export +generics::forecast + #' @importFrom recipes prep #' @export recipes::prep @@ -9,3 +13,15 @@ recipes::prep #' @importFrom recipes bake #' @export recipes::bake + +#' @importFrom recipes rand_id +#' @export +recipes::rand_id + +#' @importFrom tibble tibble +#' @export +tibble::tibble + +#' @importFrom generics tidy +#' @export +generics::tidy diff --git a/R/step_epi_naomit.R b/R/step_epi_naomit.R index 1cbc9c5d9..d81ba398d 100644 --- a/R/step_epi_naomit.R +++ b/R/step_epi_naomit.R @@ -22,6 +22,6 @@ step_epi_naomit <- function(recipe) { print.step_naomit <- # not exported from recipes package function(x, width = max(20, options()$width - 30), ...) { title <- "Removing rows with NA values in " - print_step(x$columns, x$terms, x$trained, title, width) + recipes::print_step(x$columns, x$terms, x$trained, title, width) invisible(x) } diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index ec5428d8f..465d64e7f 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -15,16 +15,12 @@ #' for this step. See [recipes::selections()] for more details. #' @param role For model terms created by this step, what analysis role should #' they be assigned? `lag` is default a predictor while `ahead` is an outcome. -#' @param trained A logical to indicate if the quantities for -#' preprocessing have been estimated. #' @param lag,ahead A vector of integers. Each specified column will #' be the lag or lead for each value in the vector. Lag integers must be #' nonnegative, while ahead integers must be positive. -#' @param prefix A prefix to indicate what type of variable this is +#' @param prefix A character string that will be prefixed to the new column. 
#' @param default Determines what fills empty rows #' left by leading/lagging (defaults to NA). -#' @param columns A character string of variable names that will -#' be populated (eventually) by the `terms` argument. #' @param skip A logical. Should the step be skipped when the #' recipe is baked by [bake()]? While all operations are baked #' when [prep()] is run, some operations may not be able to be @@ -53,43 +49,36 @@ step_epi_lag <- function(recipe, ..., - role = "predictor", - trained = FALSE, lag, + role = "predictor", prefix = "lag_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_lag")) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli_abort("This step can only operate on an `epi_recipe`.") } if (missing(lag)) { - rlang::abort( - c("The `lag` argument must not be empty.", - i = "Did you perhaps pass an integer in `...` accidentally?" - ) - ) + cli_abort(c( + "The `lag` argument must not be empty.", + i = "Did you perhaps pass an integer in `...` accidentally?" + )) } arg_is_nonneg_int(lag) arg_is_chr_scalar(prefix, id) - if (!is.null(columns)) { - rlang::abort(c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to lag." 
- )) - } - add_step( + + recipes::add_step( recipe, step_epi_lag_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, - trained = trained, - lag = lag, + trained = FALSE, + lag = as.integer(lag), prefix = prefix, default = default, - keys = epi_keys(recipe), - columns = columns, + keys = key_colnames(recipe), + columns = NULL, skip = skip, id = id ) @@ -104,43 +93,36 @@ step_epi_lag <- step_epi_ahead <- function(recipe, ..., - role = "outcome", - trained = FALSE, ahead, + role = "outcome", prefix = "ahead_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_ahead")) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli_abort("This step can only operate on an `epi_recipe`.") } if (missing(ahead)) { - rlang::abort( - c("The `ahead` argument must not be empty.", - i = "Did you perhaps pass an integer in `...` accidentally?" - ) - ) + cli_abort(c( + "The `ahead` argument must not be empty.", + i = "Did you perhaps pass an integer in `...` accidentally?" + )) } arg_is_nonneg_int(ahead) arg_is_chr_scalar(prefix, id) - if (!is.null(columns)) { - rlang::abort(c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to lead." 
- )) - } - add_step( + + recipes::add_step( recipe, step_epi_ahead_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, - trained = trained, - ahead = ahead, + trained = FALSE, + ahead = as.integer(ahead), prefix = prefix, default = default, - keys = epi_keys(recipe), - columns = columns, + keys = key_colnames(recipe), + columns = NULL, skip = skip, id = id ) @@ -151,7 +133,7 @@ step_epi_ahead <- step_epi_lag_new <- function(terms, role, trained, lag, prefix, default, keys, columns, skip, id) { - step( + recipes::step( subclass = "epi_lag", terms = terms, role = role, @@ -169,7 +151,7 @@ step_epi_lag_new <- step_epi_ahead_new <- function(terms, role, trained, ahead, prefix, default, keys, columns, skip, id) { - step( + recipes::step( subclass = "epi_ahead", terms = terms, role = role, @@ -196,7 +178,7 @@ prep.step_epi_lag <- function(x, training, info = NULL, ...) { prefix = x$prefix, default = x$default, keys = x$keys, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id ) @@ -212,7 +194,7 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) { prefix = x$prefix, default = x$default, keys = x$keys, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id ) @@ -223,7 +205,7 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) { #' @export bake.step_epi_lag <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, lag = object$lag) %>% - dplyr::mutate( + mutate( newname = glue::glue("{object$prefix}{lag}_{col}"), shift_val = lag, lag = NULL @@ -233,32 +215,28 @@ bake.step_epi_lag <- function(object, new_data, ...) { new_data_names <- colnames(new_data) intersection <- new_data_names %in% grid$newname if (any(intersection)) { - rlang::abort( - paste0( - "Name collision occured in `", class(object)[1], - "`. 
The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - "." - ) - ) + cli_abort(c( + "Name collision occured in {.cls {class(object)[1]}}", + "The following variable name{?s} already exist{?s/}: {.val {new_data_names[intersection]}}." + )) } ok <- object$keys shifted <- reduce( pmap(grid, epi_shift_single, x = new_data, key_cols = ok), - dplyr::full_join, + full_join, by = ok ) - dplyr::full_join(new_data, shifted, by = ok) %>% - dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% - dplyr::arrange(time_value) %>% - dplyr::ungroup() + full_join(new_data, shifted, by = ok) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% + arrange(time_value) %>% + ungroup() } #' @export bake.step_epi_ahead <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, ahead = object$ahead) %>% - dplyr::mutate( + mutate( newname = glue::glue("{object$prefix}{ahead}_{col}"), shift_val = -ahead, ahead = NULL @@ -268,26 +246,22 @@ bake.step_epi_ahead <- function(object, new_data, ...) { new_data_names <- colnames(new_data) intersection <- new_data_names %in% grid$newname if (any(intersection)) { - rlang::abort( - paste0( - "Name collision occured in `", class(object)[1], - "`. The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - "." - ) - ) + cli_abort(c( + "Name collision occured in {.cls {class(object)[1]}}", + "The following variable name{?s} already exist{?s/}: {.val {new_data_names[intersection]}}." 
+ )) } ok <- object$keys shifted <- reduce( pmap(grid, epi_shift_single, x = new_data, key_cols = ok), - dplyr::full_join, + full_join, by = ok ) - dplyr::full_join(new_data, shifted, by = ok) %>% - dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% - dplyr::arrange(time_value) %>% - dplyr::ungroup() + full_join(new_data, shifted, by = ok) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% + arrange(time_value) %>% + ungroup() } diff --git a/R/step_epi_slide.R b/R/step_epi_slide.R new file mode 100644 index 000000000..c7d3f9fbd --- /dev/null +++ b/R/step_epi_slide.R @@ -0,0 +1,293 @@ +#' Calculate a rolling window transformation +#' +#' `step_epi_slide()` creates a *specification* of a recipe step +#' that will generate one or more new columns of derived data by "sliding" +#' a computation along existing data. +#' +#' @inheritParams step_epi_lag +#' @param .f A function in one of the following formats: +#' 1. An unquoted function name with no arguments, e.g., `mean` +#' 2. A character string name of a function, e.g., `"mean"`. Note that this +#' can be difficult to examine for mistakes (so the misspelling `"maen"` +#' won't produce an error until you try to actually fit the model) +#' 3. A base `R` lambda function, e.g., `function(x) mean(x, na.rm = TRUE)` +#' 4. A new-style base `R` lambda function, e.g., `\(x) mean(x, na.rm = TRUE)` +#' 5. A one-sided formula like `~ mean(.x, na.rm = TRUE)`. +#' +#' Note that in cases 3 and 4, `x` can be any variable name you like (for +#' example `\(dog) mean(dog, na.rm = TRUE)` will work). But in case 5, the +#' argument must be named `.x`. A common, though very difficult to debug +#' error is using something like `function(x) mean`. This will not work +#' because it returns the function mean, rather than `mean(x)` +#' @param .window_size the size of the sliding window, required. Usually a +#' non-negative integer will suffice (e.g. 
for data indexed by date, but more +#' restrictive in other time_type cases (see [epiprocess::epi_slide()] for +#' details). For example, set to 7 for a 7-day window. +#' @param .align a character string indicating how the window should be aligned. +#' By default, this is "right", meaning the slide_window will be anchored with +#' its right end point on the reference date. (see [epiprocess::epi_slide()] +#' for details). +#' @param f_name a character string of at most 20 characters that describes the +#' function. This will be combined with `prefix` and the columns in `...` to +#' name the result using `{prefix}{f_name}_{column}`. By default it will be +#' determined automatically using `clean_f_name()`. +#' +#' @template step-return +#' +#' @export +#' @examples +#' library(dplyr) +#' jhu <- case_death_rate_subset %>% +#' filter(time_value >= as.Date("2021-01-01"), geo_value %in% c("ca", "ny")) +#' rec <- epi_recipe(jhu) %>% +#' step_epi_slide(case_rate, death_rate, +#' .f = \(x) mean(x, na.rm = TRUE), +#' .window_size = 7L +#' ) +#' bake(prep(rec, jhu), new_data = NULL) +step_epi_slide <- function(recipe, + ..., + .f, + .window_size = NULL, + .align = c("right", "center", "left"), + role = "predictor", + prefix = "epi_slide_", + f_name = clean_f_name(.f), + skip = FALSE, + id = rand_id("epi_slide")) { + if (!is_epi_recipe(recipe)) { + cli_abort("This recipe step can only operate on an {.cls epi_recipe}.") + } + .f <- validate_slide_fun(.f) + if (is.null(.window_size)) { + cli_abort("step_epi_slide: `.window_size` must be specified.") + } + epiprocess:::validate_slide_window_arg(.window_size, attributes(recipe$template)$metadata$time_type) + .align <- rlang::arg_match(.align) + arg_is_chr_scalar(role, prefix, id) + arg_is_lgl_scalar(skip) + + recipes::add_step( + recipe, + step_epi_slide_new( + terms = enquos(...), + .window_size = .window_size, + .align = .align, + .f = .f, + f_name = f_name, + role = role, + trained = FALSE, + prefix = prefix, + keys = 
key_colnames(recipe), + columns = NULL, + skip = skip, + id = id + ) + ) +} + + +step_epi_slide_new <- + function(terms, + .window_size, + .align, + .f, + f_name, + role, + trained, + prefix, + keys, + columns, + skip, + id) { + recipes::step( + subclass = "epi_slide", + terms = terms, + .window_size = .window_size, + .align = .align, + .f = .f, + f_name = f_name, + role = role, + trained = trained, + prefix = prefix, + keys = keys, + columns = columns, + skip = skip, + id = id + ) + } + + +#' @export +prep.step_epi_slide <- function(x, training, info = NULL, ...) { + col_names <- recipes::recipes_eval_select(x$terms, data = training, info = info) + + recipes::check_type(training[, col_names], types = c("double", "integer")) + + step_epi_slide_new( + terms = x$terms, + .window_size = x$.window_size, + .align = x$.align, + .f = x$.f, + f_name = x$f_name, + role = x$role, + trained = TRUE, + prefix = x$prefix, + keys = x$keys, + columns = col_names, + skip = x$skip, + id = x$id + ) +} + + +#' @export +bake.step_epi_slide <- function(object, new_data, ...) { + recipes::check_new_data(names(object$columns), object, new_data) + col_names <- object$columns + name_prefix <- paste0(object$prefix, object$f_name, "_") + newnames <- glue::glue("{name_prefix}{col_names}") + ## ensure no name clashes + new_data_names <- colnames(new_data) + intersection <- new_data_names %in% newnames + if (any(intersection)) { + nms <- new_data_names[intersection] + cli_abort( + c("In `step_epi_slide()` a name collision occurred. The following variable name{?s} already exist{?/s}:", + `*` = "{.var {nms}}" + ), + call = caller_env(), + class = "epipredict__step__name_collision_error" + ) + } + # TODO: Uncomment this whenever we make the optimized versions available. + # if (any(vapply(c(mean, sum), \(x) identical(x, object$.f), logical(1L)))) { + # cli_warn( + # c( + # "There is an optimized version of both mean and sum. 
See `step_epi_slide_mean`, `step_epi_slide_sum`, + # or `step_epi_slide_opt`." + # ), + # class = "epipredict__step_epi_slide__optimized_version" + # ) + # } + epi_slide_wrapper( + new_data, + object$.window_size, + object$.align, + object$columns, + c(object$.f), + object$f_name, + kill_time_value(object$keys), + object$prefix + ) +} + + +#' Wrapper to handle epi_slide particulars +#' +#' @description +#' This should simplify somewhat in the future when we can run `epi_slide` on +#' columns. Surprisingly, lapply is several orders of magnitude faster than +#' using roughly equivalent tidy select style. +#' +#' @param fns vector of functions, even if it's length 1. +#' @param group_keys the keys to group by. likely `epi_keys` (without `time_value`) +#' +#' @importFrom tidyr crossing +#' @importFrom dplyr bind_cols group_by ungroup +#' @importFrom epiprocess epi_slide +#' @keywords internal +epi_slide_wrapper <- function(new_data, .window_size, .align, columns, fns, fn_names, group_keys, name_prefix) { + cols_fns <- tidyr::crossing(col_name = columns, fn_name = fn_names, fn = fns) + # Iterate over the rows of cols_fns. For each row number, we will output a + # transformed column. The first result returns all the original columns along + # with the new column. The rest just return the new column. 
+ seq_len(nrow(cols_fns)) %>% + lapply(function(comp_i) { + col_name <- cols_fns[[comp_i, "col_name"]] + fn_name <- cols_fns[[comp_i, "fn_name"]] + fn <- cols_fns[[comp_i, "fn"]][[1L]] + result_name <- paste(name_prefix, fn_name, col_name, sep = "_") + result <- new_data %>% + group_by(across(all_of(group_keys))) %>% + epi_slide( + .window_size = .window_size, + .align = .align, + .new_col_name = result_name, + .f = function(slice, geo_key, ref_time_value) { + fn(slice[[col_name]]) + } + ) %>% + ungroup() + + if (comp_i == 1L) { + result + } else { + result[result_name] + } + }) %>% + bind_cols() +} + + +#' @export +print.step_epi_slide <- function(x, width = max(20, options()$width - 30), ...) { + print_epi_step( + x$columns, x$terms, x$trained, + title = "Calculating epi_slide for ", + conjunction = "with", extra_text = x$f_name + ) + invisible(x) +} + +#' Create short function names +#' +#' @param .f a function, character string, or lambda. For example, `mean`, +#' `"mean"`, `~ mean(.x)` or `\(x) mean(x, na.rm = TRUE)`. +#' @param max_length integer determining how long names can be +#' +#' @return a character string of length at most `max_length` that +#' (partially) describes the function. 
+#' @export +#' +#' @examples +#' clean_f_name(mean) +#' clean_f_name("mean") +#' clean_f_name(~ mean(.x, na.rm = TRUE)) +#' clean_f_name(\(x) mean(x, na.rm = TRUE)) +#' clean_f_name(function(x) mean(x, na.rm = TRUE, trim = 0.2357862)) +clean_f_name <- function(.f, max_length = 20L) { + if (rlang::is_formula(.f, scoped = TRUE)) { + f_name <- rlang::f_name(.f) + } else if (rlang::is_character(.f)) { + f_name <- .f + } else if (rlang::is_function(.f)) { + f_name <- as.character(substitute(.f)) + if (length(f_name) > 1L) { + f_name <- f_name[3] + if (nchar(f_name) > max_length - 5L) { + f_name <- paste0(substr(f_name, 1L, max(max_length - 8L, 5L)), "...") + } + f_name <- paste0("[ ]{", f_name, "}") + } + } + if (nchar(f_name) > max_length) { + f_name <- paste0(substr(f_name, 1L, max_length - 3L), "...") + } + f_name +} + + +validate_slide_fun <- function(.f) { + if (rlang::quo(.f) %>% rlang::quo_is_missing()) { + cli_abort("In, `step_epi_slide()`, `.f` may not be missing.") + } + if (rlang::is_formula(.f, scoped = TRUE)) { + cli_abort("In, `step_epi_slide()`, `.f` cannot be a formula.") + } else if (rlang::is_character(.f)) { + .f <- rlang::as_function(.f) + } else if (!rlang::is_function(.f)) { + cli_abort("In, `step_epi_slide()`, `.f` must be a function.") + } + .f +} diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index 74cfff284..06f8da4cf 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -8,13 +8,11 @@ #' @param horizon Bandwidth for the sliding window, when `method` is #' "rel_change" or "linear_reg". See [epiprocess::growth_rate()] for more #' details. -#' @param method Either "rel_change", "linear_reg", "smooth_spline", or -#' "trend_filter", indicating the method to use for the growth rate -#' calculation. The first two are local methods: they are run in a sliding +#' @param method Either "rel_change" or "linear_reg", +#' indicating the method to use for the growth rate +#' calculation. 
These are local methods: they are run in a sliding #' fashion over the sequence (in order to estimate derivatives and hence -#' growth rates); the latter two are global methods: they are run once over -#' the entire sequence. See [epiprocess::growth_rate()] for more -#' details. +#' growth rates). See [epiprocess::growth_rate()] for more details. #' @param log_scale Should growth rates be estimated using the parameterization #' on the log scale? See details for an explanation. Default is `FALSE`. #' @param replace_Inf Sometimes, the growth rate calculation can result in @@ -39,64 +37,55 @@ #' r #' #' r %>% -#' recipes::prep() %>% -#' recipes::bake(case_death_rate_subset) +#' prep(case_death_rate_subset) %>% +#' bake(case_death_rate_subset) step_growth_rate <- function(recipe, ..., role = "predictor", - trained = FALSE, horizon = 7, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + method = c("rel_change", "linear_reg"), log_scale = FALSE, replace_Inf = NA, prefix = "gr_", - columns = NULL, skip = FALSE, id = rand_id("growth_rate"), additional_gr_args_list = list()) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli_abort("This recipe step can only operate on an {.cls epi_recipe}.") } - method <- match.arg(method) + method <- rlang::arg_match(method) arg_is_pos_int(horizon) arg_is_scalar(horizon) if (!is.null(replace_Inf)) { - if (length(replace_Inf) != 1L) rlang::abort("replace_Inf must be a scalar.") + if (length(replace_Inf) != 1L) cli_abort("replace_Inf must be a scalar.") if (!is.na(replace_Inf)) arg_is_numeric(replace_Inf) } arg_is_chr(role) arg_is_chr_scalar(prefix, id) - arg_is_lgl_scalar(trained, log_scale, skip) + arg_is_lgl_scalar(log_scale, skip) if (!is.list(additional_gr_args_list)) { - rlang::abort( - c("`additional_gr_args_list` must be a list.", - i = "See `?epiprocess::growth_rate` for available options." 
- ) - ) - } - - if (!is.null(columns)) { - rlang::abort(c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to use." + cli_abort(c( + "`additional_gr_args_list` must be a {.cls list}.", + i = "See `?epiprocess::growth_rate` for available options." )) } - add_step( + recipes::add_step( recipe, step_growth_rate_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, - trained = trained, + trained = FALSE, horizon = horizon, method = method, log_scale = log_scale, replace_Inf = replace_Inf, prefix = prefix, - keys = epi_keys(recipe), - columns = columns, + keys = key_colnames(recipe), + columns = NULL, skip = skip, id = id, additional_gr_args_list = additional_gr_args_list @@ -119,7 +108,7 @@ step_growth_rate_new <- skip, id, additional_gr_args_list) { - step( + recipes::step( subclass = "growth_rate", terms = terms, role = role, @@ -151,7 +140,7 @@ prep.step_growth_rate <- function(x, training, info = NULL, ...) { replace_Inf = x$replace_Inf, prefix = x$prefix, keys = x$keys, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id, additional_gr_args_list = x$additional_gr_args_list @@ -170,24 +159,23 @@ bake.step_growth_rate <- function(object, new_data, ...) { new_data_names <- colnames(new_data) intersection <- new_data_names %in% newnames if (any(intersection)) { - rlang::abort( - c(paste0("Name collision occured in `", class(object)[1], "`."), - i = paste( - "The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - "." - ) - ) + nms <- new_data_names[intersection] + cli_abort( + c("In `step_growth_rate()` a name collision occurred. 
The following variable name{?s} already exist{?/s}:", + `*` = "{.var {nms}}" + ), + call = caller_env(), + class = "epipredict__step__name_collision_error" ) } ok <- object$keys gr <- new_data %>% - dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% dplyr::transmute( time_value = time_value, - dplyr::across( - dplyr::all_of(object$columns), + across( + all_of(object$columns), ~ epiprocess::growth_rate( time_value, .x, method = object$method, @@ -197,23 +185,19 @@ bake.step_growth_rate <- function(object, new_data, ...) { .names = "{object$prefix}{object$horizon}_{object$method}_{.col}" ) ) %>% - dplyr::ungroup() %>% - dplyr::mutate(time_value = time_value + object$horizon) # shift x0 right + ungroup() %>% + mutate(time_value = time_value + object$horizon) # shift x0 right + if (!is.null(object$replace_Inf)) { gr <- gr %>% - dplyr::mutate( - dplyr::across( - !dplyr::all_of(ok), - ~ vec_replace_inf(.x, object$replace_Inf) - ) - ) + mutate(across(!all_of(ok), ~ vec_replace_inf(.x, object$replace_Inf))) } - dplyr::left_join(new_data, gr, by = ok) %>% - dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% - dplyr::arrange(time_value) %>% - dplyr::ungroup() + left_join(new_data, gr, by = ok) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% + arrange(time_value) %>% + ungroup() } diff --git a/R/step_lag_difference.R b/R/step_lag_difference.R index 21878eaa7..39ae1ba59 100644 --- a/R/step_lag_difference.R +++ b/R/step_lag_difference.R @@ -16,48 +16,40 @@ #' @export #' @examples #' r <- epi_recipe(case_death_rate_subset) %>% -#' step_lag_difference(case_rate, death_rate, horizon = c(7, 14)) +#' step_lag_difference(case_rate, death_rate, horizon = c(7, 14)) %>% +#' step_epi_naomit() #' r #' #' r %>% -#' recipes::prep() %>% -#' recipes::bake(case_death_rate_subset) +#' prep(case_death_rate_subset) %>% +#' bake(case_death_rate_subset) step_lag_difference <- function(recipe, ..., role = "predictor", - 
trained = FALSE, horizon = 7, prefix = "lag_diff_", - columns = NULL, skip = FALSE, id = rand_id("lag_diff")) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli_abort("This recipe step can only operate on an {.cls epi_recipe}.") } arg_is_pos_int(horizon) arg_is_chr(role) arg_is_chr_scalar(prefix, id) - arg_is_lgl_scalar(trained, skip) + arg_is_lgl_scalar(skip) - if (!is.null(columns)) { - rlang::abort( - c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to use." - ) - ) - } - add_step( + recipes::add_step( recipe, step_lag_difference_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, - trained = trained, + trained = FALSE, horizon = horizon, prefix = prefix, - keys = epi_keys(recipe), - columns = columns, + keys = key_colnames(recipe), + columns = NULL, skip = skip, id = id ) @@ -75,7 +67,7 @@ step_lag_difference_new <- columns, skip, id) { - step( + recipes::step( subclass = "lag_difference", terms = terms, role = role, @@ -100,7 +92,7 @@ prep.step_lag_difference <- function(x, training, info = NULL, ...) { horizon = x$horizon, prefix = x$prefix, keys = x$keys, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id ) @@ -108,47 +100,46 @@ prep.step_lag_difference <- function(x, training, info = NULL, ...) 
{ epi_shift_single_diff <- function(x, col, horizon, newname, key_cols) { - x <- x %>% dplyr::select(tidyselect::all_of(c(key_cols, col))) + x <- x %>% select(all_of(c(key_cols, col))) y <- x %>% - dplyr::mutate(time_value = time_value + horizon) %>% - dplyr::rename(!!newname := {{ col }}) - x <- dplyr::left_join(x, y, by = key_cols) + mutate(time_value = time_value + horizon) %>% + rename(!!newname := {{ col }}) + x <- left_join(x, y, by = key_cols) x[, newname] <- x[, col] - x[, newname] - x %>% dplyr::select(tidyselect::all_of(c(key_cols, newname))) + x %>% select(all_of(c(key_cols, newname))) } #' @export bake.step_lag_difference <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, horizon = object$horizon) %>% - dplyr::mutate(newname = glue::glue("{object$prefix}{horizon}_{col}")) + mutate(newname = glue::glue("{object$prefix}{horizon}_{col}")) ## ensure no name clashes new_data_names <- colnames(new_data) intersection <- new_data_names %in% grid$newname if (any(intersection)) { - rlang::abort( - c(paste0("Name collision occured in `", class(object)[1], "`."), - i = paste( - "The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - "." - ) - ) + nms <- new_data_names[intersection] + cli_abort( + c("In `step_lag_difference()` a name collision occurred. 
The following variable name{?s} already exist{?/s}:", + `*` = "{.var {nms}}" + ), + call = caller_env(), + class = "epipredict__step__name_collision_error" ) } ok <- object$keys shifted <- reduce( pmap(grid, epi_shift_single_diff, x = new_data, key_cols = ok), - dplyr::full_join, + full_join, by = ok ) - dplyr::left_join(new_data, shifted, by = ok) %>% - dplyr::group_by(dplyr::across(tidyselect::all_of(ok[-1]))) %>% - dplyr::arrange(time_value) %>% - dplyr::ungroup() + left_join(new_data, shifted, by = ok) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% + arrange(time_value) %>% + ungroup() } diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index ce87ea759..4e4d3aa26 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -11,17 +11,7 @@ #' passed will *divide* the selected variables while the `rate_rescaling` #' argument is a common *multiplier* of the selected variables. #' -#' @param recipe A recipe object. The step will be added to the sequence of -#' operations for this recipe. The recipe should contain information about the -#' `epi_df` such as column names. -#' @param ... One or more selector functions to scale variables -#' for this step. See [recipes::selections()] for more details. -#' @param role For model terms created by this step, what analysis role should -#' they be assigned? By default, the new columns created by this step from the -#' original variables will be used as predictors in a model. Other options can -#' be ard are not limited to "outcome". -#' @param trained A logical to indicate if the quantities for preprocessing -#' have been estimated. +#' @inheritParams step_epi_lag #' @param df a data frame that contains the population data to be used for #' inverting the existing scaling. #' @param by A (possibly named) character vector of variables to join by. 
@@ -49,25 +39,15 @@ #' @param create_new TRUE to create a new column and keep the original column #' in the `epi_df` #' @param suffix a character. The suffix added to the column name if -#' `crete_new = TRUE`. Default to "_scaled". -#' @param columns A character string of variable names that will -#' be populated (eventually) by the `terms` argument. -#' @param skip A logical. Should the step be skipped when the -#' recipe is baked by [bake()]? While all operations are baked -#' when [prep()] is run, some operations may not be able to be -#' conducted on new data (e.g. processing the outcome variable(s)). -#' Care should be taken when using `skip = TRUE` as it may affect -#' the computations for subsequent operations. -#' @param id A unique identifier for the step +#' `create_new = TRUE`. Default to "_scaled". #' #' @return Scales raw data by the population #' @export #' @examples -#' library(epiprocess) -#' library(epipredict) -#' jhu <- epiprocess::jhu_csse_daily_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% -#' dplyr::select(geo_value, time_value, cases) +#' library(dplyr) +#' jhu <- jhu_csse_daily_subset %>% +#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% +#' select(geo_value, time_value, cases) #' #' pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) #' @@ -92,57 +72,44 @@ #' df_pop_col = "value" #' ) #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% +#' wf <- epi_workflow(r, linear_reg()) %>% #' fit(jhu) %>% #' add_frosting(f) #' -#' latest <- get_test_data( -#' recipe = r, -#' epiprocess::jhu_csse_daily_subset %>% -#' dplyr::filter( -#' time_value > "2021-11-01", -#' geo_value %in% c("ca", "ny") -#' ) %>% -#' dplyr::select(geo_value, time_value, cases) -#' ) -#' -#' -#' predict(wf, latest) +#' forecast(wf) step_population_scaling <- function(recipe, ..., role = "raw", - trained = FALSE, df, by = NULL, df_pop_col, rate_rescaling = 1, create_new = TRUE, suffix = 
"_scaled", - columns = NULL, skip = FALSE, id = rand_id("population_scaling")) { - arg_is_scalar(role, trained, df_pop_col, rate_rescaling, create_new, suffix, id) + arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id) arg_is_lgl(create_new, skip) arg_is_chr(df_pop_col, suffix, id) - arg_is_chr(by, columns, allow_null = TRUE) + arg_is_chr(by, allow_null = TRUE) if (rate_rescaling <= 0) { - cli_stop("`rate_rescaling` should be a positive number") + cli_abort("`rate_rescaling` must be a positive number.") } - add_step( + recipes::add_step( recipe, step_population_scaling_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, - trained = trained, + trained = FALSE, df = df, by = by, df_pop_col = df_pop_col, rate_rescaling = rate_rescaling, create_new = create_new, suffix = suffix, - columns = columns, + columns = NULL, skip = skip, id = id ) @@ -152,7 +119,7 @@ step_population_scaling <- step_population_scaling_new <- function(role, trained, df, by, df_pop_col, rate_rescaling, terms, create_new, suffix, columns, skip, id) { - step( + recipes::step( subclass = "population_scaling", terms = terms, role = role, @@ -181,30 +148,21 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { rate_rescaling = x$rate_rescaling, create_new = x$create_new, suffix = x$suffix, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id ) } #' @export -bake.step_population_scaling <- function(object, - new_data, - ...) { - stopifnot( - "Only one population column allowed for scaling" = - length(object$df_pop_col) == 1 +bake.step_population_scaling <- function(object, new_data, ...) 
{ + object$by <- object$by %||% intersect( + epi_keys_only(new_data), + colnames(select(object$df, !object$df_pop_col)) ) - - try_join <- try(dplyr::left_join(new_data, object$df, by = object$by), - silent = TRUE - ) - if (any(grepl("Join columns must be present in data", unlist(try_join)))) { - cli_stop(c( - "columns in `by` selectors of `step_population_scaling` ", - "must be present in data and match" - )) - } + joinby <- list(x = names(object$by) %||% object$by, y = object$by) + hardhat::validate_column_names(new_data, joinby$x) + hardhat::validate_column_names(object$df, joinby$y) if (object$suffix != "_scaled" && object$create_new == FALSE) { cli::cli_warn(c( @@ -213,23 +171,22 @@ bake.step_population_scaling <- function(object, )) } - object$df <- object$df %>% - dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) + object$df <- mutate(object$df, across(dplyr::where(is.character), tolower)) pop_col <- rlang::sym(object$df_pop_col) suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(new_data)) - dplyr::left_join(new_data, - object$df, - by = object$by, suffix = c("", ".df") - ) %>% - dplyr::mutate(dplyr::across(dplyr::all_of(object$columns), - ~ .x * object$rate_rescaling / !!pop_col, - .names = "{.col}{suffix}" - )) %>% + left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>% + mutate( + across( + all_of(object$columns), + ~ .x * object$rate_rescaling / !!pop_col, + .names = "{.col}{suffix}" + ) + ) %>% # removed so the models do not use the population column - dplyr::select(-dplyr::any_of(col_to_remove)) + select(!any_of(col_to_remove)) } #' @export diff --git a/R/step_training_window.R b/R/step_training_window.R index 7102d29d8..eafc076c7 100644 --- a/R/step_training_window.R +++ b/R/step_training_window.R @@ -5,18 +5,13 @@ #' observations in `time_value` per group, where the groups are formed #' based on the remaining `epi_keys`. 
#' -#' @param recipe A recipe object. The step will be added to the -#' sequence of operations for this recipe. -#' @param role Not used by this step since no new variables are created. -#' @param trained A logical to indicate if the quantities for -#' preprocessing have been estimated. #' @param n_recent An integer value that represents the number of most recent #' observations that are to be kept in the training window per key #' The default value is 50. #' @param epi_keys An optional character vector for specifying "key" variables #' to group on. The default, `NULL`, ensures that every key combination is #' limited. -#' @param id A character string that is unique to this step to identify it. +#' @inheritParams step_epi_lag #' @template step-return #' #' @details Note that `step_epi_lead()` and `step_epi_lag()` should come @@ -25,13 +20,10 @@ #' @export #' #' @examples -#' tib <- tibble::tibble( +#' tib <- tibble( #' x = 1:10, #' y = 1:10, -#' time_value = rep(seq(as.Date("2020-01-01"), -#' by = 1, -#' length.out = 5 -#' ), times = 2), +#' time_value = rep(seq(as.Date("2020-01-01"), by = 1, length.out = 5), 2), #' geo_value = rep(c("ca", "hi"), each = 5) #' ) %>% #' as_epi_df() @@ -42,18 +34,16 @@ #' bake(new_data = NULL) #' #' epi_recipe(y ~ x, data = tib) %>% -#' recipes::step_naomit() %>% +#' step_epi_naomit() %>% #' step_training_window(n_recent = 3) %>% #' prep(tib) %>% #' bake(new_data = NULL) step_training_window <- function(recipe, role = NA, - trained = FALSE, n_recent = 50, epi_keys = NULL, id = rand_id("training_window")) { - arg_is_lgl_scalar(trained) arg_is_scalar(n_recent, id) arg_is_pos(n_recent) if (is.finite(n_recent)) arg_is_pos_int(n_recent) @@ -63,7 +53,7 @@ step_training_window <- recipe, step_training_window_new( role = role, - trained = trained, + trained = FALSE, n_recent = n_recent, epi_keys = epi_keys, skip = TRUE, @@ -87,7 +77,7 @@ step_training_window_new <- #' @export prep.step_training_window <- function(x, training, info = NULL, 
...) { - ekt <- kill_time_value(epi_keys(training)) + ekt <- epi_keys_only(training) ek <- x$epi_keys %||% ekt %||% character(0L) hardhat::validate_column_names(training, ek) @@ -108,10 +98,10 @@ bake.step_training_window <- function(object, new_data, ...) { if (object$n_recent < Inf) { new_data <- new_data %>% - dplyr::group_by(dplyr::across(dplyr::all_of(object$epi_keys))) %>% - dplyr::arrange(time_value) %>% + group_by(across(all_of(object$epi_keys))) %>% + arrange(time_value) %>% dplyr::slice_tail(n = object$n_recent) %>% - dplyr::ungroup() + ungroup() } new_data @@ -122,10 +112,7 @@ print.step_training_window <- function(x, width = max(20, options()$width - 30), ...) { title <- "# of recent observations per key limited to:" n_recent <- x$n_recent - tr_obj <- format_selectors(rlang::enquos(n_recent), width) - recipes::print_step( - tr_obj, rlang::enquos(n_recent), - x$trained, title, width - ) + tr_obj <- recipes::format_selectors(rlang::enquos(n_recent), width) + recipes::print_step(tr_obj, rlang::enquos(n_recent), x$trained, title, width) invisible(x) } diff --git a/R/tidy.R b/R/tidy.R index 06835eff0..61b298411 100644 --- a/R/tidy.R +++ b/R/tidy.R @@ -26,8 +26,9 @@ #' `type` (the method, e.g. "predict", "naomit"), and a character column `id`. #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% @@ -52,21 +53,17 @@ tidy.frosting <- function(x, number = NA, id = NA, ...) 
{ num_oper <- length(x$layers) pattern <- "^layer_" - if (length(id) != 1L) { - rlang::abort("If `id` is provided, it must be a length 1 character vector.") - } - - if (length(number) != 1L) { - rlang::abort("If `number` is provided, it must be a length 1 integer vector.") - } + arg_is_chr_scalar(id, allow_na = TRUE) + arg_is_scalar(number, allow_na = TRUE) + if (!is.na(number)) arg_is_int(number) if (!is.na(id)) { if (!is.na(number)) { - rlang::abort("You may specify `number` or `id`, but not both.") + cli_abort("You may specify `number` or `id`, but not both.") } layer_ids <- vapply(x$layers, function(x) x$id, character(1)) if (!(id %in% layer_ids)) { - rlang::abort("Supplied `id` not found in the frosting.") + cli_abort("Supplied `id` not found in the frosting.") } number <- which(id == layer_ids) } @@ -89,13 +86,7 @@ tidy.frosting <- function(x, number = NA, id = NA, ...) { ) } else { if (number > num_oper || length(number) > 1) { - rlang::abort( - paste0( - "`number` should be a single value between 1 and ", - num_oper, - "." - ) - ) + cli_abort("`number` should be a single value between 1 and {num_oper}.") } res <- tidy(x$layers[[number]], ...) 
diff --git a/R/time_types.R b/R/time_types.R
new file mode 100644
index 000000000..f33974833
--- /dev/null
+++ b/R/time_types.R
@@ -0,0 +1,73 @@
+guess_time_type <- function(time_value) {
+  # similar to epiprocess:::guess_time_type() but w/o the gap handling
+  arg_is_scalar(time_value)
+  if (is.character(time_value)) {
+    if (nchar(time_value) <= "10") {
+      new_time_value <- tryCatch(
+        {
+          as.Date(time_value)
+        },
+        error = function(e) NULL
+      )
+    } else {
+      new_time_value <- tryCatch(
+        {
+          as.POSIXct(time_value)
+        },
+        error = function(e) NULL
+      )
+    }
+    if (!is.null(new_time_value)) time_value <- new_time_value
+  }
+  if (inherits(time_value, "POSIXct")) {
+    return("day-time")
+  }
+  if (inherits(time_value, "Date")) {
+    return("day")
+  }
+  if (inherits(time_value, "yearweek")) {
+    return("yearweek")
+  }
+  if (inherits(time_value, "yearmonth")) {
+    return("yearmonth")
+  }
+  if (inherits(time_value, "yearquarter")) {
+    return("yearquarter")
+  }
+  if (is.numeric(time_value) && all(time_value == as.integer(time_value)) &&
+    all(time_value >= 1582)) {
+    return("year")
+  }
+  return("custom")
+}
+
+coerce_time_type <- function(x, target_type) {
+  if (target_type == "year") {
+    if (is.numeric(x)) {
+      return(as.integer(x))
+    } else {
+      return(as.POSIXlt(x)$year + 1900L)
+    }
+  }
+  switch(target_type,
+    "day-time" = as.POSIXct(x),
+    "day" = as.Date(x),
+    "week" = as.Date(x),
+    "yearweek" = tsibble::yearweek(x),
+    "yearmonth" = tsibble::yearmonth(x),
+    "yearquarter" = tsibble::yearquarter(x)
+  )
+}
+
+validate_date <- function(x, expected, arg = rlang::caller_arg(x),
+                          call = rlang::caller_env()) {
+  time_type_x <- guess_time_type(x)
+  ok <- time_type_x == expected
+  if (!ok) {
+    cli_abort(c(
+      "The {.arg {arg}} was given as a {.val {time_type_x}} while the",
+      `!` = "`time_type` of the training data was {.val {expected}}.",
+      i = "See {.topic epiprocess::epi_df} for descriptions of how these are determined."
+ ), call = call) + } +} diff --git a/R/utils-arg.R b/R/utils-arg.R index 091987722..081d153fb 100644 --- a/R/utils-arg.R +++ b/R/utils-arg.R @@ -2,204 +2,200 @@ # http://adv-r.had.co.nz/Computing-on-the-language.html#substitute # Modeled after / copied from rundel/ghclass -handle_arg_list <- function(..., tests) { +handle_arg_list <- function(..., .tests) { values <- list(...) names <- eval(substitute(alist(...))) names <- map(names, deparse) - walk2(names, values, tests) -} - -arg_is_scalar <- function(..., allow_null = FALSE, allow_na = FALSE) { - handle_arg_list( - ..., - tests = function(name, value) { - if (length(value) > 1 | (!allow_null & length(value) == 0)) { - cli::cli_abort("Argument {.val {name}} must be of length 1.") - } - if (!is.null(value)) { - if (is.na(value) & !allow_na) { - cli::cli_abort( - "Argument {.val {name}} must not be a missing value ({.val {NA}})." - ) - } - } - } - ) -} - - -arg_is_lgl <- function(..., allow_null = FALSE, allow_na = FALSE, allow_empty = FALSE) { - handle_arg_list( - ..., - tests = function(name, value) { - if (is.null(value) & !allow_null) { - cli::cli_abort("Argument {.val {name}} must be of logical type.") - } - if (any(is.na(value)) & !allow_na) { - cli::cli_abort("Argument {.val {name}} must not contain any missing values ({.val {NA}}).") - } - if (!is.null(value) & (length(value) == 0 & !allow_empty)) { - cli::cli_abort("Argument {.val {name}} must have length >= 1.") - } - if (!is.null(value) & length(value) != 0 & !is.logical(value)) { - cli::cli_abort("Argument {.val {name}} must be of logical type.") - } - } - ) -} - -arg_is_lgl_scalar <- function(..., allow_null = FALSE, allow_na = FALSE) { - arg_is_lgl(..., allow_null = allow_null, allow_na = allow_na) - arg_is_scalar(..., allow_null = allow_null, allow_na = allow_na) + walk2(names, values, .tests) } -arg_is_numeric <- function(..., allow_null = FALSE) { - handle_arg_list( - ..., - tests = function(name, value) { - if (!(is.numeric(value) | 
(is.null(value) & allow_null))) { - cli::cli_abort("All {.val {name}} must numeric.") - } +arg_is_scalar <- function(..., allow_null = FALSE, allow_na = FALSE, + call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_scalar(value, null.ok = allow_null, na.ok = allow_na) + if (!ok) { + cli_abort("{.arg {name}} must be a scalar.", call = call) } - ) -} - -arg_is_pos <- function(..., allow_null = FALSE) { - arg_is_numeric(..., allow_null = allow_null) - handle_arg_list( - ..., - tests = function(name, value) { - if (!(all(value > 0) | (is.null(value) & allow_null))) { - cli::cli_abort("All {.val {name}} must be positive number(s).") - } + }) +} + +arg_is_lgl <- function(..., allow_null = FALSE, allow_na = FALSE, + allow_empty = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_logical(value, + null.ok = allow_null, any.missing = allow_na, + min.len = as.integer(!allow_empty) + ) + if (!ok) { + cli_abort("{.arg {name}} must be of type {.cls logical}.", call = call) } - ) -} - -arg_is_nonneg <- function(..., allow_null = FALSE) { - arg_is_numeric(..., allow_null = allow_null) - handle_arg_list( - ..., - tests = function(name, value) { - if (!(all(value >= 0) | (is.null(value) & allow_null))) { - cli::cli_abort("All {.val {name}} must be nonnegative number(s).") - } + }) +} + +arg_is_lgl_scalar <- function(..., allow_null = FALSE, allow_na = FALSE, + call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_logical(value, + null.ok = allow_null, any.missing = allow_na, + min.len = 1, max.len = 1 + ) + if (!ok) { + cli_abort( + "{.arg {name}} must be a scalar of type {.cls logical}.", + call = call + ) } - ) + }) } -arg_is_int <- function(..., allow_null = FALSE) { - arg_is_numeric(..., allow_null = allow_null) - handle_arg_list( - ..., - tests = function(name, value) { - if (!(all(value %% 1 == 0) | (is.null(value) & allow_null))) { - 
cli::cli_abort("All {.val {name}} must be whole positive number(s).") - } +arg_is_numeric <- function(..., allow_null = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_numeric(value, null.ok = allow_null, any.missing = FALSE) + if (!ok) { + cli_abort("{.arg {name}} must be of type {.cls numeric}.", call = call) } - ) -} - -arg_is_pos_int <- function(..., allow_null = FALSE) { - arg_is_int(..., allow_null = allow_null) - arg_is_pos(..., allow_null = allow_null) -} - - -arg_is_nonneg_int <- function(..., allow_null = FALSE) { - arg_is_int(..., allow_null = allow_null) - arg_is_nonneg(..., allow_null = allow_null) -} - -arg_is_date <- function(..., allow_null = FALSE, allow_na = FALSE) { - handle_arg_list( - ..., - tests = function(name, value) { - if (is.null(value) & !allow_null) { - cli::cli_abort("Argument {.val {name}} may not be `NULL`.") - } - if (any(is.na(value)) & !allow_na) { - cli::cli_abort("Argument {.val {name}} must not contain any missing values ({.val {NA}}).") - } - if (!(is(value, "Date") | is.null(value) | all(is.na(value)))) { - cli::cli_abort("Argument {.val {name}} must be a Date. 
Try `as.Date()`.") - } + }) +} + +arg_is_pos <- function(..., allow_null = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_numeric( + value, + lower = .Machine$double.eps, + null.ok = allow_null, any.missing = FALSE + ) + if (!ok) { + len <- length(value) + cli_abort( + "{.arg {name}} must be {cli::qty(len)} {?a/} strictly positive number{?s}.", + call = call + ) } - ) -} - -arg_is_probabilities <- function(..., allow_null = FALSE) { - arg_is_numeric(..., allow_null = allow_null) - handle_arg_list( - ..., - tests = function(name, value) { - if (!((all(value >= 0) && all(value <= 1)) | (is.null(value) & allow_null))) { - cli::cli_abort("All {.val {name}} must be in [0,1].") - } + }) +} + +arg_is_nonneg <- function(..., allow_null = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_numeric(value, lower = 0, null.ok = allow_null, any.missing = FALSE) + if (!ok) { + len <- length(value) + cli_abort( + "{.arg {name}} must be {cli::qty(len)} {?a/} non-negative number{?s}.", + call = call + ) } - ) -} - -arg_is_chr <- function(..., allow_null = FALSE, allow_na = FALSE, allow_empty = FALSE) { - handle_arg_list( - ..., - tests = function(name, value) { - if (is.null(value) & !allow_null) { - cli::cli_abort("Argument {.val {name}} may not be `NULL`.") - } - if (any(is.na(value)) & !allow_na) { - cli::cli_abort("Argument {.val {name}} must not contain any missing values ({.val {NA}}).") - } - if (!is.null(value) & (length(value) == 0L & !allow_empty)) { - cli::cli_abort("Argument {.val {name}} must have length > 0.") - } - if (!(is.character(value) | is.null(value) | all(is.na(value)))) { - cli::cli_abort("Argument {.val {name}} must be of character type.") - } + }) +} + +arg_is_int <- function(..., allow_null = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_integerish(value, null.ok = allow_null) + if (!ok) { + len <- 
length(value) + cli_abort( + "{.arg {name}} must be {cli::qty(len)} {?a/} integer{?s}.", + call = call + ) } - ) -} - -arg_is_chr_scalar <- function(..., allow_null = FALSE, allow_na = FALSE) { - arg_is_chr(..., allow_null = allow_null, allow_na = allow_na) - arg_is_scalar(..., allow_null = allow_null, allow_na = allow_na) -} - - -arg_is_function <- function(..., allow_null = FALSE) { - handle_arg_list( - ..., - tests = function(name, value) { - if (is.null(value) & !allow_null) { - cli::cli_abort("Argument {.val {name}} must be a function.") - } - if (!is.null(value) & !is.function(value)) { - cli::cli_abort("Argument {.val {name}} must be a function.") - } + }) +} + +arg_is_pos_int <- function(..., allow_null = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_integerish(value, null.ok = allow_null, lower = 1, any.missing = FALSE) + if (!ok) { + len <- length(value) + cli_abort( + "{.arg {name}} must be {cli::qty(len)} {?a/} positive integer{?s}.", + call = call + ) + } + }) +} + +arg_is_nonneg_int <- function(..., allow_null = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_integerish(value, null.ok = allow_null, lower = 0, any.missing = FALSE) + if (!ok) { + len <- length(value) + cli_abort( + "{.arg {name}} must be {cli::qty(len)} {?a/} non-negative integer{?s}.", + call = call + ) } - ) + }) +} + +arg_is_date <- function(..., allow_null = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_date(value, null.ok = allow_null) + if (!ok) { + len <- length(value) + cli_abort( + "{.arg {name}} must be {cli::qty(len)} {?a/} date{?s}.", + call = call + ) + } + }) +} + +arg_is_probabilities <- function(..., allow_null = FALSE, allow_na = FALSE, + call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_numeric(value, + lower = 0, upper = 1, null.ok = allow_null, + any.missing = allow_na + 
) + if (!ok) { + cli_abort("{.arg {name}} must lie in [0, 1].", call = call) + } + }) +} + +arg_is_chr <- function(..., allow_null = FALSE, allow_na = FALSE, allow_empty = FALSE, + call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_character(value, + null.ok = allow_null, any.missing = allow_na, + min.len = as.integer(!allow_empty) + ) + if (!ok) { + cli_abort("{.arg {name}} must be of type {.cls character}.", call = call) + } + }) +} + +arg_is_chr_scalar <- function(..., allow_null = FALSE, allow_na = FALSE, + call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_character(value, + null.ok = allow_null, any.missing = allow_na, + len = 1L + ) + if (!ok) { + cli_abort( + "{.arg {name}} must be a scalar of type {.cls character}.", + call = call + ) + } + }) } - - -arg_is_sorted <- function(..., allow_null = FALSE) { - handle_arg_list( - ..., - tests = function(name, value) { - if (is.unsorted(value, na.rm = TRUE) | (is.null(value) & !allow_null)) { - cli::cli_abort("{.val {name}} must be sorted in increasing order.") - } +arg_is_function <- function(..., allow_null = FALSE, call = caller_env()) { + handle_arg_list(..., .tests = function(name, value) { + ok <- test_function(value, null.ok = allow_null) + if (!ok) { + cli_abort("{.arg {name}} must be of type {.cls function}.", call = call) } - ) + }) } - -arg_to_date <- function(x, allow_null = FALSE, allow_na = FALSE) { - arg_is_scalar(x, allow_null = allow_null, allow_na = allow_na) +arg_to_date <- function(x, allow_null = FALSE) { + arg_is_scalar(x, allow_null = allow_null) if (!is.null(x)) { x <- tryCatch(as.Date(x, origin = "1970-01-01"), error = function(e) NA) } - arg_is_date(x, allow_null = allow_null, allow_na = allow_na) + arg_is_date(x, allow_null = allow_null) x } diff --git a/R/utils-cli.R b/R/utils-cli.R deleted file mode 100644 index ad43c95eb..000000000 --- a/R/utils-cli.R +++ /dev/null @@ -1,28 +0,0 @@ -# Modeled after / 
copied from rundel/ghclass -cli_glue <- function(..., .envir = parent.frame()) { - txt <- cli::cli_format_method(cli::cli_text(..., .envir = .envir)) - - # cli_format_method does wrapping which we dont want at this stage - # so glue things back together. - paste(txt, collapse = " ") -} - -cli_stop <- function(..., .envir = parent.frame()) { - text <- cli_glue(..., .envir = .envir) - stop(paste(text, collapse = "\n"), call. = FALSE) -} - -cli_warn <- function(..., .envir = parent.frame()) { - text <- cli_glue(..., .envir = .envir) - warning(paste(text, collapse = "\n"), call. = FALSE) -} - -#' @importFrom rlang caller_env -glubort <- - function(..., .sep = "", .envir = caller_env(), .call = .envir) { - rlang::abort(glue::glue(..., .sep = .sep, .envir = .envir), call = .call) - } - -cat_line <- function(...) { - cat(paste0(..., collapse = "\n"), "\n", sep = "") -} diff --git a/R/utils-misc.R b/R/utils-misc.R index 18f6380df..b4d1c28b7 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -32,29 +32,27 @@ check_pname <- function(res, preds, object, newname = NULL) { } -grab_forged_keys <- function(forged, mold, new_data) { - keys <- c("time_value", "geo_value", "key") +grab_forged_keys <- function(forged, workflow, new_data) { forged_roles <- names(forged$extras$roles) - extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% keys]) + extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% c("geo_value", "time_value", "key")]) # 1. these are the keys in the test data after prep/bake new_keys <- names(extras) # 2. these are the keys in the training data - old_keys <- epi_keys_mold(mold) + old_keys <- key_colnames(workflow) # 3. 
these are the keys in the test data as input - new_df_keys <- epi_keys(new_data, extra_keys = setdiff(new_keys, keys[1:2])) + new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value"))) if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) { cli::cli_warn(c( "Not all epi keys that were present in the training data are available", "in `new_data`. Predictions will have only the available keys." )) } - if (epiprocess::is_epi_df(new_data)) { - extras <- epiprocess::as_epi_df(extras) - attr(extras, "metadata") <- attr(new_data, "metadata") - } else if (all(keys[1:2] %in% new_keys)) { - l <- list() - if (length(new_keys) > 2) l <- list(other_keys = new_keys[-c(1:2)]) - extras <- epiprocess::as_epi_df(extras, additional_metadata = l) + if (is_epi_df(new_data)) { + meta <- attr(new_data, "metadata") + extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character()) + } else if (all(c("geo_value", "time_value") %in% new_keys)) { + if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")] + extras <- as_epi_df(extras, other_keys = other_keys %||% character()) } extras } @@ -64,11 +62,10 @@ get_parsnip_mode <- function(trainer) { return(trainer$mode) } cc <- class(trainer) - cli::cli_abort( - c("`trainer` must be a `parsnip` model.", - i = "This trainer has class(s) {.cls {cc}}." - ) - ) + cli_abort(c( + "`trainer` must be a `parsnip` model.", + i = "This trainer has class{?s}: {.cls {cc}}." + )) } is_classification <- function(trainer) { diff --git a/R/weighted_interval_score.R b/R/weighted_interval_score.R new file mode 100644 index 000000000..cd67bbee9 --- /dev/null +++ b/R/weighted_interval_score.R @@ -0,0 +1,147 @@ +#' Compute weighted interval score +#' +#' Weighted interval score (WIS), a well-known quantile-based +#' approximation of the commonly-used continuous ranked probability score +#' (CRPS). 
WIS is a proper score, and can be thought of as a distributional +#' generalization of absolute error. For example, see [Bracher et +#' al. (2020)](https://arxiv.org/abs/2005.12881) for discussion in the context +#' of COVID-19 forecasting. +#' +#' @param x distribution. A vector of class distribution. Ideally, this vector +#' contains `dist_quantiles()`, though other distributions are supported when +#' `quantile_levels` is specified. See below. +#' @param actual double. Actual value(s) +#' @param quantile_levels probabilities. If specified, the score will be +#' computed at this set of levels. +#' @param ... not used +#' +#' @return a vector of nonnegative scores. +#' +#' @export +#' @examples +#' quantile_levels <- c(.2, .4, .6, .8) +#' predq_1 <- 1:4 # +#' predq_2 <- 8:11 +#' dstn <- dist_quantiles(list(predq_1, predq_2), quantile_levels) +#' actual <- c(3.3, 7.1) +#' weighted_interval_score(dstn, actual) +#' weighted_interval_score(dstn, actual, c(.25, .5, .75)) +#' +#' library(distributional) +#' dstn <- dist_normal(c(.75, 2)) +#' weighted_interval_score(dstn, 1, c(.25, .5, .75)) +#' +#' # Missing value behaviours +#' dstn <- dist_quantiles(c(1, 2, NA, 4), 1:4 / 5) +#' weighted_interval_score(dstn, 2.5) +#' weighted_interval_score(dstn, 2.5, 1:9 / 10) +#' weighted_interval_score(dstn, 2.5, 1:9 / 10, na_handling = "drop") +#' weighted_interval_score(dstn, 2.5, na_handling = "propagate") +#' weighted_interval_score(dist_quantiles(1:4, 1:4 / 5), 2.5, 1:9 / 10, +#' na_handling = "fail" +#' ) +#' +#' +#' # Using some actual forecasts -------- +#' library(dplyr) +#' jhu <- case_death_rate_subset %>% +#' filter(time_value >= "2021-10-01", time_value <= "2021-12-01") +#' preds <- flatline_forecaster( +#' jhu, "death_rate", +#' flatline_args_list(quantile_levels = c(.01, .025, 1:19 / 20, .975, .99)) +#' )$predictions +#' actuals <- case_death_rate_subset %>% +#' filter(time_value == as.Date("2021-12-01") + 7) %>% +#' select(geo_value, time_value, actual = death_rate) 
+#' preds <- left_join(preds, actuals, +#' by = c("target_date" = "time_value", "geo_value") +#' ) %>% +#' mutate(wis = weighted_interval_score(.pred_distn, actual)) +#' preds +weighted_interval_score <- function(x, actual, quantile_levels = NULL, ...) { + UseMethod("weighted_interval_score") +} + +#' @export +weighted_interval_score.default <- function(x, actual, + quantile_levels = NULL, ...) { + cli_abort(c( + "Weighted interval score can only be calculated if `x`", + "has class {.cls distribution}." + )) +} + +#' @export +weighted_interval_score.distribution <- function( + x, actual, + quantile_levels = NULL, ...) { + assert_numeric(actual, finite = TRUE) + l <- vctrs::vec_recycle_common(x = x, actual = actual) + map2_dbl( + .x = vctrs::vec_data(l$x), + .y = l$actual, + .f = weighted_interval_score, + quantile_levels = quantile_levels, + ... + ) +} + +#' @export +weighted_interval_score.dist_default <- function(x, actual, + quantile_levels = NULL, ...) { + rlang::check_dots_empty() + if (is.null(quantile_levels)) { + cli_warn(c( + "Weighted interval score isn't implemented for {.cls {class(x)}}", + "as we don't know what set of quantile levels to use.", + "Use a {.cls dist_quantiles} or pass `quantile_levels`.", + "The result for this element will be `NA`." + )) + return(NA) + } + x <- extrapolate_quantiles(x, probs = quantile_levels) + weighted_interval_score(x, actual, quantile_levels = NULL) +} + +#' @param na_handling character. Determines how `quantile_levels` without a +#' corresponding `value` are handled. For `"impute"`, missing values will be +#' calculated if possible using the available quantiles. For `"drop"`, +#' explicitly missing values are ignored in the calculation of the score, but +#' implicitly missing values are imputed if possible. +#' For `"propagate"`, the resulting score will be `NA` if any missing values +#' exist in the original `quantile_levels`. 
Finally, if +#' `quantile_levels` is specified, `"fail"` will result in +#' the score being `NA` when any required quantile levels (implicit or explicit) +#' do not have corresponding values. +#' @describeIn weighted_interval_score Weighted interval score with +#' `dist_quantiles` allows for different `NA` behaviours. +#' @export +weighted_interval_score.dist_quantiles <- function( + x, actual, + quantile_levels = NULL, + na_handling = c("impute", "drop", "propagate", "fail"), + ...) { + rlang::check_dots_empty() + if (is.na(actual)) { + return(NA) + } + if (all(is.na(vctrs::field(x, "values")))) { + return(NA) + } + na_handling <- rlang::arg_match(na_handling) + old_quantile_levels <- field(x, "quantile_levels") + if (na_handling == "fail") { + if (is.null(quantile_levels)) { + cli_abort('`na_handling = "fail"` requires `quantile_levels` to be specified.') + } + old_values <- field(x, "values") + if (!all(quantile_levels %in% old_quantile_levels) || any(is.na(old_values))) { + return(NA) + } + } + tau <- quantile_levels %||% old_quantile_levels + x <- extrapolate_quantiles(x, probs = tau, replace_na = (na_handling == "impute")) + q <- field(x, "values")[field(x, "quantile_levels") %in% tau] + na_rm <- (na_handling == "drop") + 2 * mean(pmax(tau * (actual - q), (1 - tau) * (q - actual)), na.rm = na_rm) +} diff --git a/R/workflow-printing.R b/R/workflow-printing.R index c46d10848..d9c3446f9 100644 --- a/R/workflow-printing.R +++ b/R/workflow-printing.R @@ -105,7 +105,7 @@ print_preprocessor_formula <- function(x) { invisible(x) } -print_prepocessor_variables <- function(x) { +print_preprocessor_variables <- function(x) { variables <- workflows::extract_preprocessor(x) outcomes <- rlang::quo_get_expr(variables$outcomes) predictors <- rlang::quo_get_expr(variables$predictors) diff --git a/R/zzz.R b/R/zzz.R index bb7cff9bf..7e335b67d 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -8,4 +8,5 @@ make_flatline_reg() make_quantile_reg() make_smooth_quantile_reg() + 
make_grf_quantiles() } diff --git a/README.Rmd b/README.Rmd index c5dc66cab..36af14cd9 100644 --- a/README.Rmd +++ b/README.Rmd @@ -24,24 +24,26 @@ knitr::opts_chunk$set( ## Installation -You can install the development version of epipredict from [GitHub](https://github.com/) with: +To install (unless you're making changes to the package, use the stable version): -``` r -# install.packages("remotes") -remotes::install_github("cmu-delphi/epipredict") +```r +# Stable version +pak::pkg_install("cmu-delphi/epipredict@main") + +# Dev version +pak::pkg_install("cmu-delphi/epipredict@dev") ``` ## Documentation You can view documentation for the `main` branch at . - ## Goals for `epipredict` **We hope to provide:** -1. A set of basic, easy-to-use forecasters that work out of the box. You should be able to do a reasonably limited amount of customization on them. For the basic forecasters, we currently provide: - * Baseline flatline forecaster +1. A set of basic, easy-to-use forecasters that work out of the box. You should be able to do a reasonably limited amount of customization on them. For the basic forecasters, we currently provide: + * Baseline flatline forecaster * Autoregressive forecaster * Autoregressive classifier * CDC FluSight flatline forecaster @@ -54,12 +56,12 @@ You can view documentation for the `main` branch at ══ A basic forecaster of type ARX Forecaster ═══════════════════════════════ #> -#> This forecaster was fit on 2023-12-23 09:12:46. +#> This forecaster was fit on 2024-01-29 15:10:01. #> #> Training data was an with: #> • Geography: state, @@ -203,7 +207,7 @@ through the end of 2021 for the 14th of January 2022. A prediction for the death rate per 100K inhabitants is available for every state (`geo_value`) along with a 90% predictive interval. -[^1]: Other epidemiological signals for non-Covid related illnesses are +1. 
Other epidemiological signals for non-Covid related illnesses are also available with [`{epidatr}`](https://github.com/cmu-delphi/epidatr) which interfaces directly to Delphi’s [Epidata diff --git a/_pkgdown.yml b/_pkgdown.yml index af4227b89..468da62ac 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -1,85 +1,128 @@ url: https://cmu-delphi.github.io/epipredict/ + +# This is to give a default value to the `mode` parameter in the +# `pkgdown::build_site()` function. This is useful when building the site locally, +# as it will default to `devel` mode. In practice, this should all be handled +# dynamically by the CI/CD pipeline. +development: + mode: devel + version_label: success + template: bootstrap: 5 bootswatch: cosmo bslib: font_scale: 1.0 - primary: '#C41230' - link-color: '#C41230' + primary: "#C41230" + success: "#B4D43C" + link-color: "#C41230" navbar: bg: primary - type: dark + type: light -home: - links: - - text: Introduction to Delphi's Tooling Work - href: https://cmu-delphi.github.io/delphi-tooling-book/ - - text: The epiprocess R package - href: https://cmu-delphi.github.io/epiprocess/ - - text: The epidatr R package - href: https://github.com/cmu-delphi/epidatr/ - - text: The epidatasets R package - href: https://cmu-delphi.github.io/epidatasets/ - - text: The covidcast R package - href: https://cmu-delphi.github.io/covidcast/covidcastR/ +articles: +- title: Get started + navbar: ~ + contents: + - epipredict + - preprocessing-and-models + - arx-classifier + - update +- title: Advanced methods + contents: + - articles/sliding + - articles/smooth-qr + - articles/symptom-surveys + - panel-data +repo: + url: + home: https://github.com/cmu-delphi/epipredict/tree/main/ + source: https://github.com/cmu-delphi/epipredict/blob/main/ + issue: https://github.com/cmu-delphi/epipredict/issues + user: https://github.com/ + +home: + links: + - text: Introduction to Delphi's Tooling Work + href: 
https://cmu-delphi.github.io/delphi-tooling-book/ + - text: The epiprocess R package + href: https://cmu-delphi.github.io/epiprocess/ + - text: The epidatr R package + href: https://github.com/cmu-delphi/epidatr/ + - text: The epidatasets R package + href: https://cmu-delphi.github.io/epidatasets/ + - text: The covidcast R package + href: https://cmu-delphi.github.io/covidcast/covidcastR/ reference: - title: Simple forecasters desc: Complete forecasters that produce reasonable baselines contents: - - contains("flatline") - - contains("arx") - - contains("cdc") + - contains("forecaster") + - contains("classifier") + - title: Forecaster modifications + desc: Constructors to modify forecaster arguments and utilities to produce `epi_workflow` objects + contents: + - contains("args_list") + - contains("_epi_workflow") - title: Helper functions for Hub submission contents: - - flusight_hub_formatter + - flusight_hub_formatter - title: Parsnip engines desc: Prediction methods not available elsewhere contents: - - quantile_reg - - smooth_quantile_reg + - quantile_reg + - smooth_quantile_reg + - grf_quantiles - title: Custom panel data forecasting workflows contents: - epi_recipe - epi_workflow - add_epi_recipe - adjust_epi_recipe - - add_model + - Add_model - predict.epi_workflow - fit.epi_workflow - augment.epi_workflow + - forecast.epi_workflow + - title: Epi recipe preprocessing steps contents: - - starts_with("step_") - - contains("bake") - - contains("juice") + - starts_with("step_") + - contains("bake") - title: Epi recipe verification checks contents: - - check_enough_train_data + - check_enough_train_data - title: Forecast postprocessing desc: Create a series of postprocessing operations contents: - - frosting - - ends_with("_frosting") - - get_test_data - - tidy.frosting + - frosting + - ends_with("_frosting") + - get_test_data + - tidy.frosting - title: Frosting layers contents: - - contains("layer") - - contains("slather") + - contains("layer") + - 
contains("slather") + - title: Automatic forecast visualization + contents: + - autoplot.epi_workflow + - autoplot.canned_epipred - title: Utilities for quantile distribution processing contents: - - dist_quantiles - - extrapolate_quantiles - - nested_quantiles - - starts_with("pivot_quantiles") + - dist_quantiles + - extrapolate_quantiles + - nested_quantiles + - weighted_interval_score + - starts_with("pivot_quantiles") + - title: Other utilities + contents: + - clean_f_name - title: Included datasets contents: - - case_death_rate_subset - - state_census - - - + - case_death_rate_subset + - state_census + - grad_employ_subset diff --git a/data-raw/grad_employ_subset.R b/data-raw/grad_employ_subset.R new file mode 100644 index 000000000..38719a02e --- /dev/null +++ b/data-raw/grad_employ_subset.R @@ -0,0 +1,106 @@ +library(epipredict) +library(epiprocess) +library(cansim) +library(dplyr) +library(stringr) +library(tidyr) + +# https://www150.statcan.gc.ca/t1/tbl1/en/tv.action?pid=3710011501 +statcan_grad_employ <- get_cansim("37-10-0115-01") + +gemploy <- statcan_grad_employ %>% + select(c( + "REF_DATE", + "GEO", + # "DGUID", + # "UOM", + # "UOM_ID", + # "SCALAR_FACTOR", + # "SCALAR_ID", + # "VECTOR", + # "COORDINATE", + "VALUE", + "STATUS", + # "SYMBOL", + # "TERMINATED", + # "DECIMALS", + # "GeoUID", + # "Hierarchy for GEO", + # "Classification Code for Educational qualification", + # "Hierarchy for Educational qualification", + # "Classification Code for Field of study", + # "Hierarchy for Field of study", + # "Classification Code for Gender", + # "Hierarchy for Gender", + # "Classification Code for Age group", + # "Hierarchy for Age group", + # "Classification Code for Status of student in Canada", + # "Hierarchy for Status of student in Canada", + # "Classification Code for Characteristics after graduation", + # "Hierarchy for Characteristics after graduation", + # "Classification Code for Graduate statistics", + # "Hierarchy for Graduate statistics", + # 
"val_norm", + # "Date", + "Educational qualification", + "Field of study", + "Gender", + "Age group", + "Status of student in Canada", + "Characteristics after graduation", + "Graduate statistics" + )) %>% + rename( + "geo_value" = "GEO", + "time_value" = "REF_DATE", + "value" = "VALUE", + "status" = "STATUS", + "edu_qual" = "Educational qualification", + "fos" = "Field of study", + "gender" = "Gender", + "age_group" = "Age group", + "student_status" = "Status of student in Canada", + "grad_charac" = "Characteristics after graduation", + "grad_stat" = "Graduate statistics" + ) %>% + mutate( + grad_stat = recode_factor( + grad_stat, + `Number of graduates` = "num_graduates", + `Median employment income two years after graduation` = "med_income_2y", + `Median employment income five years after graduation` = "med_income_5y" + ), + time_value = as.integer(time_value) + ) %>% + pivot_wider(names_from = grad_stat, values_from = value) %>% + filter( + # Drop aggregates for some columns + geo_value != "Canada" & + age_group != "15 to 64 years" & + edu_qual != "Total, educational qualification" & + # Keep aggregates for keys we don't want to keep + fos == "Total, field of study" & + gender == "Total, gender" & + student_status == "Canadian and international students" & + # Since we're looking at 2y and 5y employment income, the only + # characteristics remaining are: + # - Graduates reporting employment income + # - Graduates reporting wages, salaries, and commissions only + # For simplicity, keep the first one only + grad_charac == "Graduates reporting employment income" & + # Only keep "good" data + is.na(status) & + # Drop NA value rows + !is.na(num_graduates) & !is.na(med_income_2y) & !is.na(med_income_5y) + ) %>% + select(-c(status, gender, student_status, grad_charac, fos)) + +nrow(gemploy) +ncol(gemploy) + +grad_employ_subset <- gemploy %>% + as_epi_df( + as_of = "2022-07-19", + other_keys = c("age_group", "edu_qual") + ) +usethis::use_data(grad_employ_subset, 
overwrite = TRUE) diff --git a/data/grad_employ_subset.rda b/data/grad_employ_subset.rda new file mode 100644 index 000000000..9380b43b5 Binary files /dev/null and b/data/grad_employ_subset.rda differ diff --git a/inst/templates/layer.R b/inst/templates/layer.R index 3fecb3c33..59556db5f 100644 --- a/inst/templates/layer.R +++ b/inst/templates/layer.R @@ -29,6 +29,7 @@ layer_{{{ name }}}_new <- function(terms, args, more_args, id) { #' @export slather.layer_{{{ name }}} <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() # if layer_ used ... in tidyselect, we need to evaluate it now exprs <- rlang::expr(c(!!!object$terms)) diff --git a/man/add_model.Rd b/man/Add_model.Rd similarity index 52% rename from man/add_model.Rd rename to man/Add_model.Rd index f1209b95f..17b65793c 100644 --- a/man/add_model.Rd +++ b/man/Add_model.Rd @@ -1,25 +1,43 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/epi_workflow.R -\name{add_model} +% Please edit documentation in R/model-methods.R +\name{Add_model} +\alias{Add_model} +\alias{Remove_model} +\alias{Update_model} +\alias{Add_model.epi_workflow} +\alias{Remove_model.epi_workflow} +\alias{Update_model.epi_workflow} +\alias{Add_model.workflow} +\alias{Remove_model.workflow} +\alias{Update_model.workflow} \alias{add_model} \alias{remove_model} \alias{update_model} -\alias{add_model.epi_workflow} -\alias{remove_model.epi_workflow} -\alias{update_model.epi_workflow} \title{Add a model to an \code{epi_workflow}} \usage{ -add_model(x, spec, ..., formula = NULL) +Add_model(x, spec, ..., formula = NULL) -remove_model(x) +Remove_model(x) -update_model(x, spec, ..., formula = NULL) +Update_model(x, spec, ..., formula = NULL) + +\method{Add_model}{epi_workflow}(x, spec, ..., formula = NULL) + +\method{Remove_model}{epi_workflow}(x) + +\method{Update_model}{epi_workflow}(x, spec, ..., formula = NULL) + +\method{Add_model}{workflow}(x, spec, ..., formula = NULL) 
-\method{add_model}{epi_workflow}(x, spec, ..., formula = NULL) +\method{Remove_model}{workflow}(x) -\method{remove_model}{epi_workflow}(x) +\method{Update_model}{workflow}(x, spec, ..., formula = NULL) -\method{update_model}{epi_workflow}(x, spec, ..., formula = NULL) +add_model(x, spec, ..., formula = NULL) + +remove_model(x) + +update_model(x, spec, ..., formula = NULL) } \arguments{ \item{x}{An \code{epi_workflow}.} @@ -45,13 +63,17 @@ Add a model to an \code{epi_workflow} \details{ Has the same behaviour as \code{\link[workflows:add_model]{workflows::add_model()}} but also ensures that the returned object is an \code{epi_workflow}. + +This family is called \verb{Add_*} / \verb{Update_*} / \verb{Remove_*} to avoid +masking the related functions in \code{{workflows}}. We also provide +aliases with the lower-case names. However, in the event that +\code{{workflows}} is loaded after \code{{epipredict}}, these may fail to function +properly. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter( - time_value > "2021-11-01", - geo_value \%in\% c("ak", "ca", "ny") - ) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% @@ -61,24 +83,24 @@ rf_model <- rand_forest(mode = "regression") wf <- epi_workflow(r) -wf <- wf \%>\% add_model(rf_model) +wf <- wf \%>\% Add_model(rf_model) wf -lm_model <- parsnip::linear_reg() +lm_model <- linear_reg() -wf <- update_model(wf, lm_model) +wf <- Update_model(wf, lm_model) wf -wf <- remove_model(wf) +wf <- Remove_model(wf) wf } \seealso{ \code{\link[workflows:add_model]{workflows::add_model()}} \itemize{ -\item \code{add_model()} adds a parsnip model to the \code{epi_workflow}. -\item \code{remove_model()} removes the model specification as well as any fitted +\item \code{Add_model()} adds a parsnip model to the \code{epi_workflow}. 
+\item \code{Remove_model()} removes the model specification as well as any fitted model object. Any extra formulas are also removed. -\item \code{update_model()} first removes the model then adds the new +\item \code{Update_model()} first removes the model then adds the new specification to the workflow. } } diff --git a/man/add_frosting.Rd b/man/add_frosting.Rd index 161a540e2..94812cbe2 100644 --- a/man/add_frosting.Rd +++ b/man/add_frosting.Rd @@ -26,15 +26,16 @@ update_frosting(x, frosting, ...) Add frosting to a workflow } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) latest <- jhu \%>\% - dplyr::filter(time_value >= max(time_value) - 14) + filter(time_value >= max(time_value) - 14) # Add frosting to a workflow and predict f <- frosting() \%>\% diff --git a/man/adjust_frosting.Rd b/man/adjust_frosting.Rd index 6cdc13b30..c089b3443 100644 --- a/man/adjust_frosting.Rd +++ b/man/adjust_frosting.Rd @@ -35,6 +35,7 @@ must be inputted as \code{...}. See the examples below for brief illustrations of the different types of updates. 
} \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% @@ -42,7 +43,7 @@ r <- epi_recipe(jhu) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) # in the frosting from the workflow f1 <- frosting() \%>\% diff --git a/man/apply_frosting.Rd b/man/apply_frosting.Rd index fc01a3461..ef18796cc 100644 --- a/man/apply_frosting.Rd +++ b/man/apply_frosting.Rd @@ -11,7 +11,7 @@ apply_frosting(workflow, ...) \method{apply_frosting}{default}(workflow, components, ...) -\method{apply_frosting}{epi_workflow}(workflow, components, new_data, ...) +\method{apply_frosting}{epi_workflow}(workflow, components, new_data, type = NULL, opts = list(), ...) } \arguments{ \item{workflow}{An object of class workflow} @@ -34,6 +34,9 @@ here for ease. \item{new_data}{a data frame containing the new predictors to preprocess and predict on} + +\item{type, opts}{forwarded (along with \code{...}) to \code{\link[=predict.model_fit]{predict.model_fit()}} and +\code{\link[=slather]{slather()}} for supported layers} } \description{ This function is intended for internal use. 
It implements postprocessing diff --git a/man/arx_class_args_list.Rd b/man/arx_class_args_list.Rd index a1205c71a..311950d62 100644 --- a/man/arx_class_args_list.Rd +++ b/man/arx_class_args_list.Rd @@ -13,7 +13,7 @@ arx_class_args_list( outcome_transform = c("growth_rate", "lag_difference"), breaks = 0.25, horizon = 7L, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + method = c("rel_change", "linear_reg"), log_scale = FALSE, additional_gr_args = list(), nafill_buffer = Inf, diff --git a/man/arx_class_epi_workflow.Rd b/man/arx_class_epi_workflow.Rd index e55e14160..713365f17 100644 --- a/man/arx_class_epi_workflow.Rd +++ b/man/arx_class_epi_workflow.Rd @@ -8,7 +8,7 @@ arx_class_epi_workflow( epi_data, outcome, predictors, - trainer = NULL, + trainer = parsnip::logistic_reg(), args_list = arx_class_args_list() ) } @@ -22,14 +22,17 @@ internally based on the \code{breaks} argument to \code{\link[=arx_class_args_li If discrete classes are already in the \code{epi_df}, it is recommended to code up a classifier from scratch using \code{\link[=epi_recipe]{epi_recipe()}}.} -\item{predictors}{A character vector giving column(s) of predictor -variables.} +\item{predictors}{A character vector giving column(s) of predictor variables. +This defaults to the \code{outcome}. However, if manually specified, only those variables +specifically mentioned will be used. (The \code{outcome} will not be added.) +By default, equals the outcome. If manually specified, does not add the +outcome variable, so make sure to specify it.} -\item{trainer}{A \code{{parsnip}} model describing the type of estimation. -For now, we enforce \code{mode = "classification"}. Typical values are +\item{trainer}{A \code{{parsnip}} model describing the type of estimation. For +now, we enforce \code{mode = "classification"}. Typical values are \code{\link[parsnip:logistic_reg]{parsnip::logistic_reg()}} or \code{\link[parsnip:multinom_reg]{parsnip::multinom_reg()}}. 
More complicated -trainers like \code{\link[parsnip:naive_Bayes]{parsnip::naive_Bayes()}} or \code{\link[parsnip:rand_forest]{parsnip::rand_forest()}} can -also be used. May be \code{NULL} (the default).} +trainers like \code{\link[parsnip:naive_Bayes]{parsnip::naive_Bayes()}} or \code{\link[parsnip:rand_forest]{parsnip::rand_forest()}} can also +be used. May be \code{NULL} if you'd like to decide later.} \item{args_list}{A list of customization arguments to determine the type of forecasting model. See \code{\link[=arx_class_args_list]{arx_class_args_list()}}.} @@ -44,9 +47,9 @@ before fitting and predicting. Supplying a trainer to the function may alter the returned \code{epi_workflow} object but can be omitted. } \examples{ - +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value >= as.Date("2021-11-01")) + filter(time_value >= as.Date("2021-11-01")) arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate")) @@ -54,7 +57,7 @@ arx_class_epi_workflow( jhu, "death_rate", c("case_rate", "death_rate"), - trainer = parsnip::multinom_reg(), + trainer = multinom_reg(), args_list = arx_class_args_list( breaks = c(-.05, .1), ahead = 14, horizon = 14, method = "linear_reg" diff --git a/man/arx_classifier.Rd b/man/arx_classifier.Rd index de487ec51..c7c2cf059 100644 --- a/man/arx_classifier.Rd +++ b/man/arx_classifier.Rd @@ -8,7 +8,7 @@ arx_classifier( epi_data, outcome, predictors, - trainer = parsnip::logistic_reg(), + trainer = logistic_reg(), args_list = arx_class_args_list() ) } @@ -22,8 +22,11 @@ internally based on the \code{breaks} argument to \code{\link[=arx_class_args_li If discrete classes are already in the \code{epi_df}, it is recommended to code up a classifier from scratch using \code{\link[=epi_recipe]{epi_recipe()}}.} -\item{predictors}{A character vector giving column(s) of predictor -variables.} +\item{predictors}{A character vector giving column(s) of predictor variables. +This defaults to the \code{outcome}. 
However, if manually specified, only those variables +specifically mentioned will be used. (The \code{outcome} will not be added.) +By default, equals the outcome. If manually specified, does not add the +outcome variable, so make sure to specify it.} \item{trainer}{A \code{{parsnip}} model describing the type of estimation. For now, we enforce \code{mode = "classification"}. Typical values are @@ -45,8 +48,9 @@ This is an autoregressive classification model for that it estimates a class at a particular target horizon. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value >= as.Date("2021-11-01")) + filter(time_value >= as.Date("2021-11-01")) out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate")) diff --git a/man/arx_fcast_epi_workflow.Rd b/man/arx_fcast_epi_workflow.Rd index 8c76bcdd7..4070a3337 100644 --- a/man/arx_fcast_epi_workflow.Rd +++ b/man/arx_fcast_epi_workflow.Rd @@ -7,8 +7,8 @@ arx_fcast_epi_workflow( epi_data, outcome, - predictors, - trainer = NULL, + predictors = outcome, + trainer = linear_reg(), args_list = arx_args_list() ) } @@ -18,11 +18,15 @@ arx_fcast_epi_workflow( \item{outcome}{A character (scalar) specifying the outcome (in the \code{epi_df}).} -\item{predictors}{A character vector giving column(s) of predictor -variables.} +\item{predictors}{A character vector giving column(s) of predictor variables. +This defaults to the \code{outcome}. However, if manually specified, only those variables +specifically mentioned will be used. (The \code{outcome} will not be added.) +By default, equals the outcome. If manually specified, does not add the +outcome variable, so make sure to specify it.} -\item{trainer}{A \code{{parsnip}} model describing the type of estimation. -For now, we enforce \code{mode = "regression"}. May be \code{NULL} (the default).} +\item{trainer}{A \code{{parsnip}} model describing the type of estimation. For +now, we enforce \code{mode = "regression"}. 
May be \code{NULL} if you'd like to +decide later.} \item{args_list}{A list of customization arguments to determine the type of forecasting model. See \code{\link[=arx_args_list]{arx_args_list()}}.} @@ -38,8 +42,9 @@ may alter the returned \code{epi_workflow} object (e.g., if you intend to use \code{\link[=quantile_reg]{quantile_reg()}}) but can be omitted. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value >= as.Date("2021-12-01")) + filter(time_value >= as.Date("2021-12-01")) arx_fcast_epi_workflow( jhu, "death_rate", diff --git a/man/arx_forecaster.Rd b/man/arx_forecaster.Rd index 7a042c65c..d8c7671dc 100644 --- a/man/arx_forecaster.Rd +++ b/man/arx_forecaster.Rd @@ -7,8 +7,8 @@ arx_forecaster( epi_data, outcome, - predictors, - trainer = parsnip::linear_reg(), + predictors = outcome, + trainer = linear_reg(), args_list = arx_args_list() ) } @@ -18,8 +18,11 @@ arx_forecaster( \item{outcome}{A character (scalar) specifying the outcome (in the \code{epi_df}).} -\item{predictors}{A character vector giving column(s) of predictor -variables.} +\item{predictors}{A character vector giving column(s) of predictor variables. +This defaults to the \code{outcome}. However, if manually specified, only those variables +specifically mentioned will be used. (The \code{outcome} will not be added.) +By default, equals the outcome. If manually specified, does not add the +outcome variable, so make sure to specify it.} \item{trainer}{A \code{{parsnip}} model describing the type of estimation. 
For now, we enforce \code{mode = "regression"}.} diff --git a/man/autoplot-epipred.Rd b/man/autoplot-epipred.Rd new file mode 100644 index 000000000..27bfdf5f7 --- /dev/null +++ b/man/autoplot-epipred.Rd @@ -0,0 +1,124 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/autoplot.R +\name{autoplot-epipred} +\alias{autoplot-epipred} +\alias{autoplot.epi_workflow} +\alias{autoplot.canned_epipred} +\title{Automatically plot an \code{epi_workflow} or \code{canned_epipred} object} +\usage{ +\method{autoplot}{epi_workflow}( + object, + predictions = NULL, + .levels = c(0.5, 0.8, 0.95), + ..., + .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"), + .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"), + .base_color = "dodgerblue4", + .point_pred_color = "orange", + .max_facets = Inf +) + +\method{autoplot}{canned_epipred}( + object, + ..., + .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"), + .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"), + .base_color = "dodgerblue4", + .point_pred_color = "orange", + .max_facets = Inf +) +} +\arguments{ +\item{object}{An \code{epi_workflow}} + +\item{predictions}{A data frame with predictions. If \code{NULL}, only the +original data is shown.} + +\item{.levels}{A numeric vector of levels to plot for any prediction bands. +More than 3 levels begins to be difficult to see.} + +\item{...}{Ignored} + +\item{.color_by}{Which variables should determine the color(s) used to plot +lines. 
Options include: +\itemize{ +\item \code{all_keys} - the default uses the interaction of any key variables +including the \code{geo_value} +\item \code{geo_value} - \code{geo_value} only +\item \code{other_keys} - any available keys that are not \code{geo_value} +\item \code{.response} - the numeric variables (same as the y-axis) +\item \code{all} - uses the interaction of all keys and numeric variables +\item \code{none} - no coloring aesthetic is applied +}} + +\item{.facet_by}{Similar to \code{.color_by} except that the default is to +display the response.} + +\item{.base_color}{If available, prediction bands will be shown with this +color.} + +\item{.point_pred_color}{If available, point forecasts will be shown with +this color.} + +\item{.max_facets}{Cut down of the number of facets displayed. Especially +useful for testing when there are many \code{geo_value}'s or keys.} +} +\description{ +For a fit workflow, the training data will be displayed, the response by +default. If \code{predictions} is not \code{NULL} then point and interval forecasts +will be shown as well. Unfit workflows will result in an error, (you +can simply call \code{autoplot()} on the original \code{epi_df}). 
+} +\examples{ +library(dplyr) +jhu <- case_death_rate_subset \%>\% + filter(time_value >= as.Date("2021-11-01")) + +r <- epi_recipe(jhu) \%>\% + step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% + step_epi_ahead(death_rate, ahead = 7) \%>\% + step_epi_lag(case_rate, lag = c(0, 7, 14)) \%>\% + step_epi_naomit() + +f <- frosting() \%>\% + layer_residual_quantiles( + quantile_levels = c(.025, .1, .25, .75, .9, .975) + ) \%>\% + layer_threshold(starts_with(".pred")) \%>\% + layer_add_target_date() + +wf <- epi_workflow(r, linear_reg(), f) \%>\% fit(jhu) + +autoplot(wf) + +latest <- jhu \%>\% filter(time_value >= max(time_value) - 14) +preds <- predict(wf, latest) +autoplot(wf, preds, .max_facets = 4) + +# ------- Show multiple horizons + +p <- lapply(c(7, 14, 21, 28), function(h) { + r <- epi_recipe(jhu) \%>\% + step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% + step_epi_ahead(death_rate, ahead = h) \%>\% + step_epi_lag(case_rate, lag = c(0, 7, 14)) \%>\% + step_epi_naomit() + ewf <- epi_workflow(r, linear_reg(), f) \%>\% fit(jhu) + forecast(ewf) +}) + +p <- do.call(rbind, p) +autoplot(wf, p, .max_facets = 4) + +# ------- Plotting canned forecaster output + +jhu <- case_death_rate_subset \%>\% + filter(time_value >= as.Date("2021-11-01")) +flat <- flatline_forecaster(jhu, "death_rate") +autoplot(flat, .max_facets = 4) + +arx <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), + args_list = arx_args_list(ahead = 14L) +) +autoplot(arx, .max_facets = 6) +} diff --git a/man/bake.Rd b/man/bake.Rd deleted file mode 100644 index c1c0137c5..000000000 --- a/man/bake.Rd +++ /dev/null @@ -1,28 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/bake.epi_recipe.R -\name{bake.epi_recipe} -\alias{bake.epi_recipe} -\title{Bake an epi_recipe} -\usage{ -\method{bake}{epi_recipe}(object, new_data, ...) 
-} -\arguments{ -\item{object}{A trained object such as a \code{\link[=recipe]{recipe()}} with at least -one preprocessing operation.} - -\item{new_data}{An \code{epi_df}, data frame or tibble for whom the -preprocessing will be applied. If \code{NULL} is given to \code{new_data}, -the pre-processed \emph{training data} will be returned.} - -\item{...}{One or more selector functions to choose which variables will be -returned by the function. See \code{\link[recipes:selections]{recipes::selections()}} for -more details. If no selectors are given, the default is to -use \code{\link[tidyselect:everything]{tidyselect::everything()}}.} -} -\value{ -An \code{epi_df} that may have different columns than the -original columns in \code{new_data}. -} -\description{ -Bake an epi_recipe -} diff --git a/man/cdc_baseline_forecaster.Rd b/man/cdc_baseline_forecaster.Rd index cd3c4ed67..0c7f1e436 100644 --- a/man/cdc_baseline_forecaster.Rd +++ b/man/cdc_baseline_forecaster.Rd @@ -44,11 +44,11 @@ weekly_deaths <- case_death_rate_subset \%>\% mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) \%>\% select(-pop, -death_rate) \%>\% group_by(geo_value) \%>\% - epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") \%>\% + epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") \%>\% ungroup() \%>\% filter(weekdays(time_value) == "Saturday") -cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") +cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum") preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn) if (require(ggplot2)) { @@ -62,7 +62,7 @@ if (require(ggplot2)) { geom_line(aes(y = .pred), color = "orange") + geom_line( data = weekly_deaths \%>\% filter(geo_value \%in\% four_states), - aes(x = time_value, y = deaths) + aes(x = time_value, y = deaths_7dsum) ) + scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) + labs(x = "Date", y = "Weekly deaths") + diff --git a/man/clean_f_name.Rd b/man/clean_f_name.Rd new file mode 
100644 index 000000000..20ed921df --- /dev/null +++ b/man/clean_f_name.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/step_epi_slide.R +\name{clean_f_name} +\alias{clean_f_name} +\title{Create short function names} +\usage{ +clean_f_name(.f, max_length = 20L) +} +\arguments{ +\item{.f}{a function, character string, or lambda. For example, \code{mean}, +\code{"mean"}, \code{~ mean(.x)} or \verb{\\(x) mean(x, na.rm = TRUE)}.} + +\item{max_length}{integer determining how long names can be} +} +\value{ +a character string of length at most \code{max_length} that +(partially) describes the function. +} +\description{ +Create short function names +} +\examples{ +clean_f_name(mean) +clean_f_name("mean") +clean_f_name(~ mean(.x, na.rm = TRUE)) +clean_f_name(\(x) mean(x, na.rm = TRUE)) +clean_f_name(function(x) mean(x, na.rm = TRUE, trim = 0.2357862)) +} diff --git a/man/create_layer.Rd b/man/create_layer.Rd deleted file mode 100644 index d36385fb2..000000000 --- a/man/create_layer.Rd +++ /dev/null @@ -1,28 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/create-layer.R -\name{create_layer} -\alias{create_layer} -\title{Create a new layer} -\usage{ -create_layer(name = NULL, open = rlang::is_interactive()) -} -\arguments{ -\item{name}{Either a string giving a file name (without directory) or -\code{NULL} to take the name from the currently open file in RStudio.} - -\item{open}{Whether to open the file for interactive editing.} -} -\description{ -This function creates the skeleton for a new \code{frosting} layer. When called -inside a package, it will create an R script in the \verb{R/} directory, -fill in the name of the layer, and open the file. 
-} -\examples{ -\dontrun{ - -# Note: running this will write `layer_strawberry.R` to -# the `R/` directory of your current project -create_layer("strawberry") -} - -} diff --git a/man/dist_quantiles.Rd b/man/dist_quantiles.Rd index 57d2f3b3b..1a3226e36 100644 --- a/man/dist_quantiles.Rd +++ b/man/dist_quantiles.Rd @@ -7,15 +7,26 @@ dist_quantiles(values, quantile_levels) } \arguments{ -\item{values}{A vector of values} +\item{values}{A vector (or list of vectors) of values.} -\item{quantile_levels}{A vector of probabilities corresponding to \code{values}} +\item{quantile_levels}{A vector (or list of vectors) of probabilities +corresponding to \code{values}. + +When creating multiple sets of \code{values}/\code{quantile_levels} resulting in +different distributions, the sizes must match. See the examples below.} +} +\value{ +A vector of class \code{"distribution"}. } \description{ A distribution parameterized by a set of quantiles } \examples{ -dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) +dist_quantiles(1:4, 1:4 / 5) +dist_quantiles(list(1:3, 1:4), list(1:3 / 4, 1:4 / 5)) +dstn <- dist_quantiles(list(1:4, 8:11), c(.2, .4, .6, .8)) +dstn + quantile(dstn, p = c(.1, .25, .5, .9)) median(dstn) @@ -23,5 +34,4 @@ median(dstn) distributional::parameters(dstn[1]) nested_quantiles(dstn[1])[[1]] -dist_quantiles(1:4, 1:4 / 5) } diff --git a/man/epi_juice.Rd b/man/epi_juice.Rd deleted file mode 100644 index 38eccb9a9..000000000 --- a/man/epi_juice.Rd +++ /dev/null @@ -1,20 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/epi_juice.R -\name{epi_juice} -\alias{epi_juice} -\title{Extract transformed training set} -\usage{ -epi_juice(object, ...) -} -\arguments{ -\item{object}{A trained object such as a \code{\link[=recipe]{recipe()}} with at least -one preprocessing operation.} - -\item{...}{One or more selector functions to choose which variables will be -returned by the function. 
See \code{\link[recipes:selections]{recipes::selections()}} for -more details. If no selectors are given, the default is to -use \code{\link[tidyselect:everything]{tidyselect::everything()}}.} -} -\description{ -Extract transformed training set -} diff --git a/man/epi_keys.Rd b/man/epi_keys.Rd deleted file mode 100644 index 8026fc140..000000000 --- a/man/epi_keys.Rd +++ /dev/null @@ -1,20 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/epi_keys.R -\name{epi_keys} -\alias{epi_keys} -\title{Grab any keys associated to an epi_df} -\usage{ -epi_keys(x, ...) -} -\arguments{ -\item{x}{a data.frame, tibble, or epi_df} - -\item{...}{additional arguments passed on to methods} -} -\value{ -If an \code{epi_df}, this returns all "keys". Otherwise \code{NULL} -} -\description{ -Grab any keys associated to an epi_df -} -\keyword{internal} diff --git a/man/epi_recipe.Rd b/man/epi_recipe.Rd index 1c9048a36..d0105d1ec 100644 --- a/man/epi_recipe.Rd +++ b/man/epi_recipe.Rd @@ -57,17 +57,19 @@ around \code{\link[recipes:recipe]{recipes::recipe()}} to properly handle the ad columns present in an \code{epi_df} } \examples{ +library(dplyr) +library(recipes) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-08-01") \%>\% - dplyr::arrange(geo_value, time_value) + filter(time_value > "2021-08-01") \%>\% + arrange(geo_value, time_value) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_lag(case_rate, lag = c(0, 7, 14)) \%>\% - recipes::step_naomit(recipes::all_predictors()) \%>\% + step_naomit(all_predictors()) \%>\% # below, `skip` means we don't do this at predict time - recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) + step_naomit(all_outcomes(), skip = TRUE) r } diff --git a/man/epi_slide_wrapper.Rd b/man/epi_slide_wrapper.Rd new file mode 100644 index 000000000..d67db1c88 --- /dev/null +++ b/man/epi_slide_wrapper.Rd @@ -0,0 +1,28 @@ +% 
This only affects +elements of class \code{dist_quantiles}.}
Any elements +of \code{x} which were originally \code{dist_quantiles} will now have a superset +of the original \code{quantile_values} (the union of those and \code{probs}). } \description{ Summarize a distribution with a set of quantiles diff --git a/man/flatline_forecaster.Rd b/man/flatline_forecaster.Rd index 052dd6428..1803f1078 100644 --- a/man/flatline_forecaster.Rd +++ b/man/flatline_forecaster.Rd @@ -20,8 +20,9 @@ ahead (unique horizon) for each unique combination of \code{key_vars}. } \description{ This is a simple forecasting model for -\link[epiprocess:epi_df]{epiprocess::epi_df} data. It uses the most recent observation as the -forcast for any future date, and produces intervals based on the quantiles +\link[epiprocess:epi_df]{epiprocess::epi_df} data. It uses the most recent +observation as the +forecast for any future date, and produces intervals based on the quantiles of the residuals of such a "flatline" forecast over all available training data. } diff --git a/man/flusight_hub_formatter.Rd b/man/flusight_hub_formatter.Rd index 8f3604756..b2be9b4fe 100644 --- a/man/flusight_hub_formatter.Rd +++ b/man/flusight_hub_formatter.Rd @@ -41,21 +41,24 @@ be done via the \code{...} argument. See the examples below. The specific requir format for this forecast task is \href{https://github.com/cdcepi/FluSight-forecast-hub/blob/main/model-output/README.md}{here}. 
} \examples{ -if (require(dplyr)) { - weekly_deaths <- case_death_rate_subset \%>\% - select(geo_value, time_value, death_rate) \%>\% - left_join(state_census \%>\% select(pop, abbr), by = c("geo_value" = "abbr")) \%>\% - mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) \%>\% - select(-pop, -death_rate) \%>\% - group_by(geo_value) \%>\% - epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") \%>\% - ungroup() \%>\% - filter(weekdays(time_value) == "Saturday") +library(dplyr) +weekly_deaths <- case_death_rate_subset \%>\% + filter( + time_value >= as.Date("2021-09-01"), + geo_value \%in\% c("ca", "ny", "dc", "ga", "vt") + ) \%>\% + select(geo_value, time_value, death_rate) \%>\% + left_join(state_census \%>\% select(pop, abbr), by = c("geo_value" = "abbr")) \%>\% + mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) \%>\% + select(-pop, -death_rate) \%>\% + group_by(geo_value) \%>\% + epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") \%>\% + ungroup() \%>\% + filter(weekdays(time_value) == "Saturday") - cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") - flusight_hub_formatter(cdc) - flusight_hub_formatter(cdc, target = "wk inc covid deaths") - flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths")) - flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile") -} +cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum") +flusight_hub_formatter(cdc) +flusight_hub_formatter(cdc, target = "wk inc covid deaths") +flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths")) +flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile") } diff --git a/man/forecast.epi_workflow.Rd b/man/forecast.epi_workflow.Rd new file mode 100644 index 000000000..b9f6870b8 --- /dev/null +++ b/man/forecast.epi_workflow.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/epi_workflow.R 
more filling. The default NULL will determine this from the recipe. For
The arguments are currently placeholders and must be NULL } \examples{ - +library(dplyr) # Toy example to show that frosting can be created and added for postprocessing f <- frosting() wf <- epi_workflow() \%>\% add_frosting(f) # A more realistic example jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% @@ -37,7 +37,6 @@ r <- epi_recipe(jhu) \%>\% step_epi_naomit() wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) -latest <- get_test_data(recipe = r, x = jhu) f <- frosting() \%>\% layer_predict() \%>\% @@ -45,6 +44,6 @@ f <- frosting() \%>\% wf1 <- wf \%>\% add_frosting(f) -p <- predict(wf1, latest) +p <- forecast(wf1) p } diff --git a/man/get_test_data.Rd b/man/get_test_data.Rd index 392d1dce2..b18685d89 100644 --- a/man/get_test_data.Rd +++ b/man/get_test_data.Rd @@ -37,7 +37,7 @@ keys, as well other variables in the original dataset. } \description{ Based on the longest lag period in the recipe, -\code{get_test_data()} creates an \link{epi_df} +\code{get_test_data()} creates an \link[epiprocess:epi_df]{epi_df} with columns \code{geo_value}, \code{time_value} and other variables in the original dataset, which will be used to create features necessary to produce forecasts. diff --git a/man/grab_names.Rd b/man/grab_names.Rd deleted file mode 100644 index cee6b19dc..000000000 --- a/man/grab_names.Rd +++ /dev/null @@ -1,31 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/grab_names.R -\name{grab_names} -\alias{grab_names} -\title{Get the names from a data frame via tidy select} -\usage{ -grab_names(dat, ...) -} -\arguments{ -\item{dat}{a data.frame} - -\item{...}{<\code{\link[dplyr:dplyr_tidy_select]{tidy-select}}> One or more unquoted -expressions separated by commas. 
Variable names can be used as if they -were positions in the data frame, so expressions like \code{x:y} can -be used to select a range of variables.} -} -\value{ -a character vector -} -\description{ -Given a data.frame, use \verb{} syntax to choose -some variables. Return the names of those variables -} -\details{ -As this is an internal function, no checks are performed. -} -\examples{ -df <- data.frame(a = 1, b = 2, cc = rep(NA, 3)) -grab_names(df, dplyr::starts_with("c")) -} -\keyword{internal} diff --git a/man/grad_employ_subset.Rd b/man/grad_employ_subset.Rd new file mode 100644 index 000000000..46ba36913 --- /dev/null +++ b/man/grad_employ_subset.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data.R +\docType{data} +\name{grad_employ_subset} +\alias{grad_employ_subset} +\title{Subset of Statistics Canada median employment income for postsecondary graduates} +\format{ +An \link[epiprocess:epi_df]{epiprocess::epi_df} with 10193 rows and 8 variables: +\describe{ +\item{geo_value}{The province in Canada associated with each +row of measurements.} +\item{time_value}{The time value, a year integer in YYYY format} +\item{edu_qual}{The education qualification} +\item{fos}{The field of study} +\item{age_group}{The age group; either 15 to 34 or 35 to 64} +\item{num_graduates}{The number of graduates for the given row of characteristics} +\item{med_income_2y}{The median employment income two years after graduation} +\item{med_income_5y}{The median employment income five years after graduation} +} +} +\source{ +This object contains modified data from the following Statistics Canada +data table: \href{https://www150.statcan.gc.ca/t1/tbl1/en/tv.action?pid=3710011501}{ +Characteristics and median employment income of longitudinal cohorts of postsecondary +graduates two and five years after graduation, by educational qualification and +field of study (primary groupings) +} + +Modifications: +\itemize{ +\item Only 
provincial-level geo_values are kept +\item Only age group, field of study, and educational qualification are kept as +covariates. For the remaining covariates, we keep aggregated values and +drop the level-specific rows. +\item No modifications were made to the time range of the data +} +} +\usage{ +grad_employ_subset +} +\description{ +Subset of Statistics Canada median employment income for postsecondary graduates +} +\keyword{datasets} diff --git a/man/grf_quantiles.Rd b/man/grf_quantiles.Rd new file mode 100644 index 000000000..e6852a55b --- /dev/null +++ b/man/grf_quantiles.Rd @@ -0,0 +1,108 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/make_grf_quantiles.R +\name{grf_quantiles} +\alias{grf_quantiles} +\title{Random quantile forests via grf} +\description{ +\code{\link[grf:quantile_forest]{grf::quantile_forest()}} fits random forests in a way that makes it easy +to calculate \emph{quantile} forests. Currently, this is the only engine +provided here, since quantile regression is the typical use-case. +} +\section{Tuning Parameters}{ + + +This model has 3 tuning parameters: +\itemize{ +\item \code{mtry}: # Randomly Selected Predictors (type: integer, default: see below) +\item \code{trees}: # Trees (type: integer, default: 2000L) +\item \code{min_n}: Minimal Node Size (type: integer, default: 5) +} + +\code{mtry} depends on the number of columns in the design matrix. +The default in \code{\link[grf:quantile_forest]{grf::quantile_forest()}} is \code{min(ceiling(sqrt(ncol(X)) + 20), ncol(X))}. + +For categorical predictors, a one-hot encoding is always used. This makes +splitting efficient, but has implications for the \code{mtry} choice. A factor +with many levels will become a large number of columns in the design matrix +which means that some of these may be selected frequently for potential splits. +This is different than in other implementations of random forest. 
For more +details, see \href{https://grf-labs.github.io/grf/articles/categorical_inputs.html}{the \code{grf} discussion}. +} + +\section{Translation from parsnip to the original package}{ + + +\if{html}{\out{
}}\preformatted{rand_forest( + mode = "regression", # you must specify the `mode = regression` + mtry = integer(1), + trees = integer(1), + min_n = integer(1) +) \%>\% + set_engine("grf_quantiles") \%>\% + translate() +#> Random Forest Model Specification (regression) +#> +#> Main Arguments: +#> mtry = integer(1) +#> trees = integer(1) +#> min_n = integer(1) +#> +#> Computational engine: grf_quantiles +#> +#> Model fit template: +#> grf::quantile_forest(X = missing_arg(), Y = missing_arg(), mtry = min_cols(~integer(1), +#> x), num.trees = integer(1), min.node.size = min_rows(~integer(1), +#> x), quantiles = c(0.1, 0.5, 0.9), num.threads = 1L, seed = stats::runif(1, +#> 0, .Machine$integer.max)) +}\if{html}{\out{
}} +} + +\section{Case weights}{ + + +Case weights are not supported. +} + +\examples{ +library(grf) +tib <- data.frame( + y = rnorm(100), x = rnorm(100), z = rnorm(100), + f = factor(sample(letters[1:3], 100, replace = TRUE)) +) +spec <- rand_forest(engine = "grf_quantiles", mode = "regression") +out <- fit(spec, formula = y ~ x + z, data = tib) +predict(out, new_data = tib[1:5, ]) \%>\% + pivot_quantiles_wider(.pred) + +# -- adjusting the desired quantiles + +spec <- rand_forest(mode = "regression") \%>\% + set_engine(engine = "grf_quantiles", quantiles = c(1:9 / 10)) +out <- fit(spec, formula = y ~ x + z, data = tib) +predict(out, new_data = tib[1:5, ]) \%>\% + pivot_quantiles_wider(.pred) + +# -- a more complicated task + +library(dplyr) +dat <- case_death_rate_subset \%>\% + filter(time_value > as.Date("2021-10-01")) +rec <- epi_recipe(dat) \%>\% + step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) \%>\% + step_epi_ahead(death_rate, ahead = 7) \%>\% + step_epi_naomit() +frost <- frosting() \%>\% + layer_predict() \%>\% + layer_threshold(.pred) +spec <- rand_forest(mode = "regression") \%>\% + set_engine(engine = "grf_quantiles", quantiles = c(.25, .5, .75)) + +ewf <- epi_workflow(rec, spec, frost) \%>\% + fit(dat) \%>\% + forecast() +ewf \%>\% + rename(forecast_date = time_value) \%>\% + mutate(target_date = forecast_date + 7L) \%>\% + pivot_quantiles_wider(.pred) + +} diff --git a/man/layer-processors.Rd b/man/layer-processors.Rd index 0c6df8c5c..76e230a7b 100644 --- a/man/layer-processors.Rd +++ b/man/layer-processors.Rd @@ -20,7 +20,7 @@ extract_layers(x, ...) is_layer(x) -validate_layer(x, ..., arg = "`x`", call = caller_env()) +validate_layer(x, ..., arg = rlang::caller_arg(x), call = caller_env()) detect_layer(x, name, ...) 
diff --git a/man/layer_add_forecast_date.Rd b/man/layer_add_forecast_date.Rd index 4e173d662..e27f2bacd 100644 --- a/man/layer_add_forecast_date.Rd +++ b/man/layer_add_forecast_date.Rd @@ -36,15 +36,16 @@ less than the maximum \code{as_of} value (from the data used pre-processing, model fitting, and postprocessing), an appropriate warning will be thrown. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) latest <- jhu \%>\% - dplyr::filter(time_value >= max(time_value) - 14) + filter(time_value >= max(time_value) - 14) # Don't specify `forecast_date` (by default, this should be last date in latest) f <- frosting() \%>\% diff --git a/man/layer_add_target_date.Rd b/man/layer_add_target_date.Rd index 3c2884e10..dc0d2f190 100644 --- a/man/layer_add_target_date.Rd +++ b/man/layer_add_target_date.Rd @@ -37,25 +37,25 @@ has been specified in a preprocessing step (most likely in in the test data to get the target date. 
} \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) -latest <- get_test_data(r, jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) # Use ahead + forecast date f <- frosting() \%>\% layer_predict() \%>\% - layer_add_forecast_date(forecast_date = "2022-05-31") \%>\% + layer_add_forecast_date(forecast_date = as.Date("2022-05-31")) \%>\% layer_add_target_date() \%>\% layer_naomit(.pred) wf1 <- wf \%>\% add_frosting(f) -p <- predict(wf1, latest) +p <- forecast(wf1) p # Use ahead + max time value from pre, fit, post @@ -66,7 +66,7 @@ f2 <- frosting() \%>\% layer_naomit(.pred) wf2 <- wf \%>\% add_frosting(f2) -p2 <- predict(wf2, latest) +p2 <- forecast(wf2) p2 # Specify own target date @@ -76,6 +76,6 @@ f3 <- frosting() \%>\% layer_naomit(.pred) wf3 <- wf \%>\% add_frosting(f3) -p3 <- predict(wf3, latest) +p3 <- forecast(wf3) p3 } diff --git a/man/layer_cdc_flatline_quantiles.Rd b/man/layer_cdc_flatline_quantiles.Rd index cf11de8eb..c3bc4f257 100644 --- a/man/layer_cdc_flatline_quantiles.Rd +++ b/man/layer_cdc_flatline_quantiles.Rd @@ -84,6 +84,7 @@ the future. This version continues to use the same set of residuals, and adds them on to produce wider intervals as \code{ahead} increases. 
} \examples{ +library(dplyr) r <- epi_recipe(case_death_rate_subset) \%>\% # data is "daily", so we fit this to 1 ahead, the result will contain # 1 day ahead residuals @@ -93,24 +94,20 @@ r <- epi_recipe(case_death_rate_subset) \%>\% forecast_date <- max(case_death_rate_subset$time_value) -latest <- get_test_data( - epi_recipe(case_death_rate_subset), case_death_rate_subset -) - f <- frosting() \%>\% layer_predict() \%>\% layer_cdc_flatline_quantiles(aheads = c(7, 14, 21, 28), symmetrize = TRUE) -eng <- parsnip::linear_reg() \%>\% parsnip::set_engine("flatline") +eng <- linear_reg(engine = "flatline") wf <- epi_workflow(r, eng, f) \%>\% fit(case_death_rate_subset) -preds <- suppressWarnings(predict(wf, new_data = latest)) \%>\% - dplyr::select(-time_value) \%>\% - dplyr::mutate(forecast_date = forecast_date) +preds <- forecast(wf) \%>\% + select(-time_value) \%>\% + mutate(forecast_date = forecast_date) preds preds <- preds \%>\% - unnest(.pred_distn_all) \%>\% + tidyr::unnest(.pred_distn_all) \%>\% pivot_quantiles_wider(.pred_distn) \%>\% mutate(target_date = forecast_date + ahead) diff --git a/man/layer_naomit.Rd b/man/layer_naomit.Rd index 74652daab..d77112f95 100644 --- a/man/layer_naomit.Rd +++ b/man/layer_naomit.Rd @@ -24,16 +24,15 @@ an updated \code{frosting} postprocessor Omit \code{NA}s from predictions or other columns } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) - -latest <- get_test_data(recipe = r, x = jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% @@ -41,6 +40,6 @@ f <- frosting() \%>\% wf1 <- wf \%>\% add_frosting(f) -p <- predict(wf1, latest) +p <- 
forecast(wf1) p } diff --git a/man/layer_point_from_distn.Rd b/man/layer_point_from_distn.Rd index 7ad69a332..276f7cb17 100644 --- a/man/layer_point_from_distn.Rd +++ b/man/layer_point_from_distn.Rd @@ -34,17 +34,17 @@ information, so one should usually call this AFTER \code{layer_quantile_distn()} or set the \code{name} argument to something specific. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) \%>\% fit(jhu) - -latest <- get_test_data(recipe = r, x = jhu) +wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) \%>\% + fit(jhu) f1 <- frosting() \%>\% layer_predict() \%>\% @@ -53,7 +53,7 @@ f1 <- frosting() \%>\% layer_naomit(.pred) wf1 <- wf \%>\% add_frosting(f1) -p1 <- predict(wf1, latest) +p1 <- forecast(wf1) p1 f2 <- frosting() \%>\% @@ -62,6 +62,6 @@ f2 <- frosting() \%>\% layer_naomit(.pred) wf2 <- wf \%>\% add_frosting(f2) -p2 <- predict(wf2, latest) +p2 <- forecast(wf2) p2 } diff --git a/man/layer_population_scaling.Rd b/man/layer_population_scaling.Rd index 179d6862c..5a105f208 100644 --- a/man/layer_population_scaling.Rd +++ b/man/layer_population_scaling.Rd @@ -74,9 +74,10 @@ passed will \emph{multiply} the selected variables while the \code{rate_rescalin argument is a common \emph{divisor} of the selected variables. 
} \examples{ -jhu <- epiprocess::jhu_csse_daily_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% - dplyr::select(geo_value, time_value, cases) +library(dplyr) +jhu <- jhu_csse_daily_subset \%>\% + filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% + select(geo_value, time_value, cases) pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) @@ -101,19 +102,9 @@ f <- frosting() \%>\% df_pop_col = "value" ) -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) \%>\% add_frosting(f) -latest <- get_test_data( - recipe = r, - x = epiprocess::jhu_csse_daily_subset \%>\% - dplyr::filter( - time_value > "2021-11-01", - geo_value \%in\% c("ca", "ny") - ) \%>\% - dplyr::select(geo_value, time_value, cases) -) - -predict(wf, latest) +forecast(wf) } diff --git a/man/layer_predict.Rd b/man/layer_predict.Rd index 03473053f..8ae92f4c8 100644 --- a/man/layer_predict.Rd +++ b/man/layer_predict.Rd @@ -58,6 +58,7 @@ postprocessing. This would typically be the first layer in a \code{frosting} postprocessor. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) @@ -66,7 +67,7 @@ r <- epi_recipe(jhu) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) latest <- jhu \%>\% filter(time_value >= max(time_value) - 14) # Predict layer alone diff --git a/man/layer_predictive_distn.Rd b/man/layer_predictive_distn.Rd index 2cd374fdb..240db5f5b 100644 --- a/man/layer_predictive_distn.Rd +++ b/man/layer_predictive_distn.Rd @@ -39,17 +39,16 @@ should be reasonably accurate for models fit using \code{lm} when the new point \verb{x*} isn't too far from the bulk of the data. 
} \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) - -latest <- get_test_data(recipe = r, x = jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% @@ -57,6 +56,6 @@ f <- frosting() \%>\% layer_naomit(.pred) wf1 <- wf \%>\% add_frosting(f) -p <- predict(wf1, latest) +p <- forecast(wf1) p } diff --git a/man/layer_quantile_distn.Rd b/man/layer_quantile_distn.Rd index 167282760..68192deee 100644 --- a/man/layer_quantile_distn.Rd +++ b/man/layer_quantile_distn.Rd @@ -32,13 +32,22 @@ quantiles will be added to the predictions. } \description{ This function calculates quantiles when the prediction was \emph{distributional}. -Currently, the only distributional engine is \code{quantile_reg()}. -If this engine is used, then this layer will grab out estimated (or extrapolated) -quantiles at the requested quantile values. +} +\details{ +Currently, the only distributional modes/engines are +\itemize{ +\item \code{quantile_reg()} +\item \code{smooth_quantile_reg()} +\item \code{rand_forest(mode = "regression") \%>\% set_engine("grf_quantiles")} +} + +If these engines were used, then this layer will grab out estimated +(or extrapolated) quantiles at the requested quantile values. 
} \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% @@ -48,14 +57,12 @@ r <- epi_recipe(jhu) \%>\% wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) \%>\% fit(jhu) -latest <- get_test_data(recipe = r, x = jhu) - f <- frosting() \%>\% layer_predict() \%>\% layer_quantile_distn() \%>\% layer_naomit(.pred) wf1 <- wf \%>\% add_frosting(f) -p <- predict(wf1, latest) +p <- forecast(wf1) p } diff --git a/man/layer_residual_quantiles.Rd b/man/layer_residual_quantiles.Rd index bf0e05be1..39e1ecfbe 100644 --- a/man/layer_residual_quantiles.Rd +++ b/man/layer_residual_quantiles.Rd @@ -39,31 +39,36 @@ residual quantiles added to the prediction Creates predictions based on residual quantiles } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) - -latest <- get_test_data(recipe = r, x = jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% - layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) \%>\% + layer_residual_quantiles( + quantile_levels = c(0.0275, 0.975), + symmetrize = FALSE + ) \%>\% layer_naomit(.pred) wf1 <- wf \%>\% add_frosting(f) -p <- predict(wf1, latest) +p <- forecast(wf1) f2 <- frosting() \%>\% layer_predict() \%>\% - layer_residual_quantiles(quantile_levels = c(0.3, 0.7), by_key = "geo_value") \%>\% + layer_residual_quantiles( + quantile_levels = c(0.3, 0.7), + 
by_key = "geo_value" + ) \%>\% layer_naomit(.pred) wf2 <- wf \%>\% add_frosting(f2) -p2 <- predict(wf2, latest) +p2 <- forecast(wf2) } diff --git a/man/layer_threshold.Rd b/man/layer_threshold.Rd index 127311ae6..0f4b1dfb7 100644 --- a/man/layer_threshold.Rd +++ b/man/layer_threshold.Rd @@ -40,21 +40,19 @@ smaller than the lower threshold or higher than the upper threshold equal to the threshold values. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value < "2021-03-08", - geo_value \%in\% c("ak", "ca", "ar")) + filter(time_value < "2021-03-08", geo_value \%in\% c("ak", "ca", "ar")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) - -latest <- get_test_data(r, jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% layer_threshold(.pred, lower = 0.180, upper = 0.310) wf <- wf \%>\% add_frosting(f) -p <- predict(wf, latest) +p <- forecast(wf) p } diff --git a/man/nested_quantiles.Rd b/man/nested_quantiles.Rd index 143532650..b34b718ca 100644 --- a/man/nested_quantiles.Rd +++ b/man/nested_quantiles.Rd @@ -16,9 +16,11 @@ a list-col Turn a vector of quantile distributions into a list-col } \examples{ +library(dplyr) +library(tidyr) edf <- case_death_rate_subset[1:3, ] edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11)) -edf_nested <- edf \%>\% dplyr::mutate(q = nested_quantiles(q)) -edf_nested \%>\% tidyr::unnest(q) +edf_nested <- edf \%>\% mutate(q = nested_quantiles(q)) +edf_nested \%>\% unnest(q) } diff --git a/man/pivot_quantiles_longer.Rd b/man/pivot_quantiles_longer.Rd index f73e6deaf..9879d5d07 100644 --- a/man/pivot_quantiles_longer.Rd +++ b/man/pivot_quantiles_longer.Rd @@ -34,9 +34,9 @@ multiple columns are selected, these will be prefixed with the column name. 
\examples{ d1 <- c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:4, 1:3 / 4)) d2 <- c(dist_quantiles(2:4, 2:4 / 5), dist_quantiles(3:5, 2:4 / 5)) -tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2) +tib <- tibble(g = c("a", "b"), d1 = d1, d2 = d2) pivot_quantiles_longer(tib, "d1") -pivot_quantiles_longer(tib, tidyselect::ends_with("1")) +pivot_quantiles_longer(tib, dplyr::ends_with("1")) pivot_quantiles_longer(tib, d1, d2) } diff --git a/man/pivot_quantiles_wider.Rd b/man/pivot_quantiles_wider.Rd index 02a33bb2f..e477777ca 100644 --- a/man/pivot_quantiles_wider.Rd +++ b/man/pivot_quantiles_wider.Rd @@ -30,6 +30,6 @@ d2 <- c(dist_quantiles(2:4, 2:4 / 5), dist_quantiles(3:5, 2:4 / 5)) tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2) pivot_quantiles_wider(tib, c("d1", "d2")) -pivot_quantiles_wider(tib, tidyselect::starts_with("d")) +pivot_quantiles_wider(tib, dplyr::starts_with("d")) pivot_quantiles_wider(tib, d2) } diff --git a/man/predict-epi_workflow.Rd b/man/predict-epi_workflow.Rd index d92fd8ca9..130279249 100644 --- a/man/predict-epi_workflow.Rd +++ b/man/predict-epi_workflow.Rd @@ -5,7 +5,7 @@ \alias{predict.epi_workflow} \title{Predict from an epi_workflow} \usage{ -\method{predict}{epi_workflow}(object, new_data, ...) +\method{predict}{epi_workflow}(object, new_data, type = NULL, opts = list(), ...) } \arguments{ \item{object}{An epi_workflow that has been fit by @@ -14,6 +14,16 @@ \item{new_data}{A data frame containing the new predictors to preprocess and predict on} +\item{type}{A single character value or \code{NULL}. Possible values +are \code{"numeric"}, \code{"class"}, \code{"prob"}, \code{"conf_int"}, \code{"pred_int"}, +\code{"quantile"}, \code{"time"}, \code{"hazard"}, \code{"survival"}, or \code{"raw"}. When \code{NULL}, +\code{predict()} will choose an appropriate value based on the model's mode.} + +\item{opts}{A list of optional arguments to the underlying +predict function that will be used when \code{type = "raw"}. 
The +list should not include options for the model object or the +new data being predicted.} + \item{...}{Additional \code{parsnip}-related options, depending on the value of \code{type}. Arguments to the underlying model's prediction function cannot be passed here (use the \code{opts} argument instead). diff --git a/man/quantile_reg.Rd b/man/quantile_reg.Rd index 8e576ac84..5079c3434 100644 --- a/man/quantile_reg.Rd +++ b/man/quantile_reg.Rd @@ -4,24 +4,35 @@ \alias{quantile_reg} \title{Quantile regression} \usage{ -quantile_reg(mode = "regression", engine = "rq", quantile_levels = 0.5) +quantile_reg( + mode = "regression", + engine = "rq", + quantile_levels = 0.5, + method = "br" +) } \arguments{ \item{mode}{A single character string for the type of model. The only possible value for this model is "regression".} \item{engine}{Character string naming the fitting function. Currently, only -"rq" is supported.} +"rq" and "grf" are supported.} \item{quantile_levels}{A scalar or vector of values in (0, 1) to determine which quantiles to estimate (default is 0.5).} + +\item{method}{A fitting method used by \code{\link[quantreg:rq]{quantreg::rq()}}. See the +documentation for a list of options.} } \description{ \code{quantile_reg()} generates a quantile regression model \emph{specification} for the \href{https://www.tidymodels.org/}{tidymodels} framework. Currently, the -only supported engine is "rq" which uses \code{\link[quantreg:rq]{quantreg::rq()}}. +only supported engines are "rq", which uses \code{\link[quantreg:rq]{quantreg::rq()}}. +Quantile regression is also possible by combining \code{\link[parsnip:rand_forest]{parsnip::rand_forest()}} +with the \code{grf} engine. See \link{grf_quantiles}. 
} \examples{ +library(quantreg) tib <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100)) rq_spec <- quantile_reg(quantile_levels = c(.2, .8)) \%>\% set_engine("rq") ff <- rq_spec \%>\% fit(y ~ ., data = tib) diff --git a/man/reexports.Rd b/man/reexports.Rd index b23f00698..f6849a53c 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -1,11 +1,16 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/reexports-tidymodels.R +% Please edit documentation in R/autoplot.R, R/reexports-tidymodels.R \docType{import} \name{reexports} \alias{reexports} +\alias{autoplot} \alias{fit} +\alias{forecast} \alias{prep} \alias{bake} +\alias{rand_id} +\alias{tibble} +\alias{tidy} \title{Objects exported from other packages} \keyword{internal} \description{ @@ -13,8 +18,12 @@ These objects are imported from other packages. Follow the links below to see their documentation. \describe{ - \item{generics}{\code{\link[generics]{fit}}} + \item{generics}{\code{\link[generics]{fit}}, \code{\link[generics]{forecast}}, \code{\link[generics]{tidy}}} - \item{recipes}{\code{\link[recipes]{bake}}, \code{\link[recipes]{prep}}} + \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} + + \item{recipes}{\code{\link[recipes]{bake}}, \code{\link[recipes]{prep}}, \code{\link[recipes]{rand_id}}} + + \item{tibble}{\code{\link[tibble]{tibble}}} }} diff --git a/man/smooth_quantile_reg.Rd b/man/smooth_quantile_reg.Rd index bd8c012f2..c6b17dd86 100644 --- a/man/smooth_quantile_reg.Rd +++ b/man/smooth_quantile_reg.Rd @@ -36,6 +36,7 @@ the \href{https://www.tidymodels.org/}{tidymodels} framework. Currently, the only supported engine is \code{\link[smoothqr:smooth_qr]{smoothqr::smooth_qr()}}. 
} \examples{ +library(smoothqr) tib <- data.frame( y1 = rnorm(100), y2 = rnorm(100), y3 = rnorm(100), y4 = rnorm(100), y5 = rnorm(100), y6 = rnorm(100), @@ -75,17 +76,16 @@ lines(pl$x, pl$`0.2`, col = "blue") lines(pl$x, pl$`0.8`, col = "blue") lines(pl$x, pl$`0.5`, col = "red") -if (require("ggplot2")) { - ggplot(data.frame(x = x, y = y), aes(x)) + - geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") + - geom_point(aes(y = y), colour = "grey") + # observed data - geom_function(fun = sin, colour = "black") + # truth - geom_vline(xintercept = fd, linetype = "dashed") + # end of training data - geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction - theme_bw() + - coord_cartesian(xlim = c(0, NA)) + - ylab("y") -} +library(ggplot2) +ggplot(data.frame(x = x, y = y), aes(x)) + + geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") + + geom_point(aes(y = y), colour = "grey") + # observed data + geom_function(fun = sin, colour = "black") + # truth + geom_vline(xintercept = fd, linetype = "dashed") + # end of training data + geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction + theme_bw() + + coord_cartesian(xlim = c(0, NA)) + + ylab("y") } \seealso{ \code{\link[=fit.model_spec]{fit.model_spec()}}, \code{\link[=set_engine]{set_engine()}} diff --git a/man/step_epi_shift.Rd b/man/step_epi_shift.Rd index bf135346e..2bf22c15d 100644 --- a/man/step_epi_shift.Rd +++ b/man/step_epi_shift.Rd @@ -8,12 +8,10 @@ step_epi_lag( recipe, ..., - role = "predictor", - trained = FALSE, lag, + role = "predictor", prefix = "lag_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_lag") ) @@ -21,12 +19,10 @@ step_epi_lag( step_epi_ahead( recipe, ..., - role = "outcome", - trained = FALSE, ahead, + role = "outcome", prefix = "ahead_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_ahead") ) @@ -38,24 +34,18 @@ sequence of operations for this recipe.} \item{...}{One or more 
selector functions to choose variables for this step. See \code{\link[recipes:selections]{recipes::selections()}} for more details.} -\item{role}{For model terms created by this step, what analysis role should -they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} - -\item{trained}{A logical to indicate if the quantities for -preprocessing have been estimated.} - \item{lag, ahead}{A vector of integers. Each specified column will be the lag or lead for each value in the vector. Lag integers must be nonnegative, while ahead integers must be positive.} -\item{prefix}{A prefix to indicate what type of variable this is} +\item{role}{For model terms created by this step, what analysis role should +they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} + +\item{prefix}{A character string that will be prefixed to the new column.} \item{default}{Determines what fills empty rows left by leading/lagging (defaults to NA).} -\item{columns}{A character string of variable names that will -be populated (eventually) by the \code{terms} argument.} - \item{skip}{A logical. Should the step be skipped when the recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked when \code{\link[=prep]{prep()}} is run, some operations may not be able to be diff --git a/man/step_epi_slide.Rd b/man/step_epi_slide.Rd new file mode 100644 index 000000000..242f8e312 --- /dev/null +++ b/man/step_epi_slide.Rd @@ -0,0 +1,92 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/step_epi_slide.R +\name{step_epi_slide} +\alias{step_epi_slide} +\title{Calculate a rolling window transformation} +\usage{ +step_epi_slide( + recipe, + ..., + .f, + .window_size = NULL, + .align = c("right", "center", "left"), + role = "predictor", + prefix = "epi_slide_", + f_name = clean_f_name(.f), + skip = FALSE, + id = rand_id("epi_slide") +) +} +\arguments{ +\item{recipe}{A recipe object. 
The step will be added to the +sequence of operations for this recipe.} + +\item{...}{One or more selector functions to choose variables +for this step. See \code{\link[recipes:selections]{recipes::selections()}} for more details.} + +\item{.f}{A function in one of the following formats: +\enumerate{ +\item An unquoted function name with no arguments, e.g., \code{mean} +\item A character string name of a function, e.g., \code{"mean"}. Note that this +can be difficult to examine for mistakes (so the misspelling \code{"maen"} +won't produce an error until you try to actually fit the model) +\item A base \code{R} lambda function, e.g., \code{function(x) mean(x, na.rm = TRUE)} +\item A new-style base \code{R} lambda function, e.g., \verb{\\(x) mean(x, na.rm = TRUE)} +\item A one-sided formula like \code{~ mean(.x, na.rm = TRUE)}. +} + +Note that in cases 3 and 4, \code{x} can be any variable name you like (for +example \verb{\\(dog) mean(dog, na.rm = TRUE)} will work). But in case 5, the +argument must be named \code{.x}. A common, though very difficult to debug +error is using something like \code{function(x) mean}. This will not work +because it returns the function mean, rather than \code{mean(x)}} + +\item{.window_size}{the size of the sliding window, required. Usually a +non-negative integer will suffice (e.g. for data indexed by date, but more +restrictive in other time_type cases (see \code{\link[epiprocess:epi_slide]{epiprocess::epi_slide()}} for +details). For example, set to 7 for a 7-day window.} + +\item{.align}{a character string indicating how the window should be aligned. +By default, this is "right", meaning the slide_window will be anchored with +its right end point on the reference date. (see \code{\link[epiprocess:epi_slide]{epiprocess::epi_slide()}} +for details).} + +\item{role}{For model terms created by this step, what analysis role should +they be assigned? 
\code{lag} is default a predictor while \code{ahead} is an outcome.} + +\item{prefix}{A character string that will be prefixed to the new column.} + +\item{f_name}{a character string of at most 20 characters that describes the +function. This will be combined with \code{prefix} and the columns in \code{...} to +name the result using \verb{\{prefix\}\{f_name\}_\{column\}}. By default it will be +determined automatically using \code{clean_f_name()}.} + +\item{skip}{A logical. Should the step be skipped when the +recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked +when \code{\link[=prep]{prep()}} is run, some operations may not be able to be +conducted on new data (e.g. processing the outcome variable(s)). +Care should be taken when using \code{skip = TRUE} as it may affect +the computations for subsequent operations.} + +\item{id}{A unique identifier for the step} +} +\value{ +An updated version of \code{recipe} with the new step added to the +sequence of any existing operations. +} +\description{ +\code{step_epi_slide()} creates a \emph{specification} of a recipe step +that will generate one or more new columns of derived data by "sliding" +a computation along existing data. 
+} +\examples{ +library(dplyr) +jhu <- case_death_rate_subset \%>\% + filter(time_value >= as.Date("2021-01-01"), geo_value \%in\% c("ca", "ny")) +rec <- epi_recipe(jhu) \%>\% + step_epi_slide(case_rate, death_rate, + .f = \(x) mean(x, na.rm = TRUE), + .window_size = 7L + ) +bake(prep(rec, jhu), new_data = NULL) +} diff --git a/man/step_growth_rate.Rd b/man/step_growth_rate.Rd index b409135b1..bc6da0bef 100644 --- a/man/step_growth_rate.Rd +++ b/man/step_growth_rate.Rd @@ -8,13 +8,11 @@ step_growth_rate( recipe, ..., role = "predictor", - trained = FALSE, horizon = 7, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + method = c("rel_change", "linear_reg"), log_scale = FALSE, replace_Inf = NA, prefix = "gr_", - columns = NULL, skip = FALSE, id = rand_id("growth_rate"), additional_gr_args_list = list() @@ -30,20 +28,15 @@ for this step. See \code{\link[recipes:selections]{recipes::selections()}} for m \item{role}{For model terms created by this step, what analysis role should they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} -\item{trained}{A logical to indicate if the quantities for -preprocessing have been estimated.} - \item{horizon}{Bandwidth for the sliding window, when \code{method} is "rel_change" or "linear_reg". See \code{\link[epiprocess:growth_rate]{epiprocess::growth_rate()}} for more details.} -\item{method}{Either "rel_change", "linear_reg", "smooth_spline", or -"trend_filter", indicating the method to use for the growth rate -calculation. The first two are local methods: they are run in a sliding +\item{method}{Either "rel_change" or "linear_reg", +indicating the method to use for the growth rate +calculation. These are local methods: they are run in a sliding fashion over the sequence (in order to estimate derivatives and hence -growth rates); the latter two are global methods: they are run once over -the entire sequence. 
See \code{\link[epiprocess:growth_rate]{epiprocess::growth_rate()}} for more -details.} +growth rates). See \code{\link[epiprocess:growth_rate]{epiprocess::growth_rate()}} for more details.} \item{log_scale}{Should growth rates be estimated using the parameterization on the log scale? See details for an explanation. Default is \code{FALSE}.} @@ -56,10 +49,7 @@ being removed from the data. Alternatively, you could specify arbitrary large values, or perhaps zero. Setting this argument to \code{NULL} will result in no replacement.} -\item{prefix}{A prefix to indicate what type of variable this is} - -\item{columns}{A character string of variable names that will -be populated (eventually) by the \code{terms} argument.} +\item{prefix}{A character string that will be prefixed to the new column.} \item{skip}{A logical. Should the step be skipped when the recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked @@ -88,8 +78,8 @@ r <- epi_recipe(case_death_rate_subset) \%>\% r r \%>\% - recipes::prep() \%>\% - recipes::bake(case_death_rate_subset) + prep(case_death_rate_subset) \%>\% + bake(case_death_rate_subset) } \seealso{ Other row operation steps: diff --git a/man/step_lag_difference.Rd b/man/step_lag_difference.Rd index b06abe43c..7969ea3a7 100644 --- a/man/step_lag_difference.Rd +++ b/man/step_lag_difference.Rd @@ -8,10 +8,8 @@ step_lag_difference( recipe, ..., role = "predictor", - trained = FALSE, horizon = 7, prefix = "lag_diff_", - columns = NULL, skip = FALSE, id = rand_id("lag_diff") ) @@ -26,16 +24,10 @@ for this step. See \code{\link[recipes:selections]{recipes::selections()}} for m \item{role}{For model terms created by this step, what analysis role should they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} -\item{trained}{A logical to indicate if the quantities for -preprocessing have been estimated.} - \item{horizon}{Scalar or vector. 
Time period(s) over which to calculate differences.} -\item{prefix}{A prefix to indicate what type of variable this is} - -\item{columns}{A character string of variable names that will -be populated (eventually) by the \code{terms} argument.} +\item{prefix}{A character string that will be prefixed to the new column.} \item{skip}{A logical. Should the step be skipped when the recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked @@ -56,12 +48,13 @@ that will generate one or more new columns of derived data. } \examples{ r <- epi_recipe(case_death_rate_subset) \%>\% - step_lag_difference(case_rate, death_rate, horizon = c(7, 14)) + step_lag_difference(case_rate, death_rate, horizon = c(7, 14)) \%>\% + step_epi_naomit() r r \%>\% - recipes::prep() \%>\% - recipes::bake(case_death_rate_subset) + prep(case_death_rate_subset) \%>\% + bake(case_death_rate_subset) } \seealso{ Other row operation steps: diff --git a/man/step_population_scaling.Rd b/man/step_population_scaling.Rd index 1a9564563..294f27f61 100644 --- a/man/step_population_scaling.Rd +++ b/man/step_population_scaling.Rd @@ -8,33 +8,25 @@ step_population_scaling( recipe, ..., role = "raw", - trained = FALSE, df, by = NULL, df_pop_col, rate_rescaling = 1, create_new = TRUE, suffix = "_scaled", - columns = NULL, skip = FALSE, id = rand_id("population_scaling") ) } \arguments{ -\item{recipe}{A recipe object. The step will be added to the sequence of -operations for this recipe. The recipe should contain information about the -\code{epi_df} such as column names.} +\item{recipe}{A recipe object. The step will be added to the +sequence of operations for this recipe.} -\item{...}{One or more selector functions to scale variables +\item{...}{One or more selector functions to choose variables for this step. See \code{\link[recipes:selections]{recipes::selections()}} for more details.} \item{role}{For model terms created by this step, what analysis role should -they be assigned? 
By default, the new columns created by this step from the -original variables will be used as predictors in a model. Other options can -be ard are not limited to "outcome".} - -\item{trained}{A logical to indicate if the quantities for preprocessing -have been estimated.} +they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} \item{df}{a data frame that contains the population data to be used for inverting the existing scaling.} @@ -68,10 +60,7 @@ scale is "per 100K", then set \code{rate_rescaling = 1e5} to get rates.} in the \code{epi_df}} \item{suffix}{a character. The suffix added to the column name if -\code{crete_new = TRUE}. Default to "_scaled".} - -\item{columns}{A character string of variable names that will -be populated (eventually) by the \code{terms} argument.} +\code{create_new = TRUE}. Default to "_scaled".} \item{skip}{A logical. Should the step be skipped when the recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked @@ -98,11 +87,10 @@ passed will \emph{divide} the selected variables while the \code{rate_rescaling} argument is a common \emph{multiplier} of the selected variables. 
} \examples{ -library(epiprocess) -library(epipredict) -jhu <- epiprocess::jhu_csse_daily_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% - dplyr::select(geo_value, time_value, cases) +library(dplyr) +jhu <- jhu_csse_daily_subset \%>\% + filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% + select(geo_value, time_value, cases) pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) @@ -127,20 +115,9 @@ f <- frosting() \%>\% df_pop_col = "value" ) -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) \%>\% add_frosting(f) -latest <- get_test_data( - recipe = r, - epiprocess::jhu_csse_daily_subset \%>\% - dplyr::filter( - time_value > "2021-11-01", - geo_value \%in\% c("ca", "ny") - ) \%>\% - dplyr::select(geo_value, time_value, cases) -) - - -predict(wf, latest) +forecast(wf) } diff --git a/man/step_training_window.Rd b/man/step_training_window.Rd index ce7c0fc74..42f6b9a95 100644 --- a/man/step_training_window.Rd +++ b/man/step_training_window.Rd @@ -7,7 +7,6 @@ step_training_window( recipe, role = NA, - trained = FALSE, n_recent = 50, epi_keys = NULL, id = rand_id("training_window") @@ -17,10 +16,8 @@ step_training_window( \item{recipe}{A recipe object. The step will be added to the sequence of operations for this recipe.} -\item{role}{Not used by this step since no new variables are created.} - -\item{trained}{A logical to indicate if the quantities for -preprocessing have been estimated.} +\item{role}{For model terms created by this step, what analysis role should +they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} \item{n_recent}{An integer value that represents the number of most recent observations that are to be kept in the training window per key @@ -30,7 +27,7 @@ The default value is 50.} to group on. 
The default, \code{NULL}, ensures that every key combination is limited.} -\item{id}{A character string that is unique to this step to identify it.} +\item{id}{A unique identifier for the step} } \value{ An updated version of \code{recipe} with the new step added to the @@ -47,13 +44,10 @@ Note that \code{step_epi_lead()} and \code{step_epi_lag()} should come after any filtering step. } \examples{ -tib <- tibble::tibble( +tib <- tibble( x = 1:10, y = 1:10, - time_value = rep(seq(as.Date("2020-01-01"), - by = 1, - length.out = 5 - ), times = 2), + time_value = rep(seq(as.Date("2020-01-01"), by = 1, length.out = 5), 2), geo_value = rep(c("ca", "hi"), each = 5) ) \%>\% as_epi_df() @@ -64,7 +58,7 @@ epi_recipe(y ~ x, data = tib) \%>\% bake(new_data = NULL) epi_recipe(y ~ x, data = tib) \%>\% - recipes::step_naomit() \%>\% + step_epi_naomit() \%>\% step_training_window(n_recent = 3) \%>\% prep(tib) \%>\% bake(new_data = NULL) diff --git a/man/tidy.frosting.Rd b/man/tidy.frosting.Rd index 6b28461b4..ba3c0f3d5 100644 --- a/man/tidy.frosting.Rd +++ b/man/tidy.frosting.Rd @@ -37,8 +37,9 @@ method for the operation exists). Note that this is a modified version of the \code{tidy} method for a recipe. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% diff --git a/man/update.layer.Rd b/man/update.layer.Rd index 0f1fe9c22..9604992e1 100644 --- a/man/update.layer.Rd +++ b/man/update.layer.Rd @@ -18,15 +18,15 @@ will replace the elements of the same name in the actual post-processing layer. Analogous to \code{update.step()} from the \code{recipes} package. 
} \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) -latest <- jhu \%>\% - dplyr::filter(time_value >= max(time_value) - 14) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) +latest <- jhu \%>\% filter(time_value >= max(time_value) - 14) # Specify a `forecast_date` that is greater than or equal to `as_of` date f <- frosting() \%>\% diff --git a/man/weighted_interval_score.Rd b/man/weighted_interval_score.Rd new file mode 100644 index 000000000..4907e2724 --- /dev/null +++ b/man/weighted_interval_score.Rd @@ -0,0 +1,97 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/weighted_interval_score.R +\name{weighted_interval_score} +\alias{weighted_interval_score} +\alias{weighted_interval_score.dist_quantiles} +\title{Compute weighted interval score} +\usage{ +weighted_interval_score(x, actual, quantile_levels = NULL, ...) + +\method{weighted_interval_score}{dist_quantiles}( + x, + actual, + quantile_levels = NULL, + na_handling = c("impute", "drop", "propagate", "fail"), + ... +) +} +\arguments{ +\item{x}{distribution. A vector of class distribution. Ideally, this vector +contains \code{dist_quantiles()}, though other distributions are supported when +\code{quantile_levels} is specified. See below.} + +\item{actual}{double. Actual value(s)} + +\item{quantile_levels}{probabilities. If specified, the score will be +computed at this set of levels.} + +\item{...}{not used} + +\item{na_handling}{character. Determines how \code{quantile_levels} without a +corresponding \code{value} are handled. 
For \code{"impute"}, missing values will be +calculated if possible using the available quantiles. For \code{"drop"}, +explicitly missing values are ignored in the calculation of the score, but +implicitly missing values are imputed if possible. +For \code{"propagate"}, the resulting score will be \code{NA} if any missing values +exist in the original \code{quantile_levels}. Finally, if +\code{quantile_levels} is specified, \code{"fail"} will result in +the score being \code{NA} when any required quantile levels (implicit or explicit) +do not have corresponding values.} +\value{ +a vector of nonnegative scores. +} +\description{ +Weighted interval score (WIS), a well-known quantile-based +approximation of the commonly-used continuous ranked probability score +(CRPS). WIS is a proper score, and can be thought of as a distributional +generalization of absolute error. For example, see \href{https://arxiv.org/abs/2005.12881}{Bracher et al. (2020)} for discussion in the context +of COVID-19 forecasting. +} +\section{Methods (by class)}{ +\itemize{ +\item \code{weighted_interval_score(dist_quantiles)}: Weighted interval score with +\code{dist_quantiles} allows for different \code{NA} behaviours. 
+ +}} +\examples{ +quantile_levels <- c(.2, .4, .6, .8) +predq_1 <- 1:4 # +predq_2 <- 8:11 +dstn <- dist_quantiles(list(predq_1, predq_2), quantile_levels) +actual <- c(3.3, 7.1) +weighted_interval_score(dstn, actual) +weighted_interval_score(dstn, actual, c(.25, .5, .75)) + +library(distributional) +dstn <- dist_normal(c(.75, 2)) +weighted_interval_score(dstn, 1, c(.25, .5, .75)) + +# Missing value behaviours +dstn <- dist_quantiles(c(1, 2, NA, 4), 1:4 / 5) +weighted_interval_score(dstn, 2.5) +weighted_interval_score(dstn, 2.5, 1:9 / 10) +weighted_interval_score(dstn, 2.5, 1:9 / 10, na_handling = "drop") +weighted_interval_score(dstn, 2.5, na_handling = "propagate") +weighted_interval_score(dist_quantiles(1:4, 1:4 / 5), 2.5, 1:9 / 10, + na_handling = "fail" +) + + +# Using some actual forecasts -------- +library(dplyr) +jhu <- case_death_rate_subset \%>\% + filter(time_value >= "2021-10-01", time_value <= "2021-12-01") +preds <- flatline_forecaster( + jhu, "death_rate", + flatline_args_list(quantile_levels = c(.01, .025, 1:19 / 20, .975, .99)) +)$predictions +actuals <- case_death_rate_subset \%>\% + filter(time_value == as.Date("2021-12-01") + 7) \%>\% + select(geo_value, time_value, actual = death_rate) +preds <- left_join(preds, actuals, + by = c("target_date" = "time_value", "geo_value") +) \%>\% + mutate(wis = weighted_interval_score(.pred_distn, actual)) +preds +} diff --git a/tests/testthat/_snaps/arg_is_.md b/tests/testthat/_snaps/arg_is_.md new file mode 100644 index 000000000..fcb823f2a --- /dev/null +++ b/tests/testthat/_snaps/arg_is_.md @@ -0,0 +1,384 @@ +# logical + + Code + arg_is_lgl(l, ll, n) + Condition + Error: + ! `n` must be of type . + +--- + + Code + arg_is_lgl(x) + Condition + Error: + ! `x` must be of type . + +--- + + Code + arg_is_lgl(l, ll, nn) + Condition + Error: + ! `nn` must be of type . + +# scalar + + Code + arg_is_scalar(x, y, n) + Condition + Error: + ! `n` must be a scalar. 
+ +--- + + Code + arg_is_scalar(x, y, nn) + Condition + Error: + ! `nn` must be a scalar. + +--- + + Code + arg_is_scalar(v, nn) + Condition + Error: + ! `v` must be a scalar. + +--- + + Code + arg_is_scalar(v, nn, allow_na = TRUE) + Condition + Error: + ! `v` must be a scalar. + +--- + + Code + arg_is_scalar(v, n, allow_null = TRUE) + Condition + Error: + ! `v` must be a scalar. + +--- + + Code + arg_is_scalar(nnn, allow_na = TRUE) + Condition + Error: + ! `nnn` must be a scalar. + +# numeric + + Code + arg_is_numeric(a) + Condition + Error: + ! `a` must be of type . + +--- + + Code + arg_is_numeric(i, j, n) + Condition + Error: + ! `n` must be of type . + +--- + + Code + arg_is_numeric(i, nn) + Condition + Error: + ! `nn` must be of type . + +# positive + + Code + arg_is_pos(a) + Condition + Error: + ! `a` must be a strictly positive number. + +--- + + Code + arg_is_pos(i, k) + Condition + Error: + ! `k` must be strictly positive numbers. + +--- + + Code + arg_is_pos(i, j, n) + Condition + Error: + ! `n` must be strictly positive numbers. + +--- + + Code + arg_is_pos(i, nn) + Condition + Error: + ! `nn` must be a strictly positive number. + +--- + + Code + arg_is_pos(a = 0:10) + Condition + Error: + ! `0:10` must be strictly positive numbers. + +# nonneg + + Code + arg_is_nonneg(a) + Condition + Error: + ! `a` must be a non-negative number. + +--- + + Code + arg_is_nonneg(i, k) + Condition + Error: + ! `k` must be non-negative numbers. + +--- + + Code + arg_is_nonneg(i, j, n) + Condition + Error: + ! `n` must be non-negative numbers. + +--- + + Code + arg_is_nonneg(i, nn) + Condition + Error: + ! `nn` must be a non-negative number. + +# nonneg-int + + Code + arg_is_nonneg_int(a) + Condition + Error: + ! `a` must be a non-negative integer. + +--- + + Code + arg_is_nonneg_int(d) + Condition + Error: + ! `d` must be a non-negative integer. + +--- + + Code + arg_is_nonneg_int(i, k) + Condition + Error: + ! `k` must be non-negative integers. 
+ +--- + + Code + arg_is_nonneg_int(i, j, n) + Condition + Error: + ! `n` must be non-negative integers. + +--- + + Code + arg_is_nonneg_int(i, nn) + Condition + Error: + ! `nn` must be a non-negative integer. + +# date + + Code + arg_is_date(d, dd, n) + Condition + Error: + ! `n` must be dates. + +--- + + Code + arg_is_date(d, dd, nn) + Condition + Error: + ! `nn` must be a date. + +--- + + Code + arg_is_date(a) + Condition + Error: + ! `a` must be a date. + +--- + + Code + arg_is_date(v) + Condition + Error: + ! `v` must be dates. + +--- + + Code + arg_is_date(ll) + Condition + Error: + ! `ll` must be dates. + +# probabilities + + Code + arg_is_probabilities(a) + Condition + Error: + ! `a` must lie in [0, 1]. + +--- + + Code + arg_is_probabilities(d) + Condition + Error: + ! `d` must lie in [0, 1]. + +--- + + Code + arg_is_probabilities(i, 1.1) + Condition + Error: + ! `1.1` must lie in [0, 1]. + +--- + + Code + arg_is_probabilities(c(0.4, 0.8), n) + Condition + Error: + ! `n` must lie in [0, 1]. + +--- + + Code + arg_is_probabilities(c(0.4, 0.8), nn) + Condition + Error: + ! `nn` must lie in [0, 1]. + +# chr + + Code + arg_is_chr(a, b, n) + Condition + Error: + ! `n` must be of type . + +--- + + Code + arg_is_chr(a, b, nn) + Condition + Error: + ! `nn` must be of type . + +--- + + Code + arg_is_chr(d) + Condition + Error: + ! `d` must be of type . + +--- + + Code + arg_is_chr(v) + Condition + Error: + ! `v` must be of type . + +--- + + Code + arg_is_chr(ll) + Condition + Error: + ! `ll` must be of type . + +--- + + Code + arg_is_chr(z) + Condition + Error: + ! `z` must be of type . + +# function + + Code + arg_is_function(c(a, b)) + Condition + Error: + ! `c(a, b)` must be of type . + +--- + + Code + arg_is_function(c(f, g)) + Condition + Error: + ! `c(f, g)` must be of type . + +--- + + Code + arg_is_function(f) + Condition + Error: + ! `f` must be of type . + +# coerce scalar to date + + Code + arg_to_date("12345") + Condition + Error in `arg_to_date()`: + ! 
`x` must be a date. + +--- + + Code + arg_to_date(c("12345", "12345")) + Condition + Error in `arg_to_date()`: + ! `x` must be a scalar. + +# simple surface step test + + Code + epi_recipe(jhu_csse_daily_subset) %>% step_epi_lag(death_rate, lag = "hello") + Condition + Error in `step_epi_lag()`: + ! `lag` must be a non-negative integer. + diff --git a/tests/testthat/_snaps/arx_args_list.md b/tests/testthat/_snaps/arx_args_list.md new file mode 100644 index 000000000..959a5e25b --- /dev/null +++ b/tests/testthat/_snaps/arx_args_list.md @@ -0,0 +1,152 @@ +# arx_args checks inputs + + Code + arx_args_list(ahead = c(0, 4)) + Condition + Error in `arx_args_list()`: + ! `ahead` must be a scalar. + +--- + + Code + arx_args_list(n_training = c(28, 65)) + Condition + Error in `arx_args_list()`: + ! `n_training` must be a scalar. + +--- + + Code + arx_args_list(ahead = -1) + Condition + Error in `arx_args_list()`: + ! `ahead` must be a non-negative integer. + +--- + + Code + arx_args_list(ahead = 1.5) + Condition + Error in `arx_args_list()`: + ! `ahead` must be a non-negative integer. + +--- + + Code + arx_args_list(n_training = -1) + Condition + Error in `arx_args_list()`: + ! `n_training` must be a strictly positive number. + +--- + + Code + arx_args_list(n_training = 1.5) + Condition + Error in `arx_args_list()`: + ! `n_training` must be a positive integer. + +--- + + Code + arx_args_list(lags = c(-1, 0)) + Condition + Error in `arx_args_list()`: + ! `lags` must be non-negative integers. + +--- + + Code + arx_args_list(lags = list(c(1:5, 6.5), 2:8)) + Condition + Error in `arx_args_list()`: + ! `lags` must be non-negative integers. + +--- + + Code + arx_args_list(symmetrize = 4) + Condition + Error in `arx_args_list()`: + ! `symmetrize` must be of type . + +--- + + Code + arx_args_list(nonneg = 4) + Condition + Error in `arx_args_list()`: + ! `nonneg` must be of type . + +--- + + Code + arx_args_list(quantile_levels = -0.1) + Condition + Error in `arx_args_list()`: + ! 
`quantile_levels` must lie in [0, 1]. + +--- + + Code + arx_args_list(quantile_levels = 1.1) + Condition + Error in `arx_args_list()`: + ! `quantile_levels` must lie in [0, 1]. + +--- + + Code + arx_args_list(target_date = "2022-01-01") + Condition + Error in `arx_args_list()`: + ! `target_date` must be a date. + +--- + + Code + arx_args_list(n_training_min = "de") + Condition + Error in `arx_args_list()`: + ! `...` must be empty. + x Problematic argument: + * n_training_min = "de" + +--- + + Code + arx_args_list(epi_keys = 1) + Condition + Error in `arx_args_list()`: + ! `...` must be empty. + x Problematic argument: + * epi_keys = 1 + +# arx forecaster disambiguates quantiles + + Code + compare_quantile_args(alist, tlist) + Condition + Error in `compare_quantile_args()`: + ! You have specified different, non-default, quantiles in the trainier and `arx_args` options. + i Please only specify quantiles in one location. + +# arx_lags_validator handles named & unnamed lists as expected + + Code + arx_lags_validator(pred_vec, lags_finit_fn_switch2) + Condition + Error in `arx_lags_validator()`: + ! You have requested 2 predictor(s) but 3 different lags. + i Lags must be a vector or a list with length == number of predictors. + +--- + + Code + arx_lags_validator(pred_vec, lags_init_other_name) + Condition + Error in `arx_lags_validator()`: + ! If lags is a named list, then all predictors must be present. + i The predictors are `death_rate` and `case_rate`. + i So lags is missing `case_rate`'. + diff --git a/tests/testthat/_snaps/arx_cargs_list.md b/tests/testthat/_snaps/arx_cargs_list.md new file mode 100644 index 000000000..30ccb4d36 --- /dev/null +++ b/tests/testthat/_snaps/arx_cargs_list.md @@ -0,0 +1,92 @@ +# arx_class_args checks inputs + + Code + arx_class_args_list(ahead = c(0, 4)) + Condition + Error in `arx_class_args_list()`: + ! `ahead` must be a scalar. 
+ +--- + + Code + arx_class_args_list(n_training = c(28, 65)) + Condition + Error in `arx_class_args_list()`: + ! `n_training` must be a scalar. + +--- + + Code + arx_class_args_list(ahead = -1) + Condition + Error in `arx_class_args_list()`: + ! `ahead` must be a non-negative integer. + +--- + + Code + arx_class_args_list(ahead = 1.5) + Condition + Error in `arx_class_args_list()`: + ! `ahead` must be a non-negative integer. + +--- + + Code + arx_class_args_list(n_training = -1) + Condition + Error in `arx_class_args_list()`: + ! `n_training` must be a strictly positive number. + +--- + + Code + arx_class_args_list(n_training = 1.5) + Condition + Error in `arx_class_args_list()`: + ! `n_training` must be a positive integer. + +--- + + Code + arx_class_args_list(lags = c(-1, 0)) + Condition + Error in `arx_class_args_list()`: + ! `lags` must be non-negative integers. + +--- + + Code + arx_class_args_list(lags = list(c(1:5, 6.5), 2:8)) + Condition + Error in `arx_class_args_list()`: + ! `lags` must be non-negative integers. + +--- + + Code + arx_class_args_list(target_date = "2022-01-01") + Condition + Error in `arx_class_args_list()`: + ! `target_date` must be a date. + +--- + + Code + arx_class_args_list(n_training_min = "de") + Condition + Error in `arx_class_args_list()`: + ! `...` must be empty. + x Problematic argument: + * n_training_min = "de" + +--- + + Code + arx_class_args_list(epi_keys = 1) + Condition + Error in `arx_class_args_list()`: + ! `...` must be empty. + x Problematic argument: + * epi_keys = 1 + diff --git a/tests/testthat/_snaps/bake-method.md b/tests/testthat/_snaps/bake-method.md new file mode 100644 index 000000000..eee28cc4b --- /dev/null +++ b/tests/testthat/_snaps/bake-method.md @@ -0,0 +1,9 @@ +# bake method works in all cases + + Code + bake(prep(r, edf), NULL, composition = "matrix") + Condition + Error in `hardhat::recompose()`: + ! `data` must only contain numeric columns. 
+ i These columns aren't numeric: "geo_value" and "time_value". + diff --git a/tests/testthat/_snaps/check-training-set.md b/tests/testthat/_snaps/check-training-set.md new file mode 100644 index 000000000..e5eec7e7c --- /dev/null +++ b/tests/testthat/_snaps/check-training-set.md @@ -0,0 +1,20 @@ +# training set validation works + + Code + validate_meta_match(t1, template, "geo_type", "abort") + Condition + Error in `validate_meta_match()`: + ! The `geo_type` of the training data appears to be different from that + used to construct the recipe. This may result in unexpected consequences. + i Training `geo_type` is 'county'. + i Originally, it was 'state'. + +--- + + Code + epi_check_training_set(t4, rec) + Condition + Error in `epi_check_training_set()`: + ! The recipe specifies keys which are not in the training data. + i The training set is missing columns for missing_col. + diff --git a/tests/testthat/_snaps/check_enough_train_data.md b/tests/testthat/_snaps/check_enough_train_data.md new file mode 100644 index 000000000..8f2389acb --- /dev/null +++ b/tests/testthat/_snaps/check_enough_train_data.md @@ -0,0 +1,46 @@ +# check_enough_train_data works on pooled data + + Code + epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = 2 * n + 1, + drop_na = FALSE) %>% prep(toy_epi_df) %>% bake(new_data = NULL) + Condition + Error in `prep()`: + ! The following columns don't have enough data to predict: x and y. + +--- + + Code + epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = 2 * n - 1, + drop_na = TRUE) %>% prep(toy_epi_df) %>% bake(new_data = NULL) + Condition + Error in `prep()`: + ! The following columns don't have enough data to predict: x and y. + +# check_enough_train_data works on unpooled data + + Code + epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = n + 1, epi_keys = "geo_value", + drop_na = FALSE) %>% prep(toy_epi_df) %>% bake(new_data = NULL) + Condition + Error in `prep()`: + ! 
The following columns don't have enough data to predict: x and y. + +--- + + Code + epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = 2 * n - 3, + epi_keys = "geo_value", drop_na = TRUE) %>% prep(toy_epi_df) %>% bake(new_data = NULL) + Condition + Error in `prep()`: + ! The following columns don't have enough data to predict: x and y. + +# check_enough_train_data works with all_predictors() downstream of constructed terms + + Code + epi_recipe(toy_epi_df) %>% step_epi_lag(x, lag = c(1, 2)) %>% + check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>% prep( + toy_epi_df) %>% bake(new_data = NULL) + Condition + Error in `prep()`: + ! The following columns don't have enough data to predict: lag_1_x, lag_2_x, and y. + diff --git a/tests/testthat/_snaps/dist_quantiles.md b/tests/testthat/_snaps/dist_quantiles.md new file mode 100644 index 000000000..da7e50100 --- /dev/null +++ b/tests/testthat/_snaps/dist_quantiles.md @@ -0,0 +1,56 @@ +# constructor returns reasonable quantiles + + Code + new_quantiles(rnorm(5), rnorm(5)) + Condition + Error in `new_quantiles()`: + ! `quantile_levels` must lie in [0, 1]. + +--- + + Code + new_quantiles(sort(rnorm(5)), sort(runif(2))) + Condition + Error in `new_quantiles()`: + ! length(values) == length(quantile_levels) is not TRUE + +--- + + Code + new_quantiles(c(2, 1, 3, 4, 5), c(0.1, 0.1, 0.2, 0.5, 0.8)) + Condition + Error in `new_quantiles()`: + ! !vctrs::vec_duplicate_any(quantile_levels) is not TRUE + +--- + + Code + new_quantiles(c(2, 1, 3, 4, 5), c(0.1, 0.15, 0.2, 0.5, 0.8)) + Condition + Error in `new_quantiles()`: + ! `values[order(quantile_levels)]` produces unsorted quantiles. + +--- + + Code + new_quantiles(c(1, 2, 3), c(0.1, 0.2, 3)) + Condition + Error in `new_quantiles()`: + ! `quantile_levels` must lie in [0, 1]. + +# arithmetic works on quantiles + + Code + sum(dstn) + Condition + Error in `mapply()`: + ! You can't perform arithmetic between two distributions like this. 
+ +--- + + Code + suppressWarnings(dstn + distributional::dist_normal()) + Condition + Error: + ! non-numeric argument to binary operator + diff --git a/tests/testthat/_snaps/enframer.md b/tests/testthat/_snaps/enframer.md new file mode 100644 index 000000000..4b05dbff3 --- /dev/null +++ b/tests/testthat/_snaps/enframer.md @@ -0,0 +1,32 @@ +# enframer errors/works as needed + + Code + enframer(1:5, letters[1]) + Condition + Error in `enframer()`: + ! is.data.frame(df) is not TRUE + +--- + + Code + enframer(data.frame(a = 1:5), 1:3) + Condition + Error in `enframer()`: + ! `x` must be of type . + +--- + + Code + enframer(data.frame(a = 1:5), letters[1:3]) + Condition + Error in `enframer()`: + ! In enframer: some new cols match existing column names + +--- + + Code + enframer(data.frame(aa = 1:5), letters[1:2], fill = 1:4) + Condition + Error in `enframer()`: + ! length(fill) == 1 || length(fill) == nrow(df) is not TRUE + diff --git a/tests/testthat/_snaps/epi_recipe.md b/tests/testthat/_snaps/epi_recipe.md new file mode 100644 index 000000000..24b046678 --- /dev/null +++ b/tests/testthat/_snaps/epi_recipe.md @@ -0,0 +1,32 @@ +# epi_recipe produces error if not an epi_df + + Code + epi_recipe(tib) + Condition + Error in `epi_recipe()`: + ! `x` must be an or a , not a . + +--- + + Code + epi_recipe(y ~ x, tib) + Condition + Error in `epi_recipe()`: + ! `epi_recipe()` has been called with a non- object. Use `recipe()` instead. + +--- + + Code + epi_recipe(m) + Condition + Error in `epi_recipe()`: + ! `x` must be an or a , not a . + +# add/update/adjust/remove epi_recipe works as intended + + Code + workflows::extract_preprocessor(wf)$steps + Condition + Error in `workflows::extract_preprocessor()`: + ! The workflow does not have a preprocessor. 
+ diff --git a/tests/testthat/_snaps/epi_workflow.md b/tests/testthat/_snaps/epi_workflow.md new file mode 100644 index 000000000..006333423 --- /dev/null +++ b/tests/testthat/_snaps/epi_workflow.md @@ -0,0 +1,33 @@ +# model can be added/updated/removed from epi_workflow + + Code + extract_spec_parsnip(wf) + Condition + Error in `extract_spec_parsnip()`: + ! The workflow does not have a model spec. + +# forecast method errors when workflow not fit + + Code + forecast(wf) + Condition + Error in `forecast()`: + ! You cannot `forecast()` a that has not been trained. + i Please use `fit()` before forecasting. + +# fit method does not silently drop the class + + Code + epi_recipe(y ~ x, data = tbl) + Condition + Error in `epi_recipe()`: + ! `epi_recipe()` has been called with a non- object. Use `recipe()` instead. + +--- + + Code + ewf_erec_edf %>% fit(tbl) + Condition + Error in `if (new_meta != old_meta) ...`: + ! argument is of length zero + diff --git a/tests/testthat/_snaps/extract_argument.md b/tests/testthat/_snaps/extract_argument.md new file mode 100644 index 000000000..d4ff44c95 --- /dev/null +++ b/tests/testthat/_snaps/extract_argument.md @@ -0,0 +1,72 @@ +# layer argument extractor works + + Code + extract_argument(f$layers[[1]], "uhoh", "bubble") + Condition + Error in `extract_argument()`: + ! Requested "uhoh" not found. This is a(n) . + +--- + + Code + extract_argument(f$layers[[1]], "layer_predict", "bubble") + Condition + Error in `extract_argument()`: + ! Requested argument "bubble" not found in "layer_predict". + +--- + + Code + extract_argument(f, "layer_thresh", "quantile_levels") + Condition + Error in `extract_argument()`: + ! frosting object does not contain a "layer_thresh". + +--- + + Code + extract_argument(epi_workflow(), "layer_residual_quantiles", "quantile_levels") + Condition + Error in `extract_frosting()`: + ! The epi_workflow does not have a postprocessor. 
+ +--- + + Code + extract_argument(wf, "layer_predict", c("type", "opts")) + Condition + Error in `FUN()`: + ! `arg` must be a scalar of type . + +# recipe argument extractor works + + Code + extract_argument(r$steps[[1]], "uhoh", "bubble") + Condition + Error in `extract_argument()`: + ! Requested "uhoh" not found. This is a . + +--- + + Code + extract_argument(r$steps[[1]], "step_epi_lag", "bubble") + Condition + Error in `extract_argument()`: + ! Requested argument "bubble" not found in "step_epi_lag". + +--- + + Code + extract_argument(r, "step_lightly", "quantile_levels") + Condition + Error in `extract_argument()`: + ! recipe object does not contain a "step_lightly". + +--- + + Code + extract_argument(epi_workflow(), "step_epi_lag", "lag") + Condition + Error in `extract_argument()`: + ! The workflow must have a recipe preprocessor. + diff --git a/tests/testthat/_snaps/flatline_args_list.md b/tests/testthat/_snaps/flatline_args_list.md new file mode 100644 index 000000000..02053f95b --- /dev/null +++ b/tests/testthat/_snaps/flatline_args_list.md @@ -0,0 +1,128 @@ +# flatline_args_list checks inputs + + Code + flatline_args_list(ahead = c(0, 4)) + Condition + Error in `flatline_args_list()`: + ! `ahead` must be a scalar. + +--- + + Code + flatline_args_list(n_training = c(28, 65)) + Condition + Error in `flatline_args_list()`: + ! `n_training` must be a scalar. + +--- + + Code + flatline_args_list(ahead = -1) + Condition + Error in `flatline_args_list()`: + ! `ahead` must be a non-negative integer. + +--- + + Code + flatline_args_list(ahead = 1.5) + Condition + Error in `flatline_args_list()`: + ! `ahead` must be a non-negative integer. + +--- + + Code + flatline_args_list(n_training = -1) + Condition + Error in `flatline_args_list()`: + ! `n_training` must be a strictly positive number. + +--- + + Code + flatline_args_list(n_training = 1.5) + Condition + Error in `flatline_args_list()`: + ! `n_training` must be a positive integer. 
+ +--- + + Code + flatline_args_list(lags = c(-1, 0)) + Condition + Error in `flatline_args_list()`: + ! `...` must be empty. + x Problematic argument: + * lags = c(-1, 0) + +--- + + Code + flatline_args_list(lags = list(c(1:5, 6.5), 2:8)) + Condition + Error in `flatline_args_list()`: + ! `...` must be empty. + x Problematic argument: + * lags = list(c(1:5, 6.5), 2:8) + +--- + + Code + flatline_args_list(symmetrize = 4) + Condition + Error in `flatline_args_list()`: + ! `symmetrize` must be of type . + +--- + + Code + flatline_args_list(nonneg = 4) + Condition + Error in `flatline_args_list()`: + ! `nonneg` must be of type . + +--- + + Code + flatline_args_list(quantile_levels = -0.1) + Condition + Error in `flatline_args_list()`: + ! `quantile_levels` must lie in [0, 1]. + +--- + + Code + flatline_args_list(quantile_levels = 1.1) + Condition + Error in `flatline_args_list()`: + ! `quantile_levels` must lie in [0, 1]. + +--- + + Code + flatline_args_list(target_date = "2022-01-01") + Condition + Error in `flatline_args_list()`: + ! `target_date` must be a date. + +--- + + Code + flatline_args_list(n_training_min = "de") + Condition + Error in `flatline_args_list()`: + ! `...` must be empty. + x Problematic argument: + * n_training_min = "de" + +--- + + Code + flatline_args_list(epi_keys = 1) + Condition + Error in `flatline_args_list()`: + ! `...` must be empty. + x Problematic argument: + * epi_keys = 1 + diff --git a/tests/testthat/_snaps/frosting.md b/tests/testthat/_snaps/frosting.md new file mode 100644 index 000000000..daf7f1ed7 --- /dev/null +++ b/tests/testthat/_snaps/frosting.md @@ -0,0 +1,16 @@ +# frosting validators / constructors work + + Code + wf %>% add_postprocessor(list()) + Condition + Error: + ! `postprocessor` must be a frosting object. + +# frosting can be created/added/updated/adjusted/removed + + Code + frosting(layers = 1:5) + Condition + Error in `frosting()`: + ! Currently, no arguments to `frosting()` are allowed to be non-null. 
+ diff --git a/tests/testthat/_snaps/get_test_data.md b/tests/testthat/_snaps/get_test_data.md new file mode 100644 index 000000000..e65b0715c --- /dev/null +++ b/tests/testthat/_snaps/get_test_data.md @@ -0,0 +1,66 @@ +# expect insufficient training data error + + Code + get_test_data(recipe = r, x = case_death_rate_subset) + Condition + Error in `get_test_data()`: + ! You supplied insufficient recent data for this recipe. + ! You need at least 367 days of data, + ! but `x` contains only 365. + +# expect error that geo_value or time_value does not exist + + Code + get_test_data(recipe = r, x = wrong_epi_df) + Condition + Error in `get_test_data()`: + ! `x` must be an `epi_df`. + +# NA fill behaves as desired + + Code + get_test_data(r, df, "A") + Condition + Error in `get_test_data()`: + ! `fill_locf` must be of type . + +--- + + Code + get_test_data(r, df, TRUE, -3) + Condition + Error in `get_test_data()`: + ! `n_recent` must be a positive integer. + +--- + + Code + get_test_data(r, df2, TRUE) + Condition + Error in `if (recipes::is_trained(recipe)) ...`: + ! argument is of length zero + +# forecast date behaves + + Code + get_test_data(r, df, TRUE, forecast_date = 9) + Condition + Error in `get_test_data()`: + ! `forecast_date` must be the same class as `x$time_value`. + +--- + + Code + get_test_data(r, df, TRUE, forecast_date = 9L) + Condition + Error in `get_test_data()`: + ! `forecast_date` must be no earlier than `max(x$time_value)` + +--- + + Code + get_test_data(r, df, forecast_date = 9L) + Condition + Error in `get_test_data()`: + ! 
`forecast_date` must be no earlier than `max(x$time_value)` + diff --git a/tests/testthat/_snaps/layer_add_forecast_date.md b/tests/testthat/_snaps/layer_add_forecast_date.md new file mode 100644 index 000000000..9e829be91 --- /dev/null +++ b/tests/testthat/_snaps/layer_add_forecast_date.md @@ -0,0 +1,42 @@ +# layer validation works + + Code + layer_add_forecast_date(f, c("2022-05-31", "2022-05-31")) + Condition + Error in `layer_add_forecast_date()`: + ! `forecast_date` must be a scalar. + +--- + + Code + layer_add_forecast_date(f, "2022-05-31", id = 2) + Condition + Error in `layer_add_forecast_date()`: + ! `id` must be a scalar of type . + +--- + + Code + layer_add_forecast_date(f, "2022-05-31", id = c("a", "b")) + Condition + Error in `layer_add_forecast_date()`: + ! `id` must be a scalar of type . + +# forecast date works for daily + + Code + predict(wf1, latest_yearly) + Condition + Error: + ! Can't convert `data$time_value` to match type of `time_value` . + +--- + + Code + predict(wf3, latest) + Condition + Error in `layer_add_forecast_date()`: + ! The `forecast_date` was given as a "year" while the + ! `time_type` of the training data was "day". + i See `?epiprocess::epi_df` for descriptions of these are determined. + diff --git a/tests/testthat/_snaps/layer_add_target_date.md b/tests/testthat/_snaps/layer_add_target_date.md new file mode 100644 index 000000000..805a4205d --- /dev/null +++ b/tests/testthat/_snaps/layer_add_target_date.md @@ -0,0 +1,8 @@ +# target date works for daily and yearly + + Code + predict(wf1, latest_bad) + Condition + Error: + ! Can't convert `data$time_value` to match type of `time_value` . + diff --git a/tests/testthat/_snaps/layer_predict.md b/tests/testthat/_snaps/layer_predict.md new file mode 100644 index 000000000..5c353eb4c --- /dev/null +++ b/tests/testthat/_snaps/layer_predict.md @@ -0,0 +1,8 @@ +# layer_predict dots validation + + Code + predict(wf_bad_arg, latest) + Condition + Error: + ! 
argument "..3" is missing, with no default + diff --git a/tests/testthat/_snaps/layer_residual_quantiles.md b/tests/testthat/_snaps/layer_residual_quantiles.md new file mode 100644 index 000000000..41aa0448d --- /dev/null +++ b/tests/testthat/_snaps/layer_residual_quantiles.md @@ -0,0 +1,18 @@ +# Errors when used with a classifier + + Code + forecast(wf) + Condition + Error in `grab_residuals()`: + ! For meaningful residuals, the predictor should be a regression model. + +# flatline_forecaster correctly errors when n_training < ahead + + Code + flatline_forecaster(jhu, "death_rate", args_list = flatline_args_list(ahead = 10, + n_training = 9)) + Condition + Error in `slather()`: + ! Residual quantiles could not be calculated due to missing residuals. + i This may be due to `n_train` < `ahead` in your . + diff --git a/tests/testthat/_snaps/layers.md b/tests/testthat/_snaps/layers.md new file mode 100644 index 000000000..a0474eab6 --- /dev/null +++ b/tests/testthat/_snaps/layers.md @@ -0,0 +1,24 @@ +# A layer can be updated in frosting + + Code + update(f$layers[[1]], lower = 100) + Condition + Error in `recipes:::update_fields()`: + ! The step you are trying to update, `layer_predict()`, does not have the lower field. + +--- + + Code + update(f$layers[[3]], lower = 100) + Condition + Error in `f$layers[[3]]`: + ! subscript out of bounds + +--- + + Code + update(f$layers[[2]], bad_param = 100) + Condition + Error in `recipes:::update_fields()`: + ! The step you are trying to update, `layer_threshold()`, does not have the bad_param field. + diff --git a/tests/testthat/_snaps/parse_period.md b/tests/testthat/_snaps/parse_period.md new file mode 100644 index 000000000..bc782dea7 --- /dev/null +++ b/tests/testthat/_snaps/parse_period.md @@ -0,0 +1,32 @@ +# parse_period works + + Code + parse_period(c(1, 2)) + Condition + Error in `parse_period()`: + ! `x` must be a scalar. + +--- + + Code + parse_period(c(1.3)) + Condition + Error in `parse_period()`: + ! 
rlang::is_integerish(x) is not TRUE + +--- + + Code + parse_period("1 year") + Condition + Error in `parse_period()`: + ! incompatible timespan in `aheads`. + +--- + + Code + parse_period("2 weeks later") + Condition + Error in `parse_period()`: + ! incompatible timespan in `aheads`. + diff --git a/tests/testthat/_snaps/parsnip_model_validation.md b/tests/testthat/_snaps/parsnip_model_validation.md new file mode 100644 index 000000000..365e6b2b8 --- /dev/null +++ b/tests/testthat/_snaps/parsnip_model_validation.md @@ -0,0 +1,18 @@ +# forecaster can validate parsnip model + + Code + get_parsnip_mode(l) + Condition + Error in `get_parsnip_mode()`: + ! `trainer` must be a `parsnip` model. + i This trainer has class: . + +--- + + Code + is_classification(l) + Condition + Error in `get_parsnip_mode()`: + ! `trainer` must be a `parsnip` model. + i This trainer has class: . + diff --git a/tests/testthat/_snaps/pivot_quantiles.md b/tests/testthat/_snaps/pivot_quantiles.md new file mode 100644 index 000000000..184eb62a6 --- /dev/null +++ b/tests/testthat/_snaps/pivot_quantiles.md @@ -0,0 +1,51 @@ +# quantile pivotting wider behaves + + Code + pivot_quantiles_wider(tib, a) + Condition + Error in `UseMethod()`: + ! no applicable method for 'family' applied to an object of class "c('integer', 'numeric')" + +--- + + Code + pivot_quantiles_wider(tib, c) + Condition + Error in `validate_pivot_quantiles()`: + ! Variables(s) `c` are not `dist_quantiles`. Cannot pivot them. + +--- + + Code + pivot_quantiles_wider(tib, d1) + Condition + Error in `pivot_quantiles_wider()`: + ! Quantiles must be the same length and have the same set of taus. + i Check failed for variables(s) `d1`. + +# quantile pivotting longer behaves + + Code + pivot_quantiles_longer(tib, a) + Condition + Error in `UseMethod()`: + ! 
no applicable method for 'family' applied to an object of class "c('integer', 'numeric')" + +--- + + Code + pivot_quantiles_longer(tib, c) + Condition + Error in `validate_pivot_quantiles()`: + ! Variables(s) `c` are not `dist_quantiles`. Cannot pivot them. + +--- + + Code + pivot_quantiles_longer(tib, d1, d3) + Condition + Error in `pivot_quantiles_longer()`: + ! Some selected columns contain different numbers of quantiles. + The result would be a very long . + To do this anyway, rerun with `.ignore_length_check = TRUE`. + diff --git a/tests/testthat/_snaps/population_scaling.md b/tests/testthat/_snaps/population_scaling.md new file mode 100644 index 000000000..9263e8e1e --- /dev/null +++ b/tests/testthat/_snaps/population_scaling.md @@ -0,0 +1,16 @@ +# expect error if `by` selector does not match + + Code + wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f) + Condition + Error in `hardhat::validate_column_names()`: + ! The following required columns are missing: 'a'. + +--- + + Code + forecast(wf) + Condition + Error in `hardhat::validate_column_names()`: + ! The following required columns are missing: 'nothere'. + diff --git a/tests/testthat/_snaps/shuffle.md b/tests/testthat/_snaps/shuffle.md new file mode 100644 index 000000000..53eea9b92 --- /dev/null +++ b/tests/testthat/_snaps/shuffle.md @@ -0,0 +1,8 @@ +# shuffle works + + Code + shuffle(matrix(NA, 2, 2)) + Condition + Error in `shuffle()`: + ! 
is.vector(x) is not TRUE + diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md new file mode 100644 index 000000000..84abf57d2 --- /dev/null +++ b/tests/testthat/_snaps/snapshots.md @@ -0,0 +1,1060 @@ +# flatline_forecaster snapshots + + structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" + ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, + 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, + 0.34820911), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.31206391), quantile_levels = c(0.05, 0.95 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.10325949, 0.52098931 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.21298119, 0.63071101), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.52311949, 0.94084931 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.40640751), quantile_levels = c(0.05, 0.95 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list" + )), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, + 18992), class = "Date"), target_date = structure(c(18999, 18999, + 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, + -6L), class = c("tbl_df", "tbl", "data.frame")) + +--- + + structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" + ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, + 0.1975426), .pred_distn = structure(list(structure(list(values = c(0.084583345, + 0.194105055), quantile_levels = c(0.05, 0.95)), class = 
c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.048438145, 0.157959855), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.257363545, 0.366885255 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.367085245, 0.476606955), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.677223545, 0.786745255 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.142781745, 0.252303455), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list" + )), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, + 18992), class = "Date"), target_date = structure(c(18993, 18993, + 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, + -6L), class = c("tbl_df", "tbl", "data.frame")) + +--- + + structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" + ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, + 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, + 0.34820911), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.31206391), quantile_levels = c(0.05, 0.95 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.10325949, 0.52098931 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.21298119, 0.63071101), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", 
"dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.52311949, 0.94084931 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.40640751), quantile_levels = c(0.05, 0.95 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list" + )), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, + 18992), class = "Date"), target_date = structure(c(18999, 18999, + 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, + -6L), class = c("tbl_df", "tbl", "data.frame")) + +--- + + structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" + ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, + 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, + 0.34820911), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.31206391), quantile_levels = c(0.05, 0.95 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.10325949, 0.52098931 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.21298119, 0.63071101), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.52311949, 0.94084931 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.40640751), quantile_levels = c(0.05, 0.95 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list" + )), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, + 18992), class = "Date"), target_date = structure(c(18993, 
18993, + 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, + -6L), class = c("tbl_df", "tbl", "data.frame")) + +# cdc_baseline_forecaster snapshots + + structure(list(geo_value = c("ca", "ca", "ca", "ca", "ca", "fl", + "fl", "fl", "fl", "fl", "ga", "ga", "ga", "ga", "ga", "ny", "ny", + "ny", "ny", "ny", "pa", "pa", "pa", "pa", "pa", "tx", "tx", "tx", + "tx", "tx"), .pred = c(0.1393442, 0.1393442, 0.1393442, 0.1393442, + 0.1393442, 0.103199, 0.103199, 0.103199, 0.103199, 0.103199, + 0.3121244, 0.3121244, 0.3121244, 0.3121244, 0.3121244, 0.4218461, + 0.4218461, 0.4218461, 0.4218461, 0.4218461, 0.7319844, 0.7319844, + 0.7319844, 0.7319844, 0.7319844, 0.1975426, 0.1975426, 0.1975426, + 0.1975426, 0.1975426), ahead = c(1L, 2L, 3L, 4L, 5L, 1L, 2L, + 3L, 4L, 5L, 1L, 2L, 3L, 4L, 5L, 1L, 2L, 3L, 4L, 5L, 1L, 2L, 3L, + 4L, 5L, 1L, 2L, 3L, 4L, 5L), .pred_distn = structure(list(structure(list( + values = c(0, 0, 0, 0.05519342, 0.082372705, 0.0936219, 0.1048711, + 0.1157573, 0.12317806, 0.1302723, 0.1353526, 0.1393442, 0.1433358, + 0.1484161, 0.15551034, 0.1629311, 0.1738173, 0.1850665, 0.196315695, + 0.22349498, 0.309768685, 0.3567520625, 0.439580229), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.0335550493877939, + 0.0604073208819088, 0.0796881899581496, 0.0945180888333883, 0.107218788833388, + 0.118830788833388, 0.129717088833388, 0.1393442, 0.148949488833388, + 0.159110072060821, 0.171080110623306, 0.184009705322953, 0.19866346102411, + 0.218798896615666, 0.250961850618106, 0.300471354816148, 0.368582781136862, + 0.43909595699107, 0.520101234797705), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", 
"dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.0310685196688967, + 0.0565901050435504, 0.0768417663716637, 0.0947104815343153, 0.110553706525765, + 0.125192081534315, 0.1393442, 0.153133424194392, 0.167807181271713, + 0.183769310145952, 0.202099979390294, 0.224139947221972, 0.252840918770688, + 0.291417895572206, 0.341073550318203, 0.420604597710477, 0.494523225410904, + 0.573647294116801), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0.00623643594225938, 0.0360877950479505, + 0.0604332430739307, 0.0824028153516535, 0.102509343235732, + 0.121439405653606, 0.1393442, 0.15780837904264, 0.176333479766098, + 0.1971089199637, 0.219859545844459, 0.246500872561225, 0.279163385675357, + 0.320379296602716, 0.374497727839579, 0.458894379633346, + 0.535545067037845, 0.628776504364044), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0, 0.0192048017017668, + 0.0478501821296211, 0.0723167026720766, 0.0958385084225842, 0.11812331897399, + 0.1393442, 0.161074539705197, 0.184026763327133, 0.207844848454635, + 0.23407004803228, 0.265166265836908, 0.302137478236883, 0.346008752873429, + 0.403205598400084, 0.495260096430714, 0.574198142463125, 0.672941852619816 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, + 0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 
0.0749343, + 0.0847941, 0.0966258, 0.103199, 0.1097722, 0.1216039, 0.1314637, + 0.1419808, 0.15414125, 0.17090286, 0.189932235, 0.22848398, 0.30542311, + 0.40216399, 0.512353658), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0.00331296053340532, 0.0234804643776438, + 0.0414109089650896, 0.0579040140087902, 0.0738391473860739, + 0.0882882738549385, 0.103199, 0.118522737211872, 0.134217143129031, + 0.15174910202592, 0.17076597900759, 0.192368859892349, 0.218887, + 0.254338497855279, 0.307871753369934, 0.407530532639726, + 0.506824682189646, 0.607973477267732), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0, 0, + 0.0185864520320203, 0.0411215858914089, 0.062281046686267, 0.0828222124563246, + 0.103199, 0.123575888447284, 0.144785989158292, 0.167277039342293, + 0.192536265178252, 0.221677797769728, 0.256887836856768, 0.302366681512415, + 0.3669383199518, 0.476508917333523, 0.574293059865274, 0.69194511433946 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, + 0, 0, 0, 0, 0, 0, 0.0271019287070871, 0.0535555494987951, 0.0785514374097741, + 0.103199, 0.128043832742677, 0.154157375592856, 0.181874602598776, + 0.212708648669987, 0.247608381738568, 0.289082621291513, 0.342486159511745, + 0.41300665395314, 0.52870334697862, 0.634316186092986, 0.767614547228429 + ), quantile_levels = c(0.01, 0.025, 0.05, 
0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, + 0, 0, 0, 0, 0, 0, 0.0118725894981448, 0.0439446210512103, 0.0736366703227029, + 0.103199, 0.133138617710077, 0.16357656105121, 0.19575459701827, + 0.230475760859608, 0.269323345322203, 0.314976554734947, 0.373424338576786, + 0.452807955824158, 0.578141866759416, 0.690542571738594, 0.837295153768033 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, + 0.0813658, 0.14899276, 0.1960782, 0.22542314, 0.2414296, 0.25890318, + 0.2747762, 0.2881148, 0.3027873, 0.3121244, 0.3214615, 0.336134, + 0.3494726, 0.36534562, 0.3828192, 0.39882566, 0.4281706, 0.47525604, + 0.542883, 0.682805397499999, 0.798878314999999), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0.0706949, + 0.1267172, 0.1667331, 0.198582473624236, 0.225423180397104, 0.2494327, + 0.2707747, 0.292116312116921, 0.3121244, 0.3321324, 0.353072222341423, + 0.375089999249792, 0.3988256, 0.425831930221552, 0.459232792604326, + 0.501467782274773, 0.562188443556836, 0.685648485782108, 0.80647163752115, + 0.939224788489265), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0.0704696868359684, 0.121407167925079, + 0.161930580284053, 
0.197682797539976, 0.228361656891269, + 0.257706650923509, 0.285717780926109, 0.3121244, 0.338115598498035, + 0.365749693067931, 0.395921877240673, 0.427437934626446, + 0.462388578749537, 0.504066064225642, 0.558443518811788, + 0.636013559040791, 0.771225883005179, 0.89210797204162, 1.02314875759509 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, + 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, + 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0.0247190015881658, 0.0834693973257732, + 0.131490031120311, 0.173258318827988, 0.211213742349423, + 0.246202447408474, 0.279926744217642, 0.3121244, 0.344908347408474, + 0.378255200773608, 0.412935547408474, 0.45191576510605, 0.494757615230152, + 0.545060918490786, 0.609312182129471, 0.69704881099591, 0.838550239412991, + 0.962653262246773, 1.11351403170759), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.0501392705767058, + 0.104248897713977, 0.151994400390804, 0.195087767727627, 0.235544124698047, + 0.274058107118071, 0.3121244, 0.350571341810268, 0.390274666572666, + 0.43048632300908, 0.474320393891039, 0.523839613390634, 0.581010268149082, + 0.652137495469405, 0.748428674762348, 0.898563270096551, 1.03273295410124, + 1.19211145220822), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0.2148017, 0.31250597, 0.350183905, 0.3745639, + 0.3884161, 0.39746621, 0.404854, 0.4115031, 0.417413315, + 0.4218461, 0.426278885, 0.4321891, 0.4388382, 0.44622599, + 
0.4552761, 0.4691283, 0.493508295, 0.53118623, 0.628890499999999, + 1.22043540499999, 1.95905017899999), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0.212369462232823, + 0.289571577546325, 0.324446887783878, 0.351262144469445, 0.37087, + 0.3863844, 0.399682509835098, 0.411036898891089, 0.4218461, 0.432927818676137, + 0.444338520819208, 0.4573077, 0.4728222, 0.492817749438994, 0.519442857224172, + 0.556165331447064, 0.635946057886079, 1.18402232252562, 1.7827032389242, + 2.5561261649726), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0.0413098183761837, 0.216633655848608, + 0.28006329699657, 0.3175577049983, 0.345923291761818, 0.368957399144641, + 0.38804556403724, 0.405400893204282, 0.4218461, 0.43864616004845, + 0.456105937661177, 0.475585378227632, 0.499018124730147, + 0.5270891900114, 0.564293444378844, 0.630730263388634, 0.898212235100651, + 1.53976520159876, 2.08228809477582, 2.80588762256078), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.114729892920429, + 0.227785958288583, 0.282278878729037, 0.320407599201492, 0.350577823459785, + 0.37665230304923, 0.39981364198757, 0.4218461, 0.444009706175862, + 0.466962725214852, 0.493098379685547, 0.523708407392674, 0.562100740111401, + 0.619050517814778, 0.754868363055733, 1.1177263295869, 1.76277018354499, + 2.37278671910076, 2.9651652434047), 
quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.0100954501382014, + 0.165091099860099, 0.244964334392844, 0.294577054174442, 0.333357739419644, + 0.365251480804308, 0.394198909379894, 0.4218461, 0.449607812233022, + 0.479120513116631, 0.511271131674317, 0.5506402899964, 0.60295411796593, + 0.690751300611906, 0.913578722060166, 1.30856988553206, 1.94020220543606, + 2.57104934168037, 3.07139639379724), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.303454977, 0.3982330425, + 0.46791125, 0.57642367, 0.631462275, 0.6694025, 0.685048, 0.69857015, + 0.7085162, 0.71633898, 0.7252792, 0.7319844, 0.7386896, 0.74762982, + 0.7554526, 0.76539865, 0.7789208, 0.7945663, 0.832506525, 0.88754513, + 0.99605755, 1.0657357575, 1.160513823), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.188727136659627, + 0.292714653217782, 0.380882595473705, 0.476427609604196, 0.5464739, + 0.6001155, 0.636506664263643, 0.6638148, 0.684726301742618, 0.701811, + 0.7174565, 0.7319844, 0.7465124, 0.7621578, 0.779322149415794, + 0.800154, 0.826981204292293, 0.8649709, 0.918345662372574, 0.987315641681917, + 1.08210087899389, 1.17564510102166, 1.27428433325155), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", 
"dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.0928040444059739, + 0.212569233904214, 0.310718449102641, 0.418013562853928, 0.489917936424114, + 0.546885925424654, 0.593410228218282, 0.631406259421094, 0.661579628218282, + 0.687282906872069, 0.710456666258662, 0.7319844, 0.754131389282943, + 0.776685628218282, 0.802388976168662, 0.832758896293562, 0.869440928218282, + 0.916359694097141, 0.97403912794778, 1.04529048496565, 1.15710382277548, + 1.25675656404419, 1.37098330871205), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.0108404989744699, + 0.144337973117581, 0.250292371898569, 0.367310419323293, 0.44444044802193, + 0.506592035751958, 0.558428768125431, 0.602035095628756, 0.64112383905529, + 0.674354964141041, 0.703707875219752, 0.7319844, 0.760702196782168, + 0.78975826405844, 0.823427572594726, 0.860294897090771, 0.904032120658957, + 0.955736581115011, 1.0165945004053, 1.09529786576616, 1.21614421175967, + 1.32331604019295, 1.45293812780298), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0.0783919019408445, + 0.19440762901709, 0.323264916745368, 0.407999619319143, 0.474764568463685, + 0.530890671381964, 0.580852443909739, 0.623441748828038, 0.661393469870099, + 0.69827126098506, 0.7319844, 0.766440770218252, 0.802260162496625, + 0.840536805657307, 0.883133954556946, 0.931565607767828, 0.98815401699637, + 1.05406790404239, 1.138596250043, 1.27030064370239, 1.39007785503355, + 1.5343628053761), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 
0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0.012845105, 0.07040502, 0.09495188, 0.12669976, + 0.1502248, 0.1659163, 0.1761341, 0.18586528, 0.191290375, + 0.1975426, 0.203794825, 0.20921992, 0.2189511, 0.2291689, + 0.2448604, 0.26838544, 0.30013332, 0.32468018, 0.382240095, + 0.5020427625, 0.590302013999998), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0.0133856545472455, + 0.0528330564916649, 0.0825071163605637, 0.107217748074731, 0.130397558147181, + 0.151367721571716, 0.1688357, 0.183736649076791, 0.1975426, 0.2111662, + 0.226622576069161, 0.244738709634746, 0.265660771838618, 0.289502, + 0.3157762, 0.347933515877459, 0.395446576674467, 0.494033943284933, + 0.586036939413118, 0.696507800090321), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.0119984314577645, + 0.0497573816250162, 0.081255049503995, 0.108502307388674, 0.132961558931189, + 0.156011650575706, 0.177125892134071, 0.1975426, 0.217737120618906, + 0.239458499211792, 0.263562581820818, 0.289525383565136, 0.31824420000725, + 0.35141305194052, 0.393862560773808, 0.453538799225292, 0.558631806850418, + 0.657452391363313, 0.767918764883928), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.0189057930465303, + 
0.0558619823820737, 0.0885055048481483, 0.117823094349893, 0.145878789120691, + 0.171852417645726, 0.1975426, 0.222526993865839, 0.249029206661066, + 0.27731797305948, 0.306704680469104, 0.340659034209842, 0.379550761828618, + 0.429562304567396, 0.499209921951019, 0.612206099576094, 0.713714149138691, + 0.835600324727346), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0, 0.0331956220262204, 0.0710455499705998, + 0.105140687231072, 0.136976315413355, 0.167518817907279, + 0.1975426, 0.226974062486675, 0.257640196272163, 0.289459502055271, + 0.323342029611596, 0.361500312536625, 0.407123841331413, + 0.46286764504675, 0.538379175655057, 0.659249503348734, 0.768470658367656, + 0.898774707571334), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992), class = "Date"), target_date = structure(c(18999, 19006, + 19013, 19020, 19027, 18999, 19006, 19013, 19020, 19027, 18999, + 19006, 19013, 19020, 19027, 18999, 19006, 19013, 19020, 19027, + 18999, 19006, 19013, 19020, 19027, 18999, 19006, 19013, 19020, + 19027), class = "Date")), row.names = c(NA, -30L), class = c("tbl_df", + "tbl", "data.frame")) + +--- + + structure(list(geo_value = c("ca", "ca", "ca", "ca", "ca", "fl", + "fl", "fl", "fl", "fl", "ga", "ga", "ga", "ga", "ga", "ny", "ny", + "ny", "ny", "ny", "pa", "pa", "pa", 
"pa", "pa", "tx", "tx", "tx", + "tx", "tx"), .pred = c(0.1393442, 0.1393442, 0.1393442, 0.1393442, + 0.1393442, 0.103199, 0.103199, 0.103199, 0.103199, 0.103199, + 0.3121244, 0.3121244, 0.3121244, 0.3121244, 0.3121244, 0.4218461, + 0.4218461, 0.4218461, 0.4218461, 0.4218461, 0.7319844, 0.7319844, + 0.7319844, 0.7319844, 0.7319844, 0.1975426, 0.1975426, 0.1975426, + 0.1975426, 0.1975426), ahead = c(2L, 3L, 4L, 5L, 6L, 2L, 3L, + 4L, 5L, 6L, 2L, 3L, 4L, 5L, 6L, 2L, 3L, 4L, 5L, 6L, 2L, 3L, 4L, + 5L, 6L, 2L, 3L, 4L, 5L, 6L), .pred_distn = structure(list(structure(list( + values = c(0, 0, 0, 0, 0.0344362435566855, 0.0610170086495865, + 0.0798865084778347, 0.0944014546310463, 0.107339121226462, + 0.11899734099851, 0.129600408649586, 0.1393442, 0.149195708649586, + 0.159627982246122, 0.170968308649587, 0.184031805880359, + 0.198909658094331, 0.219058736130861, 0.250692448549235, + 0.300646382944129, 0.368938143197633, 0.440038195052124, + 0.51997011826723), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0.0303364052628526, 0.0557306728227282, + 0.0766736159703596, 0.0942284381264812, 0.11050757203172, + 0.125214601455714, 0.1393442, 0.15359732398729, 0.168500447692877, + 0.184551468093631, 0.202926420944109, 0.22476606802393, 0.253070223293233, + 0.29122995395109, 0.341963643747938, 0.419747975311502, 0.495994046054689, + 0.5748791770223), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0.00603076915889168, 0.0356039073625737, + 0.0609470811194113, 0.0833232869645198, 0.103265350891109, + 0.121507077706427, 0.1393442, 
0.157305073932789, 0.176004666813668, + 0.196866917086671, 0.219796529731897, 0.247137200365254, + 0.280371254591746, 0.320842872758278, 0.374783454750148, + 0.461368597638526, 0.539683256474915, 0.632562403391324), + quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, + 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0, 0.018869505399304, 0.0471517885822858, + 0.0732707765908659, 0.0969223475714758, 0.118188509171441, + 0.1393442, 0.161036861715017, 0.183255665579256, 0.207206810683007, + 0.23409988698267, 0.265549713886389, 0.302197074524145, 0.346715970732557, + 0.40460690801818, 0.498076490174802, 0.580016068409433, 0.680138975526255 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, + 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, + 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0, 0.00232218982614828, 0.0342017690820909, + 0.062828756299263, 0.0893725834453345, 0.114623710996309, + 0.1393442, 0.163790622390774, 0.189495107256772, 0.216754530328403, + 0.247065337260473, 0.281410456107061, 0.32037037400004, 0.367018829587046, + 0.431198706165962, 0.52829547296083, 0.619021148955337, 0.728730172315724 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, + 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, + 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0.00233673672776743, 0.0223488000000001, + 0.040304673503435, 0.0576262998104982, 0.0732741199141993, + 0.088455610793058, 0.103199, 0.118707592060121, 0.134185928864089, + 0.151183139276793, 0.1702454, 0.191937, 0.2182298, 0.253577609846549, + 0.307351538752588, 0.407165223924639, 
0.502529513927214, + 0.605582108686126), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0, 0, 0.0190621000375005, 0.0420071558734088, + 0.0629230825705257, 0.0833688260410605, 0.103199, 0.124118509153392, + 0.145401945823358, 0.168667287877079, 0.1939090000375, 0.222597428173282, + 0.256984900377504, 0.301709122144422, 0.366495424858649, + 0.475152766217062, 0.572497835146252, 0.693762274318904), + quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, + 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0, 0, 0, 0.0269530072946728, 0.0530040092850928, + 0.0782481277003769, 0.103199, 0.12816325599641, 0.154866111682517, + 0.182302899107341, 0.213783044306043, 0.248363904708547, + 0.28995690796288, 0.341627908394784, 0.413707680386504, 0.528381820556805, + 0.635771182105746, 0.77652465912812), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0, 0, + 0, 0.0133969262208122, 0.0447913089328894, 0.0739787251314013, + 0.103199, 0.132965213784838, 0.163644939246192, 0.196475575572506, + 0.231647450729907, 0.271208219491195, 0.317741925837459, 0.376214875186902, + 0.454693715463155, 0.578781950822058, 0.695278060333427, 0.835521146843828 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", 
"vctrs_vctr")), structure(list(values = c(0, 0, + 0, 0, 0, 0, 0, 0, 0.000725156354313476, 0.036290207696477, 0.0701157049196494, + 0.103199, 0.136581757676227, 0.170980571439515, 0.20778982998995, + 0.247087076718167, 0.291689672899979, 0.343587258527985, 0.406717577407724, + 0.490437549306793, 0.620305872542078, 0.740730855925609, 0.888992767585756 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, + 0, 0.0701359181289814, 0.126021564763798, 0.165542973066331, + 0.197412078824538, 0.2254231, 0.24849244896414, 0.271074448350284, + 0.292116376731667, 0.3121244, 0.3321324, 0.3534741, 0.375505591313813, + 0.4001594, 0.4268368, 0.459466546351464, 0.501142770839258, 0.562143084394445, + 0.686511993260583, 0.808747521078011, 0.936070949770187), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0.00157374045240457, + 0.0698662194634446, 0.120287640452405, 0.16090076400914, 0.195966561494315, + 0.227802919628796, 0.257250456567366, 0.284352940452404, 0.3121244, + 0.338954445099751, 0.366682808562485, 0.395431772465525, 0.428410340452405, + 0.464424683613586, 0.505774640452405, 0.559060310062401, 0.635868688255882, + 0.771213743700187, 0.895124744284645, 1.02835689610128), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.0203251909788099, + 0.0807941084801849, 0.131156594663197, 0.173483742579226, 0.211670557196072, + 
0.246244078609487, 0.278363918673537, 0.3121244, 0.345057130768308, + 0.378403757196072, 0.414130127568126, 0.451969178608786, 0.495598517595426, + 0.545136665227352, 0.60807806098831, 0.695394235571256, 0.837130344811698, + 0.966111057134121, 1.11185508502426), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.0477276069251695, + 0.103509981435814, 0.15221877094871, 0.195952578625286, 0.236147272793828, + 0.274650521629366, 0.3121244, 0.349346986282313, 0.388561057230272, + 0.429378978625286, 0.474721256740267, 0.523806740641156, 0.581962784214742, + 0.652062951302463, 0.746838578625286, 0.896492945755508, 1.0340527654686, + 1.19219029825678), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0.0166039560593608, 0.0776387168354182, + 0.132003170161801, 0.180530886857168, 0.22594722201882, 0.268822337600976, + 0.3121244, 0.354489864523245, 0.398378553881739, 0.444274543339083, + 0.494499388431484, 0.548837448212482, 0.612239188685087, + 0.690272902609576, 0.790473599123991, 0.950950996975469, + 1.09638828065763, 1.26930966690442), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0.214450885057551, + 0.288864871241312, 0.3250653, 0.3516615, 0.3716087, 0.386718885323753, + 0.399682691320713, 0.411042976158862, 0.4218461, 0.4329278, 0.444139278140181, + 0.456951313505885, 0.4720835, 0.4920307, 
0.518626803531635, 0.555566110165902, + 0.636745822624727, 1.18069710590251, 1.79487371178211, 2.55270530204625 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, + 0, 0.0412188277837779, 0.218851219710947, 0.281178109847399, + 0.318187061211362, 0.346336916208562, 0.368500427783778, 0.387753955899259, + 0.405439627783778, 0.4218461, 0.438238911502765, 0.455473161565916, + 0.474946888792488, 0.497793222697627, 0.526600327783778, 0.565677321171112, + 0.632773149305243, 0.891087255237454, 1.53723873883164, 2.07877430490449, + 2.80265665435411), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0.11916637099981, 0.229217761668717, + 0.283591182792578, 0.32089403701397, 0.351025234947199, 0.376764238355684, + 0.399580647158371, 0.4218461, 0.44387311299288, 0.466809871716417, + 0.493008689720547, 0.523409488360383, 0.563157298622986, + 0.621505313473235, 0.756485815282202, 1.12190615310943, 1.76010655352564, + 2.36678033794496, 2.94420631979259), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.0166944520132201, + 0.165418069472795, 0.245206977511275, 0.294705591133411, 0.333122440419504, + 0.365628706470365, 0.393898304736197, 0.4218461, 0.449111464628896, + 0.478419567119571, 0.511583967360174, 0.551380591704217, 0.602914542469175, + 0.695207681738717, 0.912006796599716, 1.31516316514125, 1.94296465866439, + 
2.56528565211139, 3.07364144272118), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.095868346511765, + 0.20216012803078, 0.267545492825128, 0.314290150935209, 0.353895445422154, + 0.388115128404834, 0.4218461, 0.455823761272913, 0.49135719600286, + 0.53249009905049, 0.582341165610556, 0.654473427614026, 0.784511194125441, + 1.05644872659752, 1.47044175860169, 2.09183984013705, 2.69484857437112, + 3.1694157654766), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.189889609612846, 0.28831400446517, 0.378590156778518, + 0.474951757151471, 0.546550271666467, 0.599713541496415, + 0.636994072140471, 0.663814888730087, 0.6839305, 0.701811, + 0.71711131701917, 0.7319844, 0.746512343291783, 0.7621579, + 0.7800383, 0.800154, 0.826974702066021, 0.86472325100111, + 0.918612458720487, 0.988605006042461, 1.08324298909714, 1.1736324426019, + 1.27400190201593), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.0970521814156041, 0.21019273451422, 0.3073217, + 0.418096666577866, 0.489016664299943, 0.547102113575136, + 0.594490775323003, 0.63162246104581, 0.661579866583116, 0.687283, + 0.709633785855109, 0.7319844, 0.754030577281223, 0.776967707389074, + 0.802389, 0.832791429272493, 0.870576437517875, 0.917019363782438, + 0.973069487834329, 1.04481411391714, 1.15502640396814, 1.25613855529213, + 1.37419193312441), quantile_levels = 
c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.0121672025865257, 0.139873460696682, 0.245836896475015, + 0.366700877088971, 0.445024777793378, 0.506295707796278, + 0.557812941319663, 0.601634091201612, 0.639324955546405, + 0.673001603565436, 0.702827370737707, 0.7319844, 0.760387153293983, + 0.790515252114921, 0.823330663438584, 0.86065768198682, 0.904468070814958, + 0.954989716167962, 1.01626566701207, 1.09352836237872, 1.21548452077266, + 1.32239947141536, 1.46006378366371), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0.0755189873928237, + 0.192404624794198, 0.322282766861868, 0.409749729479745, 0.475729034228042, + 0.531171513462134, 0.579442333436034, 0.623023292701627, 0.662178609529395, + 0.697968947885378, 0.7319844, 0.766345465406154, 0.80256496503135, + 0.841452466611966, 0.884524366576965, 0.93218174000415, 0.988252217755677, + 1.05297410373014, 1.13838991320473, 1.27210128334768, 1.38822119412612, + 1.53603026586717), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.0137515321313713, 0.140785106599616, 0.283710273212032, + 0.374321519596796, 0.446394180252102, 0.505830587319873, + 0.559570052916329, 0.606684360953109, 0.65111343293503, 0.692845474832798, + 0.7319844, 0.771333743893139, 0.812267094081241, 0.855930534362644, + 0.903545840608706, 0.955193592261423, 1.01560313647486, 1.08583632750787, + 1.17818451335943, 
1.31856131315813, 1.44615719776698, 1.60468791291453 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, + 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, + 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0.0124103998425985, 0.0518320161167612, + 0.0822283734557346, 0.106956582246572, 0.130236689538895, + 0.150852198845738, 0.168835673455735, 0.183678547429124, + 0.1975426, 0.211166273455735, 0.226249473455735, 0.243919155834858, + 0.265304527061771, 0.289781663064881, 0.315985067670677, + 0.347644682675627, 0.394981842425824, 0.491215248628636, + 0.584975102439074, 0.694697494489265), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.0106056685868359, + 0.0491424720812208, 0.0803975947094471, 0.108060576398464, 0.133638500841809, + 0.155968088623186, 0.177107275224252, 0.1975426, 0.218180906543366, + 0.239601831646016, 0.262811949904799, 0.28886838404664, 0.317235975224252, + 0.350545157867879, 0.393998327257523, 0.454550976564066, 0.558555075803007, + 0.656859449317743, 0.763718974419534), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.0185370189554894, + 0.0562218087603375, 0.0890356919950198, 0.118731362266373, 0.146216910144001, + 0.172533896645116, 0.1975426, 0.223021121504065, 0.249412654553045, + 0.277680444480195, 0.308522683806638, 0.342270845449704, 0.382702709814398, + 0.433443929063141, 0.501610622734127, 0.61417580106326, 0.715138862353848, + 0.833535553075286), quantile_levels = 
c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0, 0.0346528073343234, 0.0723584880324803, + 0.106222897173122, 0.138467941096611, 0.167844669490445, + 0.1975426, 0.227591504589096, 0.258479799230192, 0.290862843650987, + 0.325718759418194, 0.364163081687565, 0.409581315443156, + 0.46531554698862, 0.54043504498905, 0.659111642885379, 0.761453612496025, + 0.889794566241181), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0, 0.0134397969692197, 0.0557212574100741, + 0.0941597345954959, 0.130401776157262, 0.164200585080601, + 0.1975426, 0.231566981332063, 0.265597088493385, 0.30192115798073, + 0.341652226704467, 0.384249568152932, 0.43541812199952, 0.495340659591346, + 0.575765691755518, 0.703032070294999, 0.815605113815338, + 0.955488202108743), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992), class = "Date"), target_date = structure(c(19006, 19013, + 19020, 19027, 19034, 19006, 19013, 19020, 19027, 19034, 19006, + 19013, 19020, 19027, 19034, 19006, 19013, 19020, 19027, 19034, + 19006, 19013, 19020, 19027, 19034, 19006, 19013, 19020, 19027, + 19034), class = 
"Date")), row.names = c(NA, -30L), class = c("tbl_df", + "tbl", "data.frame")) + +--- + + structure(list(geo_value = c("ca", "ca", "ca", "ca", "ca", "fl", + "fl", "fl", "fl", "fl", "ga", "ga", "ga", "ga", "ga", "ny", "ny", + "ny", "ny", "ny", "pa", "pa", "pa", "pa", "pa", "tx", "tx", "tx", + "tx", "tx"), .pred = c(0.1393442, 0.1393442, 0.1393442, 0.1393442, + 0.1393442, 0.103199, 0.103199, 0.103199, 0.103199, 0.103199, + 0.3121244, 0.3121244, 0.3121244, 0.3121244, 0.3121244, 0.4218461, + 0.4218461, 0.4218461, 0.4218461, 0.4218461, 0.7319844, 0.7319844, + 0.7319844, 0.7319844, 0.7319844, 0.1975426, 0.1975426, 0.1975426, + 0.1975426, 0.1975426), ahead = c(1L, 2L, 3L, 4L, 5L, 1L, 2L, + 3L, 4L, 5L, 1L, 2L, 3L, 4L, 5L, 1L, 2L, 3L, 4L, 5L, 1L, 2L, 3L, + 4L, 5L, 1L, 2L, 3L, 4L, 5L), .pred_distn = structure(list(structure(list( + values = c(0, 0, 0.00812835000000001, 0.07297428, 0.0936219, + 0.10421786, 0.1121285, 0.1201118, 0.1273693, 0.1317238, 0.1360783, + 0.1393442, 0.1426101, 0.1469646, 0.1513191, 0.1585766, 0.1665599, + 0.17447054, 0.1850665, 0.20571412, 0.27056005, 0.313941744999999, + 0.384931126999997), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0.0250982954899548, 0.0576421230361804, + 0.0776985410529105, 0.0929731777892779, 0.104205115094451, + 0.114209292598776, 0.123365027741977, 0.131496226094211, + 0.1393442, 0.147007648291083, 0.154990950042, 0.16406284204392, + 0.173835548288583, 0.185472494222942, 0.200167568392984, + 0.221760005190952, 0.260313716029161, 0.318794320716957, + 0.376941794597195, 0.461705276864399), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + 
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.028693230499105, + 0.055453963203632, 0.0755679534410344, 0.0913921813275133, 0.104804902302573, + 0.117142722458225, 0.128444430213702, 0.1393442, 0.150479535783308, + 0.161776522458225, 0.173925041831968, 0.187540579925299, 0.204200618941439, + 0.225353161205212, 0.253695961466565, 0.294498109305393, 0.358245879234942, + 0.427563795224327, 0.501665748776186), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.00587171510650109, + 0.0364866623781238, 0.0602683002957529, 0.0794861096145961, 0.0963414561651617, + 0.111439230212802, 0.125394639614746, 0.1393442, 0.153216527502025, + 0.167801944181742, 0.183359587288923, 0.200880434888349, 0.221656465706657, + 0.24743726609676, 0.279449270180852, 0.322415149384594, 0.395367499639696, + 0.464904880713406, 0.539558052669137), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.019055042091221, + 0.0457625510440105, 0.068309473710537, 0.087945102194822, 0.106033592330923, + 0.123045226382564, 0.1393442, 0.155351600131351, 0.172491058371384, + 0.19101350900654, 0.211425349928599, 0.234936300692507, 0.264303292652126, + 0.299599722715327, 0.346282638921389, 0.423857010226352, 0.494689091614341, + 0.577833814673327), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0.00138033000000002, 0.030893965, 
0.0479842, + 0.059815975, 0.07118759, 0.0815075, 0.0926819, 0.0992551, + 0.103199, 0.1071429, 0.1137161, 0.1248905, 0.13521041, 0.146582025, + 0.1584138, 0.175504035, 0.20501767, 0.25694586, 0.335051815, + 0.436709474), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0.0179658025100251, 0.0356060154111541, + 0.050834301692017, 0.0650050989327893, 0.0784417069434695, + 0.0916422518458685, 0.103199, 0.115251501692017, 0.128398001692017, + 0.142201701692017, 0.157319973859039, 0.174980914065641, + 0.196101805086251, 0.223989860848608, 0.266334685464555, + 0.354050965519204, 0.437948459272293, 0.520203978940639), + quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, + 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0, 0.0134241653129031, 0.0338447112456125, + 0.052643303388484, 0.0699345638167383, 0.0866373614747148, + 0.103199, 0.119627111136411, 0.137401026927169, 0.156056395793358, + 0.175781901322513, 0.198564535163602, 0.226934571881819, + 0.263862501322513, 0.317121769745397, 0.412419996940619, + 0.491470213131306, 0.580892509639735), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0, 0, + 0.0170903, 0.0403385023363734, 0.0616387632732329, 0.0827585779094291, + 0.103199, 0.123094939420544, 0.14464638301663, 0.1669589, 0.191770645535455, + 0.220735117412174, 0.254231042750228, 0.296807527848978, 0.357153759489695, + 0.45347931404539, 
0.538725322834228, 0.636530647411066), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0, 0, + 0.0026415954262542, 0.0297423239924899, 0.0555402340406406, 0.0792255827466275, + 0.103199, 0.127366925585556, 0.151700351432014, 0.177708522618176, + 0.206088123699737, 0.238712707453825, 0.277708313715037, 0.325132239647296, + 0.390468252727729, 0.490417296529864, 0.578557086846368, 0.688679948593326 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0.0320461375000001, + 0.129384955, 0.18940881, 0.2200878, 0.2427634, 0.2587698, 0.2734423, + 0.2841133, 0.296118, 0.3041212, 0.3121244, 0.3201276, 0.3281308, + 0.3401355, 0.3508065, 0.365479, 0.3814854, 0.404161, 0.43483999, + 0.494863845, 0.592202662499998, 0.737413847999994), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0.0319186440152902, + 0.118606588418984, 0.166386434627046, 0.198884154069741, 0.224089313858389, + 0.245418255377554, 0.2641052, 0.281445422925429, 0.297451875378704, + 0.3121244, 0.327667648091081, 0.343487967727477, 0.360314881408664, + 0.379575527422374, 0.400991145952209, 0.426605204088841, 0.4588495, + 0.506128350755908, 0.604640728888889, 0.713520019350718, 0.848429920658984 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = 
c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, + 0, 0.0628145244703447, 0.119951261697167, 0.161800708429584, + 0.194481529786298, 0.221976473503235, 0.246382528361484, 0.268661795456855, + 0.29099237601426, 0.3121244, 0.332687273503235, 0.354487379145491, + 0.376704773503235, 0.401222379758598, 0.428725473503235, 0.462071908680987, + 0.503745448659536, 0.564825512591627, 0.677307126205362, 0.788889302835928, + 0.92389000979736), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0.0154147362739629, 0.0815589624901754, + 0.130419447103471, 0.16933591200637, 0.202296191455315, 0.23230661698317, + 0.260103744489245, 0.28583424396924, 0.3121244, 0.337226511153312, + 0.3628113, 0.3894886, 0.419049975899859, 0.453339140405904, + 0.492830630339104, 0.542883079890499, 0.613577832767128, + 0.73571689900399, 0.853844909059791, 0.988010467319443), + quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, + 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0.0493531737111374, 0.104172112803728, + 0.147940700281253, 0.185518687303273, 0.220197034594646, + 0.2521005, 0.282477641919719, 0.3121244, 0.3414694, 0.371435390499905, + 0.402230766363414, 0.436173824348844, 0.474579164424894, + 0.519690345185252, 0.57667375206677, 0.655151246845668, 0.78520792902029, + 0.90968118047453, 1.05112182091783), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), 
structure(list(values = c(0, 0, 0.28439515, 0.33688581, + 0.369872555, 0.3863845, 0.3945111, 0.40189893, 0.4078092, 0.4137194, + 0.4174134, 0.4218461, 0.4262788, 0.4299728, 0.435883, 0.44179327, + 0.4491811, 0.4573077, 0.473819645, 0.50680639, 0.55929705, 0.9841905175, + 1.556671116), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, + 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, + 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0.003694, 0.268840486221162, 0.320208490155752, + 0.34804029700677, 0.368653615349654, 0.3834292, 0.3945111, + 0.4041153, 0.413171785132151, 0.4218461, 0.430424661802068, + 0.4395769, 0.4491812, 0.4610017, 0.47590450199302, 0.497193409669697, + 0.525275921931869, 0.57616046396334, 0.97179808113241, 1.42880557869041, + 2.00265362857685), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0.0925362072632727, 0.270427502912579, + 0.315212102423624, 0.343335698090731, 0.364285966419164, + 0.381412585636556, 0.3959887, 0.4092868, 0.4218461, 0.4344055, + 0.447738051828318, 0.4632179, 0.480948870517105, 0.502553166907419, + 0.531676966454865, 0.576804782629326, 0.776643061384413, + 1.21840177544959, 1.666716830807, 2.19163048441111), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.186887482630176, + 0.277238777881179, 0.317854348809488, 0.345779327332173, 0.367941987952029, + 0.38755201396574, 0.405055828677287, 0.4218461, 0.438666668060931, + 0.456611962704227, 0.476718028677287, 
0.499751625882259, 0.528508989683397, + 0.569810205861059, 0.666081219804098, 0.934028445917159, 1.42658287124316, + 1.85311957889209, 2.30760254154095), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.0845659921302213, + 0.228553649752897, 0.289236861333113, 0.326073140839108, 0.354785333802038, + 0.379166830409904, 0.401230227456875, 0.4218461, 0.442801275729157, + 0.465572618600986, 0.490133389090691, 0.520052318734487, 0.558588500497255, + 0.62065225601836, 0.788392143304334, 1.05428294678997, 1.55684044507063, + 2.01374350966068, 2.37954449328776), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.33818795, 0.4386877525, + 0.528816855, 0.61252005, 0.6626973, 0.6816954, 0.697340875, 0.7085162, + 0.7152214, 0.7208091, 0.72745833, 0.7319844, 0.73651047, 0.7431597, + 0.7487474, 0.7554526, 0.766627925, 0.7822734, 0.8012715, 0.85144875, + 0.935151945, 1.0252810475, 1.12578085), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.276821846502455, + 0.354318476867519, 0.440270225449805, 0.533132934163242, 0.5900576, + 0.631102729748298, 0.660462274661497, 0.680831108876989, 0.696223359635746, + 0.7096337, 0.7219265, 0.7319844, 0.7431597, 0.7543351, 0.7677455, + 0.783391, 0.804046832839828, 0.833541896886769, 0.873735298798638, + 0.929106903073231, 1.02188617627186, 1.10971107833641, 1.18626816850867 + ), quantile_levels = 
c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, + 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, + 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.202265200637946, + 0.298325094034965, 0.380907645938709, 0.481339524857949, 0.543219696138311, + 0.589507953775938, 0.6258186, 0.654874580912809, 0.6783427, 0.6984583, + 0.715655544727447, 0.7319844, 0.7487473, 0.7666278, 0.785715489951649, + 0.8090941, 0.83815, 0.873623567291473, 0.920206978680437, 0.98231174201862, + 1.08425930872329, 1.16639411427812, 1.25926838507547), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.129193504425124, + 0.241744300793533, 0.331949483165032, 0.43649858695157, 0.504472062268773, + 0.556141464729147, 0.597172505336053, 0.631406591640416, 0.660898437441874, + 0.686684727470375, 0.709633972330423, 0.7319844, 0.753217699696647, + 0.77608746100351, 0.8012715950276, 0.830327492252422, 0.86464477397774, + 0.906319686121761, 0.956815387818928, 1.02495125855129, 1.13129413647201, + 1.21644533535035, 1.32424172966634), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.0667682979050189, + 0.189580042212397, 0.290485041721667, 0.402951609190092, 0.475328740486855, + 0.530590906520765, 0.575504908587586, 0.613421932920829, 0.647285177364573, + 0.678099283398734, 0.70593862799773, 0.7319844, 0.758701322488325, + 0.786639532920829, 0.816837200234752, 0.850627936753767, 0.888963924063491, + 0.933785069065791, 0.988913131611816, 1.06240172852619, 1.16959624730917, + 1.2662008825538, 
1.38860505690239), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, + 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 + )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0.0419413650000001, + 0.09882005, 0.1230992, 0.14226962, 0.1600776, 0.1722416, 0.1800265, + 0.1880061, 0.1936501, 0.1975426, 0.2014351, 0.2070791, 0.2150587, + 0.2228436, 0.2350076, 0.25281558, 0.271986, 0.29626515, 0.353143835, + 0.4353357125, 0.545314878), quantile_levels = c(0.01, 0.025, + 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, + 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0.0438463650372504, 0.0808594787511875, + 0.106995615813358, 0.127478232938079, 0.145480846633466, + 0.1610508, 0.17461199504795, 0.186668812203222, 0.1975426, + 0.208428571374764, 0.2204108, 0.233930283744537, 0.249894552784127, + 0.267362348440485, 0.288755575723157, 0.316120297580926, + 0.355450425419354, 0.443192503687136, 0.536871211931719, + 0.636344785545224), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0.00188932708477086, 0.0470905919531195, + 0.079226864399944, 0.105414109111591, 0.127225815559956, + 0.146699420891509, 0.164644114298843, 0.18142942603581, 0.1975426, + 0.213933119201142, 0.231001630488804, 0.24941229702312, 0.269578845560456, + 0.292362546530965, 0.319632071367214, 0.354433951358713, + 0.406915236639266, 0.506944745332152, 0.596044605353528, + 0.695533388807317), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 
0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0.0156342454546545, 0.0536811248488485, + 0.084228833507335, 0.110407751354614, 0.134410113872139, + 0.156669167575476, 0.177701902429674, 0.1975426, 0.217759024165492, + 0.238897316673167, 0.261484572608426, 0.286120039498095, + 0.313065324705997, 0.345395334882349, 0.386811116673167, + 0.44780805303823, 0.550781846423163, 0.644984940689833, 0.752937731654986 + ), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, + 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, + 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0, 0, 0, 0, 0.0290260214229144, 0.0653218111708617, + 0.0966336637233373, 0.124670861123061, 0.149775978614687, + 0.174275935467055, 0.1975426, 0.221291415429954, 0.246723385601356, + 0.273144383515685, 0.30101566402084, 0.33204051788793, 0.369730347126771, + 0.416909038104281, 0.481925596660567, 0.58989871202142, 0.688635568252056, + 0.803906183401304), quantile_levels = c(0.01, 0.025, 0.05, + 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, + 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992), class = "Date"), target_date = structure(c(18997, 19002, + 19007, 19012, 19017, 18997, 19002, 19007, 19012, 19017, 18997, + 19002, 19007, 19012, 19017, 18997, 19002, 19007, 19012, 19017, + 18997, 19002, 19007, 19012, 19017, 18997, 19002, 19007, 19012, + 19017), class = "Date")), row.names = c(NA, -30L), class = c("tbl_df", + "tbl", "data.frame")) + 
+# arx_forecaster snapshots + + structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" + ), .pred = c(0.353013358779435, 0.648525432444877, 0.667670289394328, + 1.1418673907239, 0.830448695683587, 0.329799431948649), .pred_distn = structure(list( + structure(list(values = c(0.171022956902288, 0.535003760656582 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.46653503056773, 0.830515834322024), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.485679887517181, + 0.849660691271475), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.959876988846753, 1.32385779260105), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.64845829380644, + 1.01243909756073), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.147809030071502, 0.511789833825796), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, + 18992, 18992), class = "Date"), target_date = structure(c(18999, + 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, + -6L), class = c("tbl_df", "tbl", "data.frame")) + +--- + + structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" + ), .pred = c(0.149303403634373, 0.139764664505948, 0.333186321066645, + 0.470345577837144, 0.725986105412008, 0.212686665274007), .pred_distn = structure(list( + structure(list(values = c(0.0961118191398634, 0.202494988128882 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + 
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.0865730800114383, 0.192956249000457), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.279994736572136, + 0.386377905561154), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.417153993342634, 0.523537162331653), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.672794520917498, + 0.779177689906517), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.159495080779498, 0.265878249768516), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, + 18992, 18992), class = "Date"), target_date = structure(c(18993, + 18993, 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, + -6L), class = c("tbl_df", "tbl", "data.frame")) + +# arx_classifier snapshots + + structure(list(geo_value = c("ak", "al", "ar", "az", "ca", "co", + "ct", "dc", "de", "fl", "ga", "gu", "hi", "ia", "id", "il", "in", + "ks", "ky", "la", "ma", "me", "mi", "mn", "mo", "mp", "ms", "mt", + "nc", "nd", "ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or", + "pa", "pr", "ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa", + "wi", "wv", "wy"), .pred_class = structure(c(1L, 1L, 1L, 1L, + 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, + 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, + 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, + 1L), levels = c("(-Inf,0.25]", "(0.25, Inf]"), class = "factor"), + forecast_date = structure(c(18992, 18992, 18992, 18992, 
18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992), class = "Date"), target_date = structure(c(18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, + -53L), class = c("tbl_df", "tbl", "data.frame")) + diff --git a/tests/testthat/_snaps/step_epi_naomit.md b/tests/testthat/_snaps/step_epi_naomit.md new file mode 100644 index 000000000..653e84d0e --- /dev/null +++ b/tests/testthat/_snaps/step_epi_naomit.md @@ -0,0 +1,8 @@ +# Argument must be a recipe + + Code + step_epi_naomit(x) + Condition + Error in `step_epi_naomit()`: + ! inherits(recipe, "recipe") is not TRUE + diff --git a/tests/testthat/_snaps/step_epi_shift.md b/tests/testthat/_snaps/step_epi_shift.md new file mode 100644 index 000000000..44c828118 --- /dev/null +++ b/tests/testthat/_snaps/step_epi_shift.md @@ -0,0 +1,36 @@ +# Values for ahead and lag must be integer values + + Code + r1 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = 3.6) %>% + step_epi_lag(death_rate, lag = 1.9) + Condition + Error in `step_epi_ahead()`: + ! `ahead` must be a non-negative integer. + +# A negative lag value should should throw an error + + Code + r2 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag( + death_rate, lag = -7) + Condition + Error in `step_epi_lag()`: + ! `lag` must be a non-negative integer. 
+ +# A nonpositive ahead value should throw an error + + Code + r3 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = -7) %>% step_epi_lag( + death_rate, lag = 7) + Condition + Error in `step_epi_ahead()`: + ! `ahead` must be a non-negative integer. + +# Values for ahead and lag cannot be duplicates + + Code + slm_fit(r4) + Condition + Error in `bake()`: + ! Name collision occured in + The following variable name already exists: "lag_7_death_rate". + diff --git a/tests/testthat/_snaps/step_epi_slide.md b/tests/testthat/_snaps/step_epi_slide.md new file mode 100644 index 000000000..a4b9d64c8 --- /dev/null +++ b/tests/testthat/_snaps/step_epi_slide.md @@ -0,0 +1,145 @@ +# epi_slide errors when needed + + Code + recipe(edf) %>% step_epi_slide(value, .f = mean, .window_size = 7L) + Condition + Error in `step_epi_slide()`: + ! This recipe step can only operate on an . + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = c(3L, 6L)) + Condition + Error in `epiprocess:::validate_slide_window_arg()`: + ! Slide function expected `.window_size` to be a non-null, scalar integer >= 1. + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .align = c("right", "left")) + Condition + Error in `step_epi_slide()`: + ! step_epi_slide: `.window_size` must be specified. + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, skip = c(TRUE, FALSE)) + Condition + Error in `step_epi_slide()`: + ! `skip` must be a scalar of type . + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, role = letters[1:2]) + Condition + Error in `step_epi_slide()`: + ! `role` must be a scalar of type . + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, prefix = letters[1:2]) + Condition + Error in `step_epi_slide()`: + ! `prefix` must be a scalar of type . + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, id = letters[1:2]) + Condition + Error in `step_epi_slide()`: + ! 
`id` must be a scalar of type . + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1.5) + Condition + Error in `epiprocess:::validate_slide_window_arg()`: + ! Slide function expected `.window_size` to be a difftime with units in days or non-negative integer or Inf. + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, .align = 1.5) + Condition + Error in `step_epi_slide()`: + ! `.align` must be a character vector, not the number 1.5. + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, skip = "a") + Condition + Error in `step_epi_slide()`: + ! `skip` must be a scalar of type . + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, role = 1) + Condition + Error in `step_epi_slide()`: + ! `role` must be a scalar of type . + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, prefix = 1) + Condition + Error in `step_epi_slide()`: + ! `prefix` must be a scalar of type . + +--- + + Code + r %>% step_epi_slide(value, .f = mean, .window_size = 1L, id = 1) + Condition + Error in `step_epi_slide()`: + ! `id` must be a scalar of type . + +--- + + Code + r %>% step_epi_slide(value) + Condition + Error in `step_epi_slide()`: + ! argument ".f" is missing, with no default + +--- + + Code + r %>% step_epi_slide(value, .f = 1) + Condition + Error in `validate_slide_fun()`: + ! In, `step_epi_slide()`, `.f` must be a function. + +--- + + Code + r %>% step_epi_slide(value) + Condition + Error in `step_epi_slide()`: + ! argument ".f" is missing, with no default + +--- + + Code + r %>% step_epi_slide(value, .f = 1) + Condition + Error in `validate_slide_fun()`: + ! In, `step_epi_slide()`, `.f` must be a function. + +# epi_slide handles different function specs + + Code + lfun <- r %>% step_epi_slide(value, .f = ~ mean(.x, na.rm = TRUE), + .window_size = 4L) + Condition + Error in `validate_slide_fun()`: + ! In, `step_epi_slide()`, `.f` cannot be a formula. 
+ diff --git a/tests/testthat/_snaps/step_growth_rate.md b/tests/testthat/_snaps/step_growth_rate.md new file mode 100644 index 000000000..5a3ac6f44 --- /dev/null +++ b/tests/testthat/_snaps/step_growth_rate.md @@ -0,0 +1,121 @@ +# step_growth_rate validates arguments + + Code + step_growth_rate(r) + Condition + Error in `step_growth_rate()`: + ! This recipe step can only operate on an . + +--- + + Code + step_growth_rate(r, value, role = 1) + Condition + Error in `step_growth_rate()`: + ! `role` must be of type . + +--- + + Code + step_growth_rate(r, value, method = "abc") + Condition + Error in `step_growth_rate()`: + ! `method` must be one of "rel_change" or "linear_reg", not "abc". + +--- + + Code + step_growth_rate(r, value, horizon = 0) + Condition + Error in `step_growth_rate()`: + ! `horizon` must be a positive integer. + +--- + + Code + step_growth_rate(r, value, horizon = c(1, 2)) + Condition + Error in `step_growth_rate()`: + ! `horizon` must be a scalar. + +--- + + Code + step_growth_rate(r, value, prefix = letters[1:2]) + Condition + Error in `step_growth_rate()`: + ! `prefix` must be a scalar of type . + +--- + + Code + step_growth_rate(r, value, id = letters[1:2]) + Condition + Error in `step_growth_rate()`: + ! `id` must be a scalar of type . + +--- + + Code + step_growth_rate(r, value, prefix = letters[1:2]) + Condition + Error in `step_growth_rate()`: + ! `prefix` must be a scalar of type . + +--- + + Code + step_growth_rate(r, value, prefix = 1) + Condition + Error in `step_growth_rate()`: + ! `prefix` must be a scalar of type . + +--- + + Code + step_growth_rate(r, value, id = 1) + Condition + Error in `step_growth_rate()`: + ! `id` must be a scalar of type . + +--- + + Code + step_growth_rate(r, value, log_scale = 1) + Condition + Error in `step_growth_rate()`: + ! `log_scale` must be a scalar of type . + +--- + + Code + step_growth_rate(r, value, skip = 1) + Condition + Error in `step_growth_rate()`: + ! `skip` must be a scalar of type . 
+ +--- + + Code + step_growth_rate(r, value, additional_gr_args_list = 1:5) + Condition + Error in `step_growth_rate()`: + ! `additional_gr_args_list` must be a . + i See `?epiprocess::growth_rate` for available options. + +--- + + Code + step_growth_rate(r, value, replace_Inf = "c") + Condition + Error in `step_growth_rate()`: + ! `replace_Inf` must be of type . + +--- + + Code + step_growth_rate(r, value, replace_Inf = c(1, 2)) + Condition + Error in `step_growth_rate()`: + ! replace_Inf must be a scalar. + diff --git a/tests/testthat/_snaps/step_lag_difference.md b/tests/testthat/_snaps/step_lag_difference.md new file mode 100644 index 000000000..4edc9c287 --- /dev/null +++ b/tests/testthat/_snaps/step_lag_difference.md @@ -0,0 +1,72 @@ +# step_lag_difference validates arguments + + Code + step_lag_difference(r) + Condition + Error in `step_lag_difference()`: + ! This recipe step can only operate on an . + +--- + + Code + step_lag_difference(r, value, role = 1) + Condition + Error in `step_lag_difference()`: + ! `role` must be of type . + +--- + + Code + step_lag_difference(r, value, horizon = 0) + Condition + Error in `step_lag_difference()`: + ! `horizon` must be a positive integer. + +--- + + Code + step_lag_difference(r, value, prefix = letters[1:2]) + Condition + Error in `step_lag_difference()`: + ! `prefix` must be a scalar of type . + +--- + + Code + step_lag_difference(r, value, id = letters[1:2]) + Condition + Error in `step_lag_difference()`: + ! `id` must be a scalar of type . + +--- + + Code + step_lag_difference(r, value, prefix = letters[1:2]) + Condition + Error in `step_lag_difference()`: + ! `prefix` must be a scalar of type . + +--- + + Code + step_lag_difference(r, value, prefix = 1) + Condition + Error in `step_lag_difference()`: + ! `prefix` must be a scalar of type . + +--- + + Code + step_lag_difference(r, value, id = 1) + Condition + Error in `step_lag_difference()`: + ! `id` must be a scalar of type . 
+ +--- + + Code + step_lag_difference(r, value, skip = 1) + Condition + Error in `step_lag_difference()`: + ! `skip` must be a scalar of type . + diff --git a/tests/testthat/_snaps/wis-dist-quantiles.md b/tests/testthat/_snaps/wis-dist-quantiles.md new file mode 100644 index 000000000..fb9cfbdf6 --- /dev/null +++ b/tests/testthat/_snaps/wis-dist-quantiles.md @@ -0,0 +1,17 @@ +# wis dispatches and produces the correct values + + Code + weighted_interval_score(1:10, 10) + Condition + Error in `weighted_interval_score()`: + ! Weighted interval score can only be calculated if `x` + has class . + +--- + + Code + weighted_interval_score(dist_quantiles(list(1:4, 8:11), 1:4 / 5), 1:3) + Condition + Error in `weighted_interval_score()`: + ! Can't recycle `x` (size 2) to match `actual` (size 3). + diff --git a/tests/testthat/test-arg_is_.R b/tests/testthat/test-arg_is_.R index 7ca6f1d7f..89c2c936f 100644 --- a/tests/testthat/test-arg_is_.R +++ b/tests/testthat/test-arg_is_.R @@ -15,15 +15,16 @@ dd <- Sys.Date() - 5 v <- 1:5 l <- TRUE ll <- c(TRUE, FALSE) +z <- character(0) test_that("logical", { expect_silent(arg_is_lgl(l)) expect_silent(arg_is_lgl(ll)) expect_silent(arg_is_lgl(l, ll)) - expect_error(arg_is_lgl(l, ll, n)) - expect_error(arg_is_lgl(x)) + expect_snapshot(error = TRUE, arg_is_lgl(l, ll, n)) + expect_snapshot(error = TRUE, arg_is_lgl(x)) expect_silent(arg_is_lgl(l, ll, n, allow_null = TRUE)) - expect_error(arg_is_lgl(l, ll, nn)) + expect_snapshot(error = TRUE, arg_is_lgl(l, ll, nn)) expect_silent(arg_is_lgl(l, ll, nn, allow_na = TRUE)) }) @@ -31,122 +32,123 @@ test_that("scalar", { expect_silent(arg_is_scalar(x)) expect_silent(arg_is_scalar(dd)) expect_silent(arg_is_scalar(x, y, dd)) - expect_error(arg_is_scalar(x, y, n)) + expect_snapshot(error = TRUE, arg_is_scalar(x, y, n)) expect_silent(arg_is_scalar(x, y, n, allow_null = TRUE)) - expect_error(arg_is_scalar(x, y, nn)) + expect_snapshot(error = TRUE, arg_is_scalar(x, y, nn)) expect_silent(arg_is_scalar(x, y, 
nn, allow_na = TRUE)) - expect_error(arg_is_scalar(v, nn)) - expect_error(arg_is_scalar(v, nn, allow_na = TRUE)) - expect_error(arg_is_scalar(v, n, allow_null = TRUE)) - expect_error(arg_is_scalar(nnn, allow_na = TRUE)) + expect_snapshot(error = TRUE, arg_is_scalar(v, nn)) + expect_snapshot(error = TRUE, arg_is_scalar(v, nn, allow_na = TRUE)) + expect_snapshot(error = TRUE, arg_is_scalar(v, n, allow_null = TRUE)) + expect_snapshot(error = TRUE, arg_is_scalar(nnn, allow_na = TRUE)) }) test_that("numeric", { expect_silent(arg_is_numeric(i, j, x, y)) - expect_error(arg_is_numeric(a)) - expect_error(arg_is_numeric(d)) + expect_snapshot(error = TRUE, arg_is_numeric(a)) + expect_silent(arg_is_numeric(d)) expect_silent(arg_is_numeric(c(i, j))) expect_silent(arg_is_numeric(i, k)) expect_silent(arg_is_numeric(i, j, n, allow_null = TRUE)) - expect_error(arg_is_numeric(i, j, n)) - expect_error(arg_is_numeric(i, nn)) + expect_snapshot(error = TRUE, arg_is_numeric(i, j, n)) + expect_snapshot(error = TRUE, arg_is_numeric(i, nn)) expect_silent(arg_is_numeric(a = -10:10)) }) test_that("positive", { expect_silent(arg_is_pos(i, j, x, y)) - expect_error(arg_is_pos(a)) - expect_error(arg_is_pos(d)) + expect_snapshot(error = TRUE, arg_is_pos(a)) + expect_silent(arg_is_pos(d)) expect_silent(arg_is_pos(c(i, j))) - expect_error(arg_is_pos(i, k)) + expect_snapshot(error = TRUE, arg_is_pos(i, k)) expect_silent(arg_is_pos(i, j, n, allow_null = TRUE)) - expect_error(arg_is_pos(i, j, n)) - expect_error(arg_is_pos(i, nn)) - expect_error(arg_is_pos(a = 0:10)) + expect_snapshot(error = TRUE, arg_is_pos(i, j, n)) + expect_snapshot(error = TRUE, arg_is_pos(i, nn)) + expect_snapshot(error = TRUE, arg_is_pos(a = 0:10)) }) test_that("nonneg", { expect_silent(arg_is_nonneg(i, j, x, y)) - expect_error(arg_is_nonneg(a)) - expect_error(arg_is_nonneg(d)) + expect_snapshot(error = TRUE, arg_is_nonneg(a)) + expect_silent(arg_is_nonneg(d)) expect_silent(arg_is_nonneg(c(i, j))) - expect_error(arg_is_nonneg(i, 
k)) + expect_snapshot(error = TRUE, arg_is_nonneg(i, k)) expect_silent(arg_is_nonneg(i, j, n, allow_null = TRUE)) - expect_error(arg_is_nonneg(i, j, n)) - expect_error(arg_is_nonneg(i, nn)) + expect_snapshot(error = TRUE, arg_is_nonneg(i, j, n)) + expect_snapshot(error = TRUE, arg_is_nonneg(i, nn)) expect_silent(arg_is_nonneg(a = 0:10)) }) test_that("nonneg-int", { - expect_error(arg_is_nonneg_int(a)) - expect_error(arg_is_nonneg_int(d)) + expect_snapshot(error = TRUE, arg_is_nonneg_int(a)) + expect_snapshot(error = TRUE, arg_is_nonneg_int(d)) expect_silent(arg_is_nonneg_int(i, j)) expect_silent(arg_is_nonneg_int(c(i, j))) - expect_error(arg_is_nonneg_int(i, k)) + expect_snapshot(error = TRUE, arg_is_nonneg_int(i, k)) expect_silent(arg_is_nonneg_int(i, j, n, allow_null = TRUE)) - expect_error(arg_is_nonneg_int(i, j, n)) - expect_error(arg_is_nonneg_int(i, nn)) + expect_snapshot(error = TRUE, arg_is_nonneg_int(i, j, n)) + expect_snapshot(error = TRUE, arg_is_nonneg_int(i, nn)) expect_silent(arg_is_nonneg_int(a = 0:10)) }) test_that("date", { expect_silent(arg_is_date(d, dd)) expect_silent(arg_is_date(c(d, dd))) - expect_error(arg_is_date(d, dd, n)) - expect_error(arg_is_date(d, dd, nn)) + expect_snapshot(error = TRUE, arg_is_date(d, dd, n)) + expect_snapshot(error = TRUE, arg_is_date(d, dd, nn)) expect_silent(arg_is_date(d, dd, n, allow_null = TRUE)) - expect_silent(arg_is_date(d, dd, nn, allow_na = TRUE)) - expect_error(arg_is_date(a)) - expect_error(arg_is_date(v)) - expect_error(arg_is_date(ll)) + # Upstream issue, see: https://github.com/mllg/checkmate/issues/256 + # expect_silent(arg_is_date(d, dd, nn, allow_na = TRUE)) + expect_snapshot(error = TRUE, arg_is_date(a)) + expect_snapshot(error = TRUE, arg_is_date(v)) + expect_snapshot(error = TRUE, arg_is_date(ll)) }) test_that("probabilities", { expect_silent(arg_is_probabilities(i, x)) - expect_error(arg_is_probabilities(a)) - expect_error(arg_is_probabilities(d)) + expect_snapshot(error = TRUE, 
arg_is_probabilities(a)) + expect_snapshot(error = TRUE, arg_is_probabilities(d)) expect_silent(arg_is_probabilities(c(.4, .7))) - expect_error(arg_is_probabilities(i, 1.1)) + expect_snapshot(error = TRUE, arg_is_probabilities(i, 1.1)) expect_silent(arg_is_probabilities(c(.4, .8), n, allow_null = TRUE)) - expect_error(arg_is_probabilities(c(.4, .8), n)) - expect_error(arg_is_probabilities(c(.4, .8), nn)) + expect_snapshot(error = TRUE, arg_is_probabilities(c(.4, .8), n)) + expect_snapshot(error = TRUE, arg_is_probabilities(c(.4, .8), nn)) }) test_that("chr", { expect_silent(arg_is_chr(a, b)) expect_silent(arg_is_chr(c(a, b))) - expect_error(arg_is_chr(a, b, n)) - expect_error(arg_is_chr(a, b, nn)) + expect_snapshot(error = TRUE, arg_is_chr(a, b, n)) + expect_snapshot(error = TRUE, arg_is_chr(a, b, nn)) expect_silent(arg_is_chr(a, b, n, allow_null = TRUE)) expect_silent(arg_is_chr(a, b, nn, allow_na = TRUE)) - expect_error(arg_is_chr(d)) - expect_error(arg_is_chr(v)) - expect_error(arg_is_chr(ll)) - expect_error(arg_is_chr(z = character(0))) - expect_silent(arg_is_chr(z = character(0), allow_empty = TRUE)) + expect_snapshot(error = TRUE, arg_is_chr(d)) + expect_snapshot(error = TRUE, arg_is_chr(v)) + expect_snapshot(error = TRUE, arg_is_chr(ll)) + expect_snapshot(error = TRUE, arg_is_chr(z)) + expect_silent(arg_is_chr(z, allow_empty = TRUE)) }) test_that("function", { expect_silent(arg_is_function(f, g, parsnip::linear_reg)) - expect_error(arg_is_function(c(a, b))) - expect_error(arg_is_function(c(f, g))) - expect_error(arg_is_function(f = NULL)) - expect_silent(arg_is_function(g, f = NULL, allow_null = TRUE)) + expect_snapshot(error = TRUE, arg_is_function(c(a, b))) + expect_snapshot(error = TRUE, arg_is_function(c(f, g))) + f <- NULL + expect_snapshot(error = TRUE, arg_is_function(f)) + expect_silent(arg_is_function(g, f, allow_null = TRUE)) }) -test_that("sorted", { - expect_silent(arg_is_sorted(a = 1:5, b = 6:10)) - expect_error(arg_is_sorted(a = 5:1, b = 6:10)) 
- expect_error(arg_is_sorted(b = NULL)) - expect_silent(arg_is_sorted(b = NULL, allow_null = TRUE)) -}) - - test_that("coerce scalar to date", { - expect_error(arg_to_date("12345")) + expect_snapshot(error = TRUE, arg_to_date("12345")) expect_s3_class(arg_to_date(12345), "Date") expect_s3_class(arg_to_date("2020-01-01"), "Date") - expect_error(arg_to_date(c("12345", "12345"))) + expect_snapshot(error = TRUE, arg_to_date(c("12345", "12345"))) +}) + +test_that("simple surface step test", { + expect_snapshot( + error = TRUE, + epi_recipe(jhu_csse_daily_subset) %>% step_epi_lag(death_rate, lag = "hello") + ) }) diff --git a/tests/testthat/test-arx_args_list.R b/tests/testthat/test-arx_args_list.R index 7566fd90d..03cbc0025 100644 --- a/tests/testthat/test-arx_args_list.R +++ b/tests/testthat/test-arx_args_list.R @@ -1,30 +1,36 @@ test_that("arx_args checks inputs", { expect_s3_class(arx_args_list(), c("arx_fcast", "alist")) - expect_error(arx_args_list(ahead = c(0, 4))) - expect_error(arx_args_list(n_training = c(28, 65))) + expect_snapshot(error = TRUE, arx_args_list(ahead = c(0, 4))) + expect_snapshot(error = TRUE, arx_args_list(n_training = c(28, 65))) - expect_error(arx_args_list(ahead = -1)) - expect_error(arx_args_list(ahead = 1.5)) - expect_error(arx_args_list(n_training = -1)) - expect_error(arx_args_list(n_training = 1.5)) - expect_error(arx_args_list(lags = c(-1, 0))) - expect_error(arx_args_list(lags = list(c(1:5, 6.5), 2:8))) + expect_snapshot(error = TRUE, arx_args_list(ahead = -1)) + expect_snapshot(error = TRUE, arx_args_list(ahead = 1.5)) + expect_snapshot(error = TRUE, arx_args_list(n_training = -1)) + expect_snapshot(error = TRUE, arx_args_list(n_training = 1.5)) + expect_snapshot(error = TRUE, arx_args_list(lags = c(-1, 0))) + expect_snapshot(error = TRUE, arx_args_list(lags = list(c(1:5, 6.5), 2:8))) - expect_error(arx_args_list(symmetrize = 4)) - expect_error(arx_args_list(nonneg = 4)) + expect_snapshot(error = TRUE, arx_args_list(symmetrize = 4)) 
+ expect_snapshot(error = TRUE, arx_args_list(nonneg = 4)) - expect_error(arx_args_list(quantile_levels = -.1)) - expect_error(arx_args_list(quantile_levels = 1.1)) + expect_snapshot(error = TRUE, arx_args_list(quantile_levels = -.1)) + expect_snapshot(error = TRUE, arx_args_list(quantile_levels = 1.1)) expect_type(arx_args_list(quantile_levels = NULL), "list") - expect_error(arx_args_list(target_date = "2022-01-01")) + expect_snapshot(error = TRUE, arx_args_list(target_date = "2022-01-01")) expect_identical( arx_args_list(target_date = as.Date("2022-01-01"))$target_date, as.Date("2022-01-01") ) - expect_error(arx_args_list(n_training_min = "de")) - expect_error(arx_args_list(epi_keys = 1)) + expect_snapshot(error = TRUE, arx_args_list(n_training_min = "de")) + expect_snapshot(error = TRUE, arx_args_list(epi_keys = 1)) + + expect_warning(arx_args_list( + forecast_date = as.Date("2022-01-01"), + target_date = as.Date("2022-01-03"), + ahead = 1L + )) }) test_that("arx forecaster disambiguates quantiles", { @@ -52,7 +58,7 @@ test_that("arx forecaster disambiguates quantiles", { sort(unique(tlist)) ) alist <- c(.1, .3, .5, .7, .9) # neither default, and different, - expect_error(compare_quantile_args(alist, tlist)) + expect_snapshot(error = TRUE, compare_quantile_args(alist, tlist)) }) test_that("arx_lags_validator handles named & unnamed lists as expected", { @@ -88,7 +94,7 @@ test_that("arx_lags_validator handles named & unnamed lists as expected", { ) # More lags than predictors - Error - expect_error(arx_lags_validator(pred_vec, lags_finit_fn_switch2)) + expect_snapshot(error = TRUE, arx_lags_validator(pred_vec, lags_finit_fn_switch2)) # Unnamed list of lags lags_init_un <- list(c(0, 7, 14), c(0, 1, 2, 3, 7, 14)) @@ -109,5 +115,5 @@ test_that("arx_lags_validator handles named & unnamed lists as expected", { # Try use a name not in predictors - Error lags_init_other_name <- list(death_rate = c(0, 7, 14), test_var = c(0, 1, 2, 3, 7, 14)) - 
expect_error(arx_lags_validator(pred_vec, lags_init_other_name)) + expect_snapshot(error = TRUE, arx_lags_validator(pred_vec, lags_init_other_name)) }) diff --git a/tests/testthat/test-arx_cargs_list.R b/tests/testthat/test-arx_cargs_list.R index 69901220c..12087e45f 100644 --- a/tests/testthat/test-arx_cargs_list.R +++ b/tests/testthat/test-arx_cargs_list.R @@ -1,22 +1,28 @@ test_that("arx_class_args checks inputs", { expect_s3_class(arx_class_args_list(), c("arx_class", "alist")) - expect_error(arx_class_args_list(ahead = c(0, 4))) - expect_error(arx_class_args_list(n_training = c(28, 65))) + expect_snapshot(error = TRUE, arx_class_args_list(ahead = c(0, 4))) + expect_snapshot(error = TRUE, arx_class_args_list(n_training = c(28, 65))) - expect_error(arx_class_args_list(ahead = -1)) - expect_error(arx_class_args_list(ahead = 1.5)) - expect_error(arx_class_args_list(n_training = -1)) - expect_error(arx_class_args_list(n_training = 1.5)) - expect_error(arx_class_args_list(lags = c(-1, 0))) - expect_error(arx_class_args_list(lags = list(c(1:5, 6.5), 2:8))) + expect_snapshot(error = TRUE, arx_class_args_list(ahead = -1)) + expect_snapshot(error = TRUE, arx_class_args_list(ahead = 1.5)) + expect_snapshot(error = TRUE, arx_class_args_list(n_training = -1)) + expect_snapshot(error = TRUE, arx_class_args_list(n_training = 1.5)) + expect_snapshot(error = TRUE, arx_class_args_list(lags = c(-1, 0))) + expect_snapshot(error = TRUE, arx_class_args_list(lags = list(c(1:5, 6.5), 2:8))) - expect_error(arx_class_args_list(target_date = "2022-01-01")) + expect_snapshot(error = TRUE, arx_class_args_list(target_date = "2022-01-01")) expect_identical( arx_class_args_list(target_date = as.Date("2022-01-01"))$target_date, as.Date("2022-01-01") ) - expect_error(arx_class_args_list(n_training_min = "de")) - expect_error(arx_class_args_list(epi_keys = 1)) + expect_snapshot(error = TRUE, arx_class_args_list(n_training_min = "de")) + expect_snapshot(error = TRUE, arx_class_args_list(epi_keys 
= 1)) + + expect_warning(arx_class_args_list( + forecast_date = as.Date("2022-01-01"), + target_date = as.Date("2022-01-03"), + ahead = 1L + )) }) diff --git a/tests/testthat/test-bake-method.R b/tests/testthat/test-bake-method.R new file mode 100644 index 000000000..06f861012 --- /dev/null +++ b/tests/testthat/test-bake-method.R @@ -0,0 +1,29 @@ +test_that("bake method works in all cases", { + edf <- case_death_rate_subset %>% + filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) + r <- epi_recipe(edf) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) + + r2 <- epi_recipe(edf) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_naomit() + + b_null <- bake(prep(r, edf), NULL) + b_train <- bake(prep(r, edf), edf) + expect_s3_class(b_null, "epi_df") + expect_identical(b_null, b_train) + + b_baked <- bake(prep(r2, edf), edf) # leaves rows with NA in the response + # doesn't (because we "juice", so skip doesn't apply) + b_juiced <- bake(prep(r2, edf), NULL) + expect_equal(nrow(b_juiced), sum(complete.cases(b_train))) + expect_equal(nrow(b_baked), sum(complete.cases(b_train)) + 3 * 7) + + # check that the {recipes} behaves + expect_s3_class(bake(prep(r, edf), NULL, composition = "tibble"), "tbl_df") + expect_s3_class(bake(prep(r, edf), NULL, composition = "data.frame"), "data.frame") + # can't be a matrix because time_value/geo_value aren't numeric + expect_snapshot(error = TRUE, bake(prep(r, edf), NULL, composition = "matrix")) +}) diff --git a/tests/testthat/test-check-training-set.R b/tests/testthat/test-check-training-set.R index 0f9246282..64d4d6945 100644 --- a/tests/testthat/test-check-training-set.R +++ b/tests/testthat/test-check-training-set.R @@ -7,7 +7,7 @@ test_that("training set validation works", { expect_silent(validate_meta_match(template, template, "time_type", "blah")) attr(t1, "metadata")$geo_type <- "county" 
expect_warning(validate_meta_match(t1, template, "geo_type"), "county") - expect_error(validate_meta_match(t1, template, "geo_type", "abort"), "county") + expect_snapshot(error = TRUE, validate_meta_match(t1, template, "geo_type", "abort")) expect_identical(template, epi_check_training_set(template, rec)) @@ -25,5 +25,5 @@ test_that("training set validation works", { expect_warning(t4 <- epi_check_training_set(t3, rec)) expect_identical(rec$template, t4) attr(rec$template, "metadata")$other_keys <- "missing_col" - expect_error(epi_check_training_set(t4, rec), "missing_col") + expect_snapshot(error = TRUE, epi_check_training_set(t4, rec)) }) diff --git a/tests/testthat/test-check_enough_train_data.R b/tests/testthat/test-check_enough_train_data.R index 5eae01bb2..9b2ef5f34 100644 --- a/tests/testthat/test-check_enough_train_data.R +++ b/tests/testthat/test-check_enough_train_data.R @@ -19,23 +19,24 @@ test_that("check_enough_train_data works on pooled data", { expect_no_error( epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = 2 * n, drop_na = FALSE) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) ) # Check both column don't have enough data - expect_error( + expect_snapshot( + error = TRUE, epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = 2 * n + 1, drop_na = FALSE) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL), - regexp = "The following columns don't have enough data" + prep(toy_epi_df) %>% + bake(new_data = NULL) ) # Check drop_na works - expect_error( + expect_snapshot( + error = TRUE, epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = 2 * n - 1, drop_na = TRUE) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) ) }) @@ -44,23 +45,24 @@ test_that("check_enough_train_data works on unpooled data", { expect_no_error( epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = n, 
epi_keys = "geo_value", drop_na = FALSE) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) ) # Check one column don't have enough data - expect_error( + expect_snapshot( + error = TRUE, epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = n + 1, epi_keys = "geo_value", drop_na = FALSE) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL), - regexp = "The following columns don't have enough data" + prep(toy_epi_df) %>% + bake(new_data = NULL) ) # Check drop_na works - expect_error( + expect_snapshot( + error = TRUE, epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = 2 * n - 3, epi_keys = "geo_value", drop_na = TRUE) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) ) }) @@ -68,14 +70,14 @@ test_that("check_enough_train_data outputs the correct recipe values", { expect_no_error( p <- epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = 2 * n - 2) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) ) expect_equal(nrow(p), 2 * n) expect_equal(ncol(p), 4L) expect_s3_class(p, "epi_df") - expect_named(p, c("time_value", "geo_value", "x", "y")) + expect_named(p, c("geo_value", "time_value", "x", "y")) # order in epiprocess::new_epi_df expect_equal( p$time_value, rep(seq(as.Date("2020-01-01"), by = 1, length.out = n), times = 2) @@ -93,15 +95,15 @@ test_that("check_enough_train_data only checks train data", { expect_no_error( epi_recipe(toy_epi_df) %>% check_enough_train_data(x, y, n = n - 2, epi_keys = "geo_value") %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = toy_test_data) + prep(toy_epi_df) %>% + bake(new_data = toy_test_data) ) # Same thing, but skip = FALSE expect_no_error( epi_recipe(toy_epi_df) %>% check_enough_train_data(y, n = n - 2, epi_keys = "geo_value", skip = FALSE) %>% - recipes::prep(toy_epi_df) %>% - 
recipes::bake(new_data = toy_test_data) + prep(toy_epi_df) %>% + bake(new_data = toy_test_data) ) }) @@ -111,14 +113,15 @@ test_that("check_enough_train_data works with all_predictors() downstream of con epi_recipe(toy_epi_df) %>% step_epi_lag(x, lag = c(1, 2)) %>% check_enough_train_data(all_predictors(), y, n = 2 * n - 6) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) ) - expect_error( + expect_snapshot( + error = TRUE, epi_recipe(toy_epi_df) %>% step_epi_lag(x, lag = c(1, 2)) %>% check_enough_train_data(all_predictors(), y, n = 2 * n - 5) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) ) }) diff --git a/tests/testthat/test-dist_quantiles.R b/tests/testthat/test-dist_quantiles.R index 4fc5587d4..8112326dc 100644 --- a/tests/testthat/test-dist_quantiles.R +++ b/tests/testthat/test-dist_quantiles.R @@ -1,23 +1,15 @@ library(distributional) test_that("constructor returns reasonable quantiles", { - expect_error(new_quantiles(rnorm(5), rnorm(5))) + expect_snapshot(error = TRUE, new_quantiles(rnorm(5), rnorm(5))) expect_silent(new_quantiles(sort(rnorm(5)), sort(runif(5)))) - expect_error(new_quantiles(sort(rnorm(5)), sort(runif(2)))) + expect_snapshot(error = TRUE, new_quantiles(sort(rnorm(5)), sort(runif(2)))) expect_silent(new_quantiles(1:5, 1:5 / 10)) - expect_error(new_quantiles(c(2, 1, 3, 4, 5), c(.1, .1, .2, .5, .8))) - expect_error(new_quantiles(c(2, 1, 3, 4, 5), c(.1, .15, .2, .5, .8))) - expect_error(new_quantiles(c(1, 2, 3), c(.1, .2, 3))) + expect_snapshot(error = TRUE, new_quantiles(c(2, 1, 3, 4, 5), c(.1, .1, .2, .5, .8))) + expect_snapshot(error = TRUE, new_quantiles(c(2, 1, 3, 4, 5), c(.1, .15, .2, .5, .8))) + expect_snapshot(error = TRUE, new_quantiles(c(1, 2, 3), c(.1, .2, 3))) }) -test_that("tail functions give reasonable output", { - expect_equal(norm_q_par(qnorm(c(.75, .5), 10, 5)), list(m = 10, s = 5)) - 
expect_equal(norm_q_par(qnorm(c(.25, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(norm_q_par(qnorm(c(.25, .5), 0, 1)), list(m = 0, s = 1)) - expect_equal(exp_q_par(qlaplace(c(.75, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(exp_q_par(qlaplace(c(.25, .5), 10, 5)), list(m = 10, s = 5)) - expect_equal(exp_q_par(qlaplace(c(.25, .5), 0, 1)), list(m = 0, s = 1)) -}) test_that("single dist_quantiles works, quantiles are accessible", { z <- new_quantiles(values = 1:5, quantile_levels = c(.2, .4, .5, .6, .8)) @@ -32,21 +24,67 @@ test_that("single dist_quantiles works, quantiles are accessible", { extrapolate_quantiles(z, c(.3, .7), middle = "linear"), new_quantiles(values = c(1, 1.5, 2, 3, 4, 4.5, 5), quantile_levels = 2:8 / 10) ) + # empty values slot results in a length zero distribution + # see issue #361 + expect_length(dist_quantiles(list(), c(.1, .9)), 0L) + expect_identical( + dist_quantiles(list(), c(.1, .9)), + distributional::dist_degenerate(double()) + ) }) + test_that("quantile extrapolator works", { dstn <- dist_normal(c(10, 2), c(5, 10)) - qq <- extrapolate_quantiles(dstn, p = c(.25, 0.5, .75)) + qq <- extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) expect_s3_class(qq, "distribution") expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles") - expect_length(parameters(qq[1])$q[[1]], 3L) - + expect_length(parameters(qq[1])$quantile_levels[[1]], 3L) dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) - qq <- extrapolate_quantiles(dstn, p = c(.25, 0.5, .75)) + qq <- extrapolate_quantiles(dstn, probs = c(.25, 0.5, .75)) expect_s3_class(qq, "distribution") expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles") - expect_length(parameters(qq[1])$q[[1]], 7L) + expect_length(parameters(qq[1])$quantile_levels[[1]], 7L) + + dstn <- dist_quantiles(1:4, 1:4 / 5) + qq <- extrapolate_quantiles(dstn, 1:9 / 10) + dstn_na <- dist_quantiles(c(1, 2, NA, 4), 1:4 / 5) + qq2 <- extrapolate_quantiles(dstn_na, 1:9 / 10) + expect_equal(qq, 
qq2) + qq3 <- extrapolate_quantiles(dstn_na, 1:9 / 10, replace_na = FALSE) + qq2_vals <- field(vec_data(qq2)[[1]], "values") + qq3_vals <- field(vec_data(qq3)[[1]], "values") + qq2_vals[6] <- NA + expect_equal(qq2_vals, qq3_vals) +}) + +test_that("small deviations of quantile requests work", { + l <- c(.05, .1, .25, .75, .9, .95) + v <- c(0.0890306, 0.1424997, 0.1971793, 0.2850978, 0.3832912, 0.4240479) + badl <- l + badl[1] <- badl[1] - 1e-14 + distn <- dist_quantiles(list(v), list(l)) + + # was broken before, now works + expect_equal(quantile(distn, l), quantile(distn, badl)) + + # The tail extrapolation was still poor. It needs to _always_ use + # the smallest (largest) values or we could end up unsorted + l <- 1:9 / 10 + v <- 1:9 + distn <- dist_quantiles(list(v), list(l)) + expect_equal(quantile(distn, c(.25, .75)), list(c(2.5, 7.5))) + expect_equal(quantile(distn, c(.1, .9)), list(c(1, 9))) + qv <- data.frame(q = l, v = v) + expect_equal( + unlist(quantile(distn, c(.01, .05))), + tail_extrapolate(c(.01, .05), head(qv, 2)) + ) + expect_equal( + unlist(quantile(distn, c(.99, .95))), + tail_extrapolate(c(.95, .99), tail(qv, 2)) + ) }) test_that("unary math works on quantiles", { @@ -68,6 +106,6 @@ test_that("arithmetic works on quantiles", { expect_identical(dstn / 4, dstn2) expect_identical((1 / 4) * dstn, dstn2) - expect_error(sum(dstn)) - expect_error(suppressWarnings(dstn + distributional::dist_normal())) + expect_snapshot(error = TRUE, sum(dstn)) + expect_snapshot(error = TRUE, suppressWarnings(dstn + distributional::dist_normal())) }) diff --git a/tests/testthat/test-enframer.R b/tests/testthat/test-enframer.R index c555ea9b2..0926c587b 100644 --- a/tests/testthat/test-enframer.R +++ b/tests/testthat/test-enframer.R @@ -1,11 +1,11 @@ test_that("enframer errors/works as needed", { template1 <- data.frame(aa = 1:5, a = NA, b = NA, c = NA) template2 <- data.frame(aa = 1:5, a = 2:6, b = 2:6, c = 2:6) - expect_error(enframer(1:5, letters[1])) - 
expect_error(enframer(data.frame(a = 1:5), 1:3)) - expect_error(enframer(data.frame(a = 1:5), letters[1:3])) + expect_snapshot(error = TRUE, enframer(1:5, letters[1])) + expect_snapshot(error = TRUE, enframer(data.frame(a = 1:5), 1:3)) + expect_snapshot(error = TRUE, enframer(data.frame(a = 1:5), letters[1:3])) expect_identical(enframer(data.frame(aa = 1:5), letters[1:3]), template1) - expect_error(enframer(data.frame(aa = 1:5), letters[1:2], fill = 1:4)) + expect_snapshot(error = TRUE, enframer(data.frame(aa = 1:5), letters[1:2], fill = 1:4)) expect_identical( enframer(data.frame(aa = 1:5), letters[1:3], fill = 2:6), template2 diff --git a/tests/testthat/test-epi_recipe.R b/tests/testthat/test-epi_recipe.R index d288ec058..1b06cf24c 100644 --- a/tests/testthat/test-epi_recipe.R +++ b/tests/testthat/test-epi_recipe.R @@ -1,24 +1,12 @@ -test_that("epi_recipe produces default recipe", { - # these all call recipes::recipe(), but the template will always have 1 row +test_that("epi_recipe produces error if not an epi_df", { tib <- tibble( x = 1:5, y = 1:5, time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5) ) - rec <- recipes::recipe(tib) - rec$template <- rec$template[1, ] - expect_identical(rec, epi_recipe(tib)) - expect_equal(nrow(rec$template), 1L) - - rec <- recipes::recipe(y ~ x, tib) - rec$template <- rec$template[1, ] - expect_identical(rec, epi_recipe(y ~ x, tib)) - expect_equal(nrow(rec$template), 1L) - + expect_snapshot(error = TRUE, epi_recipe(tib)) + expect_snapshot(error = TRUE, epi_recipe(y ~ x, tib)) m <- as.matrix(tib) - rec <- recipes::recipe(m) - rec$template <- rec$template[1, ] - expect_identical(rec, epi_recipe(m)) - expect_equal(nrow(rec$template), 1L) + expect_snapshot(error = TRUE, epi_recipe(m)) }) test_that("epi_recipe formula works", { @@ -56,7 +44,7 @@ test_that("epi_recipe formula works", { time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5), geo_value = "ca", z = "dummy_key" - ) %>% 
epiprocess::as_epi_df(additional_metadata = list(other_keys = "z")) + ) %>% epiprocess::as_epi_df(other_keys = "z") # with an additional key r <- epi_recipe(y ~ x + geo_value, tib) @@ -125,7 +113,7 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { wf <- epi_workflow() %>% add_epi_recipe(r) - steps <- extract_preprocessor(wf)$steps + steps <- workflows::extract_preprocessor(wf)$steps expect_equal(length(steps), 3) expect_equal(class(steps[[1]]), c("step_epi_lag", "step")) expect_equal(steps[[1]]$lag, c(0, 7, 14)) @@ -140,7 +128,7 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { wf <- update_epi_recipe(wf, r2) - steps <- extract_preprocessor(wf)$steps + steps <- workflows::extract_preprocessor(wf)$steps expect_equal(length(steps), 2) expect_equal(class(steps[[1]]), c("step_epi_lag", "step")) expect_equal(steps[[1]]$lag, c(0, 1)) @@ -149,7 +137,7 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { # adjust_epi_recipe using step number wf <- adjust_epi_recipe(wf, which_step = 2, ahead = 7) - steps <- extract_preprocessor(wf)$steps + steps <- workflows::extract_preprocessor(wf)$steps expect_equal(length(steps), 2) expect_equal(class(steps[[1]]), c("step_epi_lag", "step")) expect_equal(steps[[1]]$lag, c(0, 1)) @@ -158,7 +146,7 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { # adjust_epi_recipe using step name wf <- adjust_epi_recipe(wf, which_step = "step_epi_ahead", ahead = 8) - steps <- extract_preprocessor(wf)$steps + steps <- workflows::extract_preprocessor(wf)$steps expect_equal(length(steps), 2) expect_equal(class(steps[[1]]), c("step_epi_lag", "step")) expect_equal(steps[[1]]$lag, c(0, 1)) @@ -167,6 +155,6 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { wf <- remove_epi_recipe(wf) - expect_error(extract_preprocessor(wf)$steps) + expect_snapshot(error = TRUE, workflows::extract_preprocessor(wf)$steps) expect_equal(wf$pre$actions$recipe$recipe, NULL) }) 
diff --git a/tests/testthat/test-epi_shift.R b/tests/testthat/test-epi_shift.R index b0ab3a21f..78c9384f1 100644 --- a/tests/testthat/test-epi_shift.R +++ b/tests/testthat/test-epi_shift.R @@ -24,7 +24,7 @@ test_that("epi shift single works, renames", { time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5), geo_value = "ca" ) %>% epiprocess::as_epi_df() - ess <- epi_shift_single(tib, "x", 1, "test", epi_keys(tib)) - expect_named(ess, c("time_value", "geo_value", "test")) + ess <- epi_shift_single(tib, "x", 1, "test", key_colnames(tib)) + expect_named(ess, c("geo_value", "time_value", "test")) expect_equal(ess$time_value, tib$time_value + 1) }) diff --git a/tests/testthat/test-epi_workflow.R b/tests/testthat/test-epi_workflow.R index c2be0e1cb..8bb58b0bc 100644 --- a/tests/testthat/test-epi_workflow.R +++ b/tests/testthat/test-epi_workflow.R @@ -59,6 +59,86 @@ test_that("model can be added/updated/removed from epi_workflow", { expect_equal(class(model_spec2), c("linear_reg", "model_spec")) wf <- remove_model(wf) - expect_error(extract_spec_parsnip(wf)) + expect_snapshot(error = TRUE, extract_spec_parsnip(wf)) expect_equal(wf$fit$actions$model$spec, NULL) }) + +test_that("forecast method works", { + jhu <- case_death_rate_subset %>% + filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) + r <- epi_recipe(jhu) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_naomit() + wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) + expect_equal( + forecast(wf), + predict(wf, new_data = get_test_data( + hardhat::extract_preprocessor(wf), + jhu + )) + ) + + args <- list( + fill_locf = TRUE, + n_recent = 360 * 3, + forecast_date = as.Date("2024-01-01") + ) + expect_equal( + forecast(wf, !!!args), + predict(wf, new_data = get_test_data( + hardhat::extract_preprocessor(wf), + jhu, + !!!args + )) + ) +}) + +test_that("forecast method errors when workflow not fit", { + jhu <- 
case_death_rate_subset %>% + filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) + r <- epi_recipe(jhu) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_naomit() + wf <- epi_workflow(r, parsnip::linear_reg()) + + expect_snapshot(error = TRUE, forecast(wf)) +}) + +test_that("fit method does not silently drop the class", { + # This is issue #363 + + library(recipes) + tbl <- tibble::tibble( + geo_value = 1, + time_value = 1:100, + x = 1:100, + y = x + rnorm(100L) + ) + edf <- as_epi_df(tbl) + + rec_tbl <- recipe(y ~ x, data = tbl) + rec_edf <- recipe(y ~ x, data = edf) + expect_snapshot(error = TRUE, epi_recipe(y ~ x, data = tbl)) + erec_edf <- epi_recipe(y ~ x, data = edf) + + ewf_rec_tbl <- epi_workflow(rec_tbl, linear_reg()) + ewf_rec_edf <- epi_workflow(rec_edf, linear_reg()) + ewf_erec_edf <- epi_workflow(erec_edf, linear_reg()) + + # above are all epi_workflows: + + expect_s3_class(ewf_rec_tbl, "epi_workflow") + expect_s3_class(ewf_rec_edf, "epi_workflow") + expect_s3_class(ewf_erec_edf, "epi_workflow") + + # but fitting drops the class or generates errors in many cases: + + expect_s3_class(ewf_rec_tbl %>% fit(tbl), "epi_workflow") + expect_s3_class(ewf_rec_tbl %>% fit(edf), "epi_workflow") + expect_s3_class(ewf_rec_edf %>% fit(tbl), "epi_workflow") + expect_s3_class(ewf_rec_edf %>% fit(edf), "epi_workflow") + expect_snapshot(ewf_erec_edf %>% fit(tbl), error = TRUE) + expect_s3_class(ewf_erec_edf %>% fit(edf), "epi_workflow") +}) diff --git a/tests/testthat/test-extract_argument.R b/tests/testthat/test-extract_argument.R index 0654304ba..7434763e7 100644 --- a/tests/testthat/test-extract_argument.R +++ b/tests/testthat/test-extract_argument.R @@ -4,27 +4,27 @@ test_that("layer argument extractor works", { layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) %>% layer_naomit(.pred) - expect_error(extract_argument(f$layers[[1]], "uhoh", "bubble")) - 
expect_error(extract_argument(f$layers[[1]], "layer_predict", "bubble")) + expect_snapshot(error = TRUE, extract_argument(f$layers[[1]], "uhoh", "bubble")) + expect_snapshot(error = TRUE, extract_argument(f$layers[[1]], "layer_predict", "bubble")) expect_identical( extract_argument(f$layers[[2]], "layer_residual_quantiles", "quantile_levels"), c(0.0275, 0.9750) ) - expect_error(extract_argument(f, "layer_thresh", "quantile_levels")) + expect_snapshot(error = TRUE, extract_argument(f, "layer_thresh", "quantile_levels")) expect_identical( extract_argument(f, "layer_residual_quantiles", "quantile_levels"), c(0.0275, 0.9750) ) wf <- epi_workflow(postprocessor = f) - expect_error(extract_argument(epi_workflow(), "layer_residual_quantiles", "quantile_levels")) + expect_snapshot(error = TRUE, extract_argument(epi_workflow(), "layer_residual_quantiles", "quantile_levels")) expect_identical( extract_argument(wf, "layer_residual_quantiles", "quantile_levels"), c(0.0275, 0.9750) ) - expect_error(extract_argument(wf, "layer_predict", c("type", "opts"))) + expect_snapshot(error = TRUE, extract_argument(wf, "layer_predict", c("type", "opts"))) }) test_that("recipe argument extractor works", { @@ -41,21 +41,21 @@ test_that("recipe argument extractor works", { step_naomit(all_outcomes(), skip = TRUE) - expect_error(extract_argument(r$steps[[1]], "uhoh", "bubble")) - expect_error(extract_argument(r$steps[[1]], "step_epi_lag", "bubble")) - expect_identical(extract_argument(r$steps[[2]], "step_epi_ahead", "ahead"), 7) + expect_snapshot(error = TRUE, extract_argument(r$steps[[1]], "uhoh", "bubble")) + expect_snapshot(error = TRUE, extract_argument(r$steps[[1]], "step_epi_lag", "bubble")) + expect_identical(extract_argument(r$steps[[2]], "step_epi_ahead", "ahead"), 7L) - expect_error(extract_argument(r, "step_lightly", "quantile_levels")) + expect_snapshot(error = TRUE, extract_argument(r, "step_lightly", "quantile_levels")) expect_identical( extract_argument(r, "step_epi_lag", "lag"), 
- list(c(0, 7, 14), c(0, 7, 14)) + list(c(0L, 7L, 14L), c(0L, 7L, 14L)) ) wf <- epi_workflow(preprocessor = r) - expect_error(extract_argument(epi_workflow(), "step_epi_lag", "lag")) + expect_snapshot(error = TRUE, extract_argument(epi_workflow(), "step_epi_lag", "lag")) expect_identical( extract_argument(wf, "step_epi_lag", "lag"), - list(c(0, 7, 14), c(0, 7, 14)) + list(c(0L, 7L, 14L), c(0L, 7L, 14L)) ) }) diff --git a/tests/testthat/test-flatline_args_list.R b/tests/testthat/test-flatline_args_list.R new file mode 100644 index 000000000..6359afc27 --- /dev/null +++ b/tests/testthat/test-flatline_args_list.R @@ -0,0 +1,35 @@ +test_that("flatline_args_list checks inputs", { + expect_s3_class(flatline_args_list(), c("flat_fcast", "alist")) + expect_snapshot(error = TRUE, flatline_args_list(ahead = c(0, 4))) + expect_snapshot(error = TRUE, flatline_args_list(n_training = c(28, 65))) + + expect_snapshot(error = TRUE, flatline_args_list(ahead = -1)) + expect_snapshot(error = TRUE, flatline_args_list(ahead = 1.5)) + expect_snapshot(error = TRUE, flatline_args_list(n_training = -1)) + expect_snapshot(error = TRUE, flatline_args_list(n_training = 1.5)) + expect_snapshot(error = TRUE, flatline_args_list(lags = c(-1, 0))) + expect_snapshot(error = TRUE, flatline_args_list(lags = list(c(1:5, 6.5), 2:8))) + + expect_snapshot(error = TRUE, flatline_args_list(symmetrize = 4)) + expect_snapshot(error = TRUE, flatline_args_list(nonneg = 4)) + + expect_snapshot(error = TRUE, flatline_args_list(quantile_levels = -.1)) + expect_snapshot(error = TRUE, flatline_args_list(quantile_levels = 1.1)) + expect_type(flatline_args_list(quantile_levels = NULL), "list") + + expect_snapshot(error = TRUE, flatline_args_list(target_date = "2022-01-01")) + expect_identical( + flatline_args_list(target_date = as.Date("2022-01-01"))$target_date, + as.Date("2022-01-01") + ) + + expect_snapshot(error = TRUE, flatline_args_list(n_training_min = "de")) + expect_snapshot(error = TRUE, 
flatline_args_list(epi_keys = 1)) + + # Detect mismatched ahead and target_date - forecast_date difference + expect_warning(flatline_args_list( + forecast_date = as.Date("2022-01-01"), + target_date = as.Date("2022-01-03"), + ahead = 1L + )) +}) diff --git a/tests/testthat/test-frosting.R b/tests/testthat/test-frosting.R index 9e8a6f90c..1bdce3b5a 100644 --- a/tests/testthat/test-frosting.R +++ b/tests/testthat/test-frosting.R @@ -7,7 +7,7 @@ test_that("frosting validators / constructors work", { expect_false(has_postprocessor_frosting(wf)) expect_silent(wf %>% add_frosting(new_frosting())) expect_silent(wf %>% add_postprocessor(new_frosting())) - expect_error(wf %>% add_postprocessor(list())) + expect_snapshot(error = TRUE, wf %>% add_postprocessor(list())) wf <- wf %>% add_frosting(new_frosting()) expect_true(has_postprocessor(wf)) @@ -16,7 +16,7 @@ test_that("frosting validators / constructors work", { test_that("frosting can be created/added/updated/adjusted/removed", { f <- frosting() - expect_error(frosting(layers = 1:5)) + expect_snapshot(error = TRUE, frosting(layers = 1:5)) wf <- epi_workflow() %>% add_frosting(f) expect_true(has_postprocessor_frosting(wf)) wf1 <- update_frosting(wf, frosting() %>% layer_predict() %>% layer_threshold(.pred)) @@ -72,8 +72,6 @@ test_that("layer_predict is added by default if missing", { wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) - latest <- get_test_data(recipe = r, x = jhu) - f1 <- frosting() %>% layer_naomit(.pred) %>% layer_residual_quantiles() @@ -86,5 +84,62 @@ test_that("layer_predict is added by default if missing", { wf1 <- wf %>% add_frosting(f1) wf2 <- wf %>% add_frosting(f2) - expect_equal(predict(wf1, latest), predict(wf2, latest)) + expect_equal(forecast(wf1), forecast(wf2)) +}) + + +test_that("parsnip settings can be passed through predict.epi_workflow", { + jhu <- case_death_rate_subset %>% + dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) + + r <- epi_recipe(jhu) 
%>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_naomit() + + wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) + + latest <- get_test_data(r, jhu) + + f1 <- frosting() %>% layer_predict() + f2 <- frosting() %>% layer_predict(type = "pred_int") + f3 <- frosting() %>% layer_predict(type = "pred_int", level = 0.6) + + pred2 <- wf %>% + add_frosting(f2) %>% + predict(latest) + pred3 <- wf %>% + add_frosting(f3) %>% + predict(latest) + + pred2_re <- wf %>% + add_frosting(f1) %>% + predict(latest, type = "pred_int") + pred3_re <- wf %>% + add_frosting(f1) %>% + predict(latest, type = "pred_int", level = 0.6) + + expect_identical(pred2, pred2_re) + expect_identical(pred3, pred3_re) + + expect_error(wf %>% add_frosting(f2) %>% predict(latest, type = "raw"), + class = "epipredict__layer_predict__conflicting_type_settings" + ) + + f4 <- frosting() %>% + layer_predict() %>% + layer_threshold(.pred, lower = 0) + + expect_error(wf %>% add_frosting(f4) %>% predict(latest, type = "pred_int"), + class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers" + ) + + # We also refuse to continue when just passing the level, which might not be ideal: + f5 <- frosting() %>% + layer_predict(type = "pred_int") %>% + layer_threshold(.pred_lower, .pred_upper, lower = 0) + + expect_error(wf %>% add_frosting(f5) %>% predict(latest, level = 0.6), + class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers" + ) }) diff --git a/tests/testthat/test-get_test_data.R b/tests/testthat/test-get_test_data.R index 035fc6463..aa799150b 100644 --- a/tests/testthat/test-get_test_data.R +++ b/tests/testthat/test-get_test_data.R @@ -25,7 +25,7 @@ test_that("expect insufficient training data error", { step_naomit(all_predictors()) %>% step_naomit(all_outcomes(), skip = TRUE) - expect_error(get_test_data(recipe = r, x = case_death_rate_subset)) + expect_snapshot(error = TRUE, get_test_data(recipe = r, x = 
case_death_rate_subset)) }) @@ -39,7 +39,7 @@ test_that("expect error that geo_value or time_value does not exist", { wrong_epi_df <- case_death_rate_subset %>% dplyr::select(-geo_value) - expect_error(get_test_data(recipe = r, x = wrong_epi_df)) + expect_snapshot(error = TRUE, get_test_data(recipe = r, x = wrong_epi_df)) }) @@ -60,15 +60,15 @@ test_that("NA fill behaves as desired", { expect_silent(tt <- get_test_data(r, df)) expect_s3_class(tt, "epi_df") - expect_error(get_test_data(r, df, "A")) - expect_error(get_test_data(r, df, TRUE, -3)) + expect_snapshot(error = TRUE, get_test_data(r, df, "A")) + expect_snapshot(error = TRUE, get_test_data(r, df, TRUE, -3)) df2 <- df df2$x1[df2$geo_value == "ca"] <- NA td <- get_test_data(r, df2) expect_true(any(is.na(td))) - expect_error(get_test_data(r, df2, TRUE)) + expect_snapshot(error = TRUE, get_test_data(r, df2, TRUE)) df1 <- df2 df1$x1[1:4] <- 1:4 @@ -93,9 +93,9 @@ test_that("forecast date behaves", { step_epi_ahead(x1, ahead = 3) %>% step_epi_lag(x1, x2, lag = c(1, 3)) - expect_error(get_test_data(r, df, TRUE, forecast_date = 9)) # class error - expect_error(get_test_data(r, df, TRUE, forecast_date = 9L)) # fd too early - expect_error(get_test_data(r, df, forecast_date = 9L)) # fd too early + expect_snapshot(error = TRUE, get_test_data(r, df, TRUE, forecast_date = 9)) # class error + expect_snapshot(error = TRUE, get_test_data(r, df, TRUE, forecast_date = 9L)) # fd too early + expect_snapshot(error = TRUE, get_test_data(r, df, forecast_date = 9L)) # fd too early ndf <- get_test_data(r, df, TRUE, forecast_date = 12L) expect_equal(max(ndf$time_value), 11L) # max lag was 1 diff --git a/tests/testthat/test-grab_names.R b/tests/testthat/test-grab_names.R deleted file mode 100644 index 6e0376f5a..000000000 --- a/tests/testthat/test-grab_names.R +++ /dev/null @@ -1,8 +0,0 @@ -df <- data.frame(b = 1, c = 2, ca = 3, cat = 4) - -test_that("Names are grabbed properly", { - expect_identical( - grab_names(df, 
dplyr::starts_with("ca")), - subset(names(df), startsWith(names(df), "ca")) - ) -}) diff --git a/tests/testthat/test-grf_quantiles.R b/tests/testthat/test-grf_quantiles.R new file mode 100644 index 000000000..2570c247d --- /dev/null +++ b/tests/testthat/test-grf_quantiles.R @@ -0,0 +1,52 @@ +set.seed(12345) +library(grf) +tib <- tibble( + y = rnorm(100), x = rnorm(100), z = rnorm(100), + f = factor(sample(letters[1:3], 100, replace = TRUE)) +) + +test_that("quantile_rand_forest defaults work", { + spec <- rand_forest(engine = "grf_quantiles", mode = "regression") + expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) + pars <- parsnip::extract_fit_engine(out) + manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, quantiles = c(0.1, 0.5, 0.9)) + expect_identical(pars$quantiles.orig, manual$quantiles) + expect_identical(pars$`_num_trees`, manual$`_num_trees`) + + fseed <- 12345 + spec_seed <- rand_forest(mode = "regression", mtry = 2L, min_n = 10) %>% + set_engine("grf_quantiles", seed = fseed) + out <- fit(spec_seed, formula = y ~ x + z - 1, data = tib) + manual <- quantile_forest( + as.matrix(tib[, 2:3]), tib$y, + quantiles = c(0.1, 0.5, 0.9), seed = fseed, + mtry = 2L, min.node.size = 10 + ) + p_pars <- predict(out, new_data = tib[1:5, ]) %>% + pivot_quantiles_wider(.pred) + p_manual <- predict(manual, newdata = as.matrix(tib[1:5, 2:3]))$predictions + colnames(p_manual) <- c("0.1", "0.5", "0.9") + p_manual <- tibble::as_tibble(p_manual) + # not equal despite the seed, etc + # expect_equal(p_pars, p_manual) +}) + +test_that("quantile_rand_forest handles alternative quantiles", { + spec <- rand_forest(mode = "regression") %>% + set_engine("grf_quantiles", quantiles = c(.2, .5, .8)) + expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) + pars <- parsnip::extract_fit_engine(out) + manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, quantiles = c(.2, .5, .8)) + expect_identical(pars$quantiles.orig, manual$quantiles.orig) + 
expect_identical(pars$`_num_trees`, manual$`_num_trees`) +}) + + +test_that("quantile_rand_forest allows setting the trees and mtry", { + spec <- rand_forest(mode = "regression", mtry = 2, trees = 100, engine = "grf_quantiles") + expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) + pars <- parsnip::extract_fit_engine(out) + manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, mtry = 2, num.trees = 100) + expect_identical(pars$quantiles.orig, manual$quantiles.orig) + expect_identical(pars$`_num_trees`, manual$`_num_trees`) +}) diff --git a/tests/testthat/test-epi_keys.R b/tests/testthat/test-key_colnames.R similarity index 51% rename from tests/testthat/test-epi_keys.R rename to tests/testthat/test-key_colnames.R index 3e794542e..3b3118740 100644 --- a/tests/testthat/test-epi_keys.R +++ b/tests/testthat/test-key_colnames.R @@ -1,46 +1,25 @@ -library(parsnip) -library(workflows) -library(dplyr) - -test_that("epi_keys returns empty for an object that isn't an epi_df", { - expect_identical(epi_keys(data.frame(x = 1:3, y = 2:4)), character(0L)) -}) - -test_that("epi_keys returns possible keys if they exist", { - expect_identical( - epi_keys(data.frame(time_value = 1:3, geo_value = 2:4)), - c("time_value", "geo_value") - ) -}) - - -test_that("Extracts keys from an epi_df", { - expect_equal(epi_keys(case_death_rate_subset), c("time_value", "geo_value")) -}) - test_that("Extracts keys from a recipe; roles are NA, giving an empty vector", { - expect_equal(epi_keys(recipe(case_death_rate_subset)), character(0L)) + expect_equal(key_colnames(recipe(case_death_rate_subset)), character(0L)) }) -test_that("epi_keys_mold extracts time_value and geo_value, but not raw", { +test_that("key_colnames extracts time_value and geo_value, but not raw", { my_recipe <- epi_recipe(case_death_rate_subset) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% step_epi_naomit() + 
expect_identical(key_colnames(my_recipe), c("geo_value", "time_value")) + my_workflow <- epi_workflow() %>% add_epi_recipe(my_recipe) %>% add_model(linear_reg()) %>% fit(data = case_death_rate_subset) - expect_setequal( - epi_keys_mold(my_workflow$pre$mold), - c("time_value", "geo_value") - ) + expect_identical(key_colnames(my_workflow), c("geo_value", "time_value")) }) -test_that("epi_keys_mold extracts additional keys when they are present", { +test_that("key_colnames extracts additional keys when they are present", { my_data <- tibble::tibble( geo_value = rep(c("ca", "fl", "pa"), each = 3), time_value = rep(seq(as.Date("2020-06-01"), as.Date("2020-06-03"), @@ -50,18 +29,24 @@ test_that("epi_keys_mold extracts additional keys when they are present", { state = rep(c("ca", "fl", "pa"), each = 3), # extra key value = 1:length(geo_value) + 0.01 * rnorm(length(geo_value)) ) %>% - epiprocess::as_epi_df( - additional_metadata = list(other_keys = c("state", "pol")) + as_epi_df( + other_keys = c("state", "pol") ) + expect_identical( + key_colnames(my_data), + c("geo_value", "state", "pol", "time_value") + ) + my_recipe <- epi_recipe(my_data) %>% step_epi_ahead(value, ahead = 7) %>% step_epi_naomit() + # order of the additional keys may be different + expect_equal(key_colnames(my_recipe), c("geo_value", "state", "pol", "time_value")) + my_workflow <- epi_workflow(my_recipe, linear_reg()) %>% fit(my_data) - expect_setequal( - epi_keys_mold(my_workflow$pre$mold), - c("time_value", "geo_value", "state", "pol") - ) + # order of the additional keys may be different + expect_equal(key_colnames(my_workflow), c("geo_value", "state", "pol", "time_value")) }) diff --git a/tests/testthat/test-layer_add_forecast_date.R b/tests/testthat/test-layer_add_forecast_date.R index 1830118dc..428922f46 100644 --- a/tests/testthat/test-layer_add_forecast_date.R +++ b/tests/testthat/test-layer_add_forecast_date.R @@ -11,8 +11,9 @@ latest <- jhu %>% test_that("layer validation works", { f <- 
frosting() - expect_error(layer_add_forecast_date(f, "a")) - expect_error(layer_add_forecast_date(f, "2022-05-31", id = c("a", "b"))) + expect_snapshot(error = TRUE, layer_add_forecast_date(f, c("2022-05-31", "2022-05-31"))) # multiple forecast_dates + expect_snapshot(error = TRUE, layer_add_forecast_date(f, "2022-05-31", id = 2)) # id is not a character + expect_snapshot(error = TRUE, layer_add_forecast_date(f, "2022-05-31", id = c("a", "b"))) # multiple ids expect_silent(layer_add_forecast_date(f, "2022-05-31")) expect_silent(layer_add_forecast_date(f)) expect_silent(layer_add_forecast_date(f, as.Date("2022-05-31"))) @@ -41,10 +42,12 @@ test_that("Specify a `forecast_date` that is less than `as_of` date", { layer_naomit(.pred) wf2 <- wf %>% add_frosting(f2) - expect_warning( - p2 <- predict(wf2, latest), - "forecast_date is less than the most recent update date of the data." - ) + # this warning has been removed + # expect_warning( + # p2 <- predict(wf2, latest), + # "forecast_date is less than the most recent update date of the data." + # ) + expect_silent(p2 <- predict(wf2, latest)) expect_equal(ncol(p2), 4L) expect_s3_class(p2, "epi_df") expect_equal(nrow(p2), 3L) @@ -59,13 +62,53 @@ test_that("Do not specify a forecast_date in `layer_add_forecast_date()`", { layer_naomit(.pred) wf3 <- wf %>% add_frosting(f3) - expect_warning( - p3 <- predict(wf3, latest), - "forecast_date is less than the most recent update date of the data." - ) + # this warning has been removed + # expect_warning( + # p3 <- predict(wf3, latest), + # "forecast_date is less than the most recent update date of the data." 
+ # ) + expect_silent(p3 <- predict(wf3, latest)) expect_equal(ncol(p3), 4L) expect_s3_class(p3, "epi_df") expect_equal(nrow(p3), 3L) expect_equal(p3$forecast_date, rep(as.Date("2021-12-31"), times = 3)) expect_named(p3, c("geo_value", "time_value", ".pred", "forecast_date")) }) + + +test_that("forecast date works for daily", { + f <- frosting() %>% + layer_predict() %>% + layer_add_forecast_date() %>% + layer_naomit(.pred) + + wf1 <- add_frosting(wf, f) + p <- predict(wf1, latest) + # both forecast_date and epi_df are dates + expect_identical(p$forecast_date[1], as.Date("2021-12-31")) + + # the error happens at predict time because the + # time_value train/test types don't match + latest_yearly <- latest %>% + unclass() %>% + as.data.frame() %>% + mutate(time_value = as.POSIXlt(time_value)$year + 1900L) %>% + group_by(geo_value, time_value) %>% + summarize(case_rate = mean(case_rate), death_rate = mean(death_rate), .groups = "drop") %>% + as_epi_df() + expect_snapshot(error = TRUE, predict(wf1, latest_yearly)) + + # forecast_date is a string, gets correctly converted to date + wf2 <- add_frosting( + wf, + adjust_frosting(f, "layer_add_forecast_date", forecast_date = "2022-01-01") + ) + expect_silent(predict(wf2, latest)) + + # forecast_date is a year/int while the epi_df is a date + wf3 <- add_frosting( + wf, + adjust_frosting(f, "layer_add_forecast_date", forecast_date = 2022L) + ) + expect_snapshot(error = TRUE, predict(wf3, latest)) +}) diff --git a/tests/testthat/test-layer_add_target_date.R b/tests/testthat/test-layer_add_target_date.R index 287956612..53506ad07 100644 --- a/tests/testthat/test-layer_add_target_date.R +++ b/tests/testthat/test-layer_add_target_date.R @@ -31,7 +31,8 @@ test_that("Use ahead + max time value from pre, fit, post", { layer_naomit(.pred) wf2 <- wf %>% add_frosting(f2) - expect_warning(p2 <- predict(wf2, latest)) + # expect_warning(p2 <- predict(wf2, latest)) # this warning has been removed + expect_silent(p2 <- predict(wf2, 
latest)) expect_equal(ncol(p2), 5L) expect_s3_class(p2, "epi_df") expect_equal(nrow(p2), 3L) @@ -85,3 +86,40 @@ test_that("Specify own target date", { expect_equal(p2$target_date, rep(as.Date("2022-01-08"), times = 3)) expect_named(p2, c("geo_value", "time_value", ".pred", "target_date")) }) + +test_that("target date works for daily and yearly", { + f <- frosting() %>% + layer_predict() %>% + layer_add_target_date() %>% + layer_naomit(.pred) + + wf1 <- add_frosting(wf, f) + p <- predict(wf1, latest) + # both target_date and epi_df are dates + expect_identical(p$target_date[1], as.Date("2021-12-31") + 7L) + + # the error happens at predict time because the + # time_value train/test types don't match + latest_bad <- latest %>% + unclass() %>% + as.data.frame() %>% + mutate(time_value = as.POSIXlt(time_value)$year + 1900L) %>% + group_by(geo_value, time_value) %>% + summarize(case_rate = mean(case_rate), death_rate = mean(death_rate), .groups = "drop") %>% + as_epi_df() + expect_snapshot(error = TRUE, predict(wf1, latest_bad)) + + # target_date is a string (gets correctly converted to Date) + wf1 <- add_frosting( + wf, + adjust_frosting(f, "layer_add_target_date", target_date = "2022-01-07") + ) + expect_silent(predict(wf1, latest)) + + # target_date is a year/int while the epi_df is a date + wf1 <- add_frosting( + wf, + adjust_frosting(f, "layer_add_target_date", target_date = 2022L) + ) + expect_error(predict(wf1, latest)) # wrong time type of forecast_date +}) diff --git a/tests/testthat/test-layer_predict.R b/tests/testthat/test-layer_predict.R index bd10de08c..70e76e593 100644 --- a/tests/testthat/test-layer_predict.R +++ b/tests/testthat/test-layer_predict.R @@ -31,3 +31,63 @@ test_that("prediction with interval works", { expect_equal(nrow(p), 108L) expect_named(p, c("geo_value", "time_value", ".pred_lower", ".pred_upper")) }) + +test_that("layer_predict dots validation", { + # We balk at unnamed arguments, though perhaps not with the most helpful error 
messages: + expect_error( + frosting() %>% layer_predict("pred_int", list(), tibble::tibble(x = 5)), + class = "epipredict__layer_predict__unnamed_dot" + ) + expect_error( + frosting() %>% layer_predict("pred_int", list(), "maybe_meant_to_be_id"), + class = "epipredict__layer_predict__unnamed_dot" + ) + # We allow arguments that might actually work at prediction time: + expect_no_error(frosting() %>% layer_predict(type = "quantile", interval = "confidence")) + + # We don't detect completely-bogus arg names until predict time: + expect_no_error(f_bad_arg <- frosting() %>% layer_predict(bogus_argument = "something")) + wf_bad_arg <- wf %>% add_frosting(f_bad_arg) + expect_snapshot(error = TRUE, predict(wf_bad_arg, latest)) + # ^ (currently with an awful error message, due to an extra comma in parsnip::check_pred_type_dots) + + # Some argument names only apply for some prediction `type`s; we don't check + # for invalid pairings, nor does {parsnip}, so we end up producing a forecast + # that silently ignores some arguments some of the time. ({workflows} doesn't + # check for these either.) 
+ expect_no_error(frosting() %>% layer_predict(eval_time = "preferably this would error")) +}) + +test_that("layer_predict dots are forwarded", { + f_lm_int_level_95 <- frosting() %>% + layer_predict(type = "pred_int") + f_lm_int_level_80 <- frosting() %>% + layer_predict(type = "pred_int", level = 0.8) + wf_lm_int_level_95 <- wf %>% add_frosting(f_lm_int_level_95) + wf_lm_int_level_80 <- wf %>% add_frosting(f_lm_int_level_80) + p <- predict(wf, latest) + p_lm_int_level_95 <- predict(wf_lm_int_level_95, latest) + p_lm_int_level_80 <- predict(wf_lm_int_level_80, latest) + expect_contains(names(p_lm_int_level_95), c(".pred_lower", ".pred_upper")) + expect_contains(names(p_lm_int_level_80), c(".pred_lower", ".pred_upper")) + expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level_95))) + expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level_80))) + expect_true( + cbind( + p, + p_lm_int_level_95 %>% dplyr::select(.pred_lower_95 = .pred_lower, .pred_upper_95 = .pred_upper), + p_lm_int_level_80 %>% dplyr::select(.pred_lower_80 = .pred_lower, .pred_upper_80 = .pred_upper) + ) %>% + na.omit() %>% + mutate( + sandwiched = + .pred_lower_95 <= .pred_lower_80 & + .pred_lower_80 <= .pred & + .pred <= .pred_upper_80 & + .pred_upper_80 <= .pred_upper_95 + ) %>% + `[[`("sandwiched") %>% + all() + ) + # There are many possible other valid configurations that aren't tested here. 
+}) diff --git a/tests/testthat/test-layer_residual_quantiles.R b/tests/testthat/test-layer_residual_quantiles.R index 73f69b54a..09ef7c9d3 100644 --- a/tests/testthat/test-layer_residual_quantiles.R +++ b/tests/testthat/test-layer_residual_quantiles.R @@ -7,7 +7,6 @@ r <- epi_recipe(jhu) %>% step_epi_naomit() wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -latest <- get_test_data(recipe = r, x = jhu) test_that("Returns expected number or rows and columns", { @@ -18,7 +17,7 @@ test_that("Returns expected number or rows and columns", { wf1 <- wf %>% add_frosting(f) - expect_silent(p <- predict(wf1, latest)) + expect_silent(p <- forecast(wf1)) expect_equal(ncol(p), 4L) expect_s3_class(p, "epi_df") expect_equal(nrow(p), 3L) @@ -47,5 +46,61 @@ test_that("Errors when used with a classifier", { layer_predict() %>% layer_residual_quantiles() wf <- wf %>% add_frosting(f) - expect_error(predict(wf, tib)) + expect_snapshot(error = TRUE, forecast(wf)) +}) + + +test_that("Grouping by keys is supported", { + f <- frosting() %>% + layer_predict() %>% + layer_naomit(.pred) %>% + layer_residual_quantiles() + wf1 <- wf %>% add_frosting(f) + expect_silent(p1 <- forecast(wf1)) + f2 <- frosting() %>% + layer_predict() %>% + layer_naomit(.pred) %>% + layer_residual_quantiles(by_key = "geo_value") + wf2 <- wf %>% add_frosting(f2) + expect_warning(p2 <- forecast(wf2)) + + pivot1 <- pivot_quantiles_wider(p1, .pred_distn) %>% + mutate(width = `0.95` - `0.05`) + pivot2 <- pivot_quantiles_wider(p2, .pred_distn) %>% + mutate(width = `0.95` - `0.05`) + expect_equal(pivot1$width, rep(pivot1$width[1], nrow(pivot1))) + expect_false(all(pivot2$width == pivot2$width[1])) +}) + +test_that("Canned forecasters work with / without", { + meta <- attr(jhu, "metadata") + meta$as_of <- max(jhu$time_value) + attr(jhu, "metadata") <- meta + + expect_silent( + flatline_forecaster(jhu, "death_rate") + ) + expect_silent( + flatline_forecaster( + jhu, "death_rate", + args_list = 
flatline_args_list(quantile_by_key = "geo_value") + ) + ) + + expect_silent( + arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate")) + ) + expect_silent( + flatline_forecaster( + jhu, "death_rate", + args_list = flatline_args_list(quantile_by_key = "geo_value") + ) + ) +}) + +test_that("flatline_forecaster correctly errors when n_training < ahead", { + expect_snapshot( + error = TRUE, + flatline_forecaster(jhu, "death_rate", args_list = flatline_args_list(ahead = 10, n_training = 9)) + ) }) diff --git a/tests/testthat/test-layers.R b/tests/testthat/test-layers.R index 13f859ac3..6e2d80111 100644 --- a/tests/testthat/test-layers.R +++ b/tests/testthat/test-layers.R @@ -11,7 +11,7 @@ test_that("A layer can be updated in frosting", { expect_equal(length(f$layers), 2) expect_equal(f$layers[[1]], fold$layers[[1]]) expect_equal(f$layers[[2]]$lower, 100) - expect_error(update(f$layers[[1]], lower = 100)) - expect_error(update(f$layers[[3]], lower = 100)) - expect_error(update(f$layers[[2]], bad_param = 100)) + expect_snapshot(error = TRUE, update(f$layers[[1]], lower = 100)) + expect_snapshot(error = TRUE, update(f$layers[[3]], lower = 100)) + expect_snapshot(error = TRUE, update(f$layers[[2]], bad_param = 100)) }) diff --git a/tests/testthat/test-pad_to_end.R b/tests/testthat/test-pad_to_end.R index 474b9001b..6949f06ac 100644 --- a/tests/testthat/test-pad_to_end.R +++ b/tests/testthat/test-pad_to_end.R @@ -32,6 +32,6 @@ test_that("test set padding works", { # make sure it maintains the epi_df dat <- dat %>% dplyr::rename(geo_value = gr1) %>% - as_epi_df(dat) + as_epi_df(other_keys = "gr2") expect_s3_class(pad_to_end(dat, "geo_value", 2), "epi_df") }) diff --git a/tests/testthat/test-parse_period.R b/tests/testthat/test-parse_period.R index 0adbcec3d..10dd5692d 100644 --- a/tests/testthat/test-parse_period.R +++ b/tests/testthat/test-parse_period.R @@ -1,8 +1,8 @@ test_that("parse_period works", { - expect_error(parse_period(c(1, 2))) - 
expect_error(parse_period(c(1.3))) - expect_error(parse_period("1 year")) - expect_error(parse_period("2 weeks later")) + expect_snapshot(error = TRUE, parse_period(c(1, 2))) + expect_snapshot(error = TRUE, parse_period(c(1.3))) + expect_snapshot(error = TRUE, parse_period("1 year")) + expect_snapshot(error = TRUE, parse_period("2 weeks later")) expect_identical(parse_period(1), 1L) expect_identical(parse_period("1 day"), 1L) expect_identical(parse_period("1 days"), 1L) diff --git a/tests/testthat/test-parsnip_model_validation.R b/tests/testthat/test-parsnip_model_validation.R index 02ed94fe0..605fad817 100644 --- a/tests/testthat/test-parsnip_model_validation.R +++ b/tests/testthat/test-parsnip_model_validation.R @@ -4,12 +4,12 @@ test_that("forecaster can validate parsnip model", { trainer2 <- parsnip::logistic_reg() trainer3 <- parsnip::rand_forest() - expect_error(get_parsnip_mode(l)) + expect_snapshot(error = TRUE, get_parsnip_mode(l)) expect_equal(get_parsnip_mode(trainer1), "regression") expect_equal(get_parsnip_mode(trainer2), "classification") expect_equal(get_parsnip_mode(trainer3), "unknown") - expect_error(is_classification(l)) + expect_snapshot(error = TRUE, is_classification(l)) expect_true(is_regression(trainer1)) expect_false(is_classification(trainer1)) expect_true(is_classification(trainer2)) diff --git a/tests/testthat/test-pivot_quantiles.R b/tests/testthat/test-pivot_quantiles.R index 908a75795..1639058e2 100644 --- a/tests/testthat/test-pivot_quantiles.R +++ b/tests/testthat/test-pivot_quantiles.R @@ -1,14 +1,14 @@ test_that("quantile pivotting wider behaves", { tib <- tibble::tibble(a = 1:5, b = 6:10) - expect_error(pivot_quantiles_wider(tib, a)) + expect_snapshot(error = TRUE, pivot_quantiles_wider(tib, a)) tib$c <- rep(dist_normal(), 5) - expect_error(pivot_quantiles_wider(tib, c)) + expect_snapshot(error = TRUE, pivot_quantiles_wider(tib, c)) d1 <- c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:5, 1:4 / 5)) # different quantiles tib <- 
tib[1:2, ] tib$d1 <- d1 - expect_error(pivot_quantiles_wider(tib, d1)) + expect_snapshot(error = TRUE, pivot_quantiles_wider(tib, d1)) d1 <- c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:4, 2:4 / 4)) tib$d1 <- d1 @@ -25,12 +25,20 @@ test_that("quantile pivotting wider behaves", { expect_length(pivot_quantiles_wider(tib, d2), 5L) }) +test_that("pivotting wider still works if there are duplicates", { + # previously this would produce a warning if pivotted because the + # two rows of the result are identical + tb <- tibble(.pred = dist_quantiles(list(1:3, 1:3), list(c(.1, .5, .9)))) + res <- tibble(`0.1` = c(1, 1), `0.5` = c(2, 2), `0.9` = c(3, 3)) + expect_identical(tb %>% pivot_quantiles_wider(.pred), res) +}) + test_that("quantile pivotting longer behaves", { tib <- tibble::tibble(a = 1:5, b = 6:10) - expect_error(pivot_quantiles_longer(tib, a)) + expect_snapshot(error = TRUE, pivot_quantiles_longer(tib, a)) tib$c <- rep(dist_normal(), 5) - expect_error(pivot_quantiles_longer(tib, c)) + expect_snapshot(error = TRUE, pivot_quantiles_longer(tib, c)) d1 <- c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:5, 1:4 / 5)) # different quantiles @@ -56,7 +64,7 @@ test_that("quantile pivotting longer behaves", { tib$d3 <- c(dist_quantiles(2:5, 2:5 / 6), dist_quantiles(3:6, 2:5 / 6)) # now the cols have different numbers of quantiles - expect_error(pivot_quantiles_longer(tib, d1, d3)) + expect_snapshot(error = TRUE, pivot_quantiles_longer(tib, d1, d3)) expect_length( pivot_quantiles_longer(tib, d1, d3, .ignore_length_check = TRUE), 6L diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index 6951d56e6..6337a2ea8 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -5,18 +5,25 @@ test_that("Column names can be passed with and without the tidy way", { value = c(1000, 2000, 3000, 4000, 5000, 6000) ) - newdata <- case_death_rate_subset %>% filter(geo_value %in% c("ak", "al", "ar", 
"as", "az", "ca")) + pop_data2 <- pop_data %>% dplyr::rename(geo_value = states) + + newdata <- case_death_rate_subset %>% + filter(geo_value %in% c("ak", "al", "ar", "as", "az", "ca")) r1 <- epi_recipe(newdata) %>% - step_population_scaling(c("case_rate", "death_rate"), + step_population_scaling( + case_rate, death_rate, df = pop_data, - df_pop_col = "value", by = c("geo_value" = "states") + df_pop_col = "value", + by = c("geo_value" = "states") ) r2 <- epi_recipe(newdata) %>% - step_population_scaling(case_rate, death_rate, - df = pop_data, - df_pop_col = "value", by = c("geo_value" = "states") + step_population_scaling( + case_rate, death_rate, + df = pop_data2, + df_pop_col = "value", + by = "geo_value" ) prep1 <- prep(r1, newdata) @@ -56,9 +63,9 @@ test_that("Number of columns and column names returned correctly, Upper and lowe suffix = "_rate" ) - prep <- prep(r, newdata) + p <- prep(r, newdata) - expect_silent(b <- bake(prep, newdata)) + b <- bake(p, newdata) expect_equal(ncol(b), 7L) expect_true("case_rate" %in% colnames(b)) expect_true("death_rate" %in% colnames(b)) @@ -75,15 +82,15 @@ test_that("Number of columns and column names returned correctly, Upper and lowe create_new = FALSE ) - expect_warning(prep <- prep(r, newdata)) + expect_warning(p <- prep(r, newdata)) - expect_warning(b <- bake(prep, newdata)) + expect_warning(b <- bake(p, newdata)) expect_equal(ncol(b), 5L) }) ## Postprocessing test_that("Postprocessing workflow works and values correct", { - jhu <- epiprocess::jhu_csse_daily_subset %>% + jhu <- jhu_csse_daily_subset %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, cases) @@ -119,17 +126,7 @@ test_that("Postprocessing workflow works and values correct", { fit(jhu) %>% add_frosting(f) - latest <- get_test_data( - recipe = r, - x = epiprocess::jhu_csse_daily_subset %>% - dplyr::filter( - time_value > "2021-11-01", - geo_value %in% c("ca", "ny") - ) %>% - 
dplyr::select(geo_value, time_value, cases) - ) - - suppressWarnings(p <- predict(wf, latest)) + p <- forecast(wf) expect_equal(nrow(p), 2L) expect_equal(ncol(p), 4L) expect_equal(p$.pred_scaled, p$.pred * c(20000, 30000)) @@ -146,7 +143,7 @@ test_that("Postprocessing workflow works and values correct", { wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f) - suppressWarnings(p <- predict(wf, latest)) + p <- forecast(wf) expect_equal(nrow(p), 2L) expect_equal(ncol(p), 4L) expect_equal(p$.pred_scaled, p$.pred * c(2, 3)) @@ -188,18 +185,7 @@ test_that("Postprocessing to get cases from case rate", { fit(jhu) %>% add_frosting(f) - latest <- get_test_data( - recipe = r, - x = case_death_rate_subset %>% - dplyr::filter( - time_value > "2021-11-01", - geo_value %in% c("ca", "ny") - ) %>% - dplyr::select(geo_value, time_value, case_rate) - ) - - - suppressWarnings(p <- predict(wf, latest)) + p <- forecast(wf) expect_equal(nrow(p), 2L) expect_equal(ncol(p), 4L) expect_equal(p$.pred_scaled, p$.pred * c(1 / 20000, 1 / 30000)) @@ -207,7 +193,6 @@ test_that("Postprocessing to get cases from case rate", { test_that("test joining by default columns", { - skip() jhu <- case_death_rate_subset %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, case_rate) @@ -229,9 +214,18 @@ test_that("test joining by default columns", { recipes::step_naomit(recipes::all_predictors()) %>% recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) - suppressMessages(prep <- prep(r, jhu)) + p <- prep(r, jhu) + b <- bake(p, new_data = NULL) + expect_named( + b, + c( + "geo_value", "time_value", "case_rate", "case_rate_scaled", + paste0("lag_", c(0, 7, 14), "_case_rate_scaled"), + "ahead_7_case_rate_scaled" + ) + ) + - suppressMessages(b <- bake(prep, jhu)) f <- frosting() %>% layer_predict() %>% @@ -243,27 +237,16 @@ test_that("test joining by default columns", { df_pop_col = "values" ) - suppressMessages( - wf <- 
epi_workflow(r, parsnip::linear_reg()) %>% - fit(jhu) %>% - add_frosting(f) - ) - - latest <- get_test_data( - recipe = r, - x = case_death_rate_subset %>% - dplyr::filter( - time_value > "2021-11-01", - geo_value %in% c("ca", "ny") - ) %>% - dplyr::select(geo_value, time_value, case_rate) - ) + wf <- epi_workflow(r, parsnip::linear_reg()) %>% + fit(jhu) %>% + add_frosting(f) - suppressMessages(p <- predict(wf, latest)) + fc <- forecast(wf) + expect_named(fc, c("geo_value", "time_value", ".pred", ".pred_scaled")) + expect_equal(fc$.pred_scaled, fc$.pred * c(1 / 20000, 1 / 30000)) }) - test_that("expect error if `by` selector does not match", { jhu <- case_death_rate_subset %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% @@ -296,7 +279,8 @@ test_that("expect error if `by` selector does not match", { df_pop_col = "values" ) - expect_error( + expect_snapshot( + error = TRUE, wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f) @@ -324,22 +308,11 @@ test_that("expect error if `by` selector does not match", { df_pop_col = "values" ) - - latest <- get_test_data( - recipe = r, - x = case_death_rate_subset %>% - dplyr::filter( - time_value > "2021-11-01", - geo_value %in% c("ca", "ny") - ) %>% - dplyr::select(geo_value, time_value, case_rate) - ) - wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f) - expect_error(predict(wf, latest)) + expect_snapshot(error = TRUE, forecast(wf)) }) @@ -407,12 +380,10 @@ test_that("Rate rescaling behaves as expected", { fit(x) %>% add_frosting(f) - latest <- get_test_data(recipe = r, x = x) - # suppress warning: prediction from a rank-deficient fit may be misleading suppressWarnings(expect_equal( - unique(predict(wf, latest)$.pred) * (1 / 1000) / 100, - unique(predict(wf, latest)$.pred_scaled) + unique(forecast(wf)$.pred) * (1 / 1000) / 100, + unique(forecast(wf)$.pred_scaled) )) }) @@ -459,7 +430,6 @@ test_that("Extra Columns are ignored", { wf <- 
epi_workflow(recip, parsnip::linear_reg()) %>% fit(x) %>% add_frosting(frost) - latest <- get_test_data(recipe = recip, x = x) # suppress warning: prediction from a rank-deficient fit may be misleading - suppressWarnings(expect_equal(ncol(predict(wf, latest)), 4)) + suppressWarnings(expect_equal(ncol(forecast(wf)), 4)) }) diff --git a/tests/testthat/test-propagate_samples.R b/tests/testthat/test-propagate_samples.R deleted file mode 100644 index 5278ab385..000000000 --- a/tests/testthat/test-propagate_samples.R +++ /dev/null @@ -1,7 +0,0 @@ -test_that("propagate_samples", { - r <- -30:50 - p <- 40 - quantiles <- 1:9 / 10 - aheads <- c(2, 4, 7) - nsim <- 100 -}) diff --git a/tests/testthat/test-shuffle.R b/tests/testthat/test-shuffle.R index 94bc1aa3b..f05e8be3d 100644 --- a/tests/testthat/test-shuffle.R +++ b/tests/testthat/test-shuffle.R @@ -1,5 +1,5 @@ test_that("shuffle works", { - expect_error(shuffle(matrix(NA, 2, 2))) + expect_snapshot(error = TRUE, shuffle(matrix(NA, 2, 2))) expect_length(shuffle(1:10), 10L) expect_identical(sort(shuffle(1:10)), 1:10) }) diff --git a/tests/testthat/test-snapshots.R b/tests/testthat/test-snapshots.R new file mode 100644 index 000000000..d624a4c21 --- /dev/null +++ b/tests/testthat/test-snapshots.R @@ -0,0 +1,81 @@ +train_data <- jhu_csse_daily_subset +expect_snapshot_tibble <- function(x) { + expect_snapshot_value(x, style = "deparse", cran = FALSE) +} + +test_that("flatline_forecaster snapshots", { + # Let's make a few forecasts using different settings and snapshot them + flat1 <- flatline_forecaster(train_data, "death_rate_7d_av") + expect_snapshot_tibble(flat1$predictions) + + flat2 <- flatline_forecaster( + train_data, "death_rate_7d_av", + args_list = flatline_args_list(ahead = 1L) + ) + expect_snapshot_tibble(flat2$predictions) + + flat3 <- flatline_forecaster( + train_data, "death_rate_7d_av", + args_list = flatline_args_list( + forecast_date = as.Date("2021-12-31") + ) + ) + expect_snapshot_tibble(flat3$predictions) 
+ + flat4 <- flatline_forecaster( + train_data, "death_rate_7d_av", + args_list = flatline_args_list( + target_date = as.Date("2022-01-01"), + ) + ) + expect_snapshot_tibble(flat4$predictions) +}) + +test_that("cdc_baseline_forecaster snapshots", { + set.seed(1234) + cdc1 <- cdc_baseline_forecaster(train_data, "death_rate_7d_av") + expect_snapshot_tibble(cdc1$predictions) + + cdc2 <- cdc_baseline_forecaster( + train_data, "death_rate_7d_av", + args_list = cdc_baseline_args_list(aheads = 2:6) + ) + expect_snapshot_tibble(cdc2$predictions) + + cdc3 <- cdc_baseline_forecaster( + train_data, "death_rate_7d_av", + args_list = cdc_baseline_args_list( + data_frequency = "5 days", + ) + ) + expect_snapshot_tibble(cdc3$predictions) +}) + +test_that("arx_forecaster snapshots", { + arx1 <- arx_forecaster( + train_data, + "death_rate_7d_av", + c("death_rate_7d_av", "case_rate_7d_av") + ) + expect_snapshot_tibble(arx1$predictions) + + arx2 <- arx_forecaster( + train_data, + "death_rate_7d_av", + c("death_rate_7d_av", "case_rate_7d_av"), + args_list = arx_args_list( + ahead = 1L + ) + ) + expect_snapshot_tibble(arx2$predictions) +}) + +test_that("arx_classifier snapshots", { + arc1 <- arx_classifier( + case_death_rate_subset %>% + dplyr::filter(time_value >= as.Date("2021-11-01")), + "death_rate", + c("case_rate", "death_rate") + ) + expect_snapshot_tibble(arc1$predictions) +}) diff --git a/tests/testthat/test-step_epi_naomit.R b/tests/testthat/test-step_epi_naomit.R index 2fb173f01..0e5e1750f 100644 --- a/tests/testthat/test-step_epi_naomit.R +++ b/tests/testthat/test-step_epi_naomit.R @@ -17,7 +17,7 @@ r <- epi_recipe(x) %>% step_epi_lag(death_rate, lag = c(0, 7, 14)) test_that("Argument must be a recipe", { - expect_error(step_epi_naomit(x)) + expect_snapshot(error = TRUE, step_epi_naomit(x)) }) z1 <- step_epi_naomit(r) diff --git a/tests/testthat/test-step_epi_shift.R b/tests/testthat/test-step_epi_shift.R index da04fd0f2..1f83120b3 100644 --- 
a/tests/testthat/test-step_epi_shift.R +++ b/tests/testthat/test-step_epi_shift.R @@ -20,7 +20,8 @@ slm_fit <- function(recipe, data = x) { } test_that("Values for ahead and lag must be integer values", { - expect_error( + expect_snapshot( + error = TRUE, r1 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = 3.6) %>% step_epi_lag(death_rate, lag = 1.9) @@ -28,7 +29,8 @@ test_that("Values for ahead and lag must be integer values", { }) test_that("A negative lag value should should throw an error", { - expect_error( + expect_snapshot( + error = TRUE, r2 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag(death_rate, lag = -7) @@ -36,7 +38,8 @@ test_that("A negative lag value should should throw an error", { }) test_that("A nonpositive ahead value should throw an error", { - expect_error( + expect_snapshot( + error = TRUE, r3 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = -7) %>% step_epi_lag(death_rate, lag = 7) @@ -48,9 +51,7 @@ test_that("Values for ahead and lag cannot be duplicates", { step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag(death_rate, lag = 7) %>% step_epi_lag(death_rate, lag = 7) - expect_error( - slm_fit(r4) - ) + expect_snapshot(error = TRUE, slm_fit(r4)) }) test_that("Check that epi_lag shifts applies the shift", { diff --git a/tests/testthat/test-step_epi_slide.R b/tests/testthat/test-step_epi_slide.R new file mode 100644 index 000000000..27f362ad6 --- /dev/null +++ b/tests/testthat/test-step_epi_slide.R @@ -0,0 +1,77 @@ +library(dplyr) + +tt <- seq(as.Date("2022-01-01"), by = "1 day", length.out = 20) +edf <- data.frame( + time_value = c(tt, tt), + geo_value = rep(c("ca", "ny"), each = 20L), + value = c(2:21, 3:22) +) %>% + as_epi_df() +r <- epi_recipe(edf) + + +test_that("epi_slide errors when needed", { + # not an epi_recipe + expect_snapshot(error = TRUE, recipe(edf) %>% step_epi_slide(value, .f = mean, .window_size = 7L)) + + # non-scalar args + expect_snapshot(error = TRUE, r %>% 
step_epi_slide(value, .f = mean, .window_size = c(3L, 6L))) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .align = c("right", "left"))) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, skip = c(TRUE, FALSE))) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, role = letters[1:2])) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, prefix = letters[1:2])) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, id = letters[1:2])) + # wrong types + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1.5)) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, .align = 1.5)) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, skip = "a")) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, role = 1)) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, prefix = 1)) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = mean, .window_size = 1L, id = 1)) + # function problems + expect_snapshot(error = TRUE, r %>% step_epi_slide(value)) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = 1)) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value)) + expect_snapshot(error = TRUE, r %>% step_epi_slide(value, .f = 1)) +}) + + +test_that("epi_slide handles different function specs", { + cfun <- r %>% + step_epi_slide(value, .f = "mean", .window_size = 4L) %>% + prep(edf) %>% + bake(new_data = NULL) + expected_out <- edf %>% + group_by(geo_value) %>% + epi_slide(~ mean(.x$value), .window_size = 4L) %>% + ungroup() %>% + rename(epi_slide__.f_value = slide_value) + expect_equal(cfun, expected_out) + ffun <- r %>% + step_epi_slide(value, .f = mean, .window_size = 4L) %>% + prep(edf) %>% + 
bake(new_data = NULL) + expect_equal(ffun, expected_out) + # formula NOT currently supported + expect_snapshot( + error = TRUE, + lfun <- r %>% + step_epi_slide(value, .f = ~ mean(.x, na.rm = TRUE), .window_size = 4L) + ) + # expect_equal(lfun, rolled_before) + blfun <- r %>% + step_epi_slide(value, .f = function(x) mean(x, na.rm = TRUE), .window_size = 4L) %>% + prep(edf) %>% + bake(new_data = NULL) + expected_out <- edf %>% + group_by(geo_value) %>% + epi_slide(~ mean(.x$value, na.rm = TRUE), .window_size = 4L) %>% + ungroup() %>% + rename(epi_slide__.f_value = slide_value) + expect_equal(blfun, expected_out) + nblfun <- r %>% + step_epi_slide(value, .f = \(x) mean(x, na.rm = TRUE), .window_size = 4L) %>% + prep(edf) %>% + bake(new_data = NULL) + expect_equal(nblfun, expected_out) +}) diff --git a/tests/testthat/test-step_growth_rate.R b/tests/testthat/test-step_growth_rate.R index d0dec170e..f2845d812 100644 --- a/tests/testthat/test-step_growth_rate.R +++ b/tests/testthat/test-step_growth_rate.R @@ -1,27 +1,25 @@ test_that("step_growth_rate validates arguments", { df <- data.frame(time_value = 1:5, geo_value = rep("a", 5), value = 6:10) r <- recipes::recipe(df) - expect_error(step_growth_rate(r)) + expect_snapshot(error = TRUE, step_growth_rate(r)) edf <- as_epi_df(df) r <- epi_recipe(edf) - expect_error(step_growth_rate(r, value, role = 1)) - expect_error(step_growth_rate(r, value, method = "abc")) - expect_error(step_growth_rate(r, value, horizon = 0)) - expect_error(step_growth_rate(r, value, horizon = c(1, 2))) - expect_error(step_growth_rate(r, value, prefix = letters[1:2])) - expect_error(step_growth_rate(r, value, id = letters[1:2])) - expect_error(step_growth_rate(r, value, prefix = letters[1:2])) - expect_error(step_growth_rate(r, value, prefix = 1)) - expect_error(step_growth_rate(r, value, id = 1)) - expect_error(step_growth_rate(r, value, trained = 1)) - expect_error(step_growth_rate(r, value, log_scale = 1)) - expect_error(step_growth_rate(r, 
value, skip = 1)) - expect_error(step_growth_rate(r, value, additional_gr_args_list = 1:5)) - expect_error(step_growth_rate(r, value, columns = letters[1:5])) - expect_error(step_growth_rate(r, value, replace_Inf = "c")) - expect_error(step_growth_rate(r, value, replace_Inf = c(1, 2))) + expect_snapshot(error = TRUE, step_growth_rate(r, value, role = 1)) + expect_snapshot(error = TRUE, step_growth_rate(r, value, method = "abc")) + expect_snapshot(error = TRUE, step_growth_rate(r, value, horizon = 0)) + expect_snapshot(error = TRUE, step_growth_rate(r, value, horizon = c(1, 2))) + expect_snapshot(error = TRUE, step_growth_rate(r, value, prefix = letters[1:2])) + expect_snapshot(error = TRUE, step_growth_rate(r, value, id = letters[1:2])) + expect_snapshot(error = TRUE, step_growth_rate(r, value, prefix = letters[1:2])) + expect_snapshot(error = TRUE, step_growth_rate(r, value, prefix = 1)) + expect_snapshot(error = TRUE, step_growth_rate(r, value, id = 1)) + expect_snapshot(error = TRUE, step_growth_rate(r, value, log_scale = 1)) + expect_snapshot(error = TRUE, step_growth_rate(r, value, skip = 1)) + expect_snapshot(error = TRUE, step_growth_rate(r, value, additional_gr_args_list = 1:5)) + expect_snapshot(error = TRUE, step_growth_rate(r, value, replace_Inf = "c")) + expect_snapshot(error = TRUE, step_growth_rate(r, value, replace_Inf = c(1, 2))) expect_silent(step_growth_rate(r, value, replace_Inf = NULL)) expect_silent(step_growth_rate(r, value, replace_Inf = NA)) }) @@ -34,7 +32,7 @@ test_that("step_growth_rate works for a single signal", { res <- r %>% step_growth_rate(value, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$gr_1_rel_change_value, c(NA, 1 / 6:9)) @@ -46,7 +44,7 @@ test_that("step_growth_rate works for a single signal", { r <- epi_recipe(edf) res <- r %>% step_growth_rate(value, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$gr_1_rel_change_value, rep(c(NA, 1 / 6:9), each = 2)) }) @@ -63,7 +61,7 @@ 
test_that("step_growth_rate works for a two signals", { res <- r %>% step_growth_rate(v1, v2, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$gr_1_rel_change_v1, c(NA, 1 / 6:9)) expect_equal(res$gr_1_rel_change_v2, c(NA, 1 / 1:4)) @@ -76,7 +74,7 @@ test_that("step_growth_rate works for a two signals", { r <- epi_recipe(edf) res <- r %>% step_growth_rate(v1, v2, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$gr_1_rel_change_v1, rep(c(NA, 1 / 6:9), each = 2)) expect_equal(res$gr_1_rel_change_v2, rep(c(NA, 1 / 1:4), each = 2)) diff --git a/tests/testthat/test-step_lag_difference.R b/tests/testthat/test-step_lag_difference.R index dc61d12d4..6ff9884a7 100644 --- a/tests/testthat/test-step_lag_difference.R +++ b/tests/testthat/test-step_lag_difference.R @@ -1,22 +1,20 @@ test_that("step_lag_difference validates arguments", { df <- data.frame(time_value = 1:5, geo_value = rep("a", 5), value = 6:10) r <- recipes::recipe(df) - expect_error(step_lag_difference(r)) + expect_snapshot(error = TRUE, step_lag_difference(r)) edf <- as_epi_df(df) r <- epi_recipe(edf) - expect_error(step_lag_difference(r, value, role = 1)) - expect_error(step_lag_difference(r, value, horizon = 0)) + expect_snapshot(error = TRUE, step_lag_difference(r, value, role = 1)) + expect_snapshot(error = TRUE, step_lag_difference(r, value, horizon = 0)) expect_silent(step_lag_difference(r, value, horizon = c(1, 2))) - expect_error(step_lag_difference(r, value, prefix = letters[1:2])) - expect_error(step_lag_difference(r, value, id = letters[1:2])) - expect_error(step_lag_difference(r, value, prefix = letters[1:2])) - expect_error(step_lag_difference(r, value, prefix = 1)) - expect_error(step_lag_difference(r, value, id = 1)) - expect_error(step_lag_difference(r, value, trained = 1)) - expect_error(step_lag_difference(r, value, skip = 1)) - expect_error(step_lag_difference(r, value, columns = letters[1:5])) + expect_snapshot(error = TRUE, 
step_lag_difference(r, value, prefix = letters[1:2])) + expect_snapshot(error = TRUE, step_lag_difference(r, value, id = letters[1:2])) + expect_snapshot(error = TRUE, step_lag_difference(r, value, prefix = letters[1:2])) + expect_snapshot(error = TRUE, step_lag_difference(r, value, prefix = 1)) + expect_snapshot(error = TRUE, step_lag_difference(r, value, id = 1)) + expect_snapshot(error = TRUE, step_lag_difference(r, value, skip = 1)) }) @@ -27,13 +25,13 @@ test_that("step_lag_difference works for a single signal", { res <- r %>% step_lag_difference(value, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_value, c(NA, rep(1, 4))) res <- r %>% step_lag_difference(value, horizon = 1:2) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_value, c(NA, rep(1, 4))) expect_equal(res$lag_diff_2_value, c(NA, NA, rep(2, 3))) @@ -48,7 +46,7 @@ test_that("step_lag_difference works for a single signal", { r <- epi_recipe(edf) res <- r %>% step_lag_difference(value, horizon = 1) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_value, c(NA, NA, rep(1, 8))) }) @@ -65,7 +63,7 @@ test_that("step_lag_difference works for a two signals", { res <- r %>% step_lag_difference(v1, v2, horizon = 1:2) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_v1, c(NA, rep(1, 4))) expect_equal(res$lag_diff_2_v1, c(NA, NA, rep(2, 3))) @@ -80,7 +78,7 @@ test_that("step_lag_difference works for a two signals", { r <- epi_recipe(edf) res <- r %>% step_lag_difference(v1, v2, horizon = 1:2) %>% - prep() %>% + prep(edf) %>% bake(edf) expect_equal(res$lag_diff_1_v1, rep(c(NA, rep(1, 4)), each = 2)) expect_equal(res$lag_diff_2_v1, rep(c(NA, NA, rep(2, 3)), each = 2)) diff --git a/tests/testthat/test-step_training_window.R b/tests/testthat/test-step_training_window.R index c8a17f43f..a9f2170d3 100644 --- a/tests/testthat/test-step_training_window.R +++ b/tests/testthat/test-step_training_window.R @@ -11,13 +11,13 
@@ toy_epi_df <- tibble::tibble( test_that("step_training_window works with default n_recent", { p <- epi_recipe(y ~ x, data = toy_epi_df) %>% step_training_window() %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) expect_equal(nrow(p), 100L) expect_equal(ncol(p), 4L) expect_s3_class(p, "epi_df") - expect_named(p, c("time_value", "geo_value", "x", "y")) + expect_named(p, c("geo_value", "time_value", "x", "y")) # order in epiprocess::new_epi_df expect_equal( p$time_value, rep(seq(as.Date("2020-02-20"), as.Date("2020-04-09"), by = 1), times = 2) @@ -28,13 +28,13 @@ test_that("step_training_window works with default n_recent", { test_that("step_training_window works with specified n_recent", { p2 <- epi_recipe(y ~ x, data = toy_epi_df) %>% step_training_window(n_recent = 5) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df) %>% + bake(new_data = NULL) expect_equal(nrow(p2), 10L) expect_equal(ncol(p2), 4L) expect_s3_class(p2, "epi_df") - expect_named(p2, c("time_value", "geo_value", "x", "y")) + expect_named(p2, c("geo_value", "time_value", "x", "y")) # order in epiprocess::new_epi_df expect_equal( p2$time_value, rep(seq(as.Date("2020-04-05"), as.Date("2020-04-09"), by = 1), times = 2) @@ -48,14 +48,14 @@ test_that("step_training_window does not proceed with specified new_data", { # testing data. 
p3 <- epi_recipe(y ~ x, data = toy_epi_df) %>% step_training_window(n_recent = 3) %>% - recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = toy_epi_df[1:10, ]) + prep(toy_epi_df) %>% + bake(new_data = toy_epi_df[1:10, ]) expect_equal(nrow(p3), 10L) expect_equal(ncol(p3), 4L) expect_s3_class(p3, "epi_df") # cols will be predictors, outcomes, time_value, geo_value - expect_named(p3, c("x", "y", "time_value", "geo_value")) + expect_named(p3, c("geo_value", "time_value", "x", "y")) # order in epiprocess::new_epi_df expect_equal( p3$time_value, rep(seq(as.Date("2020-01-01"), as.Date("2020-01-10"), by = 1), times = 1) @@ -78,13 +78,13 @@ test_that("step_training_window works with multiple keys", { p4 <- epi_recipe(y ~ x, data = toy_epi_df2) %>% step_training_window(n_recent = 3) %>% - recipes::prep(toy_epi_df2) %>% - recipes::bake(new_data = NULL) + prep(toy_epi_df2) %>% + bake(new_data = NULL) expect_equal(nrow(p4), 12L) expect_equal(ncol(p4), 5L) expect_s3_class(p4, "epi_df") - expect_named(p4, c("time_value", "geo_value", "additional_key", "x", "y")) + expect_named(p4, c("geo_value", "additional_key", "time_value", "x", "y")) expect_equal( p4$time_value, rep(c( @@ -112,20 +112,20 @@ test_that("step_training_window and step_naomit interact", { e1 <- epi_recipe(y ~ x, data = tib) %>% step_training_window(n_recent = 3) %>% - recipes::prep(tib) %>% - recipes::bake(new_data = NULL) + prep(tib) %>% + bake(new_data = NULL) e2 <- epi_recipe(y ~ x, data = tib) %>% - recipes::step_naomit() %>% + step_naomit() %>% step_training_window(n_recent = 3) %>% - recipes::prep(tib) %>% - recipes::bake(new_data = NULL) + prep(tib) %>% + bake(new_data = NULL) e3 <- epi_recipe(y ~ x, data = tib) %>% step_training_window(n_recent = 3) %>% - recipes::step_naomit() %>% - recipes::prep(tib) %>% - recipes::bake(new_data = NULL) + step_naomit() %>% + prep(tib) %>% + bake(new_data = NULL) expect_identical(e1, e2) expect_identical(e2, e3) diff --git a/tests/testthat/test-target_date_bug.R 
b/tests/testthat/test-target_date_bug.R new file mode 100644 index 000000000..4a7e7d2e8 --- /dev/null +++ b/tests/testthat/test-target_date_bug.R @@ -0,0 +1,77 @@ +# These tests address #290: +# https://github.com/cmu-delphi/epipredict/issues/290 + +library(dplyr) +train <- jhu_csse_daily_subset |> + filter(time_value >= as.Date("2021-10-01")) |> + select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av) +ngeos <- n_distinct(train$geo_value) + +test_that("flatline determines target_date where forecast_date exists", { + flat <- flatline_forecaster( + train, "dr", + args_list = flatline_args_list( + forecast_date = as.Date("2021-12-31"), + target_date = as.Date("2022-01-01"), + ahead = 1L + ) + ) + # previously, if target_date existed, it could be + # erroneously incremented by the ahead + expect_identical( + flat$predictions$target_date, + rep(as.Date("2022-01-01"), ngeos) + ) + expect_identical( + flat$predictions$forecast_date, + rep(as.Date("2021-12-31"), ngeos) + ) + expect_true(all(!is.na(flat$predictions$.pred_distn))) + expect_true(all(!is.na(flat$predictions$.pred))) +}) + +test_that("arx_forecaster determines target_date where forecast_date exists", { + arx <- arx_forecaster( + train, "dr", c("dr", "cr"), + args_list = arx_args_list( + forecast_date = as.Date("2021-12-31"), + target_date = as.Date("2022-01-01"), + ahead = 1L + ) + ) + # previously, if target_date existed, it could be + # erroneously incremented by the ahead + expect_identical( + arx$predictions$target_date, + rep(as.Date("2022-01-01"), ngeos) + ) + expect_identical( + arx$predictions$forecast_date, + rep(as.Date("2021-12-31"), ngeos) + ) + expect_true(all(!is.na(arx$predictions$.pred_distn))) + expect_true(all(!is.na(arx$predictions$.pred))) +}) + +test_that("arx_classifier determines target_date where forecast_date exists", { + arx <- arx_classifier( + train, "dr", c("dr"), + trainer = parsnip::boost_tree(mode = "classification", trees = 5), + args_list = 
arx_class_args_list( + forecast_date = as.Date("2021-12-31"), + target_date = as.Date("2022-01-01"), + ahead = 1L + ) + ) + # previously, if target_date existed, it could be + # erroneously incremented by the ahead + expect_identical( + arx$predictions$target_date, + rep(as.Date("2022-01-01"), ngeos) + ) + expect_identical( + arx$predictions$forecast_date, + rep(as.Date("2021-12-31"), ngeos) + ) + expect_true(all(!is.na(arx$predictions$.pred_class))) +}) diff --git a/tests/testthat/test-wis-dist-quantiles.R b/tests/testthat/test-wis-dist-quantiles.R new file mode 100644 index 000000000..937793189 --- /dev/null +++ b/tests/testthat/test-wis-dist-quantiles.R @@ -0,0 +1,60 @@ +test_that("wis dispatches and produces the correct values", { + tau <- c(.2, .4, .6, .8) + q1 <- 1:4 + q2 <- 8:11 + wis_one_pred <- function(q, tau, actual) { + 2 * mean(pmax(tau * (actual - q), (1 - tau) * (q - actual)), na.rm = TRUE) + } + actual <- 5 + expected <- c(wis_one_pred(q1, tau, actual), wis_one_pred(q2, tau, actual)) + + dstn <- dist_quantiles(list(q1, q2), tau) + expect_equal(weighted_interval_score(dstn, actual), expected) + + # works with a single dstn + q <- sort(10 * rexp(23)) + tau0 <- c(.01, .025, 1:19 / 20, .975, .99) + dst <- dist_quantiles(q, tau0) + expect_equal(weighted_interval_score(dst, 10), wis_one_pred(q, tau0, 10)) + + # returns NA when expected + dst <- dist_quantiles(rep(NA, 3), c(.2, .5, .95)) + expect_true(is.na(weighted_interval_score(dst, 10))) + expect_equal( + weighted_interval_score(dstn, c(NA, actual)), + c(NA, wis_one_pred(q2, tau, actual)) + ) + + # errors for non distributions + expect_snapshot(error = TRUE, weighted_interval_score(1:10, 10)) + expect_warning(w <- weighted_interval_score(dist_normal(1), 10)) + expect_true(all(is.na(w))) + expect_warning(w <- weighted_interval_score( + c(dist_normal(), dist_quantiles(1:5, 1:5 / 6)), + 10 + )) + expect_equal(w, c(NA, wis_one_pred(1:5, 1:5 / 6, 10))) + + # errors if sizes don't match + 
expect_snapshot(error = TRUE, weighted_interval_score( + dist_quantiles(list(1:4, 8:11), 1:4 / 5), # length 2 + 1:3 + )) + + # Missing value behaviours + dstn <- dist_quantiles(c(1, 2, NA, 4), 1:4 / 5) + expect_equal(weighted_interval_score(dstn, 2.5), 0.5) + expect_equal(weighted_interval_score(dstn, 2.5, c(2, 4, 5, 6, 8) / 10), 0.4) + expect_equal( + weighted_interval_score(dist_quantiles(c(1, 2, NA, 4), 1:4 / 5), 3, na_handling = "drop"), + 2 / 3 + ) + expect_equal( + weighted_interval_score(dstn, 2.5, c(2, 4, 5, 6, 8) / 10, na_handling = "drop"), + 0.4 + ) + expect_true(is.na( + weighted_interval_score(dstn, 2.5, na_handling = "propagate") + )) + expect_snapshot(error = TRUE, weighted_interval_score(dist_quantiles(1:4, 1:4 / 5), 2.5, 1:9 / 10, na_handling = "fail")) +}) diff --git a/vignettes/.gitignore b/vignettes/.gitignore index 097b24163..324ceaf7e 100644 --- a/vignettes/.gitignore +++ b/vignettes/.gitignore @@ -1,2 +1,3 @@ *.html +*_cache/ *.R diff --git a/inst/extdata/all_states_covidcast_signals.rds b/vignettes/articles/all_states_covidcast_signals.rds similarity index 100% rename from inst/extdata/all_states_covidcast_signals.rds rename to vignettes/articles/all_states_covidcast_signals.rds diff --git a/vignettes/articles/case_death_rate_archive.rds b/vignettes/articles/case_death_rate_archive.rds new file mode 100644 index 000000000..b5209fb1d Binary files /dev/null and b/vignettes/articles/case_death_rate_archive.rds differ diff --git a/vignettes/articles/sliding.Rmd b/vignettes/articles/sliding.Rmd index a0b3312bc..1556c4a72 100644 --- a/vignettes/articles/sliding.Rmd +++ b/vignettes/articles/sliding.Rmd @@ -14,6 +14,7 @@ knitr::opts_chunk$set( ```{r pkgs} library(epipredict) +library(epidatr) library(data.table) library(dplyr) library(tidyr) @@ -24,49 +25,45 @@ library(purrr) # Demonstrations of sliding AR and ARX forecasters -A key function from the epiprocess package is `epi_slide()`, which allows the -user to apply a function or formula-based computation over variables in 
an -`epi_df` over a running window of `n` time steps (see the following `epiprocess` -vignette to go over the basics of the function: ["Slide a computation over -signal values"](https://cmu-delphi.github.io/epiprocess/articles/slide.html)). -The equivalent sliding method for an `epi_archive` object can be called by using -the wrapper function `epix_slide()` (refer to the following vignette for the -basics of the function: ["Work with archive objects and data -revisions"](https://cmu-delphi.github.io/epiprocess/articles/archive.html)). The -key difference from `epi_slide()` is that it performs version-aware -computations. That is, the function only uses data that would have been -available as of time t for that reference time. - -In this vignette, we use `epi_slide()` and `epix_slide()` for backtesting our -`arx_forecaster` on historical COVID-19 case data from the US and from Canada. -More precisely, we first demonstrate using `epi_slide()` to slide ARX -forecasters over an `epi_df` object and compare the results obtained from using -different forecasting engines. We then compare the results from version-aware -and unaware forecasting, where the former is obtained from applying -`epix_slide()` to the `epi_archive` object, while the latter is obtained from -applying `epi_slide()` to the latest snapshot of the data. +A key function from the epiprocess package is `epix_slide()` (refer to the +following vignette for the basics of the function: ["Work with archive objects +and data +revisions"](https://cmu-delphi.github.io/epiprocess/articles/archive.html)) +which allows performing version-aware computations. That is, the function only +uses data that would have been available as of time t for that reference time. + +In this vignette, we use `epix_slide()` for backtesting our `arx_forecaster` on +historical COVID-19 case data from the US and from Canada. 
We first examine the +results from a version-unaware forecaster, comparing two different fitting +engines and then we contrast this with version-aware forecasting. The former +will proceed by constructing an `epi_archive` that erases its version +information and then use `epix_slide()` to forecast the future. The latter will +keep the versioned data and proceed similarly by using `epix_slide()` to +forecast the future. ## Comparing different forecasting engines -### Example using CLI and case data from US states +### Example using CLI and case data from US states First, we download the version history (ie. archive) of the percentage of -doctor’s visits with CLI (COVID-like illness) computed from medical insurance +doctor's visits with CLI (COVID-like illness) computed from medical insurance claims and the number of new confirmed COVID-19 cases per 100,000 population -(daily) for all 50 states from the COVIDcast API. We process as before, with the -modification that we use `sync = locf` in `epix_merge()` so that the last -version of each observation can be carried forward to extrapolate unavailable -versions for the less up-to-date input archive. +(daily) for all 50 states from the COVIDcast API. + +
+ +Load a data archive + +We process as before, with the modification that we use `sync = locf` in +`epix_merge()` so that the last version of each observation can be carried +forward to extrapolate unavailable versions for the less up-to-date input +archive. ```{r grab-epi-data} theme_set(theme_bw()) -y <- readRDS(system.file( - "extdata", "all_states_covidcast_signals.rds", - package = "epipredict", mustWork = TRUE -)) - -y <- purrr::map(y, ~ select(.x, geo_value, time_value, version = issue, value)) +y <- readRDS("all_states_covidcast_signals.rds") %>% + purrr::map(~ select(.x, geo_value, time_value, version = issue, value)) x <- epix_merge( y[[1]] %>% rename(percent_cli = value) %>% as_epi_archive(compactify = FALSE), @@ -77,17 +74,17 @@ x <- epix_merge( rm(y) ``` -After obtaining the latest snapshot of the data, we produce forecasts on that -data using the default engine of simple linear regression and compare to a -random forest. +
-Note that all of the warnings about the forecast date being less than the most -recent update date of the data have been suppressed to avoid cluttering the -output. +We then obtaining the latest snapshot of the data and proceed to fake the +version information by setting `version = time_value`. This has the effect of +obtaining data that arrives exactly at the day of the time_value. -```{r make-arx-kweek, warning = FALSE} +```{r arx-kweek-preliminaries, warning = FALSE} # Latest snapshot of data, and forecast dates -x_latest <- epix_as_of(x, max_version = max(x$versions_end)) +x_latest <- epix_as_of(x, version = max(x$versions_end)) %>% + mutate(version = time_value) %>% + as_epi_archive() fc_time_values <- seq( from = as.Date("2020-08-01"), to = as.Date("2021-11-01"), @@ -95,89 +92,106 @@ fc_time_values <- seq( ) aheads <- c(7, 14, 21, 28) -k_week_ahead <- function(epi_df, outcome, predictors, ahead = 7, engine) { - epi_slide( - epi_df, - ~ arx_forecaster( - .x, outcome, predictors, engine, - args_list = arx_args_list(ahead = ahead) - ) %>% - extract2("predictions") %>% - select(-geo_value), - before = 120 - 1, - ref_time_values = fc_time_values, - new_col_name = "fc" - ) %>% - select(geo_value, time_value, starts_with("fc")) %>% - mutate(engine_type = engine$engine) +forecast_k_week_ahead <- function(epi_archive, outcome, predictors, ahead = 7, engine) { + epi_archive %>% + epix_slide( + .f = function(x, gk, rtv) { + arx_forecaster( + x, outcome, predictors, engine, + args_list = arx_args_list(ahead = ahead) + )$predictions %>% + mutate(engine_type = engine$engine) %>% + pivot_quantiles_wider(.pred_distn) + }, + .before = 120, + .versions = fc_time_values + ) } +``` +```{r make-arx-kweek} # Generate the forecasts and bind them together fc <- bind_rows( - map( - aheads, - ~ k_week_ahead( - x_latest, "case_rate", c("case_rate", "percent_cli"), .x, - engine = linear_reg() - ) - ) %>% list_rbind(), - map( - aheads, - ~ k_week_ahead( - x_latest, "case_rate", 
c("case_rate", "percent_cli"), .x, - engine = rand_forest(mode = "regression") - ) - ) %>% list_rbind() -) %>% - pivot_quantiles_wider(fc_.pred_distn) + map(aheads, ~ forecast_k_week_ahead( + x_latest, + outcome = "case_rate", + predictors = c("case_rate", "percent_cli"), + ahead = .x, + engine = linear_reg() + )), + map(aheads, ~ forecast_k_week_ahead( + x_latest, + outcome = "case_rate", + predictors = c("case_rate", "percent_cli"), + ahead = .x, + engine = rand_forest(mode = "regression") + )) +) ``` Here, `arx_forecaster()` does all the heavy lifting. It creates leads of the target (respecting time stamps and locations) along with lags of the features (here, the response and doctors visits), estimates a forecasting model using the -specified engine, creates predictions, and non-parametric confidence bands. +specified engine, creates predictions, and non-parametric confidence bands. To see how the predictions compare, we plot them on top of the latest case -rates. Note that even though we've fitted the model on all states, -we'll just display the -results for two states, California (CA) and Florida (FL), to get a sense of the -model performance while keeping the graphic simple. - -```{r plot-arx, message = FALSE, warning = FALSE, fig.width = 9, fig.height = 6} -fc_cafl <- fc %>% filter(geo_value %in% c("ca", "fl")) -x_latest_cafl <- x_latest %>% filter(geo_value %in% c("ca", "fl")) - -ggplot(fc_cafl, aes(fc_target_date, group = time_value, fill = engine_type)) + +rates. Note that even though we've fitted the model on all states, we'll just +display the results for two states, California (CA) and Florida (FL), to get a +sense of the model performance while keeping the graphic simple. + +
+ +Code for plotting +```{r plot-arx, message = FALSE, warning = FALSE} +fc_cafl <- fc %>% + tibble() %>% + filter(geo_value %in% c("ca", "fl")) +x_latest_cafl <- x_latest$DT %>% + tibble() %>% + filter(geo_value %in% c("ca", "fl")) + +p1 <- ggplot(fc_cafl, aes(target_date, group = forecast_date, fill = engine_type)) + geom_line( data = x_latest_cafl, aes(x = time_value, y = case_rate), inherit.aes = FALSE, color = "gray50" ) + geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`), alpha = 0.4) + - geom_line(aes(y = fc_.pred)) + - geom_point(aes(y = fc_.pred), size = 0.5) + - geom_vline(aes(xintercept = time_value), linetype = 2, alpha = 0.5) + + geom_line(aes(y = .pred)) + + geom_point(aes(y = .pred), size = 0.5) + + geom_vline(aes(xintercept = forecast_date), linetype = 2, alpha = 0.5) + facet_grid(vars(geo_value), vars(engine_type), scales = "free") + scale_x_date(minor_breaks = "month", date_labels = "%b %y") + scale_fill_brewer(palette = "Set1") + labs(x = "Date", y = "Reported COVID-19 case rates") + theme(legend.position = "none") ``` +
+ +```{r show-plot1, fig.width = 9, fig.height = 6, echo=FALSE} +p1 +``` For the two states of interest, simple linear regression clearly performs better -than random forest in terms of accuracy of the predictions and does not -result in such in overconfident predictions (overly narrow confidence bands). -Though, in general, neither approach produces amazingly accurate forecasts. -This could be because -the behaviour is rather different across states and the effects of other notable -factors such as age and public health measures may be important to account for -in such forecasting. Including such factors as well as making enhancements such -as correcting for outliers are some improvements one could make to this simple -model.[^1] - -[^1]: Note that, despite the above caveats, simple models like this tend to out-perform many far more complicated models in the online Covid forecasting due to those models high variance predictions. +than random forest in terms of accuracy of the predictions and does not result +in such in overconfident predictions (overly narrow confidence bands). Though, +in general, neither approach produces amazingly accurate forecasts. This could +be because the behaviour is rather different across states and the effects of +other notable factors such as age and public health measures may be important to +account for in such forecasting. Including such factors as well as making +enhancements such as correcting for outliers are some improvements one could +make to this simple model.[^1] + +[^1]: Note that, despite the above caveats, simple models like this tend to +out-perform many far more complicated models in the online Covid forecasting due +to those models high variance predictions. + ### Example using case data from Canada +
+ +Data and forecasts. Similar to the above. + By leveraging the flexibility of `epiprocess`, we can apply the same techniques to data from other sources. Since some collaborators are in British Columbia, Canada, we'll do essentially the same thing for Canada as we did above. @@ -201,51 +215,50 @@ linear regression with those from using boosted regression trees. can <- readRDS(system.file( "extdata", "can_prov_cases.rds", package = "epipredict", mustWork = TRUE -)) - -can <- can %>% +)) %>% group_by(version, geo_value) %>% arrange(time_value) %>% mutate(cr_7dav = RcppRoll::roll_meanr(case_rate, n = 7L)) %>% as_epi_archive(compactify = TRUE) -can_latest <- epix_as_of(can, max_version = max(can$DT$version)) +can_latest <- epix_as_of(can, version = max(can$DT$version)) %>% + mutate(version = time_value) %>% + as_epi_archive() # Generate the forecasts, and bind them together can_fc <- bind_rows( map( aheads, - ~ k_week_ahead(can_latest, "cr_7dav", "cr_7dav", .x, linear_reg()) - ) %>% list_rbind(), + ~ forecast_k_week_ahead(can_latest, "cr_7dav", "cr_7dav", .x, linear_reg()) + ), map( aheads, - ~ k_week_ahead( + ~ forecast_k_week_ahead( can_latest, "cr_7dav", "cr_7dav", .x, boost_tree(mode = "regression", trees = 20) ) - ) %>% list_rbind() -) %>% - pivot_quantiles_wider(fc_.pred_distn) + ) +) ``` -The figures below shows the results for all of the provinces. +The figures below shows the results for all of the provinces. 
```{r plot-can-fc-lr, message = FALSE, warning = FALSE, fig.width = 9, fig.height = 12} ggplot( can_fc %>% filter(engine_type == "lm"), - aes(x = fc_target_date, group = time_value) + aes(x = target_date, group = forecast_date) ) + coord_cartesian(xlim = lubridate::ymd(c("2020-12-01", NA))) + geom_line( - data = can_latest, aes(x = time_value, y = cr_7dav), + data = can_latest$DT %>% tibble(), aes(x = time_value, y = cr_7dav), inherit.aes = FALSE, color = "gray50" ) + geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`, fill = geo_value), alpha = 0.4 ) + - geom_line(aes(y = fc_.pred)) + - geom_point(aes(y = fc_.pred), size = 0.5) + - geom_vline(aes(xintercept = time_value), linetype = 2, alpha = 0.5) + + geom_line(aes(y = .pred)) + + geom_point(aes(y = .pred), size = 0.5) + + geom_vline(aes(xintercept = forecast_date), linetype = 2, alpha = 0.5) + facet_wrap(~geo_value, scales = "free_y", ncol = 3) + scale_x_date(minor_breaks = "month", date_labels = "%b %y") + labs( @@ -258,19 +271,19 @@ ggplot( ```{r plot-can-fc-boost, message = FALSE, warning = FALSE, fig.width = 9, fig.height = 12} ggplot( can_fc %>% filter(engine_type == "xgboost"), - aes(x = fc_target_date, group = time_value) + aes(x = target_date, group = forecast_date) ) + coord_cartesian(xlim = lubridate::ymd(c("2020-12-01", NA))) + geom_line( - data = can_latest, aes(x = time_value, y = cr_7dav), + data = can_latest$DT %>% tibble(), aes(x = time_value, y = cr_7dav), inherit.aes = FALSE, color = "gray50" ) + geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`, fill = geo_value), alpha = 0.4 ) + - geom_line(aes(y = fc_.pred)) + - geom_point(aes(y = fc_.pred), size = 0.5) + - geom_vline(aes(xintercept = time_value), linetype = 2, alpha = 0.5) + + geom_line(aes(y = .pred)) + + geom_point(aes(y = .pred), size = 0.5) + + geom_vline(aes(xintercept = forecast_date), linetype = 2, alpha = 0.5) + facet_wrap(~geo_value, scales = "free_y", ncol = 3) + scale_x_date(minor_breaks = "month", date_labels = "%b %y") + labs( @@ -285,81 
+298,163 @@ and/or are overly confident (very narrow bands), particularly when boosted regression trees are used. But as this is meant to be a simple demonstration of sliding with different engines in `arx_forecaster`, we may devote another vignette to work on improving the predictive modelling using the suite of tools -available in epipredict. +available in `{epipredict}`. + +
-## Version-aware and unaware forecasting +## Version-aware forecasting -### Example using case data from US states +### Example using case data from US states We will now employ a forecaster that uses properly-versioned data (that would -have been available in real-time) to forecast future COVID-19 case rates from -current and past COVID-19 case rates for all states. That is, we can make -forecasts on the archive, `x`, and compare those to forecasts on the latest -data, `x_latest` using the same general set-up as above. For version-aware -forecasting, note that `x` is fed into `epix_slide()`, while for version-unaware -forecasting, `x_latest` is fed into `epi_slide()`. - -```{r make-ar-kweek-asof} -k_week_version_aware <- function(ahead = 7, version_aware = TRUE) { - if (version_aware) { - epix_slide( - x, - ~ arx_forecaster( - .x, "case_rate", c("case_rate", "percent_cli"), - args_list = arx_args_list(ahead = ahead) - ) %>% - extract2("predictions"), - before = 120 - 1, - ref_time_values = fc_time_values, - new_col_name = "fc" - ) %>% - mutate(engine_type = "lm", version_aware = version_aware) %>% - rename(geo_value = fc_geo_value) - } else { - k_week_ahead( - x_latest, "case_rate", c("case_rate", "percent_cli"), - ahead, linear_reg() - ) %>% mutate(version_aware = version_aware) - } +have been available in real-time) to forecast the 7 day average of future +COVID-19 case rates from current and past COVID-19 case rates and death rates +for all states. That is, we can make forecasts on the archive, `x`, and compare +those to forecasts on the latest data, `x_latest` using the same general set-up +as above. Note that in this example, we use a geo-pooled approach (using +combined data from all US states and territories) to train our model. + +
+ +

Download data using `{epidatr}`

```{r load-data, eval=FALSE} +# loading in the data +states <- "*" + +confirmed_incidence_prop <- pub_covidcast( + source = "jhu-csse", + signals = "confirmed_incidence_prop", + time_type = "day", + geo_type = "state", + time_values = epirange(20200301, 20211231), + geo_values = states, + issues = epirange(20000101, 20211231) +) %>% + select(geo_value, time_value, version = issue, case_rate = value) %>% + arrange(geo_value, time_value) %>% + as_epi_archive(compactify = FALSE) + +deaths_incidence_prop <- pub_covidcast( + source = "jhu-csse", + signals = "deaths_incidence_prop", + time_type = "day", + geo_type = "state", + time_values = epirange(20200301, 20211231), + geo_values = states, + issues = epirange(20000101, 20211231) +) %>% + select(geo_value, time_value, version = issue, death_rate = value) %>% + arrange(geo_value, time_value) %>% + as_epi_archive(compactify = FALSE) + + +x <- epix_merge(confirmed_incidence_prop, deaths_incidence_prop, sync = "locf") + +x <- x %>% + epix_slide( + .versions = fc_time_values, + function(x, gk, rtv) { + x %>% + group_by(geo_value) %>% + epi_slide_mean(case_rate, .window_size = 7L) %>% + rename(case_rate_7d_av = slide_value_case_rate) %>% + epi_slide_mean(death_rate, .window_size = 7L) %>% + rename(death_rate_7d_av = slide_value_death_rate) %>% + ungroup() + } + ) %>% + rename(version = time_value) %>% + rename( + time_value = slide_value_time_value, + geo_value = slide_value_geo_value, + case_rate = slide_value_case_rate, + death_rate = slide_value_death_rate, + case_rate_7d_av = slide_value_case_rate_7d_av, + death_rate_7d_av = slide_value_death_rate_7d_av + ) %>% + as_epi_archive(compactify = TRUE) + +saveRDS(x$DT, file = "case_death_rate_archive.rds") +``` + +```{r load-stored-data} +x <- readRDS("case_death_rate_archive.rds") +x <- as_epi_archive(x) +``` +
+ +Here we specify the ARX model. + +```{r make-arx-model} +aheads <- c(7, 14, 21) +fc_time_values <- seq( + from = as.Date("2020-09-01"), + to = as.Date("2021-12-31"), + by = "1 month" +) +forecaster <- function(x) { + map(aheads, function(ahead) { + arx_forecaster( + epi_data = x, + outcome = "death_rate_7d_av", + predictors = c("death_rate_7d_av", "case_rate_7d_av"), + trainer = quantile_reg(), + args_list = arx_args_list(lags = c(0, 7, 14, 21), ahead = ahead) + )$predictions + }) %>% + bind_rows() } +``` -# Generate the forecasts, and bind them together -fc <- bind_rows( - map(aheads, ~ k_week_version_aware(.x, TRUE)) %>% list_rbind(), - map(aheads, ~ k_week_version_aware(.x, FALSE)) %>% list_rbind() -) %>% pivot_quantiles_wider(fc_.pred_distn) +We can now use our forecaster function that we've created and use it in the +pipeline for forecasting the predictions. We store the predictions into the +`arx_preds` variable and calculate the most up to date version of the data in the +epi archive and store it as `x_latest`. + +```{r running-arx-forecaster} +arx_preds <- x %>% + epix_slide( + ~ forecaster(.x), + .before = 120, .versions = fc_time_values + ) %>% + mutate(engine_type = quantile_reg()$engine) %>% + mutate(ahead_val = target_date - forecast_date) + +x_latest <- epix_as_of(x, version = max(x$versions_end)) ``` -Now we can plot the results on top of the latest case rates. As before, we will only display and focus on the results for FL and CA for simplicity. +Now we plot both the actual and predicted 7 day average of the death rate for +the chosen states + +
-```{r plot-ar-asof, message = FALSE, warning = FALSE, fig.width = 9, fig.height = 6} -fc_cafl <- fc %>% filter(geo_value %in% c("ca", "fl")) -x_latest_cafl <- x_latest %>% filter(geo_value %in% c("ca", "fl")) +Code for the plot +```{r plot-arx-asof, message = FALSE, warning = FALSE} +states_to_show <- c("ca", "ny", "mi", "az") +fc_states <- arx_preds %>% + filter(geo_value %in% states_to_show) %>% + pivot_quantiles_wider(.pred_distn) -ggplot(fc_cafl, aes(x = fc_target_date, group = time_value, fill = version_aware)) + +x_latest_states <- x_latest %>% filter(geo_value %in% states_to_show) + +p2 <- ggplot(fc_states, aes(target_date, group = forecast_date)) + + geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`, fill = geo_value), alpha = 0.4) + geom_line( - data = x_latest_cafl, aes(x = time_value, y = case_rate), + data = x_latest_states, aes(x = time_value, y = death_rate_7d_av), inherit.aes = FALSE, color = "gray50" ) + - geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`), alpha = 0.4) + - geom_line(aes(y = fc_.pred)) + - geom_point(aes(y = fc_.pred), size = 0.5) + - geom_vline(aes(xintercept = time_value), linetype = 2, alpha = 0.5) + - facet_grid(geo_value ~ version_aware, - scales = "free", - labeller = labeller(version_aware = label_both) - ) + + geom_line(aes(y = .pred, color = geo_value)) + + geom_point(aes(y = .pred, color = geo_value), size = 0.5) + + geom_vline(aes(xintercept = forecast_date), linetype = 2, alpha = 0.5) + + facet_wrap(~geo_value, scales = "free_y", ncol = 1L) + scale_x_date(minor_breaks = "month", date_labels = "%b %y") + - labs(x = "Date", y = "Reported COVID-19 case rates") + scale_fill_brewer(palette = "Set1") + + scale_color_brewer(palette = "Set1") + + labs(x = "Date", y = "7 day average COVID-19 death rates") + theme(legend.position = "none") ``` +
-Again, we observe that the results are not great for these two states, but -that's likely due to the simplicity of the model (ex. the omission of key -factors such as age and public health measures) and the quality of the data (ex. -we have not personally corrected for anomalies in the data). - -We shall leave it to the reader to try the above version aware and unaware -forecasting exercise on the Canadian case rate data. The above code for the -American state data should be readily adaptable for this purpose. +```{r show-plot2, fig.width = 9, fig.height = 6, echo = FALSE} +p2 +``` diff --git a/vignettes/articles/smooth-qr.Rmd b/vignettes/articles/smooth-qr.Rmd new file mode 100644 index 000000000..3d626b2e1 --- /dev/null +++ b/vignettes/articles/smooth-qr.Rmd @@ -0,0 +1,520 @@ +--- +title: "Smooth quantile regression" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Smooth quantile regression} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r setup, include = FALSE} +knitr::opts_chunk$set( + collapse = FALSE, + comment = "#>", + warning = FALSE, + message = FALSE, + out.width = "100%" +) +``` + +# Introducing smooth quantile regression + +Whereas other time-series forecasting examples in this package have used +(direct) models for single horizons, in multi-period forecasting, the goal is to +(directly) forecast several horizons simultaneously. This is useful in +epidemiological applications where decisions are based on the trend of a signal. + +The idea underlying smooth quantile regression is that set forecast targets can +be approximated by a smooth curve. This novel approach from +[Tuzhilina et al., 2022](https://arxiv.org/abs/2202.09723) +enforces smoothness across the +horizons and can be applied to point estimation by regression or interval +prediction by quantile regression. Our focus in this vignette is the latter. 
+ +# Built-in function for smooth quantile regression and its parameters + +The built-in smooth quantile regression function, `smooth_quantile_reg()` +provides a model specification for smooth quantile regression that works under +the tidymodels framework. It has the following parameters and default values: + +```{r, eval = FALSE} +smooth_quantile_reg( + mode = "regression", + engine = "smoothqr", + outcome_locations = NULL, + quantile_levels = 0.5, + degree = 3L +) +``` + +For smooth quantile regression, the type of model or `mode` is regression. + +The only `engine` that is currently supported is `smooth_qr()` from the +[`smoothqr` package](https://dajmcdon.github.io/smoothqr/). + +The `outcome_locations` indicate the multiple horizon (ie. ahead) values. These +should be specified by the user. + +The `quantile_levels` parameter is a vector of values that indicates the +quantiles to be estimated. The default is the median (0.5 quantile). + +The `degree` parameter indicates the degree of the polynomials used for +smoothing of the response. It should be no more than the number of aheads. If +the degree is precisely equal to the number of aheads, then there is no +smoothing. To better understand this parameter and how it works, we should look +to its origins and how it is used in the model. + +# Model form + +Smooth quantile regression is linear auto-regressive, with the key feature being +a transformation that forces the coefficients to satisfy a smoothing constraint. +The purpose of this is for each model coefficient to be a smooth function of +ahead values, and so each such coefficient is set to be a linear combination of +smooth basis functions (such as a spline or a polynomial). + +The `degree` parameter controls the number of these polynomials used. It should +be no greater than the number of responses. This is a tuning parameter, and so +it can be chosen by performing a grid search with cross-validation. 
Intuitively, +$d = 1$ corresponds to the constant model, $d = 2$ gives straight line +forecasts, while $d = 3$ gives quadratic forecasts. Since a degree of 3 was +found to work well in the tested applications (see Section 9 of +[Tuzhilina et al., 2022](https://arxiv.org/abs/2202.09723)), +it is the default value. + +# Demonstration of smooth quantile regression + +```{r, message = FALSE} +library(epipredict) +library(dplyr) +library(purrr) +library(ggplot2) +theme_set(theme_bw()) +``` + +We will now apply smooth quantile regression on the real data used for COVID-19 +forecasting. The built-in dataset we will use is a subset of JHU daily data on +state cases and deaths. This sample data ranges from Dec. 31, 2020 to +Dec. 31, 2021. + +```{r} +edf <- case_death_rate_subset +``` + +We will set the forecast date to be November 30, 2021 so that we can produce +forecasts for target dates of 1 to 28 days ahead. We construct our test data, +`tedf` from the days beyond this. + +```{r} +fd <- as.Date("2021-11-30") + +tedf <- edf %>% filter(time_value >= fd) +``` + +We will use the most recent 3 months worth of data up to the forecast date for +training. + +```{r} +edf <- edf %>% filter(time_value < fd, time_value >= fd - 90L) +``` + +And for plotting our focus will be on a subset of two states - California and +Utah. + +```{r} +geos <- c("ut", "ca") +``` + +Suppose that our goal with this data is to predict COVID-19 death rates at +several horizons for each state. On day $t$, we want to predict new deaths $y$ +that are $a = 1,\dots, 28$ days ahead at locations $j$ using the death rates +from today, 1 week ago, and 2 weeks ago. 
So for each location, we'll predict the +median (0.5 quantile) for each of the target dates by using +$$ +\hat{y}_{j}(t+a) = \alpha(a) + \sum_{l = 0}^2 \beta_{l}(a) y_{j}(t - 7l) +$$ +where $\beta_{l}(a) = \sum_{i=1}^d \theta_{il} h_i(a)$ is the smoothing +constraint where ${h_1(a), \dots, h_d(a)}$ are the set of smooth basis functions +and $d$ is a hyperparameter that manages the flexibility of $\beta_{l}(a)$. +Remember that the goal is to have each $\beta_{l}(a)$ to be a smooth function of +the aheads and that is achieved through imposing the smoothing constraint. + +Note that this model is intended to be simple and straightforward. Our only +modification to this model is to add case rates as another predictive feature +(we will leave it to the reader to incorporate additional features beyond this +and the historical response values). We can update the basic model incorporate +the $k = 2$ predictive features of case and death rates for each location j, +$x_j(t) = (x_{j1}(t), x_{j2}(t))$ as follows: + +$$ +\hat{y}_{j}(t+a) = \alpha(a) + \sum_{k = 1}^2 \sum_{l = 0}^2 \beta_{kl}(a) x_{jk}(t - 7l) +$$ +where $\beta_{kl}(a) = \sum_{i=1}^d \theta_{ikl} h_i(a)$. + +Now, we will create our own forecaster from scratch by building up an +`epi_workflow` (there is no canned forecaster that is currently available). +Building our own forecaster allows for customization and control over the +pre-processing and post-processing actions we wish to take. + +The pre-processing steps we take in our `epi_recipe` are simply to lag the +predictor (by 0, 7, and 14 days) and lead the response by the multiple aheads +specified by the function user. + +The post-processing layers we add to our `frosting` are nearly as simple. We +first predict, unnest the prediction list-cols, omit NAs from them, and enforce +that they are greater than 0. 
The third component of an `epi_workflow`, the model, is smooth quantile
+ +We now can produce smooth quantile regression predictions for our problem: + +```{r, warning = FALSE} +smooth_preds <- smooth_fc(edf, fd = fd) + +smooth_preds +``` +Most often, we're not going to want to limit ourselves to just predicting the +median value as there is uncertainty about the predictions, so let's try to +predict several different quantiles in addition to the median: + +```{r, warning = FALSE} +several_quantiles <- c(.1, .25, .5, .75, .9) +smooth_preds <- smooth_fc(edf, quantiles = several_quantiles, fd = fd) + +smooth_preds +``` + +We can see that we have different columns for the different quantile +predictions. + +Let's visualize these results for the sample of two states. We will create a +simple plotting function, under which the median predictions are an orange line +and the surrounding quantiles are blue bands around this. For comparison, we +will include the actual values over time as a black line. + +```{r} +plot_preds <- function(preds, geos_to_plot = NULL, train_test_dat, fd) { + if (!is.null(geos_to_plot)) { + preds <- preds %>% filter(geo_value %in% geos_to_plot) + train_test_dat <- train_test_dat %>% filter(geo_value %in% geos_to_plot) + } + + ggplot(preds) + + geom_ribbon(aes(target_date, ymin = `0.1`, ymax = `0.9`), + fill = "cornflowerblue", alpha = .8 + ) + + geom_ribbon(aes(target_date, ymin = `0.25`, ymax = `0.75`), + fill = "#00488E", alpha = .8 + ) + + geom_line(data = train_test_dat, aes(time_value, death_rate)) + + geom_line(aes(target_date, `0.5`), color = "orange") + + geom_vline(xintercept = fd) + + facet_wrap(~geo_value) + + scale_x_date(name = "", date_labels = "%b %Y", date_breaks = "2 months") + + ylab("Deaths per 100K inhabitants") +} +``` + +Since we would like to plot the actual death rates for these states over time, +we bind the training and testing data together and input this into our plotting +function as follows: + +```{r, warning = FALSE} +plot_preds(smooth_preds, geos, bind_rows(tedf, edf), fd) +``` + +We 
can see that the predictions are smooth curves for each state, as expected +when using smooth quantile regression. In addition while the curvature of the +forecasts matches that of the truth, the forecasts do not look remarkably +accurate. + +## Varying the degrees parameter + +We can test the impact of different degrees by using the `map()` function. +Noting that this may take some time to run, let's try out all degrees from 1 +to 7: + +```{r, warning = FALSE} +smooth_preds_list <- map(1:7, ~ smooth_fc(edf, + degree = .x, + quantiles = c(.1, .25, .5, .75, .9), + fd = fd +) %>% + mutate(degree = .x)) %>% list_rbind() +``` + +One way to quantify the impact of these on the forecasting is to look at the +mean absolute error (MAE) or mean squared error (MSE) over the degrees. We can +select the degree that results in the lowest MAE. + +Since the MAE compares the predicted values to the actual values, we will first +join the test data to the predicted data for our comparisons: +```{r, message = FALSE} +tedf_sub <- tedf %>% + rename(target_date = time_value, actual = death_rate) %>% + select(geo_value, target_date, actual) +``` + +And then compute the MAE for each of the degrees: +```{r, message = FALSE} +smooth_preds_df_deg <- smooth_preds_list %>% + left_join(tedf_sub, by = c("geo_value", "target_date")) %>% + group_by(degree) %>% + mutate(error = abs(`0.5` - actual)) %>% + summarise(mean = mean(error)) + +# Arrange the MAE from smallest to largest +smooth_preds_df_deg %>% arrange(mean) +``` + +Instead of just looking at the raw numbers, let's create a simple line plot to +visualize how the MAE changes over degrees for this data: + +```{r} +ggplot(smooth_preds_df_deg, aes(degree, mean)) + + geom_line() + + xlab("Degrees of freedom") + + ylab("Mean MAE") +``` + +We can see that the degree that results in the lowest MAE is 3. Hence, we could +pick this degree for future forecasting work on this data. 
+ +## A brief comparison between smoothing and no smoothing + +Now, we will briefly compare the results from using smooth quantile regression +to those obtained without smoothing. The latter approach amounts to ordinary +quantile regression to get predictions for the intended target date. The main +drawback is that it ignores the fact that the responses all represent the same +signal, just for different ahead values. In contrast, the smooth quantile +regression approach utilizes this information about the data structure - the +fact that the aheads in are not be independent of each other, but that they are +naturally related over time by a smooth curve. + +To get the basic quantile regression results we can utilize the forecaster that +we've already built. We can simply set the degree to be the number of ahead +values to re-run the code without smoothing. + +```{r, warning = FALSE} +baseline_preds <- smooth_fc( + edf, + degree = 28L, quantiles = several_quantiles, fd = fd +) +``` + +And we can produce the corresponding plot to inspect the predictions obtained +under the baseline model: + +```{r, warning = FALSE} +plot_preds(baseline_preds, geos, bind_rows(tedf, edf), fd) +``` + +Unlike for smooth quantile regression, the resulting forecasts are not smooth +curves, but rather jagged and irregular in shape. 
+ +For a more formal comparison between the two approaches, we could compare the +test performance in terms of accuracy through calculating either the, MAE or +MSE, where the performance measure of choice can be calculated over over all +times and locations for each ahead value + +```{r, message = FALSE} +baseline_preds_mae_df <- baseline_preds %>% + left_join(tedf_sub, by = c("geo_value", "target_date")) %>% + group_by(ahead) %>% + mutate(error = abs(`0.5` - actual)) %>% + summarise(mean = mean(error)) %>% + mutate(type = "baseline") + +smooth_preds_mae_df <- smooth_preds %>% + left_join(tedf_sub, by = c("geo_value", "target_date")) %>% + group_by(ahead) %>% + mutate(error = abs(`0.5` - actual)) %>% + summarise(mean = mean(error)) %>% + mutate(type = "smooth") + +preds_mae_df <- bind_rows(baseline_preds_mae_df, smooth_preds_mae_df) + +ggplot(preds_mae_df, aes(ahead, mean, color = type)) + + geom_line() + + xlab("Ahead") + + ylab("Mean MAE") + + scale_color_manual(values = c("#A69943", "#063970")) +``` + +or over all aheads, times, and locations for a single numerical summary. + +```{r} +mean(baseline_preds_mae_df$mean) +mean(smooth_preds_mae_df$mean) +``` + +The former shows that forecasts for the immediate future and for the distant +future are more inaccurate for both models under consideration. The latter shows +that the smooth quantile regression model and baseline models perform very +similarly overall, with the smooth quantile regression model only slightly +beating the baseline model in terms of overall average MAE. + +One other commonly used metric is the Weighted Interval Score +(WIS, [Bracher et al., 2021](https://arxiv.org/pdf/2005.12881.pdf)), +which a scoring rule that is based on the population quantiles. The point is to +score the interval, whereas MAE only evaluates the accuracy of the point +forecast. + +Let $F$ be a forecast composed of predicted quantiles $q_{\tau}$ for the set of +quantile levels $\tau$. 
Then, in terms of the predicted quantiles, the WIS for +target variable $Y$ is represented as follows +([McDonald etal., 2021](https://www.pnas.org/doi/full/10.1073/pnas.2111453118)): + +$$ +WIS(F, Y) = 2 \sum_{\tau} \phi_{\tau} (Y - q_{\tau}) +$$ +where $\phi_{\tau}(x) = \tau |x|$ for $x \geq 0$ +and$\phi_{\tau}(x) = (1 - \tau) |x|$ for $x < 0$. + +This form is general as it can accommodate both symmetric and asymmetric +quantile levels. If the quantile levels are symmetric, then we can alternatively +express the WIS as a collection of central prediction intervals +($\ell_{\alpha}, u_{\alpha}$) parametrized by the exclusion probability +$\alpha$: + +$$ +WIS(F, Y) = \sum_{\alpha} \{ (u_{\alpha} - \ell_{\alpha}) + 2 \cdot \text{dist}(Y, [\ell_{\alpha}, u_{\alpha}]) \} +$$ +where $\text{dist}(a,S)$ is the smallest distance between point $a$ and an +element of set $S$. + +While we implement the former representation, we mention this form because it +shows the that the score can be decomposed into the addition of a sharpness +component (first term in the summand) and an under/overprediction component +(second term in the summand). This alternative representation is useful because +from it, we more easily see the major limitation to the WIS, which is that the +score tends to prioritize sharpness (how wide the interval is) relative to +coverage (if the interval contains the truth). + +Now, we write a simple function for the first representation of the score that +is compatible with the latest version of `epipredict` (adapted from the +corresponding function in +[smoothmpf-epipredict](https://github.com/dajmcdon/smoothmpf-epipredict)). The +inputs for it are the actual and predicted values and the quantile levels. 
+ +```{r} +wis_dist_quantile <- function(actual, values, quantile_levels) { + 2 * mean(pmax( + quantile_levels * (actual - values), + (1 - quantile_levels) * (values - actual), + na.rm = TRUE + )) +} +``` + +Next, we apply the `wis_dist_quantile` function to get a WIS score for each +state on each target date. We then compute the mean WIS for each ahead value +over all of the states. The results for each of the smooth and baseline +forecasters are shown in a similar style line plot as we chose for MAE: + +```{r} +smooth_preds_wis_df <- smooth_preds %>% + left_join(tedf_sub, by = c("geo_value", "target_date")) %>% + rowwise() %>% + mutate(wis = wis_dist_quantile( + actual, c(`0.1`, `0.25`, `0.5`, `0.75`, `0.9`), + several_quantiles + )) %>% + group_by(ahead) %>% + summarise(mean = mean(wis)) %>% + mutate(type = "smooth") + +baseline_preds_wis_df <- baseline_preds %>% + left_join(tedf_sub, by = c("geo_value", "target_date")) %>% + rowwise() %>% + mutate(wis = wis_dist_quantile( + actual, c(`0.1`, `0.25`, `0.5`, `0.75`, `0.9`), + several_quantiles + )) %>% + group_by(ahead) %>% + summarise(mean = mean(wis)) %>% + mutate(type = "baseline") + +preds_wis_df <- bind_rows(smooth_preds_wis_df, baseline_preds_wis_df) + +ggplot(preds_wis_df, aes(ahead, mean, color = type)) + + geom_line() + + xlab("Ahead") + + ylab("Mean WIS") + + scale_color_manual(values = c("#A69943", "#063970")) +``` + +The results are consistent with what we saw for MAE: The forecasts for the near +and distant future tend to be inaccurate for both models. The smooth quantile +regression model only slightly outperforms the baseline model. 
+ +Though averaging the WIS score over location and time tends to be the primary +aggregation scheme used in evaluation and model comparisons (see, for example, +[McDonald et al., 2021](https://www.pnas.org/doi/full/10.1073/pnas.2111453118)), +we can also obtain a single numerical summary by averaging over the aheads, +times, and locations: + +```{r} +mean(baseline_preds_wis_df$mean) +mean(smooth_preds_wis_df$mean) +``` + +Overall, both perspectives agree that the smooth quantile regression model tends +to perform only slightly better than the baseline model in terms of average WIS, +illustrating the difficulty of this forecasting problem. + +# What we've learned in a nutshell + +Smooth quantile regression is used in multi-period forecasting for predicting +several horizons simultaneously with a single smooth curve. It operates under +the key assumption that the future of the response can be approximated well by a +smooth curve. + +# Attribution + +The information presented on smooth quantile regression is from +[Tuzhilina et al., 2022](https://arxiv.org/abs/2202.09723). diff --git a/vignettes/articles/symptom-surveys.Rmd b/vignettes/articles/symptom-surveys.Rmd new file mode 100644 index 000000000..f480db575 --- /dev/null +++ b/vignettes/articles/symptom-surveys.Rmd @@ -0,0 +1,887 @@ +--- +title: "Can symptoms surveys improve COVID-19 forecasts?" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Using the update and adjust functions} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r setup, include = FALSE} +knitr::opts_chunk$set( + echo = TRUE, + collapse = FALSE, + comment = "#>", + warning = FALSE, + message = FALSE, + out.width = "100%" +) +``` + +# Introduction + +During the COVID-19 pandemic, Delphi ran COVID-19 symptom surveys through +Facebook and Google. In these surveys, millions of people in the US were asked +whether they or the people that they know are experiencing COVID-like symptoms. 
+This enabled the calculation of a "% CLI-in-community" signal for counties +across the US. This is simply an estimate of the percentage of people who know +someone who is presently sick with a COVID-like illness. + +These surveys were valuable tools for monitoring the pandemic because they +reported daily and not subject to reporting delays that plague other sources of +data. + +In this vignette, we will look at whether the % CLI-in-community indicators from +the Facebook and Google surveys improve the accuracy of short-term forecasts of +county-level COVID-19 case rates. The purpose here is to study and demonstrate +the value of the Facebook and Google % CLI-in-community signals to add +predictive power beyond what we can achieve with simple time series models +trained on case rates alone. + +Note that this vignette was adapted from the following [Delphi blog +post](https://delphi.cmu.edu/blog/2020/09/21/can-symptoms-surveys-improve-covid-19-forecasts/), +with the necessary modifications to enable the use of `epipredict`. The results +may be different from those on the blog post (one reason is that we are +exploring the use of a different forecaster and another is that we're using the +most recent versions of the datasets). + +Now, we will delve into the forecasting problem set-up and code followed by a +discussion of the results. + +## Problem Setup + +Our goal is to predict county-level COVID-19 case incidence rates for 1 and 2 +weeks ahead. For this, we restrict our attention to the 442 counties that had at +least 200 confirmed cases by May 14, 2020 (the end of the Google survey data) +and in which both the Facebook and Google % CLI-in-community signals are +available. + +To set the notation, let $Y_{l,t}$ denote the smoothed COVID-19 case incidence +rate for location (county) $l$ and day $t$. Let $F_{l,t}$ and $G_{l,t}$ denote +the Facebook and Google % CLI-in-community signals, respectively, for location +$l$ and time $t$. 
Note that we rescale all these signals from their given values +in our API so that they are true proportions. We then evaluate the following +four models: + +$$ +\begin{align} +h(Y_{l,t+d}) &\approx \alpha + \sum_{j = 0}^2 \beta_j h(Y_{l,t-7j}) \\ +h(Y_{l,t+d}) &\approx \alpha + \sum_{j = 0}^2 \beta_j h(Y_{l,t-7j}) + \sum_{j = 0}^2 \gamma_j h(F_{l, t-7j}) \\ +h(Y_{l,t+d}) &\approx \alpha + \sum_{j = 0}^2 \beta_j h(Y_{l,t-7j}) + \sum_{j = 0}^2 \tau_j h(G_{l, t-7j}) \\ +h(Y_{l,t+d}) &\approx \alpha + \sum_{j = 0}^2 \beta_j h(Y_{l,t-7j}) + \sum_{j = 0}^2 \gamma_j h(F_{l, t-7j}) + \sum_{j = 0}^2 \tau_j h(G_{l, t-7j}) +\end{align} +$$ +Here $d = 7$ or $d = 14$ depending on the target value, and $h$ is a +transformation to be specified later. + +We'll call the first model the "Cases" model because it bases its predictions of +future case rates on COVID-19 case rates (from 0, 1 and 2 weeks back). The +second model is called "Cases + Facebook" because it additionally incorporates +the current Facebook signal, and the Facebook signal from 1 and 2 weeks back. +The third model, "Cases + Google", is exactly the same as the second but +substitutes the Google signal instead of the Facebook one. The fourth and final +model model, "Cases + Facebook + Google", uses both Facebook and Google signals. +For each model, we use our canned autoregressive forecaster with quantile +regression to forecast at time $t_0$ (and predict case rates at $t_0 + d$). We +train over training over all locations, $l$ (all 442 counties), and all time $t$ +that are within the most recent 14 days of data available up to and including +time $t_0$. In other words, we use 14 trailing days for the training set. + +The forecasts are denoted by $\hat{Y}_{l, t_0 + d}$. To see how accurate these +forecasts are, we use the scaled absolute error: + +$$ +\frac{| \hat{Y}_{l, t_0 + d} - Y_{l, t_0 + d} |} {| Y_{l, t_0} - Y_{l, t_0 + d} |} +$$ + +where the error in the denominator is the strawman model error. 
This model +simply uses the most recent case rate for future predictions. You may recognize +this as an application of the flatline forecaster from `epipredict`. + +We normalize in this manner for two reasons. First, since the scaled error is +the fraction improvement over the strawman’s error, we get an interpretable +scale, where numebrs like 0.8 or 0.9 are favorable, and numbers like 2 or 5 are +increasingly disastrous. Second, in such problems we should expect considerable +county-to-county variability in the forecasting difficulty. Normalizing by the +strawman's error helps to adjust for this so that the results on aggregate are +not dominated by the county-to-county differences. + +## Transformations + +To help stabilize the variance of the case, Facebook and Google data, we chose +to use the logit transformation on their proportions. In actuality, we use a +"padded" version $h(x) = \log (\frac{x+a}{1-x+a})$ such that the numerator and +denominator are pushed away from zero by a small constant, $a = 0.01$. An +alternative to the logit transform is using a log transform (as in +$h(x) = \log (x+a)$ where $a$ is for padding). Note that such +variance-stabilizing transformations are used in the model fitting. When we +calculate the errors, we back-transform the values for comparison using the +inverse transform $h^{-1}$ so that we may calculate them on the original scale. + +## Forecasting Code + +The code below marches the forecast date $t_0$ forward, one day at a time for +the nine forecasting dates for which the four models can all be fit (May 6, 2020 +to May 14, 2020). Then, it fits the models, makes predictions 7 and 14 days +ahead (as permissible by the data), and records the errors. + +There are a number of benefits to using `epipredict` over writing the code from +scratch to fit and predict under each of the models. First, we do not have to +reformat the data for input into a model or concern ourselves with its unique +interface. 
We instead work under a unifying interface to streamline the modelling
+ +```{r, message = FALSE, warning = FALSE} +library(epidatr) +library(dplyr) +library(purrr) +library(epipredict) +library(recipes) + +case_num <- 200 +as_of_date <- "2020-05-14" +geo_values <- pub_covidcast( + source = "jhu-csse", + signals = "confirmed_cumulative_num", + geo_type = "county", + time_type = "day", + geo_values = "*", + time_values = epirange(20200514, 20200514) +) %>% + filter(value >= case_num) %>% + pull(geo_value) %>% + unique() + +# Fetch county-level Google and Facebook % CLI-in-community signals, and JHU +# confirmed case incidence proportion +start_day <- "2020-04-11" +end_day <- "2020-09-01" + +goog_sm_cli <- pub_covidcast( + source = "google-survey", + signals = "smoothed_cli", + geo_type = "county", + time_type = "day", + geo_values = "*", + time_values = epirange(start_day, end_day) +) %>% + filter(geo_value %in% geo_values) %>% + select(geo_value, time_value, value) %>% + rename(goog = value) + +fb_survey <- pub_covidcast( + source = "fb-survey", + signals = "smoothed_hh_cmnty_cli", + geo_type = "county", + time_type = "day", + geo_values = "*", + time_values = epirange(start_day, end_day) +) %>% + filter(geo_value %in% geo_values) %>% + select(geo_value, time_value, value) %>% + rename(fb = value) + +jhu_7dav_incid <- pub_covidcast( + source = "jhu-csse", + signals = "confirmed_7dav_incidence_prop", + geo_type = "county", + time_type = "day", + geo_values = "*", + time_values = epirange(start_day, end_day) +) %>% + filter(geo_value %in% geo_values) %>% + select(geo_value, time_value, value) %>% + rename(case = value) + +# Find "complete" counties, present in all three data signals at all times +geo_values_complete <- intersect( + intersect(goog_sm_cli$geo_value, fb_survey$geo_value), + jhu_7dav_incid$geo_value +) + +# Make one big matrix by joining these three data frames +z <- full_join(full_join(goog_sm_cli, fb_survey, by = c("geo_value", "time_value")), + jhu_7dav_incid, + by = c("geo_value", "time_value") +) %>% + filter(geo_value 
%in% geo_values_complete) %>% + as_epi_df() + +Logit <- function(x, a = 0.01) log((x + a) / (1 - x + a)) +Sigmd <- function(y, a = 0.01) (exp(y) * (1 + a) - a) / (1 + exp(y)) + +#### Parameters ##### + +# Transforms to consider, in what follows +trans <- Logit +inv_trans <- Sigmd + +# Rescale factors for our signals: bring them all down to proportions (between +# 0 and 1) +rescale_f <- 1e-2 # Originally a percentage +rescale_g <- 1e-2 # Originally a percentage +rescale_c <- 1e-5 # Originally a count per 100,000 people + +z <- z %>% mutate( + case = trans(case * rescale_c), + fb = trans(fb * rescale_f), + goog = trans(goog * rescale_g) +) + +# lead = 7 +leads <- c(7, 14) +lags <- c(0, 7, 14) +n <- 14 # Number of trailing days to use for the training set + +# Nine forecast dates +dates <- seq(as.Date("2020-05-06"), as.Date("2020-05-14"), by = "day") + +# List for storage of results +out_list <- vector(mode = "list", length = length(dates)) +for (k in 1:length(dates)) { + date <- dates[k] + + if (date %in% c("2020-05-13", "2020-05-14")) leads <- c(7, 14) else leads <- 7 + + # Pre-structuring test data + z_te <- z %>% + rename( + target_date = time_value, + target_case = case + ) %>% + select(geo_value, target_date, target_case) + + # Strawman model + out_df0 <- map(leads, ~ flatline_forecaster( + z %>% + filter(between(time_value, date - .x - max(lags) - n, date)) %>% + select(time_value, geo_value, case), + outcome = "case", + args_list = arx_args_list( + lags = lags, + ahead = .x, + nonneg = FALSE + ) + )$predictions %>% + mutate(lead = .x) %>% + left_join(z_te %>% filter(target_date == (date + .x)), by = c("geo_value", "target_date"))) %>% + list_rbind() %>% + mutate(err0 = abs(inv_trans(.pred) - inv_trans(target_case))) %>% + select(geo_value, forecast_date, err0, lead) + + + # Cases model + out_df1 <- map(leads, ~ arx_forecaster( + z %>% + filter(between(time_value, date - .x - max(lags) - n, date)) %>% + select(time_value, geo_value, case) %>% + 
filter(complete.cases(.)), + outcome = "case", + predictors = "case", + trainer = quantile_reg(), + args_list = arx_args_list( + lags = lags, + ahead = .x, + nonneg = FALSE + ) + )$predictions %>% + mutate(lead = .x) %>% + left_join(z_te %>% filter(target_date == (date + .x)), by = c("geo_value", "target_date"))) %>% + list_rbind() %>% + mutate(err1 = abs(inv_trans(.pred) - inv_trans(target_case))) %>% + select(geo_value, forecast_date, err1, lead) + + # Cases and Facebook model + out_df2 <- map(leads, ~ arx_forecaster( + z %>% + filter(between(time_value, date - .x - max(lags) - n, date)) %>% + select(time_value, geo_value, case, fb) %>% + filter(complete.cases(.)), + outcome = "case", + predictors = c("case", "fb"), + trainer = quantile_reg(), + args_list = arx_args_list( + lags = lags, + ahead = .x, + nonneg = FALSE + ) + )$predictions %>% + mutate(lead = .x) %>% + left_join(z_te %>% filter(target_date == (date + .x)), by = c("geo_value", "target_date"))) %>% + list_rbind() %>% + mutate(err2 = abs(inv_trans(.pred) - inv_trans(target_case))) %>% + select(geo_value, forecast_date, err2, lead) + + + # Cases and Google model + out_df3 <- map(leads, ~ arx_forecaster( + z %>% + filter(between(time_value, date - .x - max(lags) - n, date)) %>% + select(time_value, geo_value, case, goog) %>% + filter(complete.cases(.)), + outcome = "case", + predictors = c("case", "goog"), + trainer = quantile_reg(), + args_list = arx_args_list( + lags = lags, + ahead = .x, + nonneg = FALSE + ) + )$predictions %>% + mutate(lead = .x) %>% + left_join(z_te %>% filter(target_date == (date + .x)), by = c("geo_value", "target_date"))) %>% + list_rbind() %>% + mutate(err3 = abs(inv_trans(.pred) - inv_trans(target_case))) %>% + select(geo_value, forecast_date, err3, lead) + + # Cases, Facebook and Google model + out_df4 <- map(leads, ~ arx_forecaster( + z %>% + filter(between(time_value, date - .x - max(lags) - n, date)) %>% + select(time_value, geo_value, case, fb, goog) %>% + 
filter(complete.cases(.)),
+    outcome = "case",
+    predictors = c("case", "fb", "goog"),
+    trainer = quantile_reg(),
+    args_list = arx_args_list(
+      lags = lags,
+      ahead = .x,
+      nonneg = FALSE
+    )
+  )$predictions %>%
+    mutate(lead = .x) %>%
+    left_join(z_te %>% filter(target_date == (date + .x)), by = c("geo_value", "target_date"))) %>%
+    list_rbind() %>%
+    mutate(err4 = abs(inv_trans(.pred) - inv_trans(target_case))) %>%
+    select(geo_value, forecast_date, err4, lead)
+
+  # Left join of the results for all models
+  out_list[[k]] <- left_join(left_join(left_join(left_join(out_df0, out_df1), out_df2), out_df3), out_df4)
+}
+# Outside of loop bind rows of list
+out_df <- do.call(rbind, out_list)
+```
+
+## Results: All Four Models
+
+Since there are only two common forecast dates available for the four models for
+the 14-day-ahead forecasts (May 13 and May 14, 2020), we skip studying
+the 14-day-ahead forecast results in this four-way model discussion.
+
+Below we compute the median scaled errors for each of the four models over
+the 9-day test period. We can see that adding either or both of the survey
+signals improves on the median scaled error of the model that uses cases only,
+with the biggest gain achieved by the "Cases + Google" model. We can also see
+that the median scaled errors are all close to 1 (with all but that from the
+"Cases + Google" and "Cases + Facebook + Google" models exceeding 1), which
+speaks to the difficulty of the forecasting problem.
+ +```{r} +library(dplyr) +library(tidyr) +library(ggplot2) + +model_names <- c( + "Cases", "Cases + Facebook", "Cases + Google", + "Cases + Facebook + Google" +) + +# Calculate the scaled errors for each model, that is, the error relative to the strawman's error +res_all4 <- out_df %>% + drop_na() %>% # Restrict to common time + mutate(across(err1:err4, ~ .x / err0)) %>% # compute relative error to strawman + mutate(across(err2:err4, list(diff = ~ err1 - .x))) %>% # relative to cases model + ungroup() %>% + select(-err0) + +# Calculate and print median errors, for all 4 models, and just 7 days ahead +res_err4 <- res_all4 %>% + select(-ends_with("diff")) %>% + pivot_longer( + names_to = "model", values_to = "err", + cols = -c(geo_value, forecast_date, lead) + ) %>% + mutate( + lead = factor(lead, labels = paste(leads, "days ahead")), + model = factor(model, labels = model_names) + ) + +knitr::kable( + res_err4 %>% + group_by(model, lead) %>% + summarize(err = median(err), n = length(unique(forecast_date))) %>% + arrange(lead) %>% ungroup() %>% + rename( + "Model" = model, "Median scaled error" = err, + "Target" = lead, "Test days" = n + ) %>% + filter(Target == "7 days ahead"), + caption = paste( + "Test period:", min(res_err4$forecast_date), "to", + max(res_err4$forecast_date) + ), + format = "html", table.attr = "style='width:70%;'" +) +``` +$$\\[0.01in]$$ +Are these differences in median scaled errors significant? Some basic hypothesis +testing suggests that some probably are: Below we conduct a sign test for +whether the difference in the "Cases" model’s scaled error and each other +model’s scaled error is centered at zero. The sign test is run on the 9 test +days x 442 counties = 3978 pairs of scaled errors. The p-value from the "Cases" +versus "Cases + Google" test is tiny and well below a cutoff of 0.01. 
In
+contrast, the p-values from the "Cases" versus "Cases + Facebook" and the
+"Cases" versus "Cases + Facebook + Google" tests are much bigger and exceed this
+cutoff, suggesting that the Facebook survey is not adding as much for this
+situation (meaning the time and ahead considered, etc.)
+
+```{r}
+# Compute p-values using the sign test against a one-sided alternative, for
+# all models, and just 7 days ahead
+res_dif4 <- res_all4 %>%
+  select(-ends_with(as.character(1:4))) %>%
+  pivot_longer(
+    names_to = "model", values_to = "diff",
+    cols = -c(geo_value, forecast_date, lead)
+  ) %>%
+  mutate(
+    lead = factor(lead, labels = paste(leads, "days ahead")),
+    model = factor(model,
+      labels = c(
+        "Cases vs Cases + Facebook",
+        "Cases vs Cases + Google",
+        "Cases vs Cases + Facebook + Google"
+      )
+    )
+  )
+
+knitr::kable(
+  res_dif4 %>%
+    group_by(model, lead) %>%
+    summarize(p = binom.test(
+      x = sum(diff > 0, na.rm = TRUE),
+      n = n(), alt = "greater"
+    )$p.val) %>%
+    ungroup() %>% filter(lead == "7 days ahead") %>%
+    rename("Comparison" = model, "Target" = lead, "P-value" = p),
+  format = "html", table.attr = "style='width:50%;'",
+  digits = 3
+)
+```
+$$\\[0.01in]$$
+We should take these test results with a grain of salt because the sign test
+assumes independence of observations, which clearly cannot be true given the
+spatiotemporal structure of our forecasting problem. To mitigate the dependence
+across time (which intuitively seems to matter more than that across space), we
+recomputed these tests in a stratified way, where for each day we run a sign
+test on the scaled errors between two models over all 442 counties. The results
+are plotted as histograms below; the "Cases + Google" and (to a lesser extent)
+the "Cases + Facebook + Google" models appear to deliver some decently small
+p-values, but this is not very evident with the "Cases + Facebook" model.
Taking +a larger sample size (of more than nine test days) would be a natural next step +to take to see if these results persist. + +```{r} +# Red, blue (similar to ggplot defaults), then yellow +ggplot_colors <- c("#FC4E07", "#00AFBB", "#E7B800") + +ggplot(res_dif4 %>% + group_by(model, lead, forecast_date) %>% + summarize(p = binom.test( + x = sum(diff > 0, na.rm = TRUE), + n = n(), alt = "greater" + )$p.val) %>% + ungroup() %>% filter(lead == "7 days ahead"), aes(p)) + + geom_histogram(aes(color = model, fill = model), alpha = 0.4) + + scale_color_manual(values = ggplot_colors) + + scale_fill_manual(values = ggplot_colors) + + facet_wrap(vars(lead, model)) + + labs(x = "P-value", y = "Count") + + theme_bw() + + theme(legend.position = "none") +``` + +## Results: First Two Models + +One way to get a larger sample size with the current data is to compare a subset +of the models. Therefore, next we focus on comparing results between the "Cases" +and "Cases + Facebook" models only. Restricting to common forecast dates for +these two models yields a much longer test period for the 7 and 14-day-ahead +forecasts: May 20 through August 27, 2020. 
We make the code to compare these two +models a simple function so that we have the option to use it over different +dates or aheads (in particular, this function will be useful for the next +section where we explore several ahead values): + +```{r} +case_fb_mods <- function(forecast_dates, leads) { + # List for storage of results + out_list <- vector(mode = "list", length = length(forecast_dates)) + for (k in 1:length(forecast_dates)) { + date <- forecast_dates[k] + + # Pre-structuring test data + z_te <- z %>% + rename( + target_date = time_value, + target_case = case + ) %>% + select(geo_value, target_date, target_case) + + # Strawman model + out_df0 <- map(leads, ~ flatline_forecaster( + z %>% + filter(between(time_value, date - .x - max(lags) - n, date)) %>% + select(time_value, geo_value, case), + outcome = "case", + args_list = arx_args_list( + lags = lags, + ahead = .x, + nonneg = FALSE + ) + )$predictions %>% + mutate(lead = .x) %>% + left_join(z_te %>% filter(target_date == (date + .x)), by = c("geo_value", "target_date"))) %>% + list_rbind() %>% + mutate(err0 = abs(inv_trans(.pred) - inv_trans(target_case))) %>% + select(geo_value, forecast_date, err0, lead) + + # Cases model + out_df1 <- map(leads, ~ arx_forecaster( + z %>% + filter(between(time_value, date - .x - max(lags) - n, date)) %>% + select(time_value, geo_value, case) %>% + filter(complete.cases(.)), + outcome = "case", + predictors = "case", + trainer = quantile_reg(), + args_list = arx_args_list( + lags = lags, + ahead = .x, + nonneg = FALSE + ) + )$predictions %>% + mutate(lead = .x) %>% + left_join(z_te %>% filter(target_date == (date + .x)), by = c("geo_value", "target_date"))) %>% + list_rbind() %>% + mutate(err1 = abs(inv_trans(.pred) - inv_trans(target_case))) %>% + select(geo_value, forecast_date, err1, lead) + + # Cases and Facebook model + out_df2 <- map(leads, ~ arx_forecaster( + z %>% + filter(between(time_value, date - .x - max(lags) - n, date)) %>% + select(time_value, geo_value, 
case, fb) %>% + filter(complete.cases(.)), + outcome = "case", + predictors = c("case", "fb"), + trainer = quantile_reg(), + args_list = arx_args_list( + lags = lags, + ahead = .x, + nonneg = FALSE + ) + )$predictions %>% + mutate(lead = .x) %>% + left_join(z_te %>% filter(target_date == (date + .x)), by = c("geo_value", "target_date"))) %>% + list_rbind() %>% + mutate(err2 = abs(inv_trans(.pred) - inv_trans(target_case))) %>% + select(geo_value, forecast_date, err2, lead) + + # Left join of the results for all models + out_list[[k]] <- left_join(left_join(out_df0, out_df1), out_df2) + } + # Outside of loop bind rows and split into two lists by lead + out_df <- do.call(rbind, out_list) +} + +# Choose forecast dates common to the Cases and Cases + Facebook models +dates <- seq(as.Date("2020-05-20"), as.Date("2020-08-27"), by = "day") + +# Two leads to consider +leads <- c(7, 14) + +res <- case_fb_mods(dates, leads) +``` + +The median scaled errors over the test period are computed and reported below. +Now we see a decent improvement in median scaled error for the "Cases + +Facebook" model, which is true for both 7-day-ahead and 14-day-ahead forecasts. 
+ +```{r} +# For just models 1 and 2, then calculate the scaled +# errors, that is, the error relative to the strawman's error +res_all2 <- res %>% + drop_na() %>% # Restrict to common time + mutate(across(err1:err2, ~ .x / err0)) %>% # compute relative error to strawman + mutate(err12_diff = err1 - err2) %>% # Compute differences + # relative to cases model + ungroup() %>% + select(-err0) + +# Calculate and print median errors, for just models 1 and 2, and both 7 and 14 +# days ahead +res_err2 <- res_all2 %>% + select(-ends_with("diff")) %>% + pivot_longer( + names_to = "model", values_to = "err", + cols = -c(geo_value, forecast_date, lead) + ) %>% + mutate( + lead = factor(lead, labels = paste(leads, "days ahead")), + model = factor(model, labels = model_names[1:2]) + ) + +knitr::kable( + res_err2 %>% + select(-ends_with("diff")) %>% + group_by(model, lead) %>% + summarize(err = median(err), n = length(unique(forecast_date))) %>% + arrange(lead) %>% ungroup() %>% + rename( + "Model" = model, "Median scaled error" = err, + "Target" = lead, "Test days" = n + ), + caption = paste( + "Test period:", min(res_err2$forecast_date), "to", + max(res_err2$forecast_date) + ), + format = "html", table.attr = "style='width:70%;'", digits = 3 +) +``` +$$\\[0.01in]$$ + +Thanks to the extended length of the test period, we can also plot the +trajectories of the median scaled errors over time, as we do below, with the +left plot concerning 7-day-ahead forecasts, and the right 14-day-ahead +forecasts. These plots reveal something at once interesting and bothersome: the +median scaled errors are quite volatile over time, and for some periods in July, +forecasting became much harder, with the scaled errors reaching above 1.5 +for 7-day-ahead forecasts, and above 1.8 for 14-day-ahead forecasts. +Furthermore, we can see a clear visual difference between the median scaled +errors from the "Cases + Facebook" model in red and the "Cases" model in black. 
+The former appears to be below the latter during periods with low median scaled +errors and above during periods where forecasting becomes hard and the scaled +errors shoot above 1. This suggests that the Facebook signal may be more useful +to incorporate during periods of time where forecasting is easier. + +```{r} +# Plot median errors as a function of time, for models 1 and 2, and both 7 and +# 14 days ahead +ggplot( + res_err2 %>% + group_by(model, lead, forecast_date) %>% + summarize(err = median(err)) %>% ungroup(), + aes(x = forecast_date, y = err) +) + + geom_line(aes(color = model)) + + scale_color_manual(values = c("black", ggplot_colors)) + + geom_hline(yintercept = 1, linetype = 2, color = "gray") + + facet_wrap(vars(lead)) + + labs(x = "Date", y = "Median scaled error") + + theme_bw() + + theme(legend.position = "bottom", legend.title = element_blank()) +``` + +The fact that the lines are non-coincident suggests that the results we’re +seeing here are likely to be significantly different, though it’s hard to say +definitively given the complicated dependence structure present in the data. +Below we perform a sign test for whether the difference in scaled errors from +the "Cases" and "Cases + Facebook" models is centered at zero. The p-values are +essentially zero, given the large sample sizes: 98 test days in total for +the 7-day-ahead forecasts and 91 days for the 14-day-ahead forecasts (times 442 +counties for each day). 
+ +```{r} +# Compute p-values using the sign test against a one-sided alternative, just +# for models 1 and 2, and both 7 and 14 days ahead +res_dif2 <- res_all2 %>% + select(-ends_with(as.character(1:4))) %>% + pivot_longer( + names_to = "model", values_to = "diff", + cols = -c(geo_value, forecast_date, lead) + ) %>% + mutate( + lead = factor(lead, labels = paste(leads, "days ahead")), + model = factor(model, labels = "Cases > Cases + Facebook") + ) + +knitr::kable( + res_dif2 %>% + group_by(model, lead) %>% + summarize(p = binom.test( + x = sum(diff > 0, na.rm = TRUE), + n = n(), alt = "greater" + )$p.val) %>% + ungroup() %>% + rename("Comparison" = model, "Target" = lead, "P-value" = p), + format = "html", table.attr = "style='width:50%;'" +) +``` +$$\\[0.01in]$$ + +If we stratify and recompute p-values by forecast date, the bulk of p-values are +quite small. + +```{r} +ggplot(res_dif2 %>% + group_by(model, lead, forecast_date) %>% + summarize(p = binom.test( + x = sum(diff > 0, na.rm = TRUE), + n = n(), alt = "greater" + )$p.val) %>% + ungroup(), aes(p)) + + geom_histogram(aes(color = model, fill = model), alpha = 0.4) + + scale_color_manual(values = ggplot_colors) + + scale_fill_manual(values = ggplot_colors) + + facet_wrap(vars(lead, model)) + + labs(x = "P-value", y = "Count") + + theme_bw() + + theme(legend.position = "none") +``` + +This exploration illustrates an important point: The test period should be +chosen so that it is large enough in size to see differences (if there are any) +between the models under comparison. While we did not observe significant +differences between the "Cases" and "Cases + Facebook" models when the test +period was small at 9 days, we did observe a significant difference over this +extended test period of nearly 100 days. 
+ +## Varying the Number of Days Ahead + +Statistical significance refers to whether an effect exists (as opposed to +occurring by chance), while practical significance refers to the magnitude of +the effect and whether it is meaningful in the real world. Hypothesis tests, +such as the sign tests we conducted above, tell us whether the differences in +errors are statistically significant, but not about their practical +significance. For example, for 7-day-ahead forecasts, what does an improvement +of 0.019 units on the scaled error scale really mean, when comparing the "Cases ++ Facebook" model to the "Cases" model? Is this a meaningful gain in practice? + +To answer questions such as these, we can look at the way that the median scaled +errors behave as a function of the number of days ahead. Previously, we +considered forecasting case rates just 7 and 14 days ahead. Now we will +systematically examine 5 through 20 days ahead (the key difference in the code +being that we use `leads = 5:20`). Note that running the code for this many +leads may take a while. + +```{r} +# Consider a number of leads +leads <- 5:20 + +res <- case_fb_mods(dates, leads) +``` + +We obtain and plot the median scaled errors for the "Cases" and "Cases + +Facebook" models for different number of days ahead for the forecast target. +This is done over May 20 through August 27 for the forecast dates that are +common to the two models. 
+ +```{r} +err_by_lead <- res %>% + drop_na() %>% # Restrict to common time + mutate(across(err1:err2, ~ .x / err0)) %>% + ungroup() %>% + select(-err0) %>% + pivot_longer( + names_to = "model", values_to = "err", + cols = -c(geo_value, forecast_date, lead) + ) %>% + mutate(model = factor(model, labels = model_names[1:2])) %>% + group_by(model, lead) %>% + summarize(err = median(err)) %>% + ungroup() + +ggplot(err_by_lead, aes(x = lead, y = err)) + + geom_line(aes(color = model)) + + geom_point(aes(color = model)) + + scale_color_manual(values = c("black", ggplot_colors)) + + geom_hline( + yintercept = err_by_lead %>% + filter(lead %in% 7, model == "Cases") %>% pull(err), + linetype = 2, color = "gray" + ) + + labs( + title = "Forecasting errors by number of days ahead", + subtitle = sprintf( + "Over all counties with at least %i cumulative cases", + case_num + ), + x = "Number of days ahead", y = "Median scaled error" + ) + + theme_bw() # + theme(legend.position = "bottom", legend.title = element_blank()) +``` + +A first glance shows that the "Cases + Facebook" model, in red, gives better +median scaled errors at all ahead values. Furthermore, the vertical gap between +the two curves is consistently in the range of what we were seeing before (for 7 +and 14 days ahead), around 0.02 units on the scaled error scale. + +But if we look at this from a different angle, by considering the horizontal gap +between the curves, then we can infer something quite a bit more interesting: +For 7-day-ahead forecasts, the median scaled error of the "Cases" model +(indicated by the horizontal gray line) is comparable to that of 12-day-ahead +forecasts from the "Cases + Facebook" model. So using the % CLI-in-community +signal from our Facebook survey buys us around 4 extra days of lead time for +this forecasting problem, which is striking. 
As you might imagine, different +forecast targets yield different lead times (for 14-day-ahead forecasts, it +appears to be around 2 to 3 days of lead time), but the added value of the +survey signal is clear throughout. + +## Wrap-Up + +In this vignette, we've shown that either of the Facebook or Google % +CLI-in-community signals can improve the accuracy of short-term forecasts of +county-level COVID-19 case rates. The significance of these improvements is more +apparent with the Facebook signal, thanks to the much larger test period. With +either signal, the magnitude of the improvement offered seems modest but +nontrivial, especially because the forecasting problem is so difficult in the +first place. + +We reiterate that this was just a demo. Our analysis was fairly simple and lacks +a few qualities that we’d expect in a truly comprehensive, realistic forecasting +analysis. For reflection, let's discuss three possible areas to improve: + +1. The models we considered are simple autoregressive structures from standard + time series and could be improved in various ways (including, considering + other relevant dimensions like mobility measures, county health metrics, + etc.). + +2. The forecasts we produced are point rather than distributional forecasts. + That is, we predict a single number, rather than an entire distribution for + what happens 7 and 14 days ahead. Distributional forecasts portray + uncertainty in a transparent way, which is important in practice. + +3. The way we trained our forecast models does not account for data latency and + revisions, which are critical issues. For each (retrospective) forecast + date, $t_0$, we constructed forecasts by training on data that we fetched + from the API today, "as of" the day of writing this, and not "as of" the + forecast date. This matters because nearly all signals are subject to + latency and go through multiple revisions. + +On the flip side, our example here was not that far away from being realistic. 
+The models we examined are actually not too different from Delphi’s forecasters +in production. Also, the way we fit the quantile regression models in the code +extends immediately to multiple quantile regression (this just requires changing +the parameter `quantile_levels` in the call to `quantile_reg()`). And lastly, +it’s fairly easy to change the data acquisition step in the code so that data +gets pulled "as of" the forecast date (this requires specifying the parameter +`as_of` in the call to `pub_covidcast()` and should change per forecast date). + +Hopefully these preliminary findings have gotten you excited about the possible +uses of this symptom survey data. For further practice, try your hand at +implementing the suggested improvements or develop your own novel analytic +approach to extract insights from this data. diff --git a/vignettes/arx-classifier.Rmd b/vignettes/arx-classifier.Rmd new file mode 100644 index 000000000..b2a2bbf8e --- /dev/null +++ b/vignettes/arx-classifier.Rmd @@ -0,0 +1,281 @@ +--- +title: "Auto-regressive classifier" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Auto-regressive classifier} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r setup, include = FALSE} +knitr::opts_chunk$set( + echo = TRUE, + collapse = FALSE, + comment = "#>", + warning = FALSE, + message = FALSE, + out.width = "100%" +) +``` + +## Load required packages + +```{r, message = FALSE, warning = FALSE} +library(dplyr) +library(purrr) +library(ggplot2) +library(epipredict) +``` + +## Introducing the ARX classifier + +The `arx_classifier()` is an autoregressive classification model for `epi_df` +data that is used to predict a discrete class for each case under consideration. +It is a direct forecaster in that it estimates the classes at a specific horizon +or ahead value. + +To get a sense of how the `arx_classifier()` works, let's consider a simple +example with minimal inputs. 
For this, we will use the built-in
+`case_death_rate_subset` that contains confirmed COVID-19 cases and deaths from
+JHU CSSE for all states over Dec 31, 2020 to Dec 31, 2021. From this, we'll take
+a subset of data for five states over June 4, 2021 to December 31, 2021. Our
+objective is to predict whether the case rates are increasing when considering
+the 0, 7 and 14 day case rates:
+
+```{r}
+jhu <- case_death_rate_subset %>%
+  filter(
+    time_value >= "2021-06-04",
+    time_value <= "2021-12-31",
+    geo_value %in% c("ca", "fl", "tx", "ny", "nj")
+  )
+
+out <- arx_classifier(jhu, outcome = "case_rate", predictors = "case_rate")
+
+out$predictions
+```
+
+The key takeaway from the predictions is that there are two prediction classes:
+(-Inf, 0.25] and (0.25, Inf). This is because for our goal of classification
+the classes must be discrete. The discretization of the real-valued outcome is
+controlled by the `breaks` argument, which defaults to 0.25. Such breaks will be
+automatically extended to cover the entire real line. For example, the default
+break of 0.25 is silently extended to breaks = c(-Inf, .25, Inf) and, therefore,
+results in two classes: (-Inf, 0.25] and (0.25, Inf). These two classes are
+used to discretize the outcome. The conversion of the outcome to such classes is
+handled internally. So if discrete classes already exist for the outcome in the
+`epi_df`, then we recommend to code a classifier from scratch using the
+`epi_workflow` framework for more control.
+
+The `trainer` is a `parsnip` model describing the type of estimation such that
+`mode = "classification"` is enforced. The two typical trainers that are used
+are `parsnip::logistic_reg()` for two classes or `parsnip::multinom_reg()` for
+more than two classes.
+
+```{r}
+workflows::extract_spec_parsnip(out$epi_workflow)
+```
+
+From the parsnip model specification, we can see that the trainer used is
+logistic regression, which is expected for our binary outcome.
More complicated +trainers like `parsnip::naive_Bayes()` or `parsnip::rand_forest()` may also be +used (however, we will stick to the basics in this gentle introduction to the +classifier). + +If you use the default trainer of logistic regression for binary classification +and you decide against using the default break of 0.25, then you should only +input one break so that there are two classification bins to properly +dichotomize the outcome. For example, let's set a break of 0.5 instead of +relying on the default of 0.25. We can do this by passing 0.5 to the `breaks` +argument in `arx_class_args_list()` as follows: + +```{r} +out_break_0.5 <- arx_classifier( + jhu, + outcome = "case_rate", + predictors = "case_rate", + args_list = arx_class_args_list( + breaks = 0.5 + ) +) + +out_break_0.5$predictions +``` +Indeed, we can observe that the two `.pred_class` are now (-Inf, 0.5] and (0.5, +Inf). See `help(arx_class_args_list)` for other available modifications. + +Additional arguments that may be supplied to `arx_class_args_list()` include the +expected `lags` and `ahead` arguments for an autoregressive-type model. These +have default values of 0, 7, and 14 days for the lags of the predictors and 7 +days ahead of the forecast date for predicting the outcome. There is also +`n_training` to indicate the upper bound for the number of training rows per +key. If you would like some practice with using this, then remove the filtering +command to obtain data within "2021-06-04" and "2021-12-31" and instead set +`n_training` to be the number of days between these two dates, inclusive of the +end points. The end results should be the same. In addition to `n_training`, +there are `forecast_date` and `target_date` to specify the date that the +forecast is created and intended, respectively. We will not dwell on such +arguments here as they are not unique to this classifier or absolutely essential +to understanding how it operates. 
The remaining arguments will be discussed +organically, as they are needed to serve our purposes. For information on any +remaining arguments that are not discussed here, please see the function +documentation for a complete list and their definitions. + +## Example of using the ARX classifier + +Now, to demonstrate the power and utility of this built-in arx classifier, we +will loosely adapt the classification example that was written from scratch in +`vignette("preprocessing-and-models")`. However, to keep things simple and not +merely a direct translation, we will only consider two prediction categories and +leave the extension to three as an exercise for the reader. + +To motivate this example, a major use of autoregressive classification models is +to predict upswings or downswings like in hotspot prediction models to +anticipate the direction of the outcome (see [McDonald, Bien, Green, Hu, et al. +(2021)](https://www.pnas.org/doi/full/10.1073/pnas.2111453118) for more on +these). In our case, one simple question that such models can help answer is... +Do we expect that the future will have increased case rates or not relative to +the present? + +To answer this question, we can create a predictive model for upswings and +downswings of case rates rather than the raw case rates themselves. For this +situation, our target is to predict whether there is an increase in case rates +or not. Following +[McDonald, Bien, Green, Hu, et al.(2021)](https://www.pnas.org/doi/full/10.1073/pnas.2111453118), +we look at the +relative change between $Y_{l,t}$ and $Y_{l, t+a}$, where the former is the case +rate at location $l$ at time $t$ and the latter is the rate for that location at +time $t+a$. Using these variables, we define a categorical response variable +with two classes + +$$\begin{align} +Z_{l,t} = \left\{\begin{matrix} +\text{up,} & \text{if } Y_{l,t}^\Delta > 0.25\\ +\text{not up,} & \text{otherwise} +\end{matrix}\right. 
+\end{align}$$
+where $Y_{l,t}^\Delta = (Y_{l, t} - Y_{l, t-7}) / Y_{l, t-7}$. If $Y_{l,t}^\Delta > 0.25$, meaning that the number of new cases over the week has increased by over 25\%, then $Z_{l,t}$ is up. This is the criterion for location $l$ to be a hotspot at time $t$. On the other hand, if $Y_{l,t}^\Delta \leq 0.25$, then $Z_{l,t}$ is categorized as not up, meaning that there has not been a >25\% increase in the new cases over the past week.
+
+The logistic regression model we use to predict this binary response can be
+considered to be a simplification of the multinomial regression model presented
+in `vignette("preprocessing-and-models")`:
+
+$$\begin{align}
+\pi_{\text{up}}(x) &= Pr(Z_{l, t} = \text{up}|x) = \frac{e^{g_{\text{up}}(x)}}{1 + e^{g_{\text{up}}(x)}}, \\
+\pi_{\text{not up}}(x)&= Pr(Z_{l, t} = \text{not up}|x) = 1 - Pr(Z_{l, t} = \text{up}|x) = \frac{1}{1 + e^{g_{\text{up}}(x)}}
+\end{align}$$
+where
+
+$$
+g_{\text{up}}(x) = \log\left ( \frac{\Pr(Z_{l, t} = \text{up} \vert x)}{\Pr(Z_{l, t} = \text{not up} \vert x)} \right ) = \beta_{10} + \beta_{11}Y_{l,t}^\Delta + \beta_{12}Y_{l,t-7}^\Delta + \beta_{13}Y_{l,t-14}^\Delta.
+$$
+
+Now then, we will operate on the same subset of the `case_death_rate_subset`
+that we used in our above example. This time, we will use it to investigate
+whether the number of newly reported cases over the past 7 days has increased by
+at least 25% compared to the preceding week for our sample of states.
+
+Notice that by using the `arx_classifier()` function we've completely eliminated
+the need to manually categorize the response variable and implement
+pre-processing steps, which was necessary in
+`vignette("preprocessing-and-models")`.
+ +```{r} +log_res <- arx_classifier( + jhu, + outcome = "case_rate", + predictors = "case_rate", + args_list = arx_class_args_list( + breaks = 0.25 / 7 # division by 7 gives weekly not daily + ) +) + +log_res$epi_workflow +``` + +Comparing the pre-processing steps for this to those in the other vignette, we +can see that they are not precisely the same, but they cover the same essentials +of transforming `case_rate` to the growth rate scale (`step_growth_rate()`), +lagging the predictors (`step_epi_lag()`), leading the response +(`step_epi_ahead()`), which are both constructed from the growth rates, and +constructing the binary classification response variable (`step_mutate()`). + +On this topic, it is important to understand that we are not actually concerned +about the case values themselves. Rather we are concerned whether the quantity +of cases in the future is a lot larger than that in the present. For this +reason, the outcome does not remain as cases, but rather it is transformed by +using either growth rates (as the predictors and outcome in our example are) or +lagged differences. While the latter is closer to the requirements for the +[2022-23 CDC Flusight Hospitalization Experimental Target](https://github.com/cdcepi/Flusight-forecast-data/blob/745511c436923e1dc201dea0f4181f21a8217b52/data-experimental/README.md), +and it is conceptually easy to understand because it is simply the change of the +value for the horizon, it is not the default. The default is `growth_rate`. One +reason for this choice is because the growth rate is on a rate scale, not on the +absolute scale, so it fosters comparability across locations without any +conscious effort (on the other hand, when using the `lag_difference` one would +need to take care to operate on rates per 100k and not raw counts). We utilize +`epiprocess::growth_rate()` to create the outcome using some of the additional +arguments. 
One important argument for the growth rate calculation is the
+`method`. Only `rel_change` for relative change should be used as the method
+because the test data is the only data that is accessible and the other methods
+require access to the training data.
+
+The other optional arguments for controlling the growth rate calculation (that
+can be inputted as `additional_gr_args`) can be found in the documentation for
+`epiprocess::growth_rate()` and the related
+`vignette("growth_rate", package = "epiprocess")`.
+
+### Visualizing the results
+
+To visualize the prediction classes across the states for the target date, we
+can plot our results as a heatmap. However, if we were to plot the results for
+only one target date, like our 7-day ahead predictions, then that would be a
+pretty sad heatmap (which would look more like a bar chart than a heatmap)... So
+instead of doing that, let's get predictions for several aheads and plot a
+heatmap across the target dates. To get the predictions across several ahead
+values, we will use the map function in the same way that we did in other
+vignettes:
+
+```{r}
+multi_log_res <- map(1:40, ~ arx_classifier(
+  jhu,
+  outcome = "case_rate",
+  predictors = "case_rate",
+  args_list = arx_class_args_list(
+    breaks = 0.25 / 7, # division by 7 gives weekly not daily
+    ahead = .x
+  )
+)$predictions) %>% list_rbind()
+```
+
+We can plot the heatmap of the results over the aheads to see if there's
+anything novel or interesting to take away:
+
+```{r}
+ggplot(multi_log_res, aes(target_date, geo_value, fill = .pred_class)) +
+  geom_tile() +
+  ylab("State") +
+  xlab("Target date") +
+  scale_fill_brewer(palette = "Set1")
+```
+
+While there is a bit of variability near to the end, we can clearly see that
+there are upswings for all states starting from the beginning of January 2022,
+which we can recall was when there was a massive spike in cases for many states.
+So our results seem to align well with what actually happened at the beginning +of January 2022. + +## A brief reflection + +The most noticeable benefit of using the `arx_classifier()` function is the +simplification and reduction of the manual implementation of the classifier from +about 30 down to 3 lines. However, as we noted before, the trade-off for +simplicity is control over the precise pre-processing, post-processing, and +additional features embedded in the coding of a classifier. So the good thing is +that `epipredict` provides both - a built-in `arx_classifier()` or the means to +implement your own classifier from scratch by using the `epi_workflow` +framework. And which you choose will depend on the circumstances. Our advice is +to start with using the built-in classifier for ostensibly simple projects and +begin to implement your own when the modelling project takes a complicated turn. +To get some practice on coding up a classifier by hand, consider translating +this binary classification model example to an `epi_workflow`, akin to that in +`vignette("preprocessing-and-models")`. diff --git a/vignettes/epipredict.Rmd b/vignettes/epipredict.Rmd index fe911ede0..1925de2fb 100644 --- a/vignettes/epipredict.Rmd +++ b/vignettes/epipredict.Rmd @@ -28,78 +28,123 @@ library(epipredict) # Goals for the package -At a high level, our goal with `{epipredict}` is to make running simple Machine Learning / Statistical forecasters for epidemiology easy. However, this package is extremely extensible, and that is part of its utility. Our hope is that it is easy for users with epi training and some statistics to fit baseline models while still allowing those with more nuanced statistical understanding to create complicated specializations using the same framework. +At a high level, our goal with `{epipredict}` is to make running simple Machine +Learning / Statistical forecasters for epidemiology easy.
However, this package +is extremely extensible, and that is part of its utility. Our hope is that it is +easy for users with epi training and some statistics to fit baseline models +while still allowing those with more nuanced statistical understanding to create +complicated specializations using the same framework. -Serving both populations is the main motivation for our efforts, but at the same time, we have tried hard to make it useful. +Serving both populations is the main motivation for our efforts, but at the same +time, we have tried hard to make it useful. ## Baseline models -We provide a set of basic, easy-to-use forecasters that work out of the box. -You should be able to do a reasonably limited amount of customization on them. Any serious customization happens with the framework discussed below). +We provide a set of basic, easy-to-use forecasters that work out of the box. You +should be able to do a reasonably limited amount of customization on them. Any +serious customization happens with the framework discussed below. -For the basic forecasters, we provide: - -* Baseline flat-line forecaster +For the basic forecasters, we provide: + +* Baseline flat-line forecaster * Autoregressive forecaster * Autoregressive classifier -All the forcasters we provide are built on our framework. So we will use these basic models to illustrate its flexibility. +All the forecasters we provide are built on our framework. So we will use these +basic models to illustrate its flexibility. ## Forecasting framework -Our framework for creating custom forecasters views the prediction task as a set of modular components. There are four types of components: - +Our framework for creating custom forecasters views the prediction task as a set +of modular components. There are four types of components: + 1. Preprocessor: make transformations to the data before model training 2. Trainer: train a model on data, resulting in a fitted model object 3.
Predictor: make predictions, using a fitted model object and processed test data 4. Postprocessor: manipulate or transform the predictions before returning - -Users familiar with [`{tidymodels}`](https://www.tidymodels.org) and especially the [`{workflows}`](https://workflows.tidymodels.org) package will notice a lot of overlap. This is by design, and is in fact a feature. The truth is that `{epipredict}` is a wrapper around much that is contained in these packages. Therefore, if you want something from this -verse, it should "just work" (we hope). - -The reason for the overlap is that `{workflows}` _already implements_ the first three steps. And it does this very well. However, it is missing the postprocessing stage and currently has no plans for such an implementation. And this feature is important. The baseline forecaster we provide _requires_ postprocessing. Anything more complicated needs this as well. -The second omission from `{tidymodels}` is support for panel data. Besides epidemiological data, economics, psychology, sociology, and many other areas frequently deal with data of this type. So the framework of behind `{epipredict}` implements this. In principle, this has nothing to do with epidemiology, and one could simply use this package as a solution for the missing functionality in `{tidymodels}`. Again, this should "just work". - -All of the _panel data_ functionality is implemented through the `epi_df` data type in the companion [`{epiprocess}`](https://cmu-delphi.github.io/epiprocess/) package. There is much more to see there, but for the moment, it's enough to look at a simple one: +Users familiar with [`{tidymodels}`](https://www.tidymodels.org) and especially +the [`{workflows}`](https://workflows.tidymodels.org) package will notice a lot +of overlap. This is by design, and is in fact a feature. The truth is that +`{epipredict}` is a wrapper around much that is contained in these packages. 
+Therefore, if you want something from this -verse, it should "just work" (we +hope). + +The reason for the overlap is that `{workflows}` *already implements* the first +three steps. And it does this very well. However, it is missing the +postprocessing stage and currently has no plans for such an implementation. And +this feature is important. The baseline forecaster we provide *requires* +postprocessing. Anything more complicated needs this as well. + +The second omission from `{tidymodels}` is support for panel data. Besides +epidemiological data, economics, psychology, sociology, and many other areas +frequently deal with data of this type. So the framework behind +`{epipredict}` implements this. In principle, this has nothing to do with +epidemiology, and one could simply use this package as a solution for the +missing functionality in `{tidymodels}`. Again, this should "just work". + +All of the *panel data* functionality is implemented through the `epi_df` data +type in the companion [`{epiprocess}`](https://cmu-delphi.github.io/epiprocess/) +package. There is much more to see there, but for the moment, it's enough to +look at a simple one: ```{r epidf} jhu <- case_death_rate_subset jhu ``` -This data is built into the package and contains the measured variables `case_rate` and `death_rate` for COVID-19 at the daily level for each US state for the year 2021. The "panel" part is because we have repeated measurements across a number of locations. - -The `epi_df` encodes the time stamp as `time_value` and the `key` as `geo_value`. While these 2 names are required, the values don't need to actually represent such objects. Additional `key`'s are also supported (like age group, ethnicity, taxonomy, etc.). - -The `epi_df` also contains some metadata that describes the keys as well as the vintage of the data. It's possible that data collected at different times for the _same set_ of `geo_value`'s and `time_value`'s could actually be different.
For more details, see [`{epiprocess}`](https://cmu-delphi.github.io/epiprocess/articles/epiprocess.html). +This data is built into the package and contains the measured variables +`case_rate` and `death_rate` for COVID-19 at the daily level for each US state +for the year 2021. The "panel" part is because we have repeated measurements +across a number of locations. +The `epi_df` encodes the time stamp as `time_value` and the `key` as +`geo_value`. While these 2 names are required, the values don't need to actually +represent such objects. Additional `key`'s are also supported (like age group, +ethnicity, taxonomy, etc.). +The `epi_df` also contains some metadata that describes the keys as well as the +vintage of the data. It's possible that data collected at different times for +the *same set* of `geo_value`'s and `time_value`'s could actually be different. +For more details, see +[`{epiprocess}`](https://cmu-delphi.github.io/epiprocess/articles/epiprocess.html). ## Why doesn't this package already exist? As described above: -* Parts actually DO exist. There's a universe called `{tidymodels}`. It handles +* Parts actually DO exist. There's a universe called `{tidymodels}`. It handles preprocessing, training, and prediction, bound together, through a package called `{workflows}`. We built `{epipredict}` on top of that setup. In this way, you CAN use almost everything they provide. -* However, `{workflows}` doesn't do postprocessing. And nothing in the -verse handles _panel data_. +* However, `{workflows}` doesn't do postprocessing. And nothing in the -verse +handles _panel data_. * The tidy-team doesn't have plans to do either of these things. (We checked). * There are two packages that do _time series_ built on `{tidymodels}`, but it's -"basic" time series: 1-step AR models, exponential smoothing, STL decomposition, etc.[^2] Our group has not prioritized these sorts of models for epidemic forecasting, but one could also integrate these methods into our framework. 
+"basic" time series: 1-step AR models, exponential smoothing, STL decomposition, +etc.[^2] Our group has not prioritized these sorts of models for epidemic +forecasting, but one could also integrate these methods into our framework. -[^2]: These are [`{timetk}`](https://business-science.github.io/timetk/index.html) and [`{modeltime}`](https://business-science.github.io/timetk/index.html). There are _lots_ of useful methods there than can be used to do fairly complex machine learning methodology, though not directly for panel data and not for direct prediction of future targets. +[^2]: These are [`{timetk}`](https://business-science.github.io/timetk/index.html) +and [`{modeltime}`](https://business-science.github.io/modeltime/index.html). There +are *lots* of useful methods there that can be used to do fairly complex machine +learning methodology, though not directly for panel data and not for direct +prediction of future targets. # Show me the basics -We start with the `jhu` data displayed above. -One of the "canned" forecasters we provide is an autoregressive forecaster with (or without) covariates that _directly_ trains on the response. This is in contrast to a typical "iterative" AR model that trains to predict one-step-ahead, and then plugs in the predictions to "leverage up" to longer horizons. +We start with the `jhu` data displayed above. One of the "canned" forecasters we +provide is an autoregressive forecaster with (or without) covariates that +*directly* trains on the response. This is in contrast to a typical "iterative" +AR model that trains to predict one-step-ahead, and then plugs in the +predictions to "leverage up" to longer horizons. -We'll estimate the model jointly across all locations using only the most recent 30 days. +We'll estimate the model jointly across all locations using only the most +recent 30 days.
```{r demo-workflow} jhu <- jhu %>% filter(time_value >= max(time_value) - 30) @@ -110,23 +155,27 @@ out <- arx_forecaster( ) ``` -This call produces a warning, which we'll ignore for now. But essentially, it's telling us that our data comes from May 2022 but we're trying to do a forecast for January 2022. The result is likely not an accurate measure of real-time forecast performance, because the data have been revised over time. +The `out` object has two components: -The `out` object has two components: - - 1. The predictions which is just another `epi_df`. It contains the predictions for each location along with additional columns. By default, these are a 90% predictive interval, the `forecast_date` (the date on which the forecast was putatively made) and the `target_date` (the date for which the forecast is being made). + 1. The predictions which is just another `epi_df`. It contains the predictions for +each location along with additional columns. By default, these are a 90% +predictive interval, the `forecast_date` (the date on which the forecast was +putatively made) and the `target_date` (the date for which the forecast is being +made). ```{r} out$predictions ``` - 2. A list object of class `epi_workflow`. This object encapsulates all the instructions necessary to create the prediction. More details on this below. + 2. A list object of class `epi_workflow`. This object encapsulates all the +instructions necessary to create the prediction. More details on this below. ```{r} out$epi_workflow ``` -Note that the `time_value` in the predictions is not necessarily meaningful, -but it is a required column in an `epi_df`, so it remains here. - -By default, the forecaster predicts the outcome (`death_rate`) 1-week ahead, using 3 lags of each predictor (`case_rate` and `death_rate`) at 0 (today), 1 week back and 2 weeks back. The predictors and outcome can be changed directly. The rest of the defaults are encapsulated into a list of arguments. 
This list is produced by `arx_args_list()`. +By default, the forecaster predicts the outcome (`death_rate`) 1-week ahead, +using 3 lags of each predictor (`case_rate` and `death_rate`) at 0 (today), 1 +week back and 2 weeks back. The predictors and outcome can be changed directly. +The rest of the defaults are encapsulated into a list of arguments. This list is +produced by `arx_args_list()`. ## Simple adjustments @@ -148,11 +197,19 @@ out2week <- arx_forecaster( ) ``` -Here, we've used different lags on the `case_rate` and are now predicting 2 weeks ahead. This example also illustrates a major difficulty with the "iterative" versions of AR models. This model doesn't produce forecasts for `case_rate`, and so, would not have data to "plug in" for the necessary lags.[^1] +Here, we've used different lags on the `case_rate` and are now predicting 2 +weeks ahead. This example also illustrates a major difficulty with the +"iterative" versions of AR models. This model doesn't produce forecasts for +`case_rate`, and so, would not have data to "plug in" for the necessary +lags.[^1] -[^1]: An obvious fix is to instead use a VAR and predict both, but this would likely increase the variance of the model, and therefore, may lead to less accurate forecasts for the variable of interest. +[^1]: An obvious fix is to instead use a VAR and predict both, but this would +likely increase the variance of the model, and therefore, may lead to less +accurate forecasts for the variable of interest. -Another property of the basic model is the predictive interval. We describe this in more detail in a different vignette, but it is easy to request multiple quantiles. +Another property of the basic model is the predictive interval. We describe this +in more detail in a different vignette, but it is easy to request multiple +quantiles. 
```{r differential-levels} out_q <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), @@ -162,7 +219,11 @@ out_q <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), ) ``` -The column `.pred_dstn` in the `predictions` object is actually a "distribution" here parameterized by its quantiles. For this default forecaster, these are created using the quantiles of the residuals of the predictive model (possibly symmetrized). Here, we used 23 quantiles, but one can grab a particular quantile, +The column `.pred_dstn` in the `predictions` object is actually a "distribution" +here parameterized by its quantiles. For this default forecaster, these are +created using the quantiles of the residuals of the predictive model (possibly +symmetrized). Here, we used 23 quantiles, but one can grab a particular +quantile, ```{r q1} head(quantile(out_q$predictions$.pred_distn, p = .4)) @@ -178,7 +239,8 @@ out_q$predictions %>% unnest(.pred_distn) # then unnest it ``` -Additional simple adjustments to the basic forecaster can be made using the function: +Additional simple adjustments to the basic forecaster can be made using the +function: ```{r, eval = FALSE} arx_args_list( @@ -191,9 +253,12 @@ arx_args_list( ## Changing the engine -So far, our forecasts have been produced using simple linear regression. But this is not the only way to estimate such a model. -The `trainer` argument determines the type of model we want. -This takes a [`{parsnip}`](https://parsnip.tidymodels.org) model. The default is linear regression, but we could instead use a random forest with the `{ranger}` package: +So far, our forecasts have been produced using simple linear regression. But +this is not the only way to estimate such a model. The `trainer` argument +determines the type of model we want. This takes a +[`{parsnip}`](https://parsnip.tidymodels.org) model. 
The default is linear +regression, but we could instead use a random forest with the `{ranger}` +package: ```{r ranger, warning = FALSE} out_rf <- arx_forecaster( @@ -220,36 +285,38 @@ out_qr <- arx_forecaster( ) ``` -FWIW, this last case (using quantile regression), is not far from what the Delphi production forecast team used for its Covid forecasts over the past few years. +FWIW, this last case (using quantile regression), is not far from what the +Delphi production forecast team used for its Covid forecasts over the past few +years. ## Inner workings -Underneath the hood, this forecaster creates (and returns) an `epi_workflow`. -Essentially, this is a big S3 object that wraps up the 4 modular steps +Underneath the hood, this forecaster creates (and returns) an `epi_workflow`. +Essentially, this is a big S3 object that wraps up the 4 modular steps (preprocessing - postprocessing) described above. ### Preprocessing -Preprocessing is accomplished through a `recipe` (imagine baking a cake) as -provided in the [`{recipes}`](https://recipes.tidymodels.org) package. +Preprocessing is accomplished through a `recipe` (imagine baking a cake) as +provided in the [`{recipes}`](https://recipes.tidymodels.org) package. We've made a few modifications (to handle panel data) as well as added some additional options. The recipe gives a specification of how to handle training data. Think of it like a fancified -`formula` that you would pass to `lm()`: `y ~ x1 + log(x2)`. In general, -there are 2 extensions to the `formula` that `{recipes}` handles: +`formula` that you would pass to `lm()`: `y ~ x1 + log(x2)`. In general, +there are 2 extensions to the `formula` that `{recipes}` handles: - 1. Doing transformations of both training and test data that can always be - applied. These are things like taking the log of a variable, leading or + 1. Doing transformations of both training and test data that can always be + applied. 
These are things like taking the log of a variable, leading or lagging, filtering out rows, handling dummy variables, etc. - 2. Using statistics from the training data to eventually process test data. + 2. Using statistics from the training data to eventually process test data. This is a major benefit of `{recipes}`. It prevents what the tidy team calls "data leakage". A simple example is centering a predictor by its mean. We need to store the mean of the predictor from the training data and use that value on the test data rather than accidentally calculating the mean of the test predictor for centering. - + A recipe is processed in 2 steps, first it is "prepped". This calculates and -stores any intermediate statistics necessary for use on the test data. +stores any intermediate statistics necessary for use on the test data. Then it is "baked" resulting in training data ready for passing into a statistical model (like `lm`). @@ -258,13 +325,14 @@ the `time_value`, `geo_value`, and any additional keys so that these are availab when necessary. The `epi_recipe` from `out_gb` can be extracted from the result: + ```{r} extract_recipe(out_gb$epi_workflow) ``` The "Inputs" are the original `epi_df` and the "roles" that these are assigned. -None of these are predictors or outcomes. Those will be created -by the recipe when it is prepped. The "Operations" are the sequence of +None of these are predictors or outcomes. Those will be created +by the recipe when it is prepped. The "Operations" are the sequence of instructions to create the cake (baked training data). Here we create lagged predictors, lead the outcome, and then remove `NA`s. Some models like `lm` internally handle `NA`s, but not everything does, so we @@ -289,7 +357,7 @@ Users with familiarity with the `{parsnip}` package will have no trouble here. Basically, `{parsnip}` unifies the function signature across statistical models. 
For example, `lm()` "likes" to work with formulas, but `glmnet::glmnet()` uses `x` and `y` for predictors and response. `{parsnip}` is agnostic. Both of these -do "linear regression". Above we switched from `lm()` to `xgboost()` without +do "linear regression". Above we switched from `lm()` to `xgboost()` without any issue despite the fact that these functions couldn't be more different. ```{r, eval = FALSE} @@ -308,7 +376,7 @@ xgboost( ) ``` -`{epipredict}` provides a few engines/modules (the flatline forecaster and +`{epipredict}` provides a few engines/modules (the flatline forecaster and quantile regression), but you should be able to use any available models listed [here](https://www.tidymodels.org/find/parsnip/). @@ -322,8 +390,8 @@ ewf <- epi_workflow(er, linear_reg()) %>% fit(jhu) To stretch the metaphor of preparing a cake to its natural limits, we have created postprocessing functionality called "frosting". Much like the recipe, -each postprocessing operation is a "layer" and we "slather" these onto our -baked cake. To fix ideas, below is the postprocessing `frosting` for +each postprocessing operation is a "layer" and we "slather" these onto our +baked cake. To fix ideas, below is the postprocessing `frosting` for `arx_forecaster()` ```{r} @@ -333,7 +401,7 @@ extract_frosting(out_q$epi_workflow) Here we have 5 layers of frosting. The first generates the forecasts from the test data. The second uses quantiles of the residuals to create distributional forecasts. The next two add columns for the date the forecast was made and the -date for which it is intended to occur. Because we are predicting rates, they +date for which it is intended to occur. Because we are predicting rates, they should be non-negative, so the last layer thresholds both predicted values and intervals at 0. 
The code to do this (inside the forecaster) is @@ -349,13 +417,12 @@ f <- frosting() %>% layer_threshold(starts_with(".pred")) ``` -At predict time, we add this object onto the `epi_workflow` and call `predict()` +At predict time, we add this object onto the `epi_workflow` and call `forecast()` ```{r, warning=FALSE} -test_data <- get_test_data(er, jhu) ewf %>% add_frosting(f) %>% - predict(test_data) + forecast() ``` The above `get_test_data()` function examines the recipe and ensures that enough @@ -369,20 +436,18 @@ that contained the necessary predictors. ## Conclusion Internally, we provide some simple functions to create reasonable forecasts. -But ideally, a user could create their own forecasters by building up the +But ideally, a user could create their own forecasters by building up the components we provide. In other vignettes, we try to walk through some of these -customizations. +customizations. -To illustrate everything above, here is (roughly) the code for the `flatline_forecaster()` applied to the `case_rate`. +To illustrate everything above, here is (roughly) the code for the +`flatline_forecaster()` applied to the `case_rate`. ```{r} r <- epi_recipe(jhu) %>% step_epi_ahead(case_rate, ahead = 7, skip = TRUE) %>% update_role(case_rate, new_role = "predictor") %>% - add_role(all_of(epi_keys(jhu)), new_role = "predictor") - -# bit of a weird hack to get the latest values per key -latest <- get_test_data(epi_recipe(jhu), jhu) + add_role(all_of(key_colnames(jhu)), new_role = "predictor") f <- frosting() %>% layer_predict() %>% @@ -393,11 +458,11 @@ f <- frosting() %>% eng <- linear_reg() %>% set_engine("flatline") wf <- epi_workflow(r, eng, f) %>% fit(jhu) -preds <- predict(wf, latest) +preds <- forecast(wf) ``` -All that really differs from the `arx_forecaster()` is the `recipe`, the -test data, and the engine. 
The `frosting` is identical, as is the fitting +All that really differs from the `arx_forecaster()` is the `recipe`, the +test data, and the engine. The `frosting` is identical, as is the fitting and predicting procedure. ```{r} diff --git a/vignettes/panel-data.Rmd b/vignettes/panel-data.Rmd new file mode 100644 index 000000000..0dea322f2 --- /dev/null +++ b/vignettes/panel-data.Rmd @@ -0,0 +1,485 @@ +--- +title: "Using epipredict on non-epidemic panel data" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Using epipredict on non-epidemic panel data} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r setup, include=F} +knitr::opts_chunk$set( + echo = TRUE, + collapse = TRUE, + comment = "#>", + out.width = "90%", + fig.align = "center" +) +``` + +```{r libraries, warning=FALSE, message=FALSE} +library(dplyr) +library(tidyr) +library(parsnip) +library(recipes) +library(epiprocess) +library(epipredict) +library(ggplot2) +theme_set(theme_bw()) +``` + +[Panel data](https://en.wikipedia.org/wiki/Panel_data), or longitudinal data, +contain cross-sectional measurements of subjects over time. The `epipredict` +package is most suitable for running forecasters on epidemiological panel data. +A built-in example of this is the [`case_death_rate_subset`]( + https://cmu-delphi.github.io/epipredict/reference/case_death_rate_subset.html) +dataset, which contains daily state-wise measures of `case_rate` and +`death_rate` for COVID-19 in 2021: + +```{r epi-panel-ex, include=T} +head(case_death_rate_subset, 3) +``` + +`epipredict` functions work with data in +[`epi_df`](https://cmu-delphi.github.io/epiprocess/reference/epi_df.html) +format. Despite the stated goal and name of the package, other panel datasets +are also valid candidates for `epipredict` functionality, as long as they are +in `epi_df` format. 
+ +```{r employ-stats, include=F} +data("grad_employ_subset") +year_start <- min(grad_employ_subset$time_value) +year_end <- max(grad_employ_subset$time_value) +``` + +# Example panel data overview + +In this vignette, we will demonstrate using `epipredict` with employment panel +data from Statistics Canada. We will be using +[ + Table 37-10-0115-01: Characteristics and median employment income of + longitudinal cohorts of postsecondary graduates two and five years after + graduation, by educational qualification and field of study (primary + groupings) +](https://www150.statcan.gc.ca/t1/tbl1/en/tv.action?pid=3710011501). + +The full dataset contains yearly median employment income two and five years +after graduation, and number of graduates. The data is stratified by +variables such as geographic region (Canadian province), education, and +age group. The year range of the dataset is `r year_start` to `r year_end`, +inclusive. The full dataset also contains metadata that describes the +quality of data collected. For demonstration purposes, we make the following +modifications to get a subset of the full dataset: + +* Only keep provincial-level geographic region (the full data also has +"Canada" as a region) +* Only keep "good" or better quality data rows, as indicated by the [`STATUS`]( + https://www.statcan.gc.ca/en/concepts/definitions/guide-symbol) column +* Choose a subset of covariates and aggregate across the remaining ones. The +chosen covariates are age group, and educational qualification. + +To use this data with `epipredict`, we need to convert it into `epi_df` format +using `epiprocess::as_epi_df()` +with additional keys. In our case, the additional keys are `age_group`, +and `edu_qual`. Note that in the above modifications, we encoded `time_value` +as type `integer`. This lets us set `time_type = "year"`, and ensures that +lag and ahead modifications later on are using the correct time units. 
See the +`epiprocess::epi_df` documentation for +a list of all the `time_type`s available. + +```{r data-dim, include=F} +employ_rowcount <- format(nrow(grad_employ_subset), big.mark = ",") +employ_colcount <- length(names(grad_employ_subset)) +``` + +Now, we are ready to use `grad_employ_subset` with `epipredict`. +Our `epi_df` contains `r employ_rowcount` rows and `r employ_colcount` columns. +Here is a quick summary of the columns in our `epi_df`: + +* `time_value` (time value): year as an `integer` +* `geo_value` (geo value): province in Canada +* `num_graduates` (raw, time series value): number of graduates +* `med_income_2y` (raw, time series value): median employment income 2 years +after graduation +* `med_income_5y` (raw, time series value): median employment income 5 years +after graduation +* `age_group` (key): one of two age groups, either 15 to 34 years, or 35 to 64 +years +* `edu_qual` (key): one of 32 unique educational qualifications, e.g., +"Master's diploma" + +```{r preview-data, include=T} +# Rename for simplicity +employ <- grad_employ_subset +sample_n(employ, 6) +``` + +In the following sections, we will go over pre-processing the data in the +`epi_recipe` framework, and fitting a model and making predictions within the +`epipredict` framework and using the package's canned forecasters. # Autoregressive (AR) model to predict number of graduates in a year ## Pre-processing + +As a simple example, let's work with the `num_graduates` column for now. We will +first pre-process by standardizing each numeric column by the total within +each group of keys. We do this since those raw numeric values will vary greatly +from province to province since there are large differences in population.
+ +```{r employ-small, include=T} +employ_small <- employ %>% + group_by(geo_value, age_group, edu_qual) %>% + # Select groups where there are complete time series values + filter(n() >= 6) %>% + mutate( + num_graduates_prop = num_graduates / sum(num_graduates), + med_income_2y_prop = med_income_2y / sum(med_income_2y), + med_income_5y_prop = med_income_5y / sum(med_income_5y) + ) %>% + ungroup() +head(employ_small) +``` + +Below is a visualization for a sample of the small data for British Columbia and Ontario. +Note that some groups +do not have any time series information since we filtered out all time series +with incomplete dates. + +```{r employ-small-graph, include=T, eval=T, fig.width=9, fig.height=6} +employ_small %>% + filter(geo_value %in% c("British Columbia", "Ontario")) %>% + filter(grepl("degree", edu_qual, fixed = T)) %>% + group_by(geo_value, time_value, edu_qual, age_group) %>% + summarise(num_graduates_prop = sum(num_graduates_prop), .groups = "drop") %>% + ggplot(aes(x = time_value, y = num_graduates_prop, color = geo_value)) + + geom_line() + + scale_colour_manual(values = c("Cornflowerblue", "Orange"), name = "") + + facet_grid(rows = vars(edu_qual), cols = vars(age_group)) + + xlab("Year") + + ylab("Percentage of graduates") + + theme(legend.position = "bottom") +``` + +We will predict the standardized number of graduates (a proportion) in the +next year (time $t+1$) using an autoregressive model with three lags (i.e., an +AR(3) model). Such a model is represented algebraically like this: + +\[ + y_{t+1,ijk} = + \alpha_0 + \alpha_1 y_{tijk} + \alpha_2 y_{t-1,ijk} + \alpha_3 y_{t-2,ijk} + \epsilon_{tijk} \] where $y_{tijk}$ is the proportion of graduates at time $t$ in location $i$ and +age group $j$ with educational qualification $k$.
Note that creating an `epi_recipe` alone doesn't add these +outcome and predictor columns; the recipe just stores the instructions for +adding them. + +Our `epi_recipe` should add one `ahead` column representing $y_{t+1,ijk}$ and +3 `lag` columns representing $y_{tijk}$, $y_{t-1,ijk}$, and $y_{t-2,ijk}$ +(it's more accurate to think of the 0th "lag" as the "current" value with 2 lags, +but that's not quite how the processing works). +Also note that +since we specified our `time_type` to be `year`, our `lag` and `lead` +values are both in years. + +```{r make-recipe, include=T, eval=T} +r <- epi_recipe(employ_small) %>% + step_epi_ahead(num_graduates_prop, ahead = 1) %>% + step_epi_lag(num_graduates_prop, lag = 0:2) %>% + step_epi_naomit() +r +``` + +Let's apply this recipe using `prep` and `bake` to generate and view the `lag` +and `ahead` columns. + +```{r view-preprocessed, include=T} +# Display a sample of the pre-processed data +bake_and_show_sample <- function(recipe, data, n = 5) { + recipe %>% + prep(data) %>% + bake(new_data = data) %>% + sample_n(n) +} + +r %>% bake_and_show_sample(employ_small) +``` + +We can see that the `prep` and `bake` steps created new columns according to +our `epi_recipe`: + +- `ahead_1_num_graduates_prop` corresponds to $y_{t+1,ijk}$ +- `lag_0_num_graduates_prop`, `lag_1_num_graduates_prop`, and +`lag_2_num_graduates_prop` correspond to $y_{tijk}$, $y_{t-1,ijk}$, and $y_{t-2,ijk}$ +respectively. + +## Model fitting and prediction + +Since our goal for now is to fit a simple autoregressive model, we can use +[`parsnip::linear_reg()`]( + https://parsnip.tidymodels.org/reference/linear_reg.html) with the default +engine `lm`, which fits a linear regression using ordinary least squares. + +We will use `epi_workflow` with the `epi_recipe` we defined in the +pre-processing section along with the `parsnip::linear_reg()` model. Note +that `epi_workflow` is a container and doesn't actually do the fitting. 
We have +to pass the workflow into `fit()` to get our estimated model coefficients +$\widehat{\alpha}_i,\ i=0,...,3$. + +```{r linearreg-wf, include=T} +wf_linreg <- epi_workflow(r, linear_reg()) %>% + fit(employ_small) +summary(extract_fit_engine(wf_linreg)) +``` + +This output tells us the coefficients of the fitted model; for instance, +the estimated intercept is $\widehat{\alpha}_0 =$ +`r round(coef(extract_fit_engine(wf_linreg))[1], 3)` and the coefficient for +$y_{tijk}$ is +$\widehat\alpha_1 =$ `r round(coef(extract_fit_engine(wf_linreg))[2], 3)`. +The summary also tells us that all estimated coefficients are significantly +different from zero. Extracting the 95% confidence intervals for the +coefficients also leads us to +the same conclusion: all the coefficient estimates are significantly different +from 0. + +```{r} +confint(extract_fit_engine(wf_linreg)) +``` + +Now that we have our workflow, we can generate predictions from a subset of our +data. For this demo, we will predict the number of graduates using the last 2 +years of our dataset. + +```{r linearreg-predict, include=T} +latest <- get_test_data(recipe = r, x = employ_small) +preds <- stats::predict(wf_linreg, latest) %>% filter(!is.na(.pred)) +# Display a sample of the prediction values, excluding NAs +preds %>% sample_n(5) +``` + +We can do this using the `augment` function too. Note that `predict` and +`augment` both still return an `epiprocess::epi_df` with all of the keys that +were present in the original dataset. + +```{r linearreg-augment} +augment(wf_linreg, latest) %>% sample_n(5) +``` + +## Model diagnostics + +First, we'll plot the residuals (that is, $y_{tijk} - \widehat{y}_{tijk}$) +against the fitted values ($\widehat{y}_{tijk}$). + +```{r lienarreg-resid-plot, include=T, fig.height = 5, fig.width = 5} +par(mfrow = c(2, 2), mar = c(5, 3, 1.2, 0)) +plot(extract_fit_engine(wf_linreg)) +``` + +The fitted values vs. 
residuals plot shows us that the residuals are mostly
+clustered around zero, but do not form an even band around the zero line,
+indicating that the variance of the residuals is not constant. Additionally,
+the fitted values vs. square root of standardized residuals makes this more
+obvious - the spread of the square root of standardized residuals varies with
+the fitted values.
+
+The Q-Q plot shows us that the residuals have heavier tails than a Normal
+distribution. So the normality of residuals assumption doesn't hold either.
+
+Finally, the residuals vs. leverage plot shows us that we have a few influential
+points based on the Cook's distance (those outside the red dotted line).
+
+Since we appear to be violating the linear model assumptions, we might consider
+transforming our data differently, or consider a non-linear model, or
+something else.
+
+# AR model with exogenous inputs
+
+Now suppose we want to model the 1-step-ahead 5-year employment income using
+current and two previous values, while
+also incorporating information from the other two time-series in our dataset:
+the 2-year employment income and the number of graduates in the previous 2
+years. We would do this using an autoregressive model with exogenous inputs,
+defined as follows:
+
+\[
+\begin{aligned}
+  y_{t+1,ijk} &=
+  \alpha_0 + \alpha_1 y_{tijk} + \alpha_2 y_{t-1,ijk} + \alpha_3 y_{t-2,ijk}\\
+  &\quad + \beta_1 x_{tijk} + \beta_2 x_{t-1,ijk}\\
+  &\quad + \gamma_1 z_{tijk} + \gamma_2 z_{t-1,ijk} + \epsilon_{tijk}
+\end{aligned}
+\]
+
+where $y_{tijk}$ is the 5-year median income (proportion) at time $t$ (in
+location $i$, age group $j$ with educational qualification $k$),
+$x_{tijk}$ is the 2-year median income (proportion) at time $t$, and
+$z_{tijk}$ is the number of graduates (proportion) at time $t$.
+
+## Pre-processing
+
+Again, we construct an `epi_recipe` detailing the pre-processing steps. 
+ +```{r custom-arx, include=T} +rx <- epi_recipe(employ_small) %>% + step_epi_ahead(med_income_5y_prop, ahead = 1) %>% + # 5-year median income has current, and two lags c(0, 1, 2) + step_epi_lag(med_income_5y_prop, lag = 0:2) %>% + # But the two exogenous variables have current values, and 1 lag c(0, 1) + step_epi_lag(med_income_2y_prop, lag = c(0, 1)) %>% + step_epi_lag(num_graduates_prop, lag = c(0, 1)) %>% + step_epi_naomit() + +bake_and_show_sample(rx, employ_small) +``` + +## Model fitting & post-processing + +Before fitting our model and making predictions, let's add some post-processing +steps using a few [`frosting`]( + https://cmu-delphi.github.io/epipredict/reference/frosting.html) layers to do +a few things: + +1. Threshold our predictions to 0. We are predicting proportions, which can't +be negative. And the transformed values back to dollars and people can't be +negative either. +1. Generate prediction intervals based on residual quantiles, allowing us to +quantify the uncertainty associated with future predicted values. +1. Convert our predictions back to income values and number of graduates, +rather than standardized proportions. We do this via the frosting layer +`layer_population_scaling()`. 
+ + +```{r custom-arx-post, include=T} +# Create dataframe of the sums we used for standardizing +# Only have to include med_income_5y since that is our outcome +totals <- employ_small %>% + group_by(geo_value, age_group, edu_qual) %>% + summarise(med_income_5y_tot = sum(med_income_5y), .groups = "drop") + +# Define post-processing steps +f <- frosting() %>% + layer_predict() %>% + layer_naomit(.pred) %>% + layer_threshold(.pred, lower = 0) %>% + # 90% prediction interval + layer_residual_quantiles( + quantile_levels = c(0.1, 0.9), + symmetrize = FALSE + ) %>% + layer_population_scaling( + .pred, .pred_distn, + df = totals, df_pop_col = "med_income_5y_tot" + ) + +wfx_linreg <- epi_workflow(rx, parsnip::linear_reg()) %>% + fit(employ_small) %>% + add_frosting(f) + +summary(extract_fit_engine(wfx_linreg)) +``` + +Based on the summary output for this model, we can examine confidence intervals +and perform hypothesis tests as usual. + +Let's take a look at the predictions along with their 90% prediction intervals. + +```{r} +latest <- get_test_data(recipe = rx, x = employ_small) +predsx <- predict(wfx_linreg, latest) + +# Display predictions along with prediction intervals +predsx %>% + select( + geo_value, time_value, edu_qual, age_group, + .pred_scaled, .pred_distn_scaled + ) %>% + head() %>% + pivot_quantiles_wider(.pred_distn_scaled) +``` + +# Using canned forecasters + +We've seen what we can do with non-epidemiological panel data using the +recipes frame, with `epi_recipe` for pre-processing, `epi_workflow` for model +fitting, and `frosting` for post-processing. + +`epipredict` also comes with canned forecasters that do all of those steps +behind the scenes for some simple models. Even though we aren't working with +epidemiological data, canned forecasters still work as expected, out of the box. 
+We will demonstrate this with the simple
+[`flatline_forecaster`](
+  https://cmu-delphi.github.io/epipredict/reference/flatline_forecaster.html)
+and the direct autoregressive (AR) forecaster
+[`arx_forecaster`](
+  https://cmu-delphi.github.io/epipredict/reference/arx_forecaster.html).
+
+For both illustrations, we will continue to use the `employ_small` dataset
+with the transformed numeric columns that are proportions within each group
+by the keys in our `epi_df`.
+
+## Flatline forecaster
+
+In this first example, we'll use `flatline_forecaster` to make a simple
+prediction of the 2-year median income for the next year, based on one previous
+time point. This model is represented algebraically as:
+\[y_{t+1,ijk} = y_{tijk} + \epsilon_{tijk}\]
+where $y_{tijk}$ is the 2-year median income (proportion) at time $t$.
+
+```{r flatline, include=T, warning=F}
+out_fl <- flatline_forecaster(employ_small, "med_income_2y_prop",
+  args_list = flatline_args_list(ahead = 1)
+)
+
+out_fl
+```
+
+## Autoregressive forecaster with exogenous inputs
+
+In this second example, we'll use `arx_forecaster` to make a prediction of the
+5-year median income using two lags, _and_ using two lags on two exogenous
+variables: 2-year median income and number of graduates.
+
+The canned forecaster gives us a simple way of making this forecast since it
+defines the recipe, workflow, and post-processing steps behind the scenes. This
+is very similar to the model we introduced in the "Autoregressive Linear Model
+with Exogenous Inputs" section of this article, but where all inputs have the
+same number of lags.
+
+```{r arx-lr, include=T, warning=F}
+arx_args <- arx_args_list(lags = c(0L, 1L), ahead = 1L)
+
+out_arx_lr <- arx_forecaster(employ_small, "med_income_5y_prop",
+  c("med_income_5y_prop", "med_income_2y_prop", "num_graduates_prop"),
+  args_list = arx_args
+)
+
+out_arx_lr
+```
+
+Other changes to the direct AR forecaster, like changing the engine, also work
+as expected. 
Below we use a boosted tree model instead of a linear regression. + +```{r arx-rf, include=T, warning=F} +out_arx_rf <- arx_forecaster( + employ_small, "med_income_5y_prop", + c("med_income_5y_prop", "med_income_2y_prop", "num_graduates_prop"), + trainer = parsnip::boost_tree(mode = "regression", trees = 20), + args_list = arx_args +) + +out_arx_rf +``` + +# Conclusion + +While the purpose of `{epipredict}` is to allow `{tidymodels}` to operate on +epidemiology data, it can be easily adapted (both the workflows and the canned +forecasters) to work for generic panel data modelling. + diff --git a/vignettes/preprocessing-and-models.Rmd b/vignettes/preprocessing-and-models.Rmd index f1cdb3c87..63a27bd55 100644 --- a/vignettes/preprocessing-and-models.Rmd +++ b/vignettes/preprocessing-and-models.Rmd @@ -58,7 +58,8 @@ and forecasts to characterize the state of an outbreak and its course. They use it to inform public health decision makers on potential consequences of deploying control measures. -One of the outcomes that the CDC forecasts is [death counts from COVID-19](https://www.cdc.gov/coronavirus/2019-ncov/science/forecasting/forecasting-us.html). +One of the outcomes that the CDC forecasts is +[death counts from COVID-19](https://www.cdc.gov/coronavirus/2019-ncov/science/forecasting/forecasting-us.html). Although there are many state-of-the-art models, we choose to use Poisson regression, the textbook example for modeling count data, as an illustration for using the `epipredict` package with other existing tidymodels packages. @@ -99,7 +100,7 @@ intercept coefficients, we can allow for an intercept shift between states. 
The model takes the form \begin{aligned} \log\left( \mu_{t+7} \right) &= \beta_0 + \delta_1 s_{\text{state}_1} + -\delta_2 s_{\text{state}_2} + \cdots + \nonumber \\ +\delta_2 s_{\text{state}_2} + \cdots + \nonumber \\ &\quad\beta_1 \text{deaths}_{t} + \beta_2 \text{deaths}_{t-7} + \beta_3 \text{cases}_{t} + \beta_4 \text{cases}_{t-7}, @@ -109,7 +110,8 @@ Poisson distribution with mean $\mu_{t+7}$; $s_{\text{state}}$ are dummy variables for each state and take values of either 0 or 1. Preprocessing steps will be performed to prepare the -data for model fitting. But before diving into them, it will be helpful to understand what `roles` are in the `recipes` framework. +data for model fitting. But before diving into them, it will be helpful to +understand what `roles` are in the `recipes` framework. --- @@ -198,7 +200,7 @@ To model death rates, the Poisson regression would be expressed as: \begin{aligned} \log\left( \mu_{t+7} \right) &= \log(\text{population}) + \beta_0 + \delta_1 s_{\text{state}_1} + -\delta_2 s_{\text{state}_2} + \cdots + \nonumber \\ +\delta_2 s_{\text{state}_2} + \cdots + \nonumber \\ &\quad\beta_1 \text{deaths}_{t} + \beta_2 \text{deaths}_{t-7} + \beta_3 \text{cases}_{t} + \beta_4 \text{cases}_{t-7} @@ -235,9 +237,10 @@ using `layer_residual_quantiles()` should be used before population scaling or else the transformation will make the results uninterpretable. -We wish, now, to predict the 7-day ahead death counts with lagged case rates and death -rates, along with some extra behaviourial predictors. Namely, we will use survey data -from [COVID-19 Trends and Impact Survey](https://cmu-delphi.github.io/delphi-epidata/api/covidcast-signals/fb-survey.html#behavior-indicators). +We wish, now, to predict the 7-day ahead death counts with lagged case rates and +death rates, along with some extra behaviourial predictors. 
Namely, we will use +survey data from +[COVID-19 Trends and Impact Survey](https://cmu-delphi.github.io/delphi-epidata/api/covidcast-signals/fb-survey.html#behavior-indicators). The survey data provides the estimated percentage of people who wore a mask for most or all of the time while in public in the past 7 days and the estimated @@ -361,8 +364,7 @@ wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.05, .5, .95))) %>% fit(jhu) %>% add_frosting(f) -latest <- get_test_data(recipe = r, x = jhu) -p <- predict(wf, latest) +p <- forecast(wf) p ``` @@ -414,7 +416,8 @@ We say location $\ell$ is a hotspot at time $t$ when $Z_{\ell,t}$ is `up`, meaning the number of newly reported cases over the past 7 days has increased by at least 25% compared to the preceding week. When $Z_{\ell,t}$ is categorized as `down`, it suggests that there has been at least a 20% -decrease in newly reported cases over the past 7 days (a 20% decrease is the inverse of a 25% increase). Otherwise, we will +decrease in newly reported cases over the past 7 days (a 20% decrease is the +inverse of a 25% increase). Otherwise, we will consider the trend to be `flat`. 
The expression of the multinomial regression we will use is as follows: @@ -431,16 +434,17 @@ g_{\text{flat}}(x) &= \log\left(\frac{Pr(Z_{\ell,t}=\text{flat}\mid x)}{Pr(Z_{\e \beta_{10} + \beta_{11} t + \delta_{10} s_{\text{state_1}} + \delta_{11} s_{\text{state_2}} + \cdots \nonumber \\ &\quad + \beta_{12} Y^{\Delta}_{\ell, t} + -\beta_{13} Y^{\Delta}_{\ell, t-7} \\ -g_{\text{flat}}(x) &= \log\left(\frac{Pr(Z_{\ell,t}=\text{up}\mid x)}{Pr(Z_{\ell,t}=\text{down} \mid x)}\right) = +\beta_{13} Y^{\Delta}_{\ell, t-7} + \beta_{14} Y^{\Delta}_{\ell, t-14}\\ +g_{\text{up}}(x) &= \log\left(\frac{Pr(Z_{\ell,t}=\text{up}\mid x)}{Pr(Z_{\ell,t}=\text{down} \mid x)}\right) = \beta_{20} + \beta_{21}t + \delta_{20} s_{\text{state_1}} + \delta_{21} s_{\text{state}\_2} + \cdots \nonumber \\ &\quad + \beta_{22} Y^{\Delta}_{\ell, t} + -\beta_{23} Y^{\Delta}\_{\ell, t-7} +\beta_{23} Y^{\Delta}_{\ell, t-7} + \beta_{24} Y^{\Delta}_{\ell, t-14} \end{aligned} Preprocessing steps are similar to the previous models with an additional step -of categorizing the response variables. Again, we will use a subset of death rate and case rate data from our built-in dataset +of categorizing the response variables. Again, we will use a subset of death +rate and case rate data from our built-in dataset `case_death_rate_subset`. ```{r} @@ -474,11 +478,10 @@ r <- epi_recipe(jhu) %>% We will fit the multinomial regression and examine the predictions: ```{r, warning=FALSE} -wf <- epi_workflow(r, parsnip::multinom_reg()) %>% +wf <- epi_workflow(r, multinom_reg()) %>% fit(jhu) -latest <- get_test_data(recipe = r, x = jhu) -predict(wf, latest) %>% filter(!is.na(.pred_class)) +forecast(wf) %>% filter(!is.na(.pred_class)) ``` We can also look at the estimated coefficients and model summary information: @@ -587,7 +590,12 @@ Sciences 118.51 (2021): e2111453118. 
[doi:10.1073/pnas.2111453118](https://doi.o ## Attribution -This object contains a modified part of the [COVID-19 Data Repository by the Center for Systems Science and Engineering (CSSE) at Johns Hopkins University](https://github.com/CSSEGISandData/COVID-19) as [republished in the COVIDcast Epidata API.](https://cmu-delphi.github.io/delphi-epidata/api/covidcast-signals/jhu-csse.html) +This object contains a modified part of the +[COVID-19 Data Repository by the Center for Systems Science and Engineering (CSSE) at Johns Hopkins University](https://github.com/CSSEGISandData/COVID-19) +as [republished in the COVIDcast Epidata API.](https://cmu-delphi.github.io/delphi-epidata/api/covidcast-signals/jhu-csse.html) -This data set is licensed under the terms of the [Creative Commons Attribution 4.0 International license](https://creativecommons.org/licenses/by/4.0/) by the Johns Hopkins University -on behalf of its Center for Systems Science in Engineering. Copyright Johns Hopkins University 2020. +This data set is licensed under the terms of the +[Creative Commons Attribution 4.0 International license](https://creativecommons.org/licenses/by/4.0/) +by the Johns Hopkins +University on behalf of its Center for Systems Science in Engineering. Copyright +Johns Hopkins University 2020. diff --git a/vignettes/update.Rmd b/vignettes/update.Rmd index 221e1c37e..fcd3653ca 100644 --- a/vignettes/update.Rmd +++ b/vignettes/update.Rmd @@ -2,7 +2,7 @@ title: "Using the add/update/remove and adjust functions" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{Using the update and adjust functions} + %\VignetteIndexEntry{Using the add/update/remove and adjust functions} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} --- @@ -39,8 +39,8 @@ add/remove/update an `epi_recipe` or a step in it. 
For this, we have `add_epi_recipe()`, `update_epi_recipe()`, and `remove_epi_recipe()` to add/update/remove an entire `epi_recipe` in an `epi_workflow` as well as `adjust_epi_recipe()` to adjust a particular step in an `epi_recipe` or -`epi_workflow` by the step number or name. For a model, one may `add_model()`, -`update_model()`, or `remove_model()` in an `epi_workflow`. For post-processing, +`epi_workflow` by the step number or name. For a model, one may `Add_model()`, +`Update_model()`, or `Remove_model()` in an `epi_workflow`.[^1] For post-processing, where the goal is to update a frosting object or a layer in it, we have `add_frosting()`, `remove_frosting()`, and `update_frosting()` to add/update/remove an entire `frosting` object in an `epi_workflow` as well as @@ -51,9 +51,14 @@ processing step is shown by the following table: | | Add/update/remove functions | adjust functions | |----------------------------|------------------------------------------------------------|---------------------| | Pre-processing | `add_epi_recipe()`, `update_epi_recipe()`, `remove_epi_recipe()` | `adjust_epi_recipe()` | -| Model specification | `add_model()`, `update_model()` `remove_model()` | | +| Model specification | `Add_model()`, `Update_model()` `Remove_model()` | | | Post-processing | `add_frosting()`, `remove_frosting()`, `update_frosting()` | `adjust_frosting()` | +[^1]: We capitalize these names to avoid possible clashes with the `{workflows}` +versions of these functions. The lower-case versions are also available, +however, if you load `{workflows}` after `{epipredict}`, these will be masked +and may not work as expected. + Since adding/removing/updating frosting as well as adjusting a layer in a `frosting` object proceeds in the same way as performing those tasks on an `epi_recipe`, we will focus on implementing those for an `epi_recipe` in this @@ -162,16 +167,16 @@ point - Any operations performed using the old recipe are not updated automatically. 
So we should be careful to fit the model using the new recipe, `r2`. Similarly, if predictions were made using the old recipe, then they should be re-generated using the version `epi_workflow` that contains the updated -recipe. We can use `update_model()` to replace the model used in `wf`, and then +recipe. We can use `Update_model()` to replace the model used in `wf`, and then fit as before: ```{r} # fit linear model -wf <- update_model(wf, parsnip::linear_reg()) %>% fit(jhu) +wf <- Update_model(wf, parsnip::linear_reg()) %>% fit(jhu) wf ``` -Alternatively, we may use the `remove_model()` followed by `add_model()` +Alternatively, we may use the `Remove_model()` followed by `Add_model()` combination for the same effect. ## Add/update/remove a `frosting` object in an `epi_workflow` @@ -181,13 +186,11 @@ predictions. In our initial frosting object, `f`, we simply implement predictions on the fitted `epi_workflow`: ```{r} -latest <- get_test_data(recipe = r2, x = jhu) - f <- frosting() %>% layer_predict() wf1 <- wf %>% add_frosting(f) -p1 <- predict(wf1, latest) +p1 <- forecast(wf1) p1 ``` @@ -207,7 +210,7 @@ f2 <- frosting() %>% layer_add_target_date() wf2 <- wf1 %>% update_frosting(f2) -p2 <- predict(wf2, latest) +p2 <- forecast(wf2) p2 ``` @@ -223,7 +226,7 @@ remove the `frosting` object from the workflow and make predictions as follows: ```{r} wf3 <- wf2 %>% remove_frosting() -p3 <- predict(wf3, latest) +p3 <- forecast(wf3) p3 ``` @@ -285,7 +288,7 @@ Note that when we adjust the `r2` object directly, we are not adjusting the recipe in the `epi_workflow`. That is, if we modify a step in `r2`, the change will not automatically transfer over to `wf`. We would need to modify the recipe in `wf` directly (`adjust_epi_recipe()` on `wf`) or update the recipe in `wf` -with a new `epi_recipe` that has undergone the adjustment +with a new `epi_recipe` that has undergone the adjustment (using `update_epi_recipe()`): ```{r}