diff --git a/DESCRIPTION b/DESCRIPTION index 1126f8304..a219637b6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.16 +Version: 0.0.17 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), diff --git a/NAMESPACE b/NAMESPACE index 708c91e06..941ea1542 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -208,6 +208,7 @@ import(parsnip) import(recipes) importFrom(checkmate,assert) importFrom(checkmate,assert_character) +importFrom(checkmate,assert_class) importFrom(checkmate,assert_date) importFrom(checkmate,assert_function) importFrom(checkmate,assert_int) diff --git a/NEWS.md b/NEWS.md index bf3f4d9d5..4e21f8191 100644 --- a/NEWS.md +++ b/NEWS.md @@ -47,3 +47,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Revise `compat-purrr` to use the r-lang `standalone-*` version (via `{usethis}`) - `epi_recipe()` will now warn when given non-`epi_df` data +- `layer_predict()` and `predict.epi_workflow()` will now appropriately forward + `...` args intended for `predict.model_fit()` +- `bake.epi_recipe()` will now re-infer the geo and time type in case baking the + steps has changed the appropriate values diff --git a/R/arx_classifier.R b/R/arx_classifier.R index de730826c..44acb9b30 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -1,7 +1,7 @@ #' Direct autoregressive classifier with covariates #' #' This is an autoregressive classification model for -#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning #' that it estimates a class at a particular target horizon. 
#' #' @inheritParams arx_forecaster diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 10b2d2bce..1b9e3d503 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -1,7 +1,7 @@ #' Direct autoregressive forecaster with covariates #' #' This is an autoregressive forecasting model for -#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning #' that it estimates a model for a particular target horizon. #' #' diff --git a/R/cdc_baseline_forecaster.R b/R/cdc_baseline_forecaster.R index 4af6d6f3f..d5b74a9c3 100644 --- a/R/cdc_baseline_forecaster.R +++ b/R/cdc_baseline_forecaster.R @@ -1,7 +1,7 @@ #' Predict the future with the most recent value #' #' This is a simple forecasting model for -#' [epiprocess::epi_df] data. It uses the most recent observation as the +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the #' forecast for any future date, and produces intervals by shuffling the quantiles #' of the residuals of such a "flatline" forecast and incrementing these #' forward over all available training data. @@ -12,7 +12,7 @@ #' This forecaster is meant to produce exactly the CDC Baseline used for #' [COVID19ForecastHub](https://covid19forecasthub.org) #' -#' @param epi_data An [`epiprocess::epi_df`] +#' @param epi_data An [`epiprocess::epi_df`][epiprocess::as_epi_df] #' @param outcome A scalar character for the column name we wish to predict. #' @param args_list A list of additional arguments as created by the #' [cdc_baseline_args_list()] constructor function. 
diff --git a/R/data.R b/R/data.R index 6641abf44..71e5bdcd3 100644 --- a/R/data.R +++ b/R/data.R @@ -59,7 +59,7 @@ #' Subset of Statistics Canada median employment income for postsecondary graduates #' -#' @format An [epiprocess::epi_df] with 10193 rows and 8 variables: +#' @format An [epiprocess::epi_df][epiprocess::as_epi_df] with 10193 rows and 8 variables: #' \describe{ #' \item{geo_value}{The province in Canada associated with each #' row of measurements.} diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 1a1cd1455..6d01d718f 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -245,7 +245,7 @@ is_epi_recipe <- function(x) { #' @details #' `add_epi_recipe` has the same behaviour as #' [workflows::add_recipe()] but sets a different -#' default blueprint to automatically handle [epiprocess::epi_df] data. +#' default blueprint to automatically handle [epiprocess::epi_df][epiprocess::as_epi_df] data. #' #' @param x A `workflow` or `epi_workflow` #' @@ -572,9 +572,13 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") { } new_data <- NextMethod("bake") if (!is.null(meta)) { + # Baking should have dropped epi_df-ness and metadata. Re-infer some + # metadata and assume others remain the same as the object/template: new_data <- as_epi_df( - new_data, meta$geo_type, meta$time_type, meta$as_of, - meta$additional_metadata %||% list() + new_data, + as_of = meta$as_of, + # avoid NULL if meta is from saved older epi_df: + additional_metadata = meta$additional_metadata %||% list() ) } new_data diff --git a/R/epi_workflow.R b/R/epi_workflow.R index c6f1e43a9..0bdeece4f 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -119,18 +119,18 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' - Call [parsnip::predict.model_fit()] for you using the underlying fit #' parsnip model. 
#' -#' - Ensure that the returned object is an [epiprocess::epi_df] where +#' - Ensure that the returned object is an [epiprocess::epi_df][epiprocess::as_epi_df] where #' possible. Specifically, the output will have `time_value` and #' `geo_value` columns as well as the prediction. #' -#' @inheritParams parsnip::predict.model_fit -#' #' @param object An epi_workflow that has been fit by #' [workflows::fit.workflow()] #' #' @param new_data A data frame containing the new predictors to preprocess #' and predict on #' +#' @inheritParams parsnip::predict.model_fit +#' #' @return #' A data frame of model predictions, with as many rows as `new_data` has. #' If `new_data` is an `epi_df` or a data frame with `time_value` or @@ -152,7 +152,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' #' preds <- predict(wf, latest) #' preds -predict.epi_workflow <- function(object, new_data, ...) { +predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) { if (!workflows::is_trained_workflow(object)) { cli::cli_abort(c( "Can't predict on an untrained epi_workflow.", @@ -168,7 +168,7 @@ predict.epi_workflow <- function(object, new_data, ...) { components$forged, components$mold, new_data ) - components <- apply_frosting(object, components, new_data, ...) + components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...) 
components$predictions } diff --git a/R/epipredict-package.R b/R/epipredict-package.R index 4bd37c519..7746281ba 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -6,7 +6,7 @@ #' @importFrom cli cli_abort #' @importFrom checkmate assert assert_character assert_int assert_scalar #' assert_logical assert_numeric assert_number assert_integer -#' assert_integerish assert_date assert_function +#' assert_integerish assert_date assert_function assert_class #' @import epiprocess parsnip ## usethis namespace: end NULL diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index fa80dfba5..e14e44a96 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -1,7 +1,7 @@ #' Predict the future with today's value #' #' This is a simple forecasting model for -#' [epiprocess::epi_df] data. It uses the most recent observation as the +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the #' forcast for any future date, and produces intervals based on the quantiles #' of the residuals of such a "flatline" forecast over all available training #' data. @@ -13,7 +13,7 @@ #' This forecaster is very similar to that used by the #' [COVID19ForecastHub](https://covid19forecasthub.org) #' -#' @param epi_data An [epiprocess::epi_df] +#' @param epi_data An [epiprocess::epi_df][epiprocess::as_epi_df] #' @param outcome A scalar character for the column name we wish to predict. #' @param args_list A list of dditional arguments as created by the #' [flatline_args_list()] constructor function. diff --git a/R/frosting.R b/R/frosting.R index f9c5867a4..4fc0caec3 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -355,9 +355,11 @@ apply_frosting.default <- function(workflow, components, ...) 
{ #' @rdname apply_frosting #' @importFrom rlang is_null #' @importFrom rlang abort +#' @param type,opts forwarded (along with `...`) to [`predict.model_fit()`] and +#' [`slather()`] for supported layers #' @export apply_frosting.epi_workflow <- - function(workflow, components, new_data, ...) { + function(workflow, components, new_data, type = NULL, opts = list(), ...) { the_fit <- workflows::extract_fit_parsnip(workflow) if (!has_postprocessor(workflow)) { @@ -376,7 +378,7 @@ apply_frosting.epi_workflow <- "Returning unpostprocessed predictions." )) components$predictions <- predict( - the_fit, components$forged$predictors, ... + the_fit, components$forged$predictors, type, opts, ... ) components$predictions <- dplyr::bind_cols( components$keys, components$predictions @@ -397,10 +399,28 @@ apply_frosting.epi_workflow <- layers ) } + if (length(layers) > 1L && + (!is.null(type) || !identical(opts, list()) || rlang::dots_n(...) > 0L)) { + cli_abort(" + Passing `type`, `opts`, or `...` into `predict.epi_workflow()` is not + supported if you have frosting layers other than `layer_predict`. Please + provide these arguments earlier (i.e. while constructing the frosting + object) by passing them into an explicit call to `layer_predict(), and + adjust the remaining layers to account for resulting differences in + output format from these settings. + ", class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers") + } for (l in seq_along(layers)) { la <- layers[[l]] - components <- slather(la, components, workflow, new_data) + if (inherits(la, "layer_predict")) { + components <- slather(la, components, workflow, new_data, type = type, opts = opts, ...) + } else { + # The check above should ensure we have default `type` and `opts`, and + # empty `...`; don't forward these default `type` and `opts`, to avoid + # upsetting some slather method validation. 
+ components <- slather(la, components, workflow, new_data) + } } return(components) diff --git a/R/get_test_data.R b/R/get_test_data.R index e76715daf..2a2484749 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -1,7 +1,7 @@ #' Get test data for prediction based on longest lag period #' #' Based on the longest lag period in the recipe, -#' `get_test_data()` creates an [epi_df] +#' `get_test_data()` creates an [epi_df][epiprocess::as_epi_df] #' with columns `geo_value`, `time_value` #' and other variables in the original dataset, #' which will be used to create features necessary to produce forecasts. diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 2174b7330..c4bb7d483 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -86,6 +86,7 @@ layer_add_forecast_date_new <- function(forecast_date, id) { #' @export slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() if (is.null(object$forecast_date)) { max_time_value <- as.Date(max( workflows::extract_preprocessor(workflow)$max_time_value, diff --git a/R/layer_naomit.R b/R/layer_naomit.R index ad6c5606c..85842bfdf 100644 --- a/R/layer_naomit.R +++ b/R/layer_naomit.R @@ -45,6 +45,7 @@ layer_naomit_new <- function(terms, id) { #' @export slather.layer_naomit <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layer_point_from_distn.R b/R/layer_point_from_distn.R index 52ecef3cc..f415e7bd4 100644 --- a/R/layer_point_from_distn.R +++ b/R/layer_point_from_distn.R @@ -76,7 +76,6 @@ layer_point_from_distn_new <- function(type, name, id) { #' @export slather.layer_point_from_distn <- function(object, components, workflow, new_data, ...) 
{ - rlang::check_dots_empty() dstn <- components$predictions$.pred if (!inherits(dstn, "distribution")) { rlang::warn( @@ -86,6 +85,7 @@ slather.layer_point_from_distn <- ) return(components) } + rlang::check_dots_empty() dstn <- match.fun(object$type)(dstn) if (is.null(object$name)) { diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 1d02604e5..33183198d 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -128,11 +128,11 @@ layer_population_scaling_new <- #' @export slather.layer_population_scaling <- function(object, components, workflow, new_data, ...) { - rlang::check_dots_empty() stopifnot( "Only one population column allowed for scaling" = length(object$df_pop_col) == 1 ) + rlang::check_dots_empty() if (is.null(object$by)) { object$by <- intersect( diff --git a/R/layer_predict.R b/R/layer_predict.R index b40c24be5..46d81be18 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -45,12 +45,19 @@ layer_predict <- id = rand_id("predict_default")) { arg_is_chr_scalar(id) arg_is_chr_scalar(type, allow_null = TRUE) + assert_class(opts, "list") + dots_list <- rlang::dots_list(..., .homonyms = "error", .check_assign = TRUE) + if (any(rlang::names2(dots_list) == "")) { + cli_abort("All `...` arguments must be named.", + class = "epipredict__layer_predict__unnamed_dot" + ) + } add_layer( frosting, layer_predict_new( type = type, opts = opts, - dots_list = rlang::list2(...), # can't figure how to use this + dots_list = dots_list, id = id ) ) @@ -62,14 +69,27 @@ layer_predict_new <- function(type, opts, dots_list, id) { } #' @export -slather.layer_predict <- function(object, components, workflow, new_data, ...) { +slather.layer_predict <- function(object, components, workflow, new_data, type = NULL, opts = list(), ...) 
{ + arg_is_chr_scalar(type, allow_null = TRUE) + if (!is.null(object$type) && !is.null(type) && !identical(object$type, type)) { + cli_abort(" + Conflicting `type` settings were specified during frosting construction + (in call to `layer_predict()`) and while slathering (in call to + `slather()`/ `predict()`/etc.): {object$type} vs. {type}. Please remove + one of these `type` settings. + ", class = "epipredict__layer_predict__conflicting_type_settings") + } + assert_class(opts, "list") + the_fit <- workflows::extract_fit_parsnip(workflow) - components$predictions <- predict( + components$predictions <- rlang::inject(predict( the_fit, components$forged$predictors, - type = object$type, opts = object$opts - ) + type = object$type %||% type, + opts = c(object$opts, opts), + !!!object$dots_list, ... + )) components$predictions <- dplyr::bind_cols( components$keys, components$predictions ) diff --git a/R/layer_predictive_distn.R b/R/layer_predictive_distn.R index 652e42368..9b1a160e1 100644 --- a/R/layer_predictive_distn.R +++ b/R/layer_predictive_distn.R @@ -73,6 +73,7 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) { slather.layer_predictive_distn <- function(object, components, workflow, new_data, ...) { the_fit <- workflows::extract_fit_parsnip(workflow) + rlang::check_dots_empty() m <- components$predictions$.pred r <- grab_residuals(the_fit, components) diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index d763207a4..734ccec9e 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -79,6 +79,8 @@ slather.layer_quantile_distn <- "These are of class {.cls {class(dstn)}}." 
)) } + rlang::check_dots_empty() + dstn <- dist_quantiles( quantile(dstn, object$quantile_levels), object$quantile_levels diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index 514bddc5f..85c1c6ed0 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -75,6 +75,8 @@ layer_residual_quantiles_new <- function( #' @export slather.layer_residual_quantiles <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() + the_fit <- workflows::extract_fit_parsnip(workflow) if (is.null(object$quantile_levels)) { diff --git a/R/layer_threshold_preds.R b/R/layer_threshold_preds.R index ef1781a3c..8b2b56d1e 100644 --- a/R/layer_threshold_preds.R +++ b/R/layer_threshold_preds.R @@ -98,6 +98,7 @@ snap.dist_quantiles <- function(x, lower, upper, ...) { #' @export slather.layer_threshold <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layer_unnest.R b/R/layer_unnest.R index 64b17a306..dfc391942 100644 --- a/R/layer_unnest.R +++ b/R/layer_unnest.R @@ -28,6 +28,7 @@ layer_unnest_new <- function(terms, id) { #' @export slather.layer_unnest <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/inst/templates/layer.R b/inst/templates/layer.R index 3fecb3c33..59556db5f 100644 --- a/inst/templates/layer.R +++ b/inst/templates/layer.R @@ -29,6 +29,7 @@ layer_{{{ name }}}_new <- function(terms, args, more_args, id) { #' @export slather.layer_{{{ name }}} <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() # if layer_ used ... 
in tidyselect, we need to evaluate it now exprs <- rlang::expr(c(!!!object$terms)) diff --git a/man/apply_frosting.Rd b/man/apply_frosting.Rd index fc01a3461..ef18796cc 100644 --- a/man/apply_frosting.Rd +++ b/man/apply_frosting.Rd @@ -11,7 +11,7 @@ apply_frosting(workflow, ...) \method{apply_frosting}{default}(workflow, components, ...) -\method{apply_frosting}{epi_workflow}(workflow, components, new_data, ...) +\method{apply_frosting}{epi_workflow}(workflow, components, new_data, type = NULL, opts = list(), ...) } \arguments{ \item{workflow}{An object of class workflow} @@ -34,6 +34,9 @@ here for ease. \item{new_data}{a data frame containing the new predictors to preprocess and predict on} + +\item{type, opts}{forwarded (along with \code{...}) to \code{\link[=predict.model_fit]{predict.model_fit()}} and +\code{\link[=slather]{slather()}} for supported layers} } \description{ This function is intended for internal use. It implements postprocessing diff --git a/man/get_test_data.Rd b/man/get_test_data.Rd index 392d1dce2..b18685d89 100644 --- a/man/get_test_data.Rd +++ b/man/get_test_data.Rd @@ -37,7 +37,7 @@ keys, as well other variables in the original dataset. } \description{ Based on the longest lag period in the recipe, -\code{get_test_data()} creates an \link{epi_df} +\code{get_test_data()} creates an \link[epiprocess:epi_df]{epi_df} with columns \code{geo_value}, \code{time_value} and other variables in the original dataset, which will be used to create features necessary to produce forecasts. diff --git a/man/predict-epi_workflow.Rd b/man/predict-epi_workflow.Rd index d92fd8ca9..130279249 100644 --- a/man/predict-epi_workflow.Rd +++ b/man/predict-epi_workflow.Rd @@ -5,7 +5,7 @@ \alias{predict.epi_workflow} \title{Predict from an epi_workflow} \usage{ -\method{predict}{epi_workflow}(object, new_data, ...) +\method{predict}{epi_workflow}(object, new_data, type = NULL, opts = list(), ...) 
} \arguments{ \item{object}{An epi_workflow that has been fit by @@ -14,6 +14,16 @@ \item{new_data}{A data frame containing the new predictors to preprocess and predict on} +\item{type}{A single character value or \code{NULL}. Possible values +are \code{"numeric"}, \code{"class"}, \code{"prob"}, \code{"conf_int"}, \code{"pred_int"}, +\code{"quantile"}, \code{"time"}, \code{"hazard"}, \code{"survival"}, or \code{"raw"}. When \code{NULL}, +\code{predict()} will choose an appropriate value based on the model's mode.} + +\item{opts}{A list of optional arguments to the underlying +predict function that will be used when \code{type = "raw"}. The +list should not include options for the model object or the +new data being predicted.} + \item{...}{Additional \code{parsnip}-related options, depending on the value of \code{type}. Arguments to the underlying model's prediction function cannot be passed here (use the \code{opts} argument instead). diff --git a/tests/testthat/test-frosting.R b/tests/testthat/test-frosting.R index 8af9f1c39..5cab9c494 100644 --- a/tests/testthat/test-frosting.R +++ b/tests/testthat/test-frosting.R @@ -86,3 +86,60 @@ test_that("layer_predict is added by default if missing", { expect_equal(forecast(wf1), forecast(wf2)) }) + + +test_that("parsnip settings can be passed through predict.epi_workflow", { + jhu <- case_death_rate_subset %>% + dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) + + r <- epi_recipe(jhu) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_naomit() + + wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) + + latest <- get_test_data(r, jhu) + + f1 <- frosting() %>% layer_predict() + f2 <- frosting() %>% layer_predict(type = "pred_int") + f3 <- frosting() %>% layer_predict(type = "pred_int", level = 0.6) + + pred2 <- wf %>% + add_frosting(f2) %>% + predict(latest) + pred3 <- wf %>% + add_frosting(f3) %>% + predict(latest) + + pred2_re <- 
wf %>% + add_frosting(f1) %>% + predict(latest, type = "pred_int") + pred3_re <- wf %>% + add_frosting(f1) %>% + predict(latest, type = "pred_int", level = 0.6) + + expect_identical(pred2, pred2_re) + expect_identical(pred3, pred3_re) + + expect_error(wf %>% add_frosting(f2) %>% predict(latest, type = "raw"), + class = "epipredict__layer_predict__conflicting_type_settings" + ) + + f4 <- frosting() %>% + layer_predict() %>% + layer_threshold(.pred, lower = 0) + + expect_error(wf %>% add_frosting(f4) %>% predict(latest, type = "pred_int"), + class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers" + ) + + # We also refuse to continue when just passing the level, which might not be ideal: + f5 <- frosting() %>% + layer_predict(type = "pred_int") %>% + layer_threshold(.pred_lower, .pred_upper, lower = 0) + + expect_error(wf %>% add_frosting(f5) %>% predict(latest, level = 0.6), + class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers" + ) +}) diff --git a/tests/testthat/test-layer_predict.R b/tests/testthat/test-layer_predict.R index bd10de08c..041516b29 100644 --- a/tests/testthat/test-layer_predict.R +++ b/tests/testthat/test-layer_predict.R @@ -31,3 +31,63 @@ test_that("prediction with interval works", { expect_equal(nrow(p), 108L) expect_named(p, c("geo_value", "time_value", ".pred_lower", ".pred_upper")) }) + +test_that("layer_predict dots validation", { + # We balk at unnamed arguments, though perhaps not with the most helpful error messages: + expect_error( + frosting() %>% layer_predict("pred_int", list(), tibble::tibble(x = 5)), + class = "epipredict__layer_predict__unnamed_dot" + ) + expect_error( + frosting() %>% layer_predict("pred_int", list(), "maybe_meant_to_be_id"), + class = "epipredict__layer_predict__unnamed_dot" + ) + # We allow arguments that might actually work at prediction time: + expect_no_error(frosting() %>% layer_predict(type = "quantile", interval = "confidence")) + + # We don't detect 
completely-bogus arg names until predict time: + expect_no_error(f_bad_arg <- frosting() %>% layer_predict(bogus_argument = "something")) + wf_bad_arg <- wf %>% add_frosting(f_bad_arg) + expect_error(predict(wf_bad_arg, latest)) + # ^ (currently with an awful error message, due to an extra comma in parsnip::check_pred_type_dots) + + # Some argument names only apply for some prediction `type`s; we don't check + # for invalid pairings, nor does {parsnip}, so we end up producing a forecast + # that silently ignores some arguments some of the time. ({workflows} doesn't + # check for these either.) + expect_no_error(frosting() %>% layer_predict(eval_time = "preferably this would error")) +}) + +test_that("layer_predict dots are forwarded", { + f_lm_int_level_95 <- frosting() %>% + layer_predict(type = "pred_int") + f_lm_int_level_80 <- frosting() %>% + layer_predict(type = "pred_int", level = 0.8) + wf_lm_int_level_95 <- wf %>% add_frosting(f_lm_int_level_95) + wf_lm_int_level_80 <- wf %>% add_frosting(f_lm_int_level_80) + p <- predict(wf, latest) + p_lm_int_level_95 <- predict(wf_lm_int_level_95, latest) + p_lm_int_level_80 <- predict(wf_lm_int_level_80, latest) + expect_contains(names(p_lm_int_level_95), c(".pred_lower", ".pred_upper")) + expect_contains(names(p_lm_int_level_80), c(".pred_lower", ".pred_upper")) + expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level_95))) + expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level_80))) + expect_true( + cbind( + p, + p_lm_int_level_95 %>% dplyr::select(.pred_lower_95 = .pred_lower, .pred_upper_95 = .pred_upper), + p_lm_int_level_80 %>% dplyr::select(.pred_lower_80 = .pred_lower, .pred_upper_80 = .pred_upper) + ) %>% + na.omit() %>% + mutate( + sandwiched = + .pred_lower_95 <= .pred_lower_80 & + .pred_lower_80 <= .pred & + .pred <= .pred_upper_80 & + .pred_upper_80 <= .pred_upper_95 + ) %>% + `[[`("sandwiched") %>% + all() + ) + # There are many possible other valid configurations that aren't tested here. 
+}) diff --git a/tests/testthat/test-pad_to_end.R b/tests/testthat/test-pad_to_end.R index 474b9001b..0ea6244b0 100644 --- a/tests/testthat/test-pad_to_end.R +++ b/tests/testthat/test-pad_to_end.R @@ -32,6 +32,6 @@ test_that("test set padding works", { # make sure it maintains the epi_df dat <- dat %>% dplyr::rename(geo_value = gr1) %>% - as_epi_df(dat) + as_epi_df() expect_s3_class(pad_to_end(dat, "geo_value", 2), "epi_df") })