diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..362fafd1d
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,3 @@
+# using styler at all
+aca7d5e7b66d8bac9d9fbcec3acdb98a087d58fa
+f12fcc2bf3fe0a75ba2b10eaaf8a1f1d22486a17
diff --git a/.github/workflows/styler.yml b/.github/workflows/styler.yml
new file mode 100644
index 000000000..c78ae8dd4
--- /dev/null
+++ b/.github/workflows/styler.yml
@@ -0,0 +1,80 @@
+# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
+# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
+on:
+  push:
+    paths:
+      [
+        "**.[rR]",
+        "**.[qrR]md",
+        "**.[rR]markdown",
+        "**.[rR]nw",
+        "**.[rR]profile",
+      ]
+
+name: Style
+
+jobs:
+  style:
+    runs-on: ubuntu-latest
+    env:
+      GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+
+      - name: Setup R
+        uses: r-lib/actions/setup-r@v2
+        with:
+          use-public-rspm: true
+
+      - name: Install dependencies
+        uses: r-lib/actions/setup-r-dependencies@v2
+        with:
+          extra-packages: any::styler, any::roxygen2
+          needs: styler
+
+      - name: Enable styler cache
+        run: styler::cache_activate()
+        shell: Rscript {0}
+
+      - name: Determine cache location
+        id: styler-location
+        run: |
+          cat(
+            "location=",
+            styler::cache_info(format = "tabular")$location,
+            "\n",
+            file = Sys.getenv("GITHUB_OUTPUT"),
+            append = TRUE,
+            sep = ""
+          )
+        shell: Rscript {0}
+
+      - name: Cache styler
+        uses: actions/cache@v3
+        with:
+          path: ${{ steps.styler-location.outputs.location }}
+          key: ${{ runner.os }}-styler-${{ github.sha }}
+          restore-keys: |
+            ${{ runner.os }}-styler-
+            ${{ runner.os }}-
+
+      - name: Style
+        run: styler::style_pkg()
+        shell: Rscript {0}
+
+      - name: Commit and push changes
+        run: |
+          if FILES_TO_COMMIT=($(git diff-index --name-only ${{ github.sha }} \
+            | egrep --ignore-case '\.(R|[qR]md|Rmarkdown|Rnw|Rprofile)$'))
+          then
+            git config --local user.name "$GITHUB_ACTOR"
+            git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com"
+            git commit ${FILES_TO_COMMIT[*]} -m "Style code (GHA)"
+            git pull --ff-only
+            git push origin
+          else
+            echo "No changes to commit."
+          fi
diff --git a/R/arx_classifier.R b/R/arx_classifier.R
index a8a8ea2b2..9370da423 100644
--- a/R/arx_classifier.R
+++ b/R/arx_classifier.R
@@ -47,9 +47,9 @@ arx_classifier <- function(
     predictors,
     trainer = parsnip::logistic_reg(),
     args_list = arx_class_args_list()) {
-
-  if (!is_classification(trainer))
+  if (!is_classification(trainer)) {
     cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
+  }
 
   wf <- arx_class_epi_workflow(
     epi_data, outcome, predictors, trainer, args_list
@@ -65,13 +65,15 @@ arx_classifier <- function(
     tibble::as_tibble() %>%
     dplyr::select(-time_value)
 
-  structure(list(
-    predictions = preds,
-    epi_workflow = wf,
-    metadata = list(
-      training = attr(epi_data, "metadata"),
-      forecast_created = Sys.time()
-    )),
+  structure(
+    list(
+      predictions = preds,
+      epi_workflow = wf,
+      metadata = list(
+        training = attr(epi_data, "metadata"),
+        forecast_created = Sys.time()
+      )
+    ),
     class = c("arx_class", "canned_epipred")
   )
 }
@@ -117,12 +119,13 @@ arx_class_epi_workflow <- function(
     predictors,
     trainer = NULL,
     args_list = arx_class_args_list()) {
-
   validate_forecaster_inputs(epi_data, outcome, predictors)
-  if (!inherits(args_list, c("arx_class", "alist")))
+  if (!inherits(args_list, c("arx_class", "alist"))) {
     rlang::abort("args_list was not created using `arx_class_args_list()`.")
-  if (!(is.null(trainer) || is_classification(trainer)))
+  }
+  if (!(is.null(trainer) || is_classification(trainer))) {
     rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.")
+  }
 
   lags <- arx_lags_validator(predictors, args_list$lags)
 
   # --- preprocessor
@@ -172,8 +175,10 @@ arx_class_epi_workflow <- function(
   o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o))
   r <- r %>%
     step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>%
-    step_mutate(outcome_class = cut(!!o2, breaks = args_list$breaks),
-                role = "outcome") %>%
+    step_mutate(
+      outcome_class = cut(!!o2, breaks = args_list$breaks),
+      role = "outcome"
+    ) %>%
     step_epi_naomit() %>%
     step_training_window(n_recent = args_list$n_training)
@@ -245,9 +250,7 @@ arx_class_args_list <- function(
     method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
     log_scale = FALSE,
     additional_gr_args = list(),
-    nafill_buffer = Inf
-) {
-
+    nafill_buffer = Inf) {
   .lags <- lags
   if (is.list(lags)) lags <- unlist(lags)
   method <- match.arg(method)
@@ -266,7 +269,8 @@ arx_class_args_list <- function(
     cli::cli_abort(
       c("`additional_gr_args` must be a {.cls list}.",
         "!" = "This is a {.cls {class(additional_gr_args)}}.",
-        i = "See `?epiprocess::growth_rate` for available arguments.")
+        i = "See `?epiprocess::growth_rate` for available arguments."
+      )
     )
   }
@@ -277,19 +281,20 @@ arx_class_args_list <- function(
   max_lags <- max(lags)
 
   structure(
-    enlist(lags = .lags,
-           ahead,
-           n_training,
-           breaks,
-           forecast_date,
-           target_date,
-           outcome_transform,
-           max_lags,
-           horizon,
-           method,
-           log_scale,
-           additional_gr_args,
-           nafill_buffer
+    enlist(
+      lags = .lags,
+      ahead,
+      n_training,
+      breaks,
+      forecast_date,
+      target_date,
+      outcome_transform,
+      max_lags,
+      horizon,
+      method,
+      log_scale,
+      additional_gr_args,
+      nafill_buffer
     ),
     class = c("arx_class", "alist")
   )
@@ -300,4 +305,3 @@ print.arx_class <- function(x, ...) {
   name <- "ARX Classifier"
   NextMethod(name = name, ...)
 }
-
diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R
index 172daa17a..2e242d770 100644
--- a/R/arx_forecaster.R
+++ b/R/arx_forecaster.R
@@ -25,20 +25,24 @@
 #' jhu <- case_death_rate_subset %>%
 #'   dplyr::filter(time_value >= as.Date("2021-12-01"))
 #'
-#' out <- arx_forecaster(jhu, "death_rate",
-#'   c("case_rate", "death_rate"))
+#' out <- arx_forecaster(
+#'   jhu, "death_rate",
+#'   c("case_rate", "death_rate")
+#' )
 #'
 #' out <- arx_forecaster(jhu, "death_rate",
-#'   c("case_rate", "death_rate"), trainer = quantile_reg(),
-#'   args_list = arx_args_list(levels = 1:9 / 10))
+#'   c("case_rate", "death_rate"),
+#'   trainer = quantile_reg(),
+#'   args_list = arx_args_list(levels = 1:9 / 10)
+#' )
 arx_forecaster <- function(epi_data,
                            outcome,
                            predictors,
                            trainer = parsnip::linear_reg(),
                            args_list = arx_args_list()) {
-
-  if (!is_regression(trainer))
+  if (!is_regression(trainer)) {
     cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
+  }
 
   wf <- arx_fcast_epi_workflow(
     epi_data, outcome, predictors, trainer, args_list
@@ -54,13 +58,15 @@ arx_forecaster <- function(epi_data,
     tibble::as_tibble() %>%
     dplyr::select(-time_value)
 
-  structure(list(
-    predictions = preds,
-    epi_workflow = wf,
-    metadata = list(
-      training = attr(epi_data, "metadata"),
-      forecast_created = Sys.time()
-    )),
+  structure(
+    list(
+      predictions = preds,
+      epi_workflow = wf,
+      metadata = list(
+        training = attr(epi_data, "metadata"),
+        forecast_created = Sys.time()
+      )
+    ),
     class = c("arx_fcast", "canned_epipred")
   )
 }
@@ -85,25 +91,30 @@ arx_forecaster <- function(epi_data,
 #' jhu <- case_death_rate_subset %>%
 #'   dplyr::filter(time_value >= as.Date("2021-12-01"))
 #'
-#' arx_fcast_epi_workflow(jhu, "death_rate",
-#'   c("case_rate", "death_rate"))
+#' arx_fcast_epi_workflow(
+#'   jhu, "death_rate",
+#'   c("case_rate", "death_rate")
+#' )
 #'
 #' arx_fcast_epi_workflow(jhu, "death_rate",
-#'   c("case_rate", "death_rate"), trainer = quantile_reg(),
-#'   args_list = arx_args_list(levels = 1:9 / 10))
+#'   c("case_rate", "death_rate"),
+#'   trainer = quantile_reg(),
+#'   args_list = arx_args_list(levels = 1:9 / 10)
+#' )
 arx_fcast_epi_workflow <- function(
     epi_data,
     outcome,
     predictors,
     trainer = NULL,
    args_list = arx_args_list()) {
-
   # --- validation
   validate_forecaster_inputs(epi_data, outcome, predictors)
-  if (!inherits(args_list, c("arx_fcast", "alist")))
+  if (!inherits(args_list, c("arx_fcast", "alist"))) {
     cli::cli_abort("args_list was not created using `arx_args_list()`.")
-  if (!(is.null(trainer) || is_regression(trainer)))
+  }
+  if (!(is.null(trainer) || is_regression(trainer))) {
     cli::cli_abort("`trainer` must be a `{parsnip}` model of mode 'regression'.")
+  }
 
   lags <- arx_lags_validator(predictors, args_list$lags)
 
   # --- preprocessor
@@ -126,15 +137,17 @@ arx_fcast_epi_workflow <- function(
     # add all levels to the forecaster and update postprocessor
     tau <- sort(compare_quantile_args(
       args_list$levels,
-      rlang::eval_tidy(trainer$args$tau))
-    )
+      rlang::eval_tidy(trainer$args$tau)
+    ))
     args_list$levels <- tau
     trainer$args$tau <- rlang::enquo(tau)
     f <- layer_quantile_distn(f, levels = tau) %>%
       layer_point_from_distn()
   } else {
     f <- layer_residual_quantiles(
-      f, probs = args_list$levels, symmetrize = args_list$symmetrize,
-      by_key = args_list$quantile_by_key)
+      f,
+      probs = args_list$levels, symmetrize = args_list$symmetrize,
+      by_key = args_list$quantile_by_key
+    )
   }
   f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
     layer_add_target_date(target_date = target_date)
@@ -204,7 +217,6 @@ arx_args_list <- function(
     nonneg = TRUE,
     quantile_by_key = character(0L),
     nafill_buffer = Inf) {
-
   # error checking if lags is a list
   .lags <- lags
   if (is.list(lags)) lags <- unlist(lags)
@@ -222,17 +234,19 @@ arx_args_list <- function(
   max_lags <- max(lags)
 
   structure(
-    enlist(lags = .lags,
-           ahead,
-           n_training,
-           levels,
-           forecast_date,
-           target_date,
-           symmetrize,
-           nonneg,
-           max_lags,
-           quantile_by_key,
-           nafill_buffer),
+    enlist(
+      lags = .lags,
+      ahead,
+      n_training,
+      levels,
+      forecast_date,
+      target_date,
+      symmetrize,
+      nonneg,
+      max_lags,
+      quantile_by_key,
+      nafill_buffer
+    ),
     class = c("arx_fcast", "alist")
   )
 }
@@ -248,16 +262,22 @@ compare_quantile_args <- function(alist, tlist) {
   default_alist <- eval(formals(arx_args_list)$levels)
   default_tlist <- eval(formals(quantile_reg)$tau)
   if (setequal(alist, default_alist)) {
-    if (setequal(tlist, default_tlist)) return(sort(unique(union(alist, tlist))))
-    else return(sort(unique(tlist)))
+    if (setequal(tlist, default_tlist)) {
+      return(sort(unique(union(alist, tlist))))
+    } else {
+      return(sort(unique(tlist)))
+    }
   } else {
-    if (setequal(tlist, default_tlist)) return(sort(unique(alist)))
-    else {
-      if (setequal(alist, tlist)) return(sort(unique(alist)))
+    if (setequal(tlist, default_tlist)) {
+      return(sort(unique(alist)))
+    } else {
+      if (setequal(alist, tlist)) {
+        return(sort(unique(alist)))
+      }
       rlang::abort(c(
         "You have specified different, non-default, quantiles in the trainer and `arx_args` options.",
-        i = "Please only specify quantiles in one location.")
-      )
+        i = "Please only specify quantiles in one location."
+      ))
     }
   }
 }
diff --git a/R/bake.epi_recipe.R b/R/bake.epi_recipe.R
index ba29e97a2..6857df4ef 100644
--- a/R/bake.epi_recipe.R
+++ b/R/bake.epi_recipe.R
@@ -17,7 +17,6 @@
 #' @rdname bake
 #' @export
 bake.epi_recipe <- function(object, new_data, ...) {
-
   if (rlang::is_missing(new_data)) {
     rlang::abort("'new_data' must be either an epi_df or NULL. No value is not allowed.")
   }
@@ -83,7 +82,8 @@ bake.epi_recipe <- function(object, new_data, ...) {
 
   # Now reduce to only user selected columns
   out_names <- recipes_eval_select(terms, new_data, info,
-                                   check_case_weights = FALSE)
+    check_case_weights = FALSE
+  )
   new_data <- new_data[, out_names]
 
   # The levels are not null when no nominal data are present or
diff --git a/R/blueprint-epi_recipe-default.R b/R/blueprint-epi_recipe-default.R
index 147efc4fc..886cd5512 100644
--- a/R/blueprint-epi_recipe-default.R
+++ b/R/blueprint-epi_recipe-default.R
@@ -1,4 +1,3 @@
-
 #' Recipe blueprint that accounts for `epi_df` panel data
 #'
 #' Used for simplicity. See [hardhat::new_recipe_blueprint()] or
@@ -15,17 +14,17 @@ new_epi_recipe_blueprint <- function(intercept = FALSE,
                                      allow_novel_levels = FALSE,
                                      fresh = TRUE,
                                      composition = "tibble",
                                      ptypes = NULL, recipe = NULL, ...,
                                      subclass = character()) {
-  hardhat::new_recipe_blueprint(
-    intercept = intercept,
-    allow_novel_levels = allow_novel_levels,
-    fresh = fresh,
-    composition = composition,
-    ptypes = ptypes,
-    recipe = recipe,
-    ...,
-    subclass = c(subclass, "epi_recipe_blueprint")
-  )
-}
+    hardhat::new_recipe_blueprint(
+      intercept = intercept,
+      allow_novel_levels = allow_novel_levels,
+      fresh = fresh,
+      composition = composition,
+      ptypes = ptypes,
+      recipe = recipe,
+      ...,
+      subclass = c(subclass, "epi_recipe_blueprint")
+    )
+  }
 
 #' @rdname new_epi_recipe_blueprint
@@ -34,10 +33,12 @@ epi_recipe_blueprint <- function(intercept = FALSE,
                                  allow_novel_levels = FALSE,
                                  fresh = TRUE,
                                  composition = "tibble") {
-  new_epi_recipe_blueprint(intercept = intercept,
-                           allow_novel_levels = allow_novel_levels,
-                           fresh = fresh,
-                           composition = composition)
+  new_epi_recipe_blueprint(
+    intercept = intercept,
+    allow_novel_levels = allow_novel_levels,
+    fresh = fresh,
+    composition = composition
+  )
 }
 
 #' @rdname new_epi_recipe_blueprint
@@ -61,18 +62,18 @@ new_default_epi_recipe_blueprint <-
           fresh = TRUE, composition = "tibble",
           ptypes = NULL, recipe = NULL,
           extra_role_ptypes = NULL, ...,
           subclass = character()) {
-  new_epi_recipe_blueprint(
-    intercept = intercept,
-    allow_novel_levels = allow_novel_levels,
-    fresh = fresh,
-    composition = composition,
-    ptypes = ptypes,
-    recipe = recipe,
-    extra_role_ptypes = extra_role_ptypes,
-    ...,
-    subclass = c(subclass, "default_epi_recipe_blueprint", "default_recipe_blueprint")
-  )
-}
+    new_epi_recipe_blueprint(
+      intercept = intercept,
+      allow_novel_levels = allow_novel_levels,
+      fresh = fresh,
+      composition = composition,
+      ptypes = ptypes,
+      recipe = recipe,
+      extra_role_ptypes = extra_role_ptypes,
+      ...,
+      subclass = c(subclass, "default_epi_recipe_blueprint", "default_recipe_blueprint")
+    )
+  }
 
 #' @importFrom hardhat run_mold
 #' @export
diff --git a/R/canned-epipred.R b/R/canned-epipred.R
index d6f2f3680..bf99d74c7 100644
--- a/R/canned-epipred.R
+++ b/R/canned-epipred.R
@@ -7,8 +7,9 @@ validate_forecaster_inputs <- function(epi_data, outcome, predictors) {
   }
   arg_is_chr(predictors)
   arg_is_chr_scalar(outcome)
-  if (!outcome %in% names(epi_data))
+  if (!outcome %in% names(epi_data)) {
     cli::cli_abort("{outcome} was not found in the training data.")
+  }
   check <- hardhat::check_column_names(epi_data, predictors)
   if (!check$ok) {
     cli::cli_abort(c(
@@ -25,8 +26,9 @@ arx_lags_validator <- function(predictors, lags) {
   if (!is.list(lags)) lags <- list(lags)
   l <- length(lags)
 
-  if (l == 1) lags <- rep(lags, p)
-  else if (length(lags) != p) {
+  if (l == 1) {
+    lags <- rep(lags, p)
+  } else if (length(lags) != p) {
     cli::cli_abort(c(
       "You have requested {p} predictor(s) but {l} different lags.",
       i = "Lags must be a vector or a list with length == number of predictors."
@@ -64,7 +66,8 @@ print.canned_epipred <- function(x, name, ...) {
   cat("\n")
   date_created <- glue::glue(
-    "This forecaster was fit on {format(x$metadata$forecast_created)}")
+    "This forecaster was fit on {format(x$metadata$forecast_created)}"
+  )
   cat_line(date_created)
   cat("\n")
diff --git a/R/compat-recipes.R b/R/compat-recipes.R
index c035a426e..12d11049a 100644
--- a/R/compat-recipes.R
+++ b/R/compat-recipes.R
@@ -1,12 +1,15 @@
 # These are copied from `recipes` where they are unexported
-fun_calls <- function (f) {
-  if (is.function(f)) fun_calls(body(f))
-  else if (rlang::is_quosure(f)) fun_calls(rlang::quo_get_expr(f))
-  else if (is.call(f)) {
+fun_calls <- function(f) {
+  if (is.function(f)) {
+    fun_calls(body(f))
+  } else if (rlang::is_quosure(f)) {
+    fun_calls(rlang::quo_get_expr(f))
+  } else if (is.call(f)) {
     fname <- as.character(f[[1]])
-    if (identical(fname, ".Internal"))
+    if (identical(fname, ".Internal")) {
       return(fname)
+    }
     unique(c(fname, unlist(lapply(f[-1], fun_calls), use.names = FALSE)))
   }
 }
diff --git a/R/create-layer.R b/R/create-layer.R
index 6e30dc606..fee279796 100644
--- a/R/create-layer.R
+++ b/R/create-layer.R
@@ -1,4 +1,3 @@
-
 #' Create a new layer
 #'
 #' This function creates the skeleton for a new `frosting` layer. When called
@@ -13,9 +12,9 @@
 #' @examples
 #' \dontrun{
 #'
-#'  # Note: running this will write `layer_strawberry.R` to
-#'  # the `R/` directory of your current project
-#'  create_layer("strawberry")
+#' # Note: running this will write `layer_strawberry.R` to
+#' # the `R/` directory of your current project
+#' create_layer("strawberry")
 #' }
 #'
 create_layer <- function(name = NULL, open = rlang::is_interactive()) {
@@ -25,7 +24,8 @@ create_layer <- function(name = NULL, open = rlang::is_interactive()) {
     if (substr(nn, 1, 1) == "_") nn <- substring(nn, 2)
     cli::cli_abort(
       c('`name` should not begin with "layer" or "layer_".',
-        i = 'Did you mean to use `create_layer("{ nn }")`?')
+        i = 'Did you mean to use `create_layer("{ nn }")`?'
+      )
     )
   }
   layer_name <- name
@@ -35,7 +35,8 @@ create_layer <- function(name = NULL, open = rlang::is_interactive()) {
     path <- fs::path("R", name)
     if (!fs::file_exists(path)) {
       usethis::use_template(
-        "layer.R", save_as = path,
+        "layer.R",
+        save_as = path,
         data = list(name = layer_name),
         open = FALSE, package = "epipredict"
       )
diff --git a/R/dist_quantiles.R b/R/dist_quantiles.R
index 24d2301d6..032a4d96c 100644
--- a/R/dist_quantiles.R
+++ b/R/dist_quantiles.R
@@ -10,11 +10,13 @@ new_quantiles <- function(q = double(), tau = double()) {
     q <- q[o]
     tau <- tau[o]
   }
-  if (is.unsorted(q, na.rm = TRUE))
+  if (is.unsorted(q, na.rm = TRUE)) {
     rlang::abort("`q[order(tau)]` produces unsorted quantiles.")
+  }
 
   new_rcrd(list(q = q, tau = tau),
-           class = c("dist_quantiles", "dist_default"))
+    class = c("dist_quantiles", "dist_default")
+  )
 }
 
 #' @export
@@ -42,7 +44,7 @@ format.dist_quantiles <- function(x, digits = 2, ...) {
 #'
 #' @import vctrs
 #' @examples
-#' dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2,.4,.6,.8)))
+#' dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
 #' quantile(dstn, p = c(.1, .25, .5, .9))
 #' median(dstn)
 #'
@@ -74,13 +76,15 @@ dist_quantiles <- function(x, tau) {
 #' dstn <- dist_normal(c(10, 2), c(5, 10))
 #' extrapolate_quantiles(dstn, p = c(.25, 0.5, .75))
 #'
-#' dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2,.4,.6,.8)))
+#' dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
 #' # because this distribution is already quantiles, any extra quantiles are
 #' # appended
 #' extrapolate_quantiles(dstn, p = c(.25, 0.5, .75))
 #'
-#' dstn <- c(dist_normal(c(10, 2), c(5, 10)),
-#'           dist_quantiles(list(1:4, 8:11), list(c(.2,.4,.6,.8))))
+#' dstn <- c(
+#'   dist_normal(c(10, 2), c(5, 10)),
+#'   dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8)))
+#' )
 #' extrapolate_quantiles(dstn, p = c(.25, 0.5, .75))
 extrapolate_quantiles <- function(x, p, ...) {
   UseMethod("extrapolate_quantiles")
@@ -120,8 +124,8 @@ is_dist_quantiles <- function(x) {
 #' @export
 #'
 #' @examples
-#' edf <- case_death_rate_subset[1:3,]
-#' edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5/6, 2:4/5, 3:10/11))
+#' edf <- case_death_rate_subset[1:3, ]
+#' edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11))
 #'
 #' edf_nested <- edf %>% dplyr::mutate(q = nested_quantiles(q))
 #' edf_nested %>% tidyr::unnest(q)
@@ -177,7 +181,9 @@ pivot_quantiles <- function(.data, ...) {
     nms <- cols[!checks]
     cli::cli_abort(
       c("Quantiles must be the same length and have the same set of taus.",
-        i = "Check failed for variables(s) {.var {nms}}."))
+        i = "Check failed for variable(s) {.var {nms}}."
+      )
+    )
   }
   if (length(cols) > 1L) {
     for (col in cols) {
@@ -219,15 +225,14 @@ quantile.dist_quantiles <- function(x, probs, ...,
                                    left_tail = c("normal", "exponential"),
                                    right_tail = c("normal", "exponential")) {
   arg_is_probabilities(probs)
-  middle = match.arg(middle)
-  left_tail = match.arg(left_tail)
-  right_tail = match.arg(right_tail)
+  middle <- match.arg(middle)
+  left_tail <- match.arg(left_tail)
+  right_tail <- match.arg(right_tail)
   quantile_extrapolate(x, probs, middle, left_tail, right_tail)
 }
 
 
 quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
-
   tau <- field(x, "tau")
   qvals <- field(x, "q")
   r <- range(tau, na.rm = TRUE)
@@ -235,10 +240,14 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
 
   # short circuit if we aren't actually extrapolating
   # matches to ~15 decimals
-  if (all(tau_out %in% tau)) return(qvals[match(tau_out, tau)])
+  if (all(tau_out %in% tau)) {
+    return(qvals[match(tau_out, tau)])
+  }
   if (length(qvals) < 3 || r[1] > .25 || r[2] < .75) {
-    rlang::warn(c("Quantile extrapolation is not possible with fewer than",
-                  "3 quantiles or when the probs don't span [.25, .75]"))
+    rlang::warn(c(
+      "Quantile extrapolation is not possible with fewer than",
+      "3 quantiles or when the probs don't span [.25, .75]"
+    ))
     return(qvals_out)
   }
@@ -248,11 +257,15 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
 
   if (middle == "cubic") {
     method <- "cubic"
-    result <- tryCatch({
-      Q <- stats::splinefun(tau, qvals, method = "hyman")
-      qvals_out[indm] <- Q(tau_out[indm])
-      quartiles <- Q(c(.25, .5, .75))},
-      error = function(e) { return(NA) }
+    result <- tryCatch(
+      {
+        Q <- stats::splinefun(tau, qvals, method = "hyman")
+        qvals_out[indm] <- Q(tau_out[indm])
+        quartiles <- Q(c(.25, .5, .75))
+      },
+      error = function(e) {
+        return(NA)
+      }
     )
   }
   if (middle == "linear" || any(is.na(result))) {
@@ -262,19 +275,21 @@ quantile_extrapolate <- function(x, tau_out, middle, left_tail, right_tail) {
 
 
   if (any(indm)) {
-    qvals_out[indm] <- switch(
-      method,
+    qvals_out[indm] <- switch(method,
       linear = stats::approx(tau, qvals, tau_out[indm])$y,
       cubic = Q(tau_out[indm])
-    )}
+    )
+  }
   if (any(indl)) {
     qvals_out[indl] <- tail_extrapolate(
       tau_out[indl], quartiles, "left", left_tail
-    )}
+    )
+  }
   if (any(indr)) {
-    qvals_out[indr] <- tail_extrapolate(
-      tau_out[indr], quartiles, "right", right_tail
-    )}
+    qvals_out[indr] <- tail_extrapolate(
+      tau_out[indr], quartiles, "right", right_tail
+    )
+  }
   qvals_out
 }
@@ -287,17 +302,21 @@ tail_extrapolate <- function(tau_out, quartiles, tail, type) {
     p <- c(.75, .5)
     par <- quartiles[3:2]
   }
-  if (type == "normal") return(norm_tail_q(p, par, tau_out))
-  if (type == "exponential") return(exp_tail_q(p, par, tau_out))
+  if (type == "normal") {
+    return(norm_tail_q(p, par, tau_out))
+  }
+  if (type == "exponential") {
+    return(exp_tail_q(p, par, tau_out))
+  }
 }
 
 exp_q_par <- function(q) {
   # tau should always be c(.75, .5) or c(.25, .5)
   iqr <- 2 * abs(diff(q))
-  s <- iqr / (2*log(2))
+  s <- iqr / (2 * log(2))
   m <- q[2]
-  return(list(m=m, s=s))
+  return(list(m = m, s = s))
 }
 
 exp_tail_q <- function(p, q, target) {
@@ -315,7 +334,7 @@ norm_q_par <- function(q) {
   iqr <- 2 * abs(diff(q))
   s <- iqr / 1.34897950039 # abs(diff(qnorm(c(.75, .25))))
   m <- q[2]
-  return(list(m=m, s=s))
+  return(list(m = m, s = s))
 }
 
 norm_tail_q <- function(p, q, target) {
@@ -335,8 +354,10 @@ Math.dist_quantiles <- function(x, ...) {
 #' @method Ops dist_quantiles
 #' @export
 Ops.dist_quantiles <- function(e1, e2) {
-  is_quantiles <- c(inherits(e1, "dist_quantiles"),
-                    inherits(e2, "dist_quantiles"))
+  is_quantiles <- c(
+    inherits(e1, "dist_quantiles"),
+    inherits(e2, "dist_quantiles")
+  )
   is_dist <- c(inherits(e1, "dist_default"), inherits(e2, "dist_default"))
   tau1 <- tau2 <- NULL
   if (is_quantiles[1]) {
@@ -353,8 +374,11 @@ Ops.dist_quantiles <- function(e1, e2) {
       "You can't perform arithmetic between two distributions like this."
     )
   } else {
-    if (is_quantiles[1]) q2 <- e2
-    else q1 <- e1
+    if (is_quantiles[1]) {
+      q2 <- e2
+    } else {
+      q1 <- e1
+    }
   }
   q <- vctrs::vec_arith(.Generic, q1, q2)
   new_quantiles(q = q, tau = tau)
diff --git a/R/epi_check_training_set.R b/R/epi_check_training_set.R
index 22e70dc60..0c7dc9036 100644
--- a/R/epi_check_training_set.R
+++ b/R/epi_check_training_set.R
@@ -45,8 +45,8 @@ validate_meta_match <- function(x, template, meta, warn_or_abort = "warn") {
   )
   if (new_meta != old_meta) {
     switch(warn_or_abort,
-           warn = cli::cli_warn(msg),
-           abort = cli::cli_abort(msg)
+      warn = cli::cli_warn(msg),
+      abort = cli::cli_abort(msg)
     )
   }
 }
diff --git a/R/epi_juice.R b/R/epi_juice.R
index bf48152c3..d9d23df97 100644
--- a/R/epi_juice.R
+++ b/R/epi_juice.R
@@ -21,7 +21,8 @@ epi_juice <- function(object, ...) {
   # Get user requested columns
   new_data <- object$template
   out_names <- recipes_eval_select(terms, new_data, object$term_info,
-                                   check_case_weights = FALSE)
+    check_case_weights = FALSE
+  )
   new_data <- new_data[, out_names]
 
   # Since most models require factors, do the conversion from character
diff --git a/R/epi_recipe.R b/R/epi_recipe.R
index 4caea7476..bd83a4eae 100644
--- a/R/epi_recipe.R
+++ b/R/epi_recipe.R
@@ -17,8 +17,9 @@ epi_recipe <- function(x, ...) {
 
 #' @export
 epi_recipe.default <- function(x, ...) {
   ## if not a formula or an epi_df, we just pass to recipes::recipe
-  if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x))
-    x <- x[1,,drop=FALSE]
+  if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x)) {
+    x <- x[1, , drop = FALSE]
+  }
   recipes::recipe(x, ...)
 }
@@ -98,8 +99,9 @@ epi_recipe.epi_df <-
     if (!is.null(roles)) {
       if (length(roles) != length(vars)) {
         rlang::abort(c(
-          "The number of roles should be the same as the number of ",
-          "variables."))
+          "The number of roles should be the same as the number of ",
+          "variables."
+        ))
       }
       var_info$role <- roles
     } else {
@@ -122,7 +124,8 @@ epi_recipe.epi_df <-
       role,
       levels = union(
         c("predictor", "outcome", "time_value", "geo_value", "key"),
-        unique(role)) # anything else
+        unique(role)
+      ) # anything else
     ))
 
     ## Return final object of class `recipe`
@@ -130,7 +133,7 @@ epi_recipe.epi_df <-
       var_info = var_info,
      term_info = var_info,
       steps = NULL,
-      template = x[1,],
+      template = x[1, ],
       max_time_value = max(x$time_value),
       levels = NULL,
       retained = NA
@@ -145,7 +148,7 @@ epi_recipe.epi_df <-
 #' @export
 epi_recipe.formula <- function(formula, data, ...) {
   # we ensure that there's only 1 row in the template
-  data <- data[1,]
+  data <- data[1, ]
   # check for minus:
   if (!epiprocess::is_epi_df(data)) {
     return(recipes::recipe(formula, data, ...))
@@ -171,7 +174,7 @@ epi_recipe.formula <- function(formula, data, ...) {
 
 # slightly modified version of `form2args()` in {recipes}
 epi_form2args <- function(formula, data, ...) {
-  if (! rlang::is_formula(formula)) formula <- as.formula(formula)
+  if (!rlang::is_formula(formula)) formula <- as.formula(formula)
 
   ## check for in-line formulas
   recipes:::inline_check(formula)
@@ -303,9 +306,11 @@ prep.epi_recipe <- function(
   }
   skippers <- map_lgl(x$steps, recipes:::is_skipable)
   if (any(skippers) & !retain) {
-    rlang::warn(c("Since some operations have `skip = TRUE`, using ",
-                  "`retain = TRUE` will allow those steps results to ",
-                  "be accessible."))
+    rlang::warn(c(
+      "Since some operations have `skip = TRUE`, using ",
+      "`retain = TRUE` will allow those steps' results to ",
+      "be accessible."
+    ))
   }
   if (fresh) x$term_info <- x$var_info
@@ -317,7 +322,8 @@ prep.epi_recipe <- function(
       arg <- paste0("'", arg, "'", collapse = ", ")
       msg <- paste0(
         "You cannot `prep()` a tuneable recipe. Argument(s) with `tune()`: ",
-        arg, ". Do you want to use a tuning function such as `tune_grid()`?")
+        arg, ". Do you want to use a tuning function such as `tune_grid()`?"
+      )
       rlang::abort(msg)
     }
     note <- paste("oper", i, gsub("_", " ", class(x$steps[[i]])[1]))
@@ -327,8 +333,10 @@ prep.epi_recipe <- function(
     }
     before_nms <- names(training)
     before_template <- training[1, ]
-    x$steps[[i]] <- prep(x$steps[[i]], training = training,
-                         info = x$term_info)
+    x$steps[[i]] <- prep(x$steps[[i]],
+      training = training,
+      info = x$term_info
+    )
     training <- bake(x$steps[[i]], new_data = training)
     if (!tibble::is_tibble(training)) {
       abort("bake() methods should always return tibbles")
@@ -337,7 +345,8 @@ prep.epi_recipe <- function(
       # tidymodels killed our class
       # for now, we only allow step_epi_* to alter the metadata
       training <- dplyr::dplyr_reconstruct(
-        epiprocess::as_epi_df(training), before_template)
+        epiprocess::as_epi_df(training), before_template
+      )
     }
     training <- dplyr::relocate(training, tidyselect::all_of(epi_keys(training)))
     x$term_info <- recipes:::merge_term_info(get_types(training), x$term_info)
@@ -352,7 +361,8 @@ prep.epi_recipe <- function(
       recipes:::changelog(log_changes, before_nms, names(training), x$steps[[i]])
       running_info <- rbind(
         running_info,
-        dplyr::mutate(x$term_info, number = i, skip = x$steps[[i]]$skip))
+        dplyr::mutate(x$term_info, number = i, skip = x$steps[[i]]$skip)
+      )
     } else {
       if (verbose) cat(note, "[pre-trained]\n")
     }
@@ -367,9 +377,11 @@ prep.epi_recipe <- function(
   }
   if (retain) {
     if (verbose) {
-      cat("The retained training set is ~",
-          format(utils::object.size(training), units = "Mb", digits = 2),
-          " in memory.\n\n")
+      cat(
+        "The retained training set is ~",
+        format(utils::object.size(training), units = "Mb", digits = 2),
+        " in memory.\n\n"
+      )
     }
     x$template <- training
   } else {
@@ -389,7 +401,8 @@ prep.epi_recipe <- function(
       source = dplyr::first(source),
      number = dplyr::first(number),
       skip = dplyr::first(skip),
-      .groups = "keep")
+      .groups = "keep"
+    )
 
   x
 }
diff --git a/R/epi_selectors.R b/R/epi_selectors.R
index a800782c3..673e8c575 100644
--- a/R/epi_selectors.R
+++ b/R/epi_selectors.R
@@ -5,4 +5,3 @@ all_epi_keys <- function() {
 base_epi_keys <- function() {
   union(has_role("time_value"), has_role("geo_value"))
 }
-
diff --git a/R/epi_shift.R b/R/epi_shift.R
index 0264b2ad5..b40b36ecc 100644
--- a/R/epi_shift.R
+++ b/R/epi_shift.R
@@ -16,18 +16,22 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") {
   if (!is.data.frame(x)) x <- data.frame(x)
   if (is.null(keys)) keys <- rep("empty", nrow(x))
 
-  p_in = ncol(x)
+  p_in <- ncol(x)
 
   out_list <- tibble::tibble(i = 1:p_in, shift = shifts) %>%
     tidyr::unchop(shift) %>% # what is chop
     dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>%
     # One list element for each shifted feature
     pmap(function(i, shift, name) {
       tibble(keys,
-             time_value = time_value + shift, # Shift back
-             !!name := x[[i]])
+        time_value = time_value + shift, # Shift back
+        !!name := x[[i]]
+      )
     })
 
-  if (is.data.frame(keys)) common_names <- c(names(keys), "time_value")
-  else common_names <- c("keys", "time_value")
+  if (is.data.frame(keys)) {
+    common_names <- c(names(keys), "time_value")
+  } else {
+    common_names <- c("keys", "time_value")
+  }
 
   reduce(out_list, dplyr::full_join, by = common_names)
 }
diff --git a/R/epi_workflow.R b/R/epi_workflow.R
index 1379fef86..bc72b23b2 100644
--- a/R/epi_workflow.R
+++ b/R/epi_workflow.R
@@ -84,7 +84,7 @@ is_epi_workflow <- function(x) {
 #' @export
 #' @examples
 #' jhu <- case_death_rate_subset %>%
-#'  filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
+#'   filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
 #'
 #' r <- epi_recipe(jhu) %>%
 #'   step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
@@ -94,8 +94,7 @@ is_epi_workflow <- function(x) {
 #' wf
 #'
 #' @export
-fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()){
-
+fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()) {
   object$fit$meta <- list(max_time_value = max(data$time_value),
                           as_of = attributes(data)$metadata$as_of)
 
   NextMethod()
@@ -152,14 +151,19 @@ predict.epi_workflow <- function(object, new_data, ...) {
   if (!workflows::is_trained_workflow(object)) {
     rlang::abort(
       c("Can't predict on an untrained epi_workflow.",
-        i = "Do you need to call `fit()`?"))
+        i = "Do you need to call `fit()`?"
+      )
+    )
   }
   components <- list()
   components$mold <- workflows::extract_mold(object)
   components$forged <- hardhat::forge(new_data,
-                                      blueprint = components$mold$blueprint)
-  components$keys <- grab_forged_keys(components$forged,
-                                      components$mold, new_data)
+    blueprint = components$mold$blueprint
+  )
+  components$keys <- grab_forged_keys(
+    components$forged,
+    components$mold, new_data
+  )
   components <- apply_frosting(object, components, new_data, ...)
   components$predictions
 }
@@ -174,18 +178,27 @@ predict.epi_workflow <- function(object, new_data, ...) {
 #'
 #' @return new_data with additional columns containing the predicted values
 #' @export
-augment.epi_workflow <- function (x, new_data, ...) {
+augment.epi_workflow <- function(x, new_data, ...) {
   predictions <- predict(x, new_data, ...)
-  if (epiprocess::is_epi_df(predictions)) join_by <- epi_keys(predictions)
-  else rlang::abort(
-    c("Cannot determine how to join new_data with the predictions.",
-      "Try converting new_data to an epi_df with `as_epi_df(new_data)`."))
+  if (epiprocess::is_epi_df(predictions)) {
+    join_by <- epi_keys(predictions)
+  } else {
+    rlang::abort(
+      c(
+        "Cannot determine how to join new_data with the predictions.",
+        "Try converting new_data to an epi_df with `as_epi_df(new_data)`."
+      )
+    )
+  }
   complete_overlap <- intersect(names(new_data), join_by)
   if (length(complete_overlap) < length(join_by)) {
     rlang::warn(
-      glue::glue("Your original training data had keys {join_by}, but",
-                 "`new_data` only has {complete_overlap}. The output",
-                 "may be strange."))
+      glue::glue(
+        "Your original training data had keys {join_by}, but ",
+        "`new_data` only has {complete_overlap}. The output ",
+        "may be strange."
+      )
+    )
   }
   dplyr::full_join(predictions, new_data, by = join_by)
 }
@@ -195,9 +208,9 @@ new_epi_workflow <- function(
     fit = workflows:::new_stage_fit(),
     post = workflows:::new_stage_post(),
     trained = FALSE) {
-
   out <- workflows:::new_workflow(
-    pre = pre, fit = fit, post = post, trained = trained)
+    pre = pre, fit = fit, post = post, trained = trained
+  )
   class(out) <- c("epi_workflow", class(out))
 }
@@ -206,7 +219,7 @@ new_epi_workflow <- function(
 print.epi_workflow <- function(x, ...) {
   print_header(x)
   workflows:::print_preprocessor(x)
-  #workflows:::print_case_weights(x)
+  # workflows:::print_case_weights(x)
   workflows:::print_model(x)
   print_postprocessor(x)
   invisible(x)
@@ -254,4 +267,3 @@ print_header <- function(x) {
 
   invisible(x)
 }
-
diff --git a/R/extract.R b/R/extract.R
index bbb7c9152..574cc40cc 100644
--- a/R/extract.R
+++ b/R/extract.R
@@ -25,11 +25,13 @@ extract_argument <- function(x, name, arg, ...) {
 extract_argument.layer <- function(x, name, arg, ...) {
   rlang::check_dots_empty()
   arg_is_chr_scalar(name, arg)
-  in_layer_name = class(x)[1]
-  if (name != in_layer_name)
+  in_layer_name <- class(x)[1]
+  if (name != in_layer_name) {
     cli_stop("Requested {name} not found. This is a(n) {in_layer_name}.")
-  if (! arg %in% names(x))
+  }
+  if (!arg %in% names(x)) {
     cli_stop("Requested argument {arg} not found in {name}.")
+  }
   x[[arg]]
 }
@@ -37,21 +39,24 @@ extract_argument.layer <- function(x, name, arg, ...) {
 extract_argument.step <- function(x, name, arg, ...) {
   rlang::check_dots_empty()
   arg_is_chr_scalar(name, arg)
-  in_step_name = class(x)[1]
-  if (name != in_step_name)
+  in_step_name <- class(x)[1]
+  if (name != in_step_name) {
     cli_stop("Requested {name} not found. This is a {in_step_name}.")
-  if (! arg %in% names(x))
+  }
+  if (!arg %in% names(x)) {
     cli_stop("Requested argument {arg} not found in {name}.")
+  }
   x[[arg]]
 }
 
 #' @export
-extract_argument.recipe <- function(x, name, arg, ...){
+extract_argument.recipe <- function(x, name, arg, ...) {
   rlang::check_dots_empty()
-  step_names <- map_chr(x$steps, ~class(.x)[1])
+  step_names <- map_chr(x$steps, ~ class(.x)[1])
   has_step <- name %in% step_names
-  if (!has_step)
+  if (!has_step) {
     cli_stop("recipe object does not contain a {name}.")
+  }
   step_locations <- which(name == step_names)
   out <- map(x$steps[step_locations], extract_argument, name = name, arg = arg)
   if (length(out) == 1) out <- out[[1]]
@@ -63,8 +68,9 @@ extract_argument.frosting <- function(x, name, arg, ...) {
   rlang::check_dots_empty()
   layer_names <- map_chr(x$layers, ~ class(.x)[1])
   has_layer <- name %in% layer_names
-  if (! has_layer)
+  if (!has_layer) {
     cli_stop("frosting object does not contain a {name} layer.")
+  }
   layer_locations <- which(name == layer_names)
   out <- map(x$layers[layer_locations], extract_argument, name = name, arg = arg)
   if (length(out) == 1) out <- out[[1]]
@@ -76,14 +82,17 @@ extract_argument.epi_workflow <- function(x, name, arg, ...) {
   rlang::check_dots_empty()
   type <- sub("_.*", "", name)
   if (type %in% c("check", "step")) {
-    if (!workflows:::has_preprocessor_recipe(x))
+    if (!workflows:::has_preprocessor_recipe(x)) {
      cli_stop("The workflow must have a recipe preprocessor.")
+    }
     out <- extract_argument(x$pre$actions$recipe$recipe, name, arg)
   }
-  if (type %in% "layer")
+  if (type %in% "layer") {
     out <- extract_argument(extract_frosting(x), name, arg)
-  if (! type %in% c("check", "step", "layer"))
+  }
+  if (!type %in% c("check", "step", "layer")) {
     cli_stop("{name} must begin with one of step, check, or layer")
+  }
   return(out)
 }
diff --git a/R/flatline.R b/R/flatline.R
index 14d14ebf3..0f98b0e2b 100644
--- a/R/flatline.R
+++ b/R/flatline.R
@@ -1,4 +1,3 @@
-
 #' (Internal) implementation of the flatline forecaster
 #'
 #' This is an internal function that is used to create a [parsnip::linear_reg()]
@@ -29,8 +28,10 @@
 #' @keywords internal
 #'
 #' @examples
-#' tib <- data.frame(y = runif(100),
-#'                   expand.grid(k = letters[1:4], j = letters[5:9], time_value = 1:5)) %>%
+#' tib <- data.frame(
+#'   y = runif(100),
+#'   expand.grid(k = letters[1:4], j = letters[5:9], time_value = 1:5)
+#' ) %>%
 #'   dplyr::group_by(k, j) %>%
 #'   dplyr::mutate(y2 = dplyr::lead(y, 2)) # predict 2 steps ahead
 #' flat <- flatline(y2 ~ j + k + y, tib) # predictions for 20 locations
@@ -41,13 +42,16 @@ flatline <- function(formula, data) {
   n <- length(rhs)
   observed <- rhs[n] # DANGER!!
   ek <- rhs[-n]
-  if (length(response) > 1)
+  if (length(response) > 1) {
     cli_stop("flatline forecaster can accept only 1 observed time series.")
+  }
   keys <- kill_time_value(ek)
 
   preds <- data %>%
-    dplyr::mutate(.pred = !!rlang::sym(observed),
-                  .resid = !!rlang::sym(response) - .pred)
+    dplyr::mutate(
+      .pred = !!rlang::sym(observed),
+      .resid = !!rlang::sym(response) - .pred
+    )
   .pred <- preds %>%
     dplyr::filter(!is.na(.pred)) %>%
     dplyr::group_by(!!!rlang::syms(keys)) %>%
@@ -56,9 +60,11 @@ flatline <- function(formula, data) {
     dplyr::ungroup() %>%
     dplyr::select(tidyselect::all_of(c(keys, ".pred")))
 
-  structure(list(
-    residuals = dplyr::select(preds, dplyr::all_of(c(keys, ".resid"))),
-    .pred = .pred),
+  structure(
+    list(
+      residuals = dplyr::select(preds, dplyr::all_of(c(keys, ".resid"))),
+      .pred = .pred
+    ),
     class = "flatline"
   )
 }
@@ -74,8 +80,10 @@ predict.flatline <- function(object, newdata, ...) {
   metadata <- names(object)[names(object) != ".pred"]
   ek <- names(newdata)
   if (!all(metadata %in% ek)) {
-    cli_stop("`newdata` has different metadata than was used",
-             "to fit the flatline forecaster")
+    cli_stop(
+      "`newdata` has different metadata than was used",
+      "to fit the flatline forecaster"
+    )
   }
 
   dplyr::left_join(newdata, object, by = metadata) %>%
diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R
index 8529ba56b..e437f50ea 100644
--- a/R/flatline_forecaster.R
+++ b/R/flatline_forecaster.R
@@ -31,7 +31,6 @@ flatline_forecaster <- function(
     epi_data,
     outcome,
     args_list = flatline_args_list()) {
-
   validate_forecaster_inputs(epi_data, outcome, "time_value")
   if (!inherits(args_list, c("flat_fcast", "alist"))) {
     cli_stop("args_list was not created using `flatline_args_list()`.")
   }
@@ -61,7 +60,8 @@ flatline_forecaster <- function(
     layer_residual_quantiles(
       probs = args_list$levels, symmetrize = args_list$symmetrize,
-      by_key = args_list$quantile_by_key) %>%
+      by_key = args_list$quantile_by_key
+    ) %>%
     layer_add_forecast_date(forecast_date = forecast_date) %>%
     layer_add_target_date(target_date = target_date)
   if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
@@ -74,13 +74,15 @@ flatline_forecaster <- function(
     tibble::as_tibble() %>%
     dplyr::select(-time_value)
 
-  structure(list(
-    predictions = preds,
-    epi_workflow = wf,
-    metadata = list(
-      training = attr(epi_data, "metadata"),
-      forecast_created = Sys.time()
-    )),
+  structure(
+    list(
+      predictions = preds,
+      epi_workflow = wf,
+      metadata = list(
+        training = attr(epi_data, "metadata"),
+        forecast_created = Sys.time()
+      )
+    ),
     class = c("flat_fcast", "canned_epipred")
   )
 }
@@ -109,9 +111,7 @@ flatline_args_list <- function(
     symmetrize = TRUE,
     nonneg = TRUE,
     quantile_by_key = character(0L),
-    nafill_buffer = Inf
-) {
-
+    nafill_buffer = Inf) {
   arg_is_scalar(ahead, n_training)
   arg_is_chr(quantile_by_key, allow_empty = TRUE)
   arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
@@ -124,15 +124,17 @@ flatline_args_list <- function(
   if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
 
   structure(
-    enlist(ahead,
-           n_training,
-           forecast_date,
-           target_date,
-           levels,
-           symmetrize,
-           nonneg,
-           quantile_by_key,
-           nafill_buffer),
+    enlist(
+      ahead,
+      n_training,
+      forecast_date,
+      target_date,
+      levels,
+      symmetrize,
+      nonneg,
+      quantile_by_key,
+      nafill_buffer
+    ),
     class = c("flat_fcast", "alist")
   )
 }
diff --git a/R/frosting.R b/R/frosting.R
index 88cd44b5b..f5b2adcf8 100644
--- a/R/frosting.R
+++ b/R/frosting.R
@@ -20,7 +20,9 @@
 #'   dplyr::filter(time_value >= max(time_value) - 14)
 #'
 #' # Add frosting to a workflow and predict
-#' f <- frosting() %>% layer_predict() %>% layer_naomit(.pred)
+#' f <- frosting() %>%
+#'   layer_predict() %>%
+#'   layer_naomit(.pred)
 #' wf1 <- wf %>% add_frosting(f)
 #' p1 <- predict(wf1, latest)
 #' p1
@@ -78,7 +80,8 @@ validate_has_postprocessor <- function(x, ..., call = caller_env()) {
   has_postprocessor <- has_postprocessor_frosting(x)
   if (!has_postprocessor) {
     message <- c("The workflow must have a frosting postprocessor.",
-                 i = "Provide one with `add_frosting()`.")
+      i = "Provide one with `add_frosting()`."
+    )
     rlang::abort(message, call = call)
   }
   invisible(x)
@@ -139,8 +142,8 @@ new_frosting <- function() {
 #' @examples
 #'
 #' # Toy example to show that frosting can be created and added for postprocessing
-#' f <- frosting()
-#' wf <- epi_workflow() %>% add_frosting(f)
+#' f <- frosting()
+#' wf <- epi_workflow() %>% add_frosting(f)
 #'
 #' # A more realistic example
 #' jhu <- case_death_rate_subset %>%
@@ -164,8 +167,10 @@ new_frosting <- function() {
 #' p
 frosting <- function(layers = NULL, requirements = NULL) {
   if (!is_null(layers) || !is_null(requirements)) {
-    rlang::abort(c("Currently, no arguments to `frosting()` are allowed",
-                   "to be non-null."))
+    rlang::abort(c(
+      "Currently, no arguments to `frosting()` are allowed",
+      "to be non-null."
+    ))
   }
   out <- new_frosting()
 }
@@ -185,14 +190,18 @@ extract_frosting <- function(x, ...) {
 #' @export
 extract_frosting.default <- function(x, ...) {
   abort(c("Frosting is only available for epi_workflows currently.",
-          i = "Can you use `epi_workflow()` instead of `workflow()`?"))
+    i = "Can you use `epi_workflow()` instead of `workflow()`?"
+  ))
   invisible(x)
 }
 
 #' @export
 extract_frosting.epi_workflow <- function(x, ...) {
-  if (has_postprocessor_frosting(x)) return(x$post$actions$frosting$frosting)
-  else cli_stop("The epi_workflow does not have a postprocessor.")
+  if (has_postprocessor_frosting(x)) {
+    return(x$post$actions$frosting$frosting)
+  } else {
+    cli_stop("The epi_workflow does not have a postprocessor.")
+  }
 }
 
 #' Apply postprocessing to a fitted workflow
@@ -215,7 +224,8 @@ apply_frosting <- function(workflow, ...) {
 apply_frosting.default <- function(workflow, components, ...) {
   if (has_postprocessor(workflow)) {
     abort(c("Postprocessing is only available for epi_workflows currently.",
-            i = "Can you use `epi_workflow()` instead of `workflow()`?"))
+      i = "Can you use `epi_workflow()` instead of `workflow()`?"
+    ))
   }
   return(components)
 }
@@ -228,24 +238,29 @@ apply_frosting.default <- function(workflow, components, ...) {
 #' @export
 apply_frosting.epi_workflow <-
   function(workflow, components, new_data, ...) {
-
     the_fit <- workflows::extract_fit_parsnip(workflow)
 
     if (!has_postprocessor(workflow)) {
       components$predictions <- predict(
-        the_fit, components$forged$predictors, ...)
+        the_fit, components$forged$predictors, ...
+      )
      components$predictions <- dplyr::bind_cols(
-        components$keys, components$predictions)
+        components$keys, components$predictions
+      )
       return(components)
     }
 
     if (!has_postprocessor_frosting(workflow)) {
-      rlang::warn(c("Only postprocessors of class frosting are allowed.",
-                    "Returning unpostprocessed predictions."))
+      rlang::warn(c(
+        "Only postprocessors of class frosting are allowed.",
+        "Returning unpostprocessed predictions."
+      ))
      components$predictions <- predict(
-        the_fit, components$forged$predictors, ...)
+        the_fit, components$forged$predictors, ...
+      )
       components$predictions <- dplyr::bind_cols(
-        components$keys, components$predictions)
+        components$keys, components$predictions
+      )
       return(components)
     }
@@ -255,9 +270,12 @@ apply_frosting.epi_workflow <-
     if (rlang::is_null(layers)) {
       layers <- extract_layers(frosting() %>% layer_predict())
     } else if (!detect_layer(workflow, "layer_predict")) {
-      layers <- c(list(
-        layer_predict_new(NULL, list(), list(), rand_id("predict_default"))),
-        layers)
+      layers <- c(
+        list(
+          layer_predict_new(NULL, list(), list(), rand_id("predict_default"))
+        ),
+        layers
+      )
     }
 
     for (l in seq_along(layers)) {
@@ -283,14 +301,15 @@ print.frosting <- function(x, form_width = 30, ...) {
 
 # Currently only used in the workflow printing
 print_frosting <- function(x, ...) {
-
   layers <- x$layers
   n_layers <- length(layers)
   layer <- ifelse(n_layers == 1L, "Layer", "Layers")
   n_layers_msg <- glue::glue("{n_layers} Frosting {layer}")
   cat_line(n_layers_msg)
 
-  if (n_layers == 0L) return(invisible(x))
+  if (n_layers == 0L) {
+    return(invisible(x))
+  }
 
   cat_line("")
@@ -316,7 +335,9 @@ print_frosting <- function(x, ...) {
 }
 
 print_postprocessor <- function(x) {
-  if (!has_postprocessor_frosting(x)) return(invisible(x))
+  if (!has_postprocessor_frosting(x)) {
+    return(invisible(x))
+  }
 
   header <- cli::rule("Postprocessor")
   cat_line(header)
diff --git a/R/get_test_data.R b/R/get_test_data.R
index 4de8910af..b4c8a2eb2 100644
--- a/R/get_test_data.R
+++ b/R/get_test_data.R
@@ -48,14 +48,14 @@ get_test_data <- function(
     x,
     fill_locf = FALSE,
     n_recent = NULL,
-    forecast_date = max(x$time_value)
-) {
+    forecast_date = max(x$time_value)) {
   if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.")
   arg_is_lgl(fill_locf)
   arg_is_scalar(fill_locf)
   arg_is_scalar(n_recent, allow_null = TRUE)
-  if (!is.null(n_recent) && is.finite(n_recent))
+  if (!is.null(n_recent) && is.finite(n_recent)) {
     arg_is_pos_int(n_recent, allow_null = TRUE)
+  }
   if (!is.null(n_recent)) n_recent <- abs(n_recent) # in case they passed -Inf
 
   check <- hardhat::check_column_names(x, colnames(recipe$template))
@@ -66,12 +66,14 @@ get_test_data <- function(
     ))
   }
 
-  if (class(forecast_date) != class(x$time_value))
+  if (class(forecast_date) != class(x$time_value)) {
     cli::cli_abort("`forecast_date` must be the same class as `x$time_value`.")
+  }
 
-  if (forecast_date < max(x$time_value))
+  if (forecast_date < max(x$time_value)) {
     cli::cli_abort("`forecast_date` must be no earlier than `max(x$time_value)`")
+  }
 
   min_lags <- min(map_dbl(recipe$steps, ~ min(.x$lag %||% Inf)), Inf)
   max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)), 0)
@@ -87,7 +89,8 @@ get_test_data <- function(
     cli::cli_abort(c(
       "You supplied insufficient recent data for this recipe. ",
       "!" = "You need at least {min_required} days of data,",
-      "!" = "but `x` contains only {avail_recent}."))
+      "!" = "but `x` contains only {avail_recent}."
+    ))
   }
 
   x <- arrange(x, time_value)
@@ -104,8 +107,9 @@ get_test_data <- function(
     epiprocess::group_by(dplyr::across(dplyr::all_of(groups)))
 
   # If all(lags > 0), then we get rid of recent data
-  if (min_lags > 0 && min_lags < Inf)
+  if (min_lags > 0 && min_lags < Inf) {
     x <- dplyr::filter(x, forecast_date - time_value >= min_lags)
+  }
 
   # Now, fill forward missing data if requested
   if (fill_locf) {
@@ -126,14 +130,15 @@ get_test_data <- function(
       unlist()
     if (any(cannot_be_used)) {
       bad_vars <- names(cannot_be_used)[cannot_be_used]
-      if (recipes::is_trained(recipe))
-        cli::cli_abort(c(
-          "The variables {.var {bad_vars}} have too many recent missing",
-          `!` = "values to be filled automatically. ",
-          i = "You should either choose `n_recent` larger than its current ",
-          i = "value {n_recent}, or perform NA imputation manually, perhaps with ",
-          i = "{.code recipes::step_impute_*()} or with {.code tidyr::fill()}."
-        ))
+      if (recipes::is_trained(recipe)) {
+        cli::cli_abort(c(
+          "The variables {.var {bad_vars}} have too many recent missing",
+          `!` = "values to be filled automatically. ",
+          i = "You should either choose `n_recent` larger than its current ",
+          i = "value {n_recent}, or perform NA imputation manually, perhaps with ",
+          i = "{.code recipes::step_impute_*()} or with {.code tidyr::fill()}."
+        ))
+      }
     }
     x <- tidyr::fill(x, !time_value)
   }
@@ -159,6 +164,8 @@ pad_to_end <- function(x, groups, end_date) {
 }
 
 Seq <- function(from, to, by) {
-  if (from > to) return(NULL)
+  if (from > to) {
+    return(NULL)
+  }
   seq(from = from, to = to, by = by)
 }
diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R
index 0b522ef65..6bb2cf572 100644
--- a/R/layer_add_forecast_date.R
+++ b/R/layer_add_forecast_date.R
@@ -29,15 +29,17 @@
 #' latest <- jhu %>%
 #'   dplyr::filter(time_value >= max(time_value) - 14)
 #'
-#' # Don't specify `forecast_date` (by default, this should be last date in latest)
-#' f <- frosting() %>% layer_predict() %>%
-#'   layer_naomit(.pred)
+#' # Don't specify `forecast_date` (by default, this should be last date in latest)
+#' f <- frosting() %>%
+#'   layer_predict() %>%
+#'   layer_naomit(.pred)
 #' wf0 <- wf %>% add_frosting(f)
 #' p0 <- predict(wf0, latest)
 #' p0
 #'
 #' # Specify a `forecast_date` that is greater than or equal to `as_of` date
-#' f <- frosting() %>% layer_predict() %>%
+#' f <- frosting() %>%
+#'   layer_predict() %>%
 #'   layer_add_forecast_date(forecast_date = "2022-05-31") %>%
 #'   layer_naomit(.pred)
 #' wf1 <- wf %>% add_frosting(f)
@@ -56,7 +58,7 @@
 #' p2
 #'
 #' # Do not specify a forecast_date
-#' f3 <- frosting() %>%
+#' f3 <- frosting() %>%
 #'   layer_predict() %>%
 #'   layer_add_forecast_date() %>%
 #'   layer_naomit(.pred)
@@ -83,11 +85,12 @@ layer_add_forecast_date_new <- function(forecast_date, id) {
 
 #' @export
 slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {
-
   if (is.null(object$forecast_date)) {
-    max_time_value <- max(workflows::extract_preprocessor(workflow)$max_time_value,
-                          workflow$fit$meta$max_time_value,
-                          max(new_data$time_value))
+    max_time_value <- max(
+      workflows::extract_preprocessor(workflow)$max_time_value,
+      workflow$fit$meta$max_time_value,
+      max(new_data$time_value)
+    )
     object$forecast_date <- max_time_value
   }
   as_of_pre <- attributes(workflows::extract_preprocessor(workflow)$template)$metadata$as_of
@@ -100,7 +103,8 @@ slather.layer_add_forecast_date <- function(object, components, workflow, new_da
     cli_warn(
       c("The forecast_date is less than the most ",
         "recent update date of the data: ",
-        i = "forecast_date = {object$forecast_date} while data is from {as_of_date}.")
+        i = "forecast_date = {object$forecast_date} while data is from {as_of_date}."
+      )
     )
   }
   components$predictions <- dplyr::bind_cols(
@@ -113,11 +117,10 @@ slather.layer_add_forecast_date <- function(object, components, workflow, new_da
 #' @export
 print.layer_add_forecast_date <- function(
     x, width = max(20, options()$width - 30), ...) {
-
   title <- "Adding forecast date"
   fd <- ifelse(is.null(x$forecast_date), "",
-               as.character(x$forecast_date))
+    as.character(x$forecast_date)
+  )
   fd <- rlang::enquos(fd)
   print_layer(fd, title = title, width = width)
 }
-
diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R
index 1fe151bce..bc5372baf 100644
--- a/R/layer_add_target_date.R
+++ b/R/layer_add_target_date.R
@@ -31,7 +31,8 @@
 #' latest <- get_test_data(r, jhu)
 #'
 #' # Use ahead + forecast date
-#' f <- frosting() %>% layer_predict() %>%
+#' f <- frosting() %>%
+#'   layer_predict() %>%
 #'   layer_add_forecast_date(forecast_date = "2022-05-31") %>%
 #'   layer_add_target_date() %>%
 #'   layer_naomit(.pred)
@@ -42,7 +43,8 @@
 #'
 #' # Use ahead + max time value from pre, fit, post
 #' # which is the same if include `layer_add_forecast_date()`
-#' f2 <- frosting() %>% layer_predict() %>%
+#' f2 <- frosting() %>%
+#'   layer_predict() %>%
 #'   layer_add_target_date() %>%
 #'   layer_naomit(.pred)
 #' wf2 <- wf %>% add_frosting(f2)
@@ -73,51 +75,56 @@ layer_add_target_date <-
   }
 
 layer_add_target_date_new <- function(id = id, target_date = target_date) {
-    layer("add_target_date", target_date = target_date, id = id)
+  layer("add_target_date", target_date = target_date, id = id)
 }
 
 #' @export
 slather.layer_add_target_date <- function(object, components, workflow, new_data, ...) {
-
   the_recipe <- workflows::extract_recipe(workflow)
   the_frosting <- extract_frosting(workflow)
 
   if (!is.null(object$target_date)) {
-    target_date = as.Date(object$target_date)
+    target_date <- as.Date(object$target_date)
   } else {
     # null target date case
     if (detect_layer(the_frosting, "layer_add_forecast_date") &&
-        !is.null(extract_argument(the_frosting,
-                                  "layer_add_forecast_date", "forecast_date"))) {
-      forecast_date <- extract_argument(the_frosting,
-                                        "layer_add_forecast_date", "forecast_date")
+      !is.null(extract_argument(
+        the_frosting,
+        "layer_add_forecast_date", "forecast_date"
+      ))) {
+      forecast_date <- extract_argument(
+        the_frosting,
+        "layer_add_forecast_date", "forecast_date"
+      )
       ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
 
-      target_date = forecast_date + ahead
+      target_date <- forecast_date + ahead
     } else {
-      max_time_value <- max(workflows::extract_preprocessor(workflow)$max_time_value,
-                            workflow$fit$meta$max_time_value,
-                            max(new_data$time_value))
+      max_time_value <- max(
+        workflows::extract_preprocessor(workflow)$max_time_value,
+        workflow$fit$meta$max_time_value,
+        max(new_data$time_value)
+      )
       ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
 
-      target_date = max_time_value + ahead
+      target_date <- max_time_value + ahead
     }
   }
 
   components$predictions <- dplyr::bind_cols(components$predictions,
-                                             target_date = target_date)
+    target_date = target_date
+  )
   components
 }
 
 #' @export
 print.layer_add_target_date <- function(
     x, width = max(20, options()$width - 30), ...) {
-
   title <- "Adding target date"
   td <- ifelse(is.null(x$target_date), "",
-               as.character(x$target_date))
+    as.character(x$target_date)
+  )
   td <- rlang::enquos(td)
   print_layer(td, title = title, width = width)
 }
-
diff --git a/R/layer_naomit.R b/R/layer_naomit.R
index ba1081e8d..33c93f0ab 100644
--- a/R/layer_naomit.R
+++ b/R/layer_naomit.R
@@ -58,9 +58,6 @@ slather.layer_naomit <- function(object, components, workflow, new_data, ...) {
 #' @export
 print.layer_naomit <- function(
     x, width = max(20, options()$width - 30), ...) {
-
   title <- "Removing na predictions from"
   print_layer(x$terms, title = title, width = width)
 }
-
-
diff --git a/R/layer_point_from_distn.R b/R/layer_point_from_distn.R
index 855d8b194..9c7b0eb3e 100644
--- a/R/layer_point_from_distn.R
+++ b/R/layer_point_from_distn.R
@@ -69,20 +69,22 @@ layer_point_from_distn <- function(frosting,
 
 layer_point_from_distn_new <- function(type, name, id) {
   layer("point_from_distn",
-        type = type,
-        name = name,
-        id = id)
+    type = type,
+    name = name,
+    id = id
+  )
 }
 
 #' @export
 slather.layer_point_from_distn <-
   function(object, components, workflow, new_data, ...) {
-
     dstn <- components$predictions$.pred
     if (!inherits(dstn, "distribution")) {
       rlang::warn(
         c("`layer_point_from_distn` requires distributional predictions.",
-          i = "These are of class {class(dstn)}. Ignoring this layer."))
+          i = "These are of class {class(dstn)}. Ignoring this layer."
+        )
+      )
       return(components)
     }
@@ -100,7 +102,6 @@ slather.layer_point_from_distn <-
 #' @export
 print.layer_point_from_distn <- function(
     x, width = max(20, options()$width - 30), ...) {
-
   title <- "Extracting point predictions"
   if (is.null(x$name)) {
     cnj <- NULL
@@ -111,4 +112,3 @@ print.layer_point_from_distn <- function(
   }
   print_layer(title = title, width = width, conjunction = cnj, extra_text = ext)
 }
-
diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R
index eb0bff290..3cffbdd87 100644
--- a/R/layer_population_scaling.R
+++ b/R/layer_population_scaling.R
@@ -51,13 +51,15 @@
 #'   dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
 #'   dplyr::select(geo_value, time_value, cases)
 #'
-#' pop_data = data.frame(states = c("ca", "ny"), value = c(20000, 30000))
+#' pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000))
 #'
 #' r <- epi_recipe(jhu) %>%
-#'   step_population_scaling(df = pop_data,
-#'                           df_pop_col = "value",
-#'                           by = c("geo_value" = "states"),
-#'                           cases, suffix = "_scaled") %>%
+#'   step_population_scaling(
+#'     df = pop_data,
+#'     df_pop_col = "value",
+#'     by = c("geo_value" = "states"),
+#'     cases, suffix = "_scaled"
+#'   ) %>%
 #'   step_epi_lag(cases_scaled, lag = c(0, 7, 14)) %>%
 #'   step_epi_ahead(cases_scaled, ahead = 7, role = "outcome") %>%
 #'   step_epi_naomit()
@@ -66,9 +68,11 @@
 #'   layer_predict() %>%
 #'   layer_threshold(.pred) %>%
 #'   layer_naomit(.pred) %>%
-#'   layer_population_scaling(.pred, df = pop_data,
-#'                            by = c("geo_value" = "states"),
-#'                            df_pop_col = "value")
+#'   layer_population_scaling(.pred,
+#'     df = pop_data,
+#'     by = c("geo_value" = "states"),
+#'     df_pop_col = "value"
+#'   )
 #'
 #' wf <- epi_workflow(r, parsnip::linear_reg()) %>%
 #'   fit(jhu) %>%
@@ -77,27 +81,30 @@
 #' latest <- get_test_data(
 #'   recipe = r,
 #'   x = epiprocess::jhu_csse_daily_subset %>%
-#'     dplyr::filter(time_value > "2021-11-01",
-#'                   geo_value %in% c("ca", "ny")) %>%
-#'     dplyr::select(geo_value, time_value, cases))
+#'     dplyr::filter(
+#'       time_value > "2021-11-01",
+#'       geo_value %in% c("ca", "ny")
+#'     ) %>%
+#'     dplyr::select(geo_value, time_value, cases)
+#' )
 #'
 #' predict(wf, latest)
 layer_population_scaling <- function(frosting,
-                                      ...,
-                                      df,
-                                      by = NULL,
-                                      df_pop_col,
-                                      rate_rescaling = 1,
-                                      create_new = TRUE,
-                                      suffix = "_scaled",
-                                      id = rand_id("population_scaling")) {
-
+                                     ...,
+                                     df,
+                                     by = NULL,
+                                     df_pop_col,
+                                     rate_rescaling = 1,
+                                     create_new = TRUE,
+                                     suffix = "_scaled",
+                                     id = rand_id("population_scaling")) {
   arg_is_scalar(df_pop_col, rate_rescaling, create_new, suffix, id)
   arg_is_lgl(create_new)
   arg_is_chr(df_pop_col, suffix, id)
   arg_is_chr(by, allow_null = TRUE)
-  if (rate_rescaling <= 0)
+  if (rate_rescaling <= 0) {
     cli_stop("`rate_rescaling` should be a positive number")
+  }
 
   add_layer(
     frosting,
@@ -116,37 +123,46 @@ layer_population_scaling <- function(frosting,
 
 layer_population_scaling_new <-
   function(df, by, df_pop_col, rate_rescaling, terms, create_new, suffix, id) {
-  layer("population_scaling",
-        df = df,
-        by = by,
-        df_pop_col = df_pop_col,
-        rate_rescaling = rate_rescaling,
-        terms = terms,
-        create_new = create_new,
-        suffix = suffix,
-        id = id)
-}
+    layer("population_scaling",
+      df = df,
+      by = by,
+      df_pop_col = df_pop_col,
+      rate_rescaling = rate_rescaling,
+      terms = terms,
+      create_new = create_new,
+      suffix = suffix,
+      id = id
+    )
+  }
 
 #' @export
 slather.layer_population_scaling <-
   function(object, components, workflow, new_data, ...) {
{ - stopifnot("Only one population column allowed for scaling" = - length(object$df_pop_col) == 1) + stopifnot( + "Only one population column allowed for scaling" = + length(object$df_pop_col) == 1 + ) - try_join <- try(dplyr::left_join(components$predictions, object$df, - by = object$by), - silent = TRUE) + try_join <- try( + dplyr::left_join(components$predictions, object$df, + by = object$by + ), + silent = TRUE + ) if (any(grepl("Join columns must be present in data", unlist(try_join)))) { - cli_stop(c("columns in `by` selectors of `layer_population_scaling` ", - "must be present in data and match"))} + cli_stop(c( + "columns in `by` selectors of `layer_population_scaling` ", + "must be present in data and match" + )) + } object$df <- object$df %>% dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) - pop_col = rlang::sym(object$df_pop_col) + pop_col <- rlang::sym(object$df_pop_col) exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) - suffix = ifelse(object$create_new, object$suffix, "") + suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions)) components$predictions <- dplyr::left_join( @@ -167,9 +183,6 @@ slather.layer_population_scaling <- #' @export print.layer_population_scaling <- function( x, width = max(20, options()$width - 30), ...) { - title <- "Scaling predictions by population" print_layer(x$terms, title = title, width = width) } - - diff --git a/R/layer_predict.R b/R/layer_predict.R index e60f0595c..b40c24be5 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -20,9 +20,9 @@ #' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% -#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% -#' step_epi_ahead(death_rate, ahead = 7) %>% -#' step_epi_naomit() +#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% +#' step_epi_ahead(death_rate, ahead = 7) %>% +#' step_epi_naomit() #' #' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) #' latest <- jhu %>% filter(time_value >= max(time_value) - 14) @@ -63,26 +63,24 @@ layer_predict_new <- function(type, opts, dots_list, id) { #' @export slather.layer_predict <- function(object, components, workflow, new_data, ...) { - the_fit <- workflows::extract_fit_parsnip(workflow) components$predictions <- predict( the_fit, components$forged$predictors, - type = object$type, opts = object$opts) + type = object$type, opts = object$opts + ) components$predictions <- dplyr::bind_cols( - components$keys, components$predictions) + components$keys, components$predictions + ) components } #' @export print.layer_predict <- function( x, width = max(20, options()$width - 30), ...) { - title <- "Creating predictions" td <- "" td <- rlang::enquos(td) print_layer(td, title = title, width = width) } - - diff --git a/R/layer_predictive_distn.R b/R/layer_predictive_distn.R index c951d9ccd..b72be6ec3 100644 --- a/R/layer_predictive_distn.R +++ b/R/layer_predictive_distn.R @@ -65,14 +65,15 @@ layer_predictive_distn <- function(frosting, } layer_predictive_distn_new <- function(dist_type, truncate, name, id) { - layer("predictive_distn", dist_type = dist_type, truncate = truncate, - name = name, id = id) + layer("predictive_distn", + dist_type = dist_type, truncate = truncate, + name = name, id = id + ) } #' @export slather.layer_predictive_distn <- function(object, components, workflow, new_data, ...) 
{ - the_fit <- workflows::extract_fit_parsnip(workflow) m <- components$predictions$.pred @@ -82,9 +83,8 @@ slather.layer_predictive_distn <- papprox <- ncol(components$mold$predictors) + 1 if (is.null(df)) df <- n - papprox mse <- sum(r^2, na.rm = TRUE) / df - s <- sqrt(mse * (1 + papprox / df )) # E[x (X'X)^1 x] if E[X'X] ~= (n-p) I - dstn <- switch( - object$dist_type, + s <- sqrt(mse * (1 + papprox / df)) # E[x (X'X)^{-1} x] if E[X'X] ~= (n-p) I + dstn <- switch(object$dist_type, gaussian = distributional::dist_normal(m, s), student_t = distributional::dist_student_t(df, m, s) ) @@ -101,11 +101,11 @@ slather.layer_predictive_distn <- #' @export print.layer_predictive_distn <- function( x, width = max(20, options()$width - 30), ...) { - title <- "Creating approximate predictive intervals" td <- "" td <- rlang::enquos(td) - print_layer(td, title = title, width = width, conjunction = "type", - extra_text = x$dist_type) + print_layer(td, + title = title, width = width, conjunction = "type", + extra_text = x$dist_type + ) } - diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index 97d546ed1..2b63206b2 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -63,21 +63,24 @@ layer_quantile_distn <- function(frosting, layer_quantile_distn_new <- function(levels, truncate, name, id) { layer("quantile_distn", - levels = levels, - truncate = truncate, - name = name, - id = id) + levels = levels, + truncate = truncate, + name = name, + id = id + ) } #' @export slather.layer_quantile_distn <- function(object, components, workflow, new_data, ...) { - dstn <- components$predictions$.pred if (!inherits(dstn, "distribution")) { rlang::abort( - c("`layer_quantile_distn` requires distributional predictions.", - "These are of class {class(dstn)}.")) + c( + "`layer_quantile_distn` requires distributional predictions.", + "These are of class {class(dstn)}." + ) + ) } dstn <- dist_quantiles(quantile(dstn, object$levels), object["levels"]) @@ -94,14 +97,12 @@ slather.layer_quantile_distn <- #' @export print.layer_quantile_distn <- function( x, width = max(20, options()$width - 30), ...) { - title <- "Creating predictive quantiles" td <- "" td <- rlang::enquos(td) ext <- x$levels - print_layer(td, title = title, width = width, conjunction = "levels", - extra_text = ext) + print_layer(td, + title = title, width = width, conjunction = "levels", + extra_text = ext + ) } - - - diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index c97525b41..a9a8cab24 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -66,17 +66,20 @@ layer_residual_quantiles <- function(frosting, ..., } layer_residual_quantiles_new <- function(probs, symmetrize, by_key, name, id) { - layer("residual_quantiles", probs = probs, symmetrize = symmetrize, - by_key = by_key, name = name, id = id) + layer("residual_quantiles", + probs = probs, symmetrize = symmetrize, + by_key = by_key, name = name, id = id + ) } #' @export slather.layer_residual_quantiles <- function(object, components, workflow, new_data, ...)
{ - the_fit <- workflows::extract_fit_parsnip(workflow) - if (is.null(object$probs)) return(components) + if (is.null(object$probs)) { + return(components) + } s <- ifelse(object$symmetrize, -1, NA) r <- grab_residuals(the_fit, components) @@ -127,8 +130,9 @@ slather.layer_residual_quantiles <- } grab_residuals <- function(the_fit, components) { - if (the_fit$spec$mode != "regression") + if (the_fit$spec$mode != "regression") { rlang::abort("For meaningful residuals, the predictor should be a regression model.") + } r_generic <- attr(utils::methods(class = class(the_fit$fit)[1]), "info")$generic if ("residuals" %in% r_generic) { # Try to use the available method. cl <- class(the_fit$fit)[1] @@ -169,12 +173,12 @@ grab_residuals <- function(the_fit, components) { #' @export print.layer_residual_quantiles <- function( x, width = max(20, options()$width - 30), ...) { - title <- "Resampling residuals for predictive quantiles" td <- "" td <- rlang::enquos(td) ext <- x$probs - print_layer(td, title = title, width = width, conjunction = "levels", - extra_text = ext) + print_layer(td, + title = title, width = width, conjunction = "levels", + extra_text = ext + ) } - diff --git a/R/layer_threshold_preds.R b/R/layer_threshold_preds.R index eb1cb0577..4107504a9 100644 --- a/R/layer_threshold_preds.R +++ b/R/layer_threshold_preds.R @@ -108,7 +108,8 @@ slather.layer_threshold <- dplyr::across( dplyr::all_of(col_names), ~ snap(.x, object$lower, object$upper) - )) + ) + ) components } @@ -116,12 +117,12 @@ slather.layer_threshold <- #' @export print.layer_threshold <- function( x, width = max(20, options()$width - 30), ...) { - title <- "Thresholding predictions" lwr <- ifelse(is.infinite(x$lower), "(", "[") upr <- ifelse(is.infinite(x$upper), ")", "]") rng <- paste0(lwr, round(x$lower, 3), ", ", round(x$upper, 3), upr) - print_layer(x$terms, title = title, width = width, conjunction = "to", - extra_text = rng) + print_layer(x$terms, + title = title, width = width, conjunction = "to", + extra_text = rng + ) } - diff --git a/R/layer_unnest.R b/R/layer_unnest.R index 8b545c9cd..64b17a306 100644 --- a/R/layer_unnest.R +++ b/R/layer_unnest.R @@ -40,7 +40,6 @@ slather.layer_unnest <- #' @export print.layer_unnest <- function( x, width = max(20, options()$width - 30), ...) 
{ - title <- "Unnesting prediction list-cols" print_layer(x$terms, title = title, width = width) } diff --git a/R/make_flatline_reg.R b/R/make_flatline_reg.R index 33c135f08..0f3076639 100644 --- a/R/make_flatline_reg.R +++ b/R/make_flatline_reg.R @@ -11,7 +11,8 @@ make_flatline_reg <- function() { protect = c("formula", "data"), func = c(pkg = "epipredict", fun = "flatline"), defaults = list() - )) + ) + ) parsnip::set_encoding( model = "linear_reg", @@ -35,5 +36,4 @@ make_flatline_reg <- function() { args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) - } diff --git a/R/make_quantile_reg.R b/R/make_quantile_reg.R index b181e8a80..eef4d4c97 100644 --- a/R/make_quantile_reg.R +++ b/R/make_quantile_reg.R @@ -1,4 +1,3 @@ - #' Quantile regression #' #' @description @@ -23,9 +22,9 @@ #' rq_spec <- quantile_reg(tau = c(.2, .8)) %>% set_engine("rq") #' ff <- rq_spec %>% fit(y ~ ., data = tib) #' predict(ff, new_data = tib) -quantile_reg <- function(mode = "regression", engine = "rq", tau = 0.5) { +quantile_reg <- function(mode = "regression", engine = "rq", tau = 0.5) { # Check for correct mode - if (mode != "regression") { + if (mode != "regression") { rlang::abort("`mode` should be 'regression'") } @@ -78,7 +77,8 @@ make_quantile_reg <- function() { defaults = list( method = "br", na.action = rlang::expr(stats::na.omit), - model = FALSE) + model = FALSE + ) ) ) @@ -100,15 +100,15 @@ make_quantile_reg <- function() { # can't make a method because object is second - out <- switch( - type, + out <- switch(type, rq = dist_quantiles(unname(as.list(x)), object$tau), # one quantile rqs = { x <- lapply(unname(split(x, seq(nrow(x)))), function(q) sort(q)) dist_quantiles(x, list(object$tau)) }, rlang::abort(c("Prediction not implemented for this `rq` type.", - i = "See `?quantreg::rq`.")) + i = "See `?quantreg::rq`." 
+ )) ) return(data.frame(.pred = out)) } @@ -127,4 +127,3 @@ make_quantile_reg <- function() { ) ) } - diff --git a/R/make_smooth_quantile_reg.R b/R/make_smooth_quantile_reg.R index b4e197a7b..6eab2a132 100644 --- a/R/make_smooth_quantile_reg.R +++ b/R/make_smooth_quantile_reg.R @@ -1,4 +1,3 @@ - #' Smooth quantile regression #' #' @description @@ -27,9 +26,10 @@ #' tib <- data.frame( #' y1 = rnorm(100), y2 = rnorm(100), y3 = rnorm(100), #' y4 = rnorm(100), y5 = rnorm(100), y6 = rnorm(100), -#' x1 = rnorm(100), x2 = rnorm(100)) +#' x1 = rnorm(100), x2 = rnorm(100) +#' ) #' qr_spec <- smooth_quantile_reg(tau = c(.2, .5, .8), outcome_locations = 1:6) -#' ff <- qr_spec %>% fit(cbind(y1, y2 , y3 , y4 , y5 , y6) ~ ., data = tib) +#' ff <- qr_spec %>% fit(cbind(y1, y2, y3, y4, y5, y6) ~ ., data = tib) #' p <- predict(ff, new_data = tib) #' #' x <- -99:99 / 100 * 2 * pi @@ -38,21 +38,23 @@ #' XY <- smoothqr::lagmat(y[1:(length(y) - 20)], c(-20:20)) #' XY <- tibble::as_tibble(XY) #' qr_spec <- smooth_quantile_reg(tau = c(.2, .5, .8), outcome_locations = 20:1) -#' tt <- qr_spec %>% fit_xy(x = XY[,21:41], y = XY[,1:20]) +#' tt <- qr_spec %>% fit_xy(x = XY[, 21:41], y = XY[, 1:20]) #' #' library(tidyr) #' library(dplyr) #' pl <- predict( -#' object = tt, -#' new_data = XY[max(which(complete.cases(XY[,21:41]))), 21:41] -#' ) +#' object = tt, +#' new_data = XY[max(which(complete.cases(XY[, 21:41]))), 21:41] +#' ) #' pl <- pl %>% -#' unnest(.pred) %>% -#' mutate(distn = nested_quantiles(distn)) %>% -#' unnest(distn) %>% -#' mutate(x = x[length(x) - 20] + ahead / 100 * 2 * pi, -#' ahead = NULL) %>% -#' pivot_wider(names_from = tau, values_from = q) +#' unnest(.pred) %>% +#' mutate(distn = nested_quantiles(distn)) %>% +#' unnest(distn) %>% +#' mutate( +#' x = x[length(x) - 20] + ahead / 100 * 2 * pi, +#' ahead = NULL +#' ) %>% +#' pivot_wider(names_from = tau, values_from = q) #' plot(x, y, pch = 16, xlim = c(pi, 2 * pi), col = "lightgrey") #' curve(sin(x), add = TRUE) #' abline(v = fd, lty = 2) @@ -76,7 +78,6 @@ smooth_quantile_reg <- function( outcome_locations = NULL, tau = 0.5, degree = 3L) { - # Check for correct mode if (mode != "regression") rlang::abort("`mode` must be 'regression'") if (engine != "smoothqr") rlang::abort("`engine` must be 'smoothqr'") @@ -90,8 +91,10 @@ smooth_quantile_reg <- function( tau <- sort(tau) } - args <- list(tau = rlang::enquo(tau), degree = rlang::enquo(degree), - outcome_locations = rlang::enquo(outcome_locations)) + args <- list( + tau = rlang::enquo(tau), degree = rlang::enquo(degree), + outcome_locations = rlang::enquo(outcome_locations) + ) # Save some empty slots for future parts of the specification parsnip::new_model_spec( @@ -169,8 +172,8 @@ make_smooth_quantile_reg <- function() { object <- parsnip::extract_fit_engine(object) list_of_pred_distns <- lapply(x, function(p) { x <- lapply(unname(split( - p, seq(nrow(p)))), function(q) unname(sort(q, na.last = TRUE) - )) + p, seq(nrow(p)) + )), function(q) unname(sort(q, na.last = TRUE))) dist_quantiles(x, list(object$tau)) }) n_preds <- length(list_of_pred_distns[[1]]) @@ -178,7 +181,8 @@ make_smooth_quantile_reg <- function() { tib <- tibble::tibble( ids = rep(seq(n_preds), times = nout), ahead = rep(object$aheads, each = n_preds), - distn = do.call(c, unname(list_of_pred_distns))) %>% + distn = do.call(c, unname(list_of_pred_distns)) + ) %>% tidyr::nest(.pred = c(ahead, distn)) return(tib[".pred"]) @@ -197,4 +201,3 @@ make_smooth_quantile_reg <- function() { ) ) } - diff --git a/R/print_epi_step.R 
b/R/print_epi_step.R index 557a70a81..0af52a4e7 100644 --- a/R/print_epi_step.R +++ b/R/print_epi_step.R @@ -3,17 +3,19 @@ print_epi_step <- function( width = max(20, options()$width - 30), case_weights = NULL, conjunction = NULL, extra_text = NULL) { theme_div_id <- cli::cli_div( - theme = list(.pkg = list(`vec-trunc` = Inf, `vec-last` = ", ")) - ) + theme = list(.pkg = list(`vec-trunc` = Inf, `vec-last` = ", ")) + ) title <- trimws(title) trained_text <- dplyr::if_else(trained, "Trained", "") case_weights_text <- dplyr::case_when( is.null(case_weights) ~ "", isTRUE(case_weights) ~ "weighted", - isFALSE(case_weights) ~ "ignored weights") + isFALSE(case_weights) ~ "ignored weights" + ) vline_seperator <- dplyr::if_else(trained_text == "", "", "|") comma_seperator <- dplyr::if_else( - trained_text != "" && case_weights_text != "", true = ",", false = "") + trained_text != "" && case_weights_text != "", true = ",", false = "" + ) extra_text <- recipes::format_ch_vec(extra_text) width_title <- nchar(paste0( "* ", title, ":", " ", conjunction, " ", extra_text, " ", vline_seperator, @@ -42,7 +44,8 @@ print_epi_step <- function( ) more_dots <- ifelse(first_line == length(elements), "", ", ...") cli::cli_bullets( - c(`*` = "\n {title}: \\\n {.pkg {cli::cli_vec(elements[seq_len(first_line)])}}\\\n {more_dots} \\\n {conjunction} \\\n {.pkg {extra_text}} \\\n {vline_seperator} \\\n {.emph {trained_text}}\\\n {comma_seperator} \\\n {.emph {case_weights_text}}\n ")) + c(`*` = "\n {title}: \\\n {.pkg {cli::cli_vec(elements[seq_len(first_line)])}}\\\n {more_dots} \\\n {conjunction} \\\n {.pkg {extra_text}} \\\n {vline_seperator} \\\n {.emph {trained_text}}\\\n {comma_seperator} \\\n {.emph {case_weights_text}}\n ") + ) cli::cli_end(theme_div_id) invisible(NULL) } diff --git a/R/print_layer.R b/R/print_layer.R index 777eab513..9863bf5e7 100644 --- a/R/print_layer.R +++ b/R/print_layer.R @@ -5,7 +5,8 @@ print_layer <- function( width_title <- nchar(paste0("* ", title, ":", " ")) extra_text <- recipes::format_ch_vec(extra_text) width_title <- nchar(paste0( - "* ", title, ":", " ", conjunction, " ", extra_text)) + "* ", title, ":", " ", conjunction, " ", extra_text + )) width_diff <- cli::console_width() * 1 - width_title elements <- lapply(layer_obj, function(x) { rlang::expr_deparse(rlang::quo_get_expr(x), width = Inf) @@ -24,6 +25,7 @@ print_layer <- function( ) more_dots <- ifelse(first_line == length(elements), "", ", ...") cli::cli_bullets( - c(`*` = "\n {title}: \\\n {.pkg {elements[seq_len(first_line)]}}\\\n {more_dots} \\\n {conjunction} \\\n {.pkg {extra_text}}")) + c(`*` = "\n {title}: \\\n {.pkg {elements[seq_len(first_line)]}}\\\n {more_dots} \\\n {conjunction} \\\n {.pkg {extra_text}}") + ) invisible(NULL) } diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index adec97d1f..ec5428d8f 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -48,7 +48,7 @@ #' @examples #' r <- epi_recipe(case_death_rate_subset) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% -#' step_epi_lag(death_rate, lag = c(0,7,14)) +#' step_epi_lag(death_rate, lag = c(0, 7, 14)) #' r step_epi_lag <- function(recipe, @@ -61,32 +61,39 @@ step_epi_lag <- columns = NULL, skip = FALSE, id = rand_id("epi_lag")) { - if (!is_epi_recipe(recipe)) + if (!is_epi_recipe(recipe)) { rlang::abort("This recipe step can only operate on an `epi_recipe`.") + } if (missing(lag)) { rlang::abort( c("The `lag` argument must not be empty.", - i = "Did you perhaps pass an integer in `...` accidentally?")) + i = "Did you perhaps pass an 
integer in `...` accidentally?" + ) + ) } arg_is_nonneg_int(lag) arg_is_chr_scalar(prefix, id) - if (!is.null(columns)) + if (!is.null(columns)) { rlang::abort(c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to lag.")) + i = "Use `tidyselect` methods to choose columns to lag." + )) + } - add_step(recipe, - step_epi_lag_new( - terms = dplyr::enquos(...), - role = role, - trained = trained, - lag = lag, - prefix = prefix, - default = default, - keys = epi_keys(recipe), - columns = columns, - skip = skip, - id = id - )) + add_step( + recipe, + step_epi_lag_new( + terms = dplyr::enquos(...), + role = role, + trained = trained, + lag = lag, + prefix = prefix, + default = default, + keys = epi_keys(recipe), + columns = columns, + skip = skip, + id = id + ) + ) } #' Create a shifted predictor @@ -105,32 +112,39 @@ step_epi_ahead <- function(recipe, columns = NULL, skip = FALSE, id = rand_id("epi_ahead")) { - if (!is_epi_recipe(recipe)) + if (!is_epi_recipe(recipe)) { rlang::abort("This recipe step can only operate on an `epi_recipe`.") + } if (missing(ahead)) { rlang::abort( c("The `ahead` argument must not be empty.", - i = "Did you perhaps pass an integer in `...` accidentally?")) + i = "Did you perhaps pass an integer in `...` accidentally?" + ) + ) } arg_is_nonneg_int(ahead) arg_is_chr_scalar(prefix, id) - if (!is.null(columns)) + if (!is.null(columns)) { rlang::abort(c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to lead.")) + i = "Use `tidyselect` methods to choose columns to lead." + )) + } - add_step(recipe, - step_epi_ahead_new( - terms = dplyr::enquos(...), - role = role, - trained = trained, - ahead = ahead, - prefix = prefix, - default = default, - keys = epi_keys(recipe), - columns = columns, - skip = skip, - id = id - )) + add_step( + recipe, + step_epi_ahead_new( + terms = dplyr::enquos(...), + role = role, + trained = trained, + ahead = ahead, + prefix = prefix, + default = default, + keys = epi_keys(recipe), + columns = columns, + skip = skip, + id = id + ) + ) } @@ -209,19 +223,24 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) { #' @export bake.step_epi_lag <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, lag = object$lag) %>% - dplyr::mutate(newname = glue::glue("{object$prefix}{lag}_{col}"), - shift_val = lag, - lag = NULL) + dplyr::mutate( + newname = glue::glue("{object$prefix}{lag}_{col}"), + shift_val = lag, + lag = NULL + ) ## ensure no name clashes new_data_names <- colnames(new_data) intersection <- new_data_names %in% grid$newname if (any(intersection)) { rlang::abort( - paste0("Name collision occured in `", class(object)[1], - "`. The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - ".")) + paste0( + "Name collision occurred in `", class(object)[1], + "`. The following variable names already exist: ", + paste0(new_data_names[intersection], collapse = ", "), + "." + ) + ) } ok <- object$keys shifted <- reduce( @@ -234,25 +253,29 @@ bake.step_epi_lag <- function(object, new_data, ...) { dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% dplyr::arrange(time_value) %>% dplyr::ungroup() - } #' @export bake.step_epi_ahead <- function(object, new_data, ...)
{ grid <- tidyr::expand_grid(col = object$columns, ahead = object$ahead) %>% - dplyr::mutate(newname = glue::glue("{object$prefix}{ahead}_{col}"), - shift_val = -ahead, - ahead = NULL) + dplyr::mutate( + newname = glue::glue("{object$prefix}{ahead}_{col}"), + shift_val = -ahead, + ahead = NULL + ) ## ensure no name clashes new_data_names <- colnames(new_data) intersection <- new_data_names %in% grid$newname if (any(intersection)) { rlang::abort( - paste0("Name collision occured in `", class(object)[1], - "`. The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - ".")) + paste0( + "Name collision occurred in `", class(object)[1], + "`. The following variable names already exist: ", + paste0(new_data_names[intersection], collapse = ", "), + "." + ) + ) } ok <- object$keys shifted <- reduce( @@ -265,21 +288,24 @@ bake.step_epi_ahead <- function(object, new_data, ...) { dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% dplyr::arrange(time_value) %>% dplyr::ungroup() - } #' @export print.step_epi_lag <- function(x, width = max(20, options()$width - 30), ...) { - print_epi_step(x$columns, x$terms, x$trained, "Lagging", conjunction = "by", - extra_text = x$lag) + print_epi_step(x$columns, x$terms, x$trained, "Lagging", + conjunction = "by", + extra_text = x$lag + ) invisible(x) } #' @export print.step_epi_ahead <- function(x, width = max(20, options()$width - 30), ...) { - print_epi_step(x$columns, x$terms, x$trained, "Leading", conjunction = "by", - extra_text = x$ahead) + print_epi_step(x$columns, x$terms, x$trained, "Leading", + conjunction = "by", + extra_text = x$ahead + ) invisible(x) } @@ -287,19 +313,24 @@ print.step_epi_ahead <- function(x, width = max(20, options()$width - 30), ...) print_step_shift <- function( tr_obj = NULL, untr_obj = NULL, trained = FALSE, title = NULL, width = max(20, options()$width - 30), case_weights = NULL, shift = NULL) { - cat(title) - if (trained) txt <- recipes::format_ch_vec(tr_obj, width = width) - else txt <- recipes::format_selectors(untr_obj, width = width) + if (trained) { + txt <- recipes::format_ch_vec(tr_obj, width = width) + } else { + txt <- recipes::format_selectors(untr_obj, width = width) + } if (length(txt) == 0L) txt <- "" cat(txt) if (trained) { - if (is.null(case_weights)) cat(" [trained]") - else { + if (is.null(case_weights)) { + cat(" [trained]") + } else { case_weights_ind <- ifelse(case_weights, "weighted", - "ignored weights") + "ignored weights" + ) trained_txt <- paste(case_weights_ind, "trained", - sep = ", ") + sep = ", " + ) trained_txt <- paste0(" [", trained_txt, "]") cat(trained_txt) } diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index 8d573ebcc..f6ad29a5b 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -38,27 +38,28 @@ #' step_growth_rate(case_rate, death_rate) #' r #' -#' r %>% recipes::prep() %>% recipes::bake(case_death_rate_subset) +#' r %>% +#' recipes::prep() %>% +#' recipes::bake(case_death_rate_subset) step_growth_rate <- function( - recipe, - ..., - role = "predictor", - trained = FALSE, - horizon = 7, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), - log_scale = FALSE, - replace_Inf = NA, - prefix = "gr_", - columns = NULL, - skip = FALSE, - id = rand_id("growth_rate"), - additional_gr_args_list = list() - ) { - - if (!is_epi_recipe(recipe)) + recipe, + ..., + role = "predictor", + trained = FALSE, + horizon = 7, + method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + log_scale = 
FALSE, + replace_Inf = NA, + prefix = "gr_", + columns = NULL, + skip = FALSE, + id = rand_id("growth_rate"), + additional_gr_args_list = list()) { + if (!is_epi_recipe(recipe)) { rlang::abort("This recipe step can only operate on an `epi_recipe`.") - method = match.arg(method) + } + method <- match.arg(method) arg_is_pos_int(horizon) arg_is_scalar(horizon) if (!is.null(replace_Inf)) { @@ -73,30 +74,35 @@ step_growth_rate <- if (!is.list(additional_gr_args_list)) { rlang::abort( c("`additional_gr_args_list` must be a list.", - i = "See `?epiprocess::growth_rate` for available options.")) + i = "See `?epiprocess::growth_rate` for available options." + ) + ) } if (!is.null(columns)) { rlang::abort(c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to use.")) + i = "Use `tidyselect` methods to choose columns to use." + )) } - add_step(recipe, - step_growth_rate_new( - terms = dplyr::enquos(...), - role = role, - trained = trained, - horizon = horizon, - method = method, - log_scale = log_scale, - replace_Inf = replace_Inf, - prefix = prefix, - keys = epi_keys(recipe), - columns = columns, - skip = skip, - id = id, - additional_gr_args_list = additional_gr_args_list - )) + add_step( + recipe, + step_growth_rate_new( + terms = dplyr::enquos(...), + role = role, + trained = trained, + horizon = horizon, + method = method, + log_scale = log_scale, + replace_Inf = replace_Inf, + prefix = prefix, + keys = epi_keys(recipe), + columns = columns, + skip = skip, + id = id, + additional_gr_args_list = additional_gr_args_list + ) + ) } @@ -167,10 +173,13 @@ bake.step_growth_rate <- function(object, new_data, ...) { if (any(intersection)) { rlang::abort( c(paste0("Name collision occured in `", class(object)[1], "`."), - i = paste("The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - ".") - )) + i = paste( + "The following variable names already exist: ", + paste0(new_data_names[intersection], collapse = ", "), + "." + ) + ) + ) } ok <- object$keys @@ -181,12 +190,14 @@ bake.step_growth_rate <- function(object, new_data, ...) { dplyr::across( dplyr::all_of(object$columns), ~ epiprocess::growth_rate( - time_value, .x, method = object$method, + time_value, .x, + method = object$method, h = object$horizon, log_scale = object$log_scale, !!!object$additional_gr_args_list ), .names = "{object$prefix}{object$horizon}_{object$method}_{.col}" - )) %>% + ) + ) %>% dplyr::ungroup() %>% dplyr::mutate(time_value = time_value + object$horizon) # shift x0 right @@ -212,7 +223,8 @@ print.step_growth_rate <- function(x, width = max(20, options()$width - 30), ...
print_epi_step( x$columns, x$terms, x$trained, title = "Calculating growth_rate for ", - conjunction = "by", extra_text = x$method) + conjunction = "by", extra_text = x$method + ) invisible(x) } diff --git a/R/step_lag_difference.R b/R/step_lag_difference.R index e4096e113..2482be46a 100644 --- a/R/step_lag_difference.R +++ b/R/step_lag_difference.R @@ -19,22 +19,23 @@ #' step_lag_difference(case_rate, death_rate, horizon = c(7, 14)) #' r #' -#' r %>% recipes::prep() %>% recipes::bake(case_death_rate_subset) +#' r %>% +#' recipes::prep() %>% +#' recipes::bake(case_death_rate_subset) step_lag_difference <- function( - recipe, - ..., - role = "predictor", - trained = FALSE, - horizon = 7, - prefix = "lag_diff_", - columns = NULL, - skip = FALSE, - id = rand_id("lag_diff") - ) { - - if (!is_epi_recipe(recipe)) + recipe, + ..., + role = "predictor", + trained = FALSE, + horizon = 7, + prefix = "lag_diff_", + columns = NULL, + skip = FALSE, + id = rand_id("lag_diff")) { + if (!is_epi_recipe(recipe)) { rlang::abort("This recipe step can only operate on an `epi_recipe`.") + } arg_is_pos_int(horizon) arg_is_chr(role) arg_is_chr_scalar(prefix, id) @@ -43,22 +44,25 @@ step_lag_difference <- if (!is.null(columns)) { rlang::abort( c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to use.") + i = "Use `tidyselect` methods to choose columns to use." + ) ) } - add_step(recipe, - step_lag_difference_new( - terms = dplyr::enquos(...), - role = role, - trained = trained, - horizon = horizon, - prefix = prefix, - keys = epi_keys(recipe), - columns = columns, - skip = skip, - id = id - )) + add_step( + recipe, + step_lag_difference_new( + terms = dplyr::enquos(...), + role = role, + trained = trained, + horizon = horizon, + prefix = prefix, + keys = epi_keys(recipe), + columns = columns, + skip = skip, + id = id + ) + ) } @@ -110,7 +114,7 @@ epi_shift_single_diff <- function(x, col, horizon, newname, key_cols) { dplyr::mutate(time_value = time_value + horizon) %>% dplyr::rename(!!newname := {{ col }}) x <- dplyr::left_join(x, y, by = key_cols) - x[ ,newname] <- x[ ,col] - x[ ,newname] + x[, newname] <- x[, col] - x[, newname] x %>% dplyr::select(tidyselect::all_of(c(key_cols, newname))) } @@ -126,10 +130,13 @@ bake.step_lag_difference <- function(object, new_data, ...) { if (any(intersection)) { rlang::abort( c(paste0("Name collision occured in `", class(object)[1], "`."), - i = paste("The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - ".") - )) + i = paste( + "The following variable names already exist: ", + paste0(new_data_names[intersection], collapse = ", "), + "." + ) + ) + ) } ok <- object$keys @@ -149,8 +156,9 @@ bake.step_lag_difference <- function(object, new_data, ...) { #' @export print.step_lag_difference <- function(x, width = max(20, options()$width - 30), ...)
{ print_epi_step(x$columns, x$terms, x$trained, - title = "Calculating lag_difference for", - conjunction = "by", - extra_text = x$horizon) + title = "Calculating lag_difference for", + conjunction = "by", + extra_text = x$horizon + ) invisible(x) } diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 529c08e0a..ce87ea759 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -69,13 +69,15 @@ #' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% #' dplyr::select(geo_value, time_value, cases) #' -#' pop_data = data.frame(states = c("ca", "ny"), value = c(20000, 30000)) +#' pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) #' #' r <- epi_recipe(jhu) %>% -#' step_population_scaling(df = pop_data, -#' df_pop_col = "value", -#' by = c("geo_value" = "states"), -#' cases, suffix = "_scaled") %>% +#' step_population_scaling( +#' df = pop_data, +#' df_pop_col = "value", +#' by = c("geo_value" = "states"), +#' cases, suffix = "_scaled" +#' ) %>% #' step_epi_lag(cases_scaled, lag = c(0, 7, 14)) %>% #' step_epi_ahead(cases_scaled, ahead = 7, role = "outcome") %>% #' step_epi_naomit() @@ -84,9 +86,11 @@ #' layer_predict() %>% #' layer_threshold(.pred) %>% #' layer_naomit(.pred) %>% -#' layer_population_scaling(.pred, df = pop_data, -#' by = c("geo_value" = "states"), -#' df_pop_col = "value") +#' layer_population_scaling(.pred, +#' df = pop_data, +#' by = c("geo_value" = "states"), +#' df_pop_col = "value" +#' ) #' #' wf <- epi_workflow(r, parsnip::linear_reg()) %>% #' fit(jhu) %>% @@ -95,8 +99,10 @@ #' latest <- get_test_data( #' recipe = r, #' epiprocess::jhu_csse_daily_subset %>% -#' dplyr::filter(time_value > "2021-11-01", -#' geo_value %in% c("ca", "ny")) %>% +#' dplyr::filter( +#' time_value > "2021-11-01", +#' geo_value %in% c("ca", "ny") +#' ) %>% #' dplyr::select(geo_value, time_value, cases) #' ) #' @@ -104,43 +110,44 @@ #' predict(wf, latest) step_population_scaling <- function(recipe, - ..., - role = "raw", - trained = FALSE, - df, - by = NULL, - df_pop_col, - rate_rescaling = 1, - create_new = TRUE, - suffix = "_scaled", - columns = NULL, - skip = FALSE, - id = rand_id("population_scaling")){ - arg_is_scalar(role, trained, df_pop_col, rate_rescaling, create_new, suffix, id) - arg_is_lgl(create_new, skip) - arg_is_chr(df_pop_col, suffix, id) - arg_is_chr(by, columns, allow_null = TRUE) - if (rate_rescaling <= 0) - cli_stop("`rate_rescaling` should be a positive number") + ..., + role = "raw", + trained = FALSE, + df, + by = NULL, + df_pop_col, + rate_rescaling = 1, + create_new = TRUE, + suffix = "_scaled", + columns = NULL, + skip = FALSE, + id = rand_id("population_scaling")) { + arg_is_scalar(role, trained, df_pop_col, rate_rescaling, create_new, suffix, id) + arg_is_lgl(create_new, skip) + arg_is_chr(df_pop_col, suffix, id) + arg_is_chr(by, columns, allow_null = TRUE) + if (rate_rescaling <= 0) { + cli_stop("`rate_rescaling` should be a positive number") + } - add_step( - recipe, - step_population_scaling_new( - terms = dplyr::enquos(...), - role = role, - trained = trained, - df = df, - by = by, - df_pop_col = df_pop_col, - rate_rescaling = rate_rescaling, - create_new = create_new, - suffix = suffix, - columns = columns, - skip = skip, - id = id + add_step( + recipe, + step_population_scaling_new( + terms = dplyr::enquos(...), + role = role, + trained = trained, + df = df, + by = by, + df_pop_col = df_pop_col, + rate_rescaling = rate_rescaling, + create_new = create_new, + suffix = suffix, + 
columns = columns, + skip = skip, + id = id + ) ) - ) -} + } step_population_scaling_new <- function(role, trained, df, by, df_pop_col, rate_rescaling, terms, create_new, @@ -182,17 +189,22 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { #' @export bake.step_population_scaling <- function(object, - new_data, - ...) { - - stopifnot("Only one population column allowed for scaling" = - length(object$df_pop_col) == 1) + new_data, + ...) { + stopifnot( + "Only one population column allowed for scaling" = + length(object$df_pop_col) == 1 + ) try_join <- try(dplyr::left_join(new_data, object$df, by = object$by), - silent = TRUE) + silent = TRUE + ) if (any(grepl("Join columns must be present in data", unlist(try_join)))) { - cli_stop(c("columns in `by` selectors of `step_population_scaling` ", - "must be present in data and match"))} + cli_stop(c( + "columns in `by` selectors of `step_population_scaling` ", + "must be present in data and match" + )) + } if (object$suffix != "_scaled" && object$create_new == FALSE) { cli::cli_warn(c( @@ -204,16 +216,18 @@ bake.step_population_scaling <- function(object, object$df <- object$df %>% dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) - pop_col = rlang::sym(object$df_pop_col) - suffix = ifelse(object$create_new, object$suffix, "") + pop_col <- rlang::sym(object$df_pop_col) + suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(new_data)) dplyr::left_join(new_data, - object$df, - by = object$by, suffix = c("", ".df")) %>% + object$df, + by = object$by, suffix = c("", ".df") + ) %>% dplyr::mutate(dplyr::across(dplyr::all_of(object$columns), - ~.x * object$rate_rescaling /!!pop_col , - .names = "{.col}{suffix}")) %>% + ~ .x * object$rate_rescaling / !!pop_col, + .names = "{.col}{suffix}" + )) %>% # removed so the models do not use the population column dplyr::select(-dplyr::any_of(col_to_remove)) } @@ -221,9 +235,7 @@ bake.step_population_scaling <- function(object, #' @export print.step_population_scaling <- function(x, width = max(20, options()$width - 35), ...) { - title <- "Population scaling" - print_epi_step(x$terms, x$terms, x$trained, title, extra_text = "to rates") - invisible(x) -} - - + title <- "Population scaling" + print_epi_step(x$terms, x$terms, x$trained, title, extra_text = "to rates") + invisible(x) + } diff --git a/R/step_training_window.R b/R/step_training_window.R index a05ad1540..7102d29d8 100644 --- a/R/step_training_window.R +++ b/R/step_training_window.R @@ -28,9 +28,12 @@ #' tib <- tibble::tibble( #' x = 1:10, #' y = 1:10, -#' time_value = rep(seq(as.Date("2020-01-01"), by = 1, -#' length.out = 5), times = 2), -#' geo_value = rep(c("ca", "hi"), each = 5)) %>% +#' time_value = rep(seq(as.Date("2020-01-01"), +#' by = 1, +#' length.out = 5 +#' ), times = 2), +#' geo_value = rep(c("ca", "hi"), each = 5) +#' ) %>% #' as_epi_df() #' #' epi_recipe(y ~ x, data = tib) %>% @@ -50,7 +53,6 @@ step_training_window <- n_recent = 50, epi_keys = NULL, id = rand_id("training_window")) { - arg_is_lgl_scalar(trained) arg_is_scalar(n_recent, id) arg_is_pos(n_recent) @@ -85,7 +87,6 @@ step_training_window_new <- #' @export prep.step_training_window <- function(x, training, info = NULL, ...) { - ekt <- kill_time_value(epi_keys(training)) ek <- x$epi_keys %||% ekt %||% character(0L) @@ -103,7 +104,6 @@ prep.step_training_window <- function(x, training, info = NULL, ...) { #' @export bake.step_training_window <- function(object, new_data, ...) 
{ - hardhat::validate_column_names(new_data, object$epi_keys) if (object$n_recent < Inf) { @@ -121,9 +121,11 @@ bake.step_training_window <- function(object, new_data, ...) { print.step_training_window <- function(x, width = max(20, options()$width - 30), ...) { title <- "# of recent observations per key limited to:" - n_recent = x$n_recent - tr_obj = format_selectors(rlang::enquos(n_recent), width) - recipes::print_step(tr_obj, rlang::enquos(n_recent), - x$trained, title, width) + n_recent <- x$n_recent + tr_obj <- format_selectors(rlang::enquos(n_recent), width) + recipes::print_step( + tr_obj, rlang::enquos(n_recent), + x$trained, title, width + ) invisible(x) } diff --git a/R/utils-arg.R b/R/utils-arg.R index 68d2211c1..091987722 100644 --- a/R/utils-arg.R +++ b/R/utils-arg.R @@ -2,20 +2,21 @@ # http://adv-r.had.co.nz/Computing-on-the-language.html#substitute # Modeled after / copied from rundel/ghclass -handle_arg_list = function(..., tests) { - values = list(...) - names = eval(substitute(alist(...))) - names = map(names, deparse) +handle_arg_list <- function(..., tests) { + values <- list(...) + names <- eval(substitute(alist(...))) + names <- map(names, deparse) walk2(names, values, tests) } -arg_is_scalar = function(..., allow_null = FALSE, allow_na = FALSE) { +arg_is_scalar <- function(..., allow_null = FALSE, allow_na = FALSE) { handle_arg_list( ..., tests = function(name, value) { - if (length(value) > 1 | (!allow_null & length(value) == 0)) + if (length(value) > 1 | (!allow_null & length(value) == 0)) { cli::cli_abort("Argument {.val {name}} must be of length 1.") + } if (!is.null(value)) { if (is.na(value) & !allow_na) { cli::cli_abort( @@ -28,18 +29,22 @@ arg_is_scalar = function(..., allow_null = FALSE, allow_na = FALSE) { } -arg_is_lgl = function(..., allow_null = FALSE, allow_na = FALSE, allow_empty = FALSE) { +arg_is_lgl <- function(..., allow_null = FALSE, allow_na = FALSE, allow_empty = FALSE) { handle_arg_list( ..., tests = function(name, value) { - if (is.null(value) & !allow_null) + if (is.null(value) & !allow_null) { cli::cli_abort("Argument {.val {name}} must be of logical type.") - if (any(is.na(value)) & !allow_na) + } + if (any(is.na(value)) & !allow_na) { cli::cli_abort("Argument {.val {name}} must not contain any missing values ({.val {NA}}).") - if (!is.null(value) & (length(value) == 0 & !allow_empty)) + } + if (!is.null(value) & (length(value) == 0 & !allow_empty)) { cli::cli_abort("Argument {.val {name}} must have length >= 1.") - if (!is.null(value) & length(value) != 0 & !is.logical(value)) + } + if (!is.null(value) & length(value) != 0 & !is.logical(value)) { cli::cli_abort("Argument {.val {name}} must be of logical type.") + } } ) } @@ -49,130 +54,144 @@ arg_is_lgl_scalar <- function(..., allow_null = FALSE, allow_na = FALSE) { arg_is_scalar(..., allow_null = allow_null, allow_na = allow_na) } -arg_is_numeric = function(..., allow_null = FALSE) { +arg_is_numeric <- function(..., allow_null = FALSE) { handle_arg_list( ..., tests = function(name, value) { - if (!(is.numeric(value) | (is.null(value) & allow_null))) + if (!(is.numeric(value) | (is.null(value) & allow_null))) { cli::cli_abort("All {.val {name}} must numeric.") + } } ) } -arg_is_pos = function(..., allow_null = FALSE) { +arg_is_pos <- function(..., allow_null = FALSE) { arg_is_numeric(..., allow_null = allow_null) handle_arg_list( ..., tests = function(name, value) { - if (!(all(value > 0) | (is.null(value) & allow_null))) + if (!(all(value > 0) | (is.null(value) & allow_null))) { 
cli::cli_abort("All {.val {name}} must be positive number(s).") + } } ) } -arg_is_nonneg = function(..., allow_null = FALSE) { +arg_is_nonneg <- function(..., allow_null = FALSE) { arg_is_numeric(..., allow_null = allow_null) handle_arg_list( ..., tests = function(name, value) { - if (!(all(value >= 0) | (is.null(value) & allow_null))) + if (!(all(value >= 0) | (is.null(value) & allow_null))) { cli::cli_abort("All {.val {name}} must be nonnegative number(s).") + } } ) - } -arg_is_int = function(..., allow_null = FALSE) { +arg_is_int <- function(..., allow_null = FALSE) { arg_is_numeric(..., allow_null = allow_null) handle_arg_list( ..., tests = function(name, value) { - if (!(all(value %% 1 == 0) | (is.null(value) & allow_null))) + if (!(all(value %% 1 == 0) | (is.null(value) & allow_null))) { cli::cli_abort("All {.val {name}} must be whole positive number(s).") + } } ) } -arg_is_pos_int = function(..., allow_null = FALSE) { +arg_is_pos_int <- function(..., allow_null = FALSE) { arg_is_int(..., allow_null = allow_null) arg_is_pos(..., allow_null = allow_null) } -arg_is_nonneg_int = function(..., allow_null = FALSE) { +arg_is_nonneg_int <- function(..., allow_null = FALSE) { arg_is_int(..., allow_null = allow_null) arg_is_nonneg(..., allow_null = allow_null) } -arg_is_date = function(..., allow_null = FALSE, allow_na = FALSE) { +arg_is_date <- function(..., allow_null = FALSE, allow_na = FALSE) { handle_arg_list( ..., tests = function(name, value) { - if (is.null(value) & !allow_null) + if (is.null(value) & !allow_null) { cli::cli_abort("Argument {.val {name}} may not be `NULL`.") - if (any(is.na(value)) & !allow_na) + } + if (any(is.na(value)) & !allow_na) { cli::cli_abort("Argument {.val {name}} must not contain any missing values ({.val {NA}}).") - if (!(is(value, "Date") | is.null(value) | all(is.na(value)))) + } + if (!(is(value, "Date") | is.null(value) | all(is.na(value)))) { cli::cli_abort("Argument {.val {name}} must be a Date. 
Try `as.Date()`.") + } } ) } -arg_is_probabilities = function(..., allow_null = FALSE) { +arg_is_probabilities <- function(..., allow_null = FALSE) { arg_is_numeric(..., allow_null = allow_null) handle_arg_list( ..., tests = function(name, value) { - if (!((all(value >= 0) && all(value <= 1)) | (is.null(value) & allow_null))) + if (!((all(value >= 0) && all(value <= 1)) | (is.null(value) & allow_null))) { cli::cli_abort("All {.val {name}} must be in [0,1].") + } } ) } -arg_is_chr = function(..., allow_null = FALSE, allow_na = FALSE, allow_empty = FALSE) { +arg_is_chr <- function(..., allow_null = FALSE, allow_na = FALSE, allow_empty = FALSE) { handle_arg_list( ..., tests = function(name, value) { - if (is.null(value) & !allow_null) + if (is.null(value) & !allow_null) { cli::cli_abort("Argument {.val {name}} may not be `NULL`.") - if (any(is.na(value)) & !allow_na) + } + if (any(is.na(value)) & !allow_na) { cli::cli_abort("Argument {.val {name}} must not contain any missing values ({.val {NA}}).") - if (!is.null(value) & (length(value) == 0L & !allow_empty)) + } + if (!is.null(value) & (length(value) == 0L & !allow_empty)) { cli::cli_abort("Argument {.val {name}} must have length > 0.") - if (!(is.character(value) | is.null(value) | all(is.na(value)))) + } + if (!(is.character(value) | is.null(value) | all(is.na(value)))) { cli::cli_abort("Argument {.val {name}} must be of character type.") + } } ) } -arg_is_chr_scalar = function(..., allow_null = FALSE, allow_na = FALSE) { +arg_is_chr_scalar <- function(..., allow_null = FALSE, allow_na = FALSE) { arg_is_chr(..., allow_null = allow_null, allow_na = allow_na) arg_is_scalar(..., allow_null = allow_null, allow_na = allow_na) } -arg_is_function = function(..., allow_null = FALSE) { +arg_is_function <- function(..., allow_null = FALSE) { handle_arg_list( ..., tests = function(name, value) { - if (is.null(value) & !allow_null) + if (is.null(value) & !allow_null) { cli::cli_abort("Argument {.val {name}} must be a function.") - if (!is.null(value) & !is.function(value)) + } + if (!is.null(value) & !is.function(value)) { cli::cli_abort("Argument {.val {name}} must be a function.") + } } ) } -arg_is_sorted = function(..., allow_null = FALSE) { +arg_is_sorted <- function(..., allow_null = FALSE) { handle_arg_list( ..., tests = function(name, value) { - if (is.unsorted(value, na.rm = TRUE) | (is.null(value) & !allow_null)) + if (is.unsorted(value, na.rm = TRUE) | (is.null(value) & !allow_null)) { cli::cli_abort("{.val {name}} must be sorted in increasing order.") - - }) + } + } + ) } diff --git a/R/utils-cli.R b/R/utils-cli.R index 7170d7476..ad43c95eb 100644 --- a/R/utils-cli.R +++ b/R/utils-cli.R @@ -1,20 +1,19 @@ - # Modeled after / copied from rundel/ghclass -cli_glue = function(..., .envir = parent.frame()) { - txt = cli::cli_format_method(cli::cli_text(..., .envir = .envir)) +cli_glue <- function(..., .envir = parent.frame()) { + txt <- cli::cli_format_method(cli::cli_text(..., .envir = .envir)) # cli_format_method does wrapping which we dont want at this stage # so glue things back together. paste(txt, collapse = " ") } -cli_stop = function(..., .envir = parent.frame()) { - text = cli_glue(..., .envir = .envir) +cli_stop <- function(..., .envir = parent.frame()) { + text <- cli_glue(..., .envir = .envir) stop(paste(text, collapse = "\n"), call. 
= FALSE) } -cli_warn = function(..., .envir = parent.frame()) { - text = cli_glue(..., .envir = .envir) +cli_warn <- function(..., .envir = parent.frame()) { + text <- cli_glue(..., .envir = .envir) warning(paste(text, collapse = "\n"), call. = FALSE) } diff --git a/R/utils-enframer.R b/R/utils-enframer.R index d55a611af..387d04356 100644 --- a/R/utils-enframer.R +++ b/R/utils-enframer.R @@ -2,18 +2,21 @@ enframer <- function(df, x, fill = NA) { stopifnot(is.data.frame(df)) stopifnot(length(fill) == 1 || length(fill) == nrow(df)) arg_is_chr(x, allow_null = TRUE) - if (is.null(x)) return(df) - if (any(names(df) %in% x)) + if (is.null(x)) { + return(df) + } + if (any(names(df) %in% x)) { stop("In enframer: some new cols match existing column names") + } for (v in x) df <- dplyr::mutate(df, !!v := fill) df } enlist <- function(...) { # in epiprocess - x = list(...) - n = as.character(sys.call())[-1] - if (!is.null(n0 <- names(x))) n[n0 != ""] = n0[n0 != ""] - names(x) = n + x <- list(...) + n <- as.character(sys.call())[-1] + if (!is.null(n0 <- names(x))) n[n0 != ""] <- n0[n0 != ""] + names(x) <- n x } diff --git a/R/utils-knn.R b/R/utils-knn.R index 08ddbf6c9..90ac67435 100644 --- a/R/utils-knn.R +++ b/R/utils-knn.R @@ -2,4 +2,4 @@ embedding <- function(dat) { dat <- as.matrix(dat) dat <- dat / sqrt(rowSums(dat^2) + 1e-12) return(dat) -} \ No newline at end of file +} diff --git a/R/utils-misc.R b/R/utils-misc.R index c7c7a69bb..ffc19ab83 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -45,8 +45,8 @@ grab_forged_keys <- function(forged, mold, new_data) { if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) { cli::cli_warn(c( "Not all epi keys that were present in the training data are available", - "in `new_data`. Predictions will have only the available keys.") - ) + "in `new_data`. Predictions will have only the available keys." + )) } if (epiprocess::is_epi_df(new_data)) { extras <- epiprocess::as_epi_df(extras) @@ -60,11 +60,14 @@ grab_forged_keys <- function(forged, mold, new_data) { } get_parsnip_mode <- function(trainer) { - if (inherits(trainer, "model_spec")) return(trainer$mode) + if (inherits(trainer, "model_spec")) { + return(trainer$mode) + } cc <- class(trainer) cli::cli_abort( c("`trainer` must be a `parsnip` model.", - i = "This trainer has class(s) {cc}.") + i = "This trainer has class(es) {cc}." + ) ) } diff --git a/README.Rmd b/README.Rmd index 6b924b03e..7f1e4f168 100644 --- a/README.Rmd +++ b/README.Rmd @@ -75,14 +75,14 @@ To create and train a simple auto-regressive forecaster to predict the death rat ```{r make-forecasts, warning=FALSE} two_week_ahead <- arx_forecaster( - jhu, - outcome = "death_rate", + jhu, + outcome = "death_rate", predictors = c("case_rate", "death_rate"), args_list = arx_args_list( - lags = list(c(0,1,2,3,7,14), c(0,7,14)), + lags = list(c(0, 1, 2, 3, 7, 14), c(0, 7, 14)), ahead = 14 ) -) +) ``` In this case, we have used a number of different lags for the case rate, while only using 3 weekly lags for the death rate (as predictors). The result is both a fitted model object which could be used any time in the future to create different forecasts, as well as a set of predicted values (and prediction intervals) for each location 14 days after the last available time value in the data. @@ -111,12 +111,12 @@ feel very familiar to anyone working in `R`+`{tidyverse}`.
**Simple linear autoregressive model with scaling (modular)** ```{r ideal-framework, eval=FALSE} -my_fcaster = new_epi_predictor() %>% +my_fcaster <- new_epi_predictor() %>% add_preprocessor(scaler, var = cases, by = pop) %>% add_preprocessor(lagger, var = dv_cli, lags = c(0, 7, 14)) %>% add_trainer(lm) %>% add_predictor(lm.predict) %>% - add_postprocessor(scaler, by = 1/pop) + add_postprocessor(scaler, by = 1 / pop) ``` Then you could run this on an `epi_df` with one line. diff --git a/tests/testthat/test-arx_args_list.R b/tests/testthat/test-arx_args_list.R index de6b9ffa3..dcd7a1cfe 100644 --- a/tests/testthat/test-arx_args_list.R +++ b/tests/testthat/test-arx_args_list.R @@ -8,7 +8,7 @@ test_that("arx_args checks inputs", { expect_error(arx_args_list(n_training = -1)) expect_error(arx_args_list(n_training = 1.5)) expect_error(arx_args_list(lags = c(-1, 0))) - expect_error(arx_args_list(lags = list(c(1:5,6.5), 2:8))) + expect_error(arx_args_list(lags = list(c(1:5, 6.5), 2:8))) expect_error(arx_args_list(symmetrize = 4)) expect_error(arx_args_list(nonneg = 4)) @@ -53,27 +53,36 @@ test_that("arx forecaster disambiguates quantiles", { }) test_that("arx_lags_validator handles named & unnamed lists as expected", { - # Fully named list of lags in order of predictors pred_vec <- c("death_rate", "case_rate") lags_init_fn <- list(death_rate = c(0, 7, 14), case_rate = c(0, 1, 2, 3, 7, 14)) - expect_equal(arx_lags_validator(pred_vec, lags_init_fn), - lags_init_fn) + expect_equal( + arx_lags_validator(pred_vec, lags_init_fn), + lags_init_fn + ) # Fully named list of lags not in order of predictors lags_finit_fn_switch <- list(case_rate = c(0, 1, 2, 3, 7, 14), death_rate = c(0, 7, 14)) - expect_equal(arx_lags_validator(pred_vec, lags_finit_fn_switch), - list(death_rate = c(0, 7, 14), case_rate = c(0, 1, 2, 3, 7, 14))) + expect_equal( + arx_lags_validator(pred_vec, lags_finit_fn_switch), + list(death_rate = c(0, 7, 14), case_rate = c(0, 1, 2, 3, 7, 14)) + ) # Fully named list of lags not in order of predictors (longer ex.) 
- pred_vec2 <- c("death_rate", "other_var", "case_rate") - lags_finit_fn_switch2 <- list(case_rate = c(0, 1, 2, 3, 7, 14), death_rate = c(0, 7, 14), - other_var = c(0, 1)) - expect_equal(arx_lags_validator(pred_vec2, lags_finit_fn_switch2), - list(death_rate = c(0, 7, 14), - other_var = c(0, 1), case_rate = c(0, 1, 2, 3, 7, 14))) + pred_vec2 <- c("death_rate", "other_var", "case_rate") + lags_finit_fn_switch2 <- list( + case_rate = c(0, 1, 2, 3, 7, 14), death_rate = c(0, 7, 14), + other_var = c(0, 1) + ) + expect_equal( + arx_lags_validator(pred_vec2, lags_finit_fn_switch2), + list( + death_rate = c(0, 7, 14), + other_var = c(0, 1), case_rate = c(0, 1, 2, 3, 7, 14) + ) + ) # More lags than predictors - Error expect_error(arx_lags_validator(pred_vec, lags_finit_fn_switch2)) @@ -98,6 +107,4 @@ test_that("arx_lags_validator handles named & unnamed lists as expected", { lags_init_other_name <- list(death_rate = c(0, 7, 14), test_var = c(0, 1, 2, 3, 7, 14)) expect_error(arx_lags_validator(pred_vec, lags_init_other_name)) - }) - diff --git a/tests/testthat/test-arx_cargs_list.R b/tests/testthat/test-arx_cargs_list.R index 40035890d..31ed7cd10 100644 --- a/tests/testthat/test-arx_cargs_list.R +++ b/tests/testthat/test-arx_cargs_list.R @@ -8,7 +8,7 @@ test_that("arx_class_args checks inputs", { expect_error(arx_class_args_list(n_training = -1)) expect_error(arx_class_args_list(n_training = 1.5)) expect_error(arx_class_args_list(lags = c(-1, 0))) - expect_error(arx_class_args_list(lags = list(c(1:5,6.5), 2:8))) + expect_error(arx_class_args_list(lags = list(c(1:5, 6.5), 2:8))) expect_error(arx_class_args_list(target_date = "2022-01-01")) @@ -17,4 +17,3 @@ test_that("arx_class_args checks inputs", { as.Date("2022-01-01") ) }) - diff --git a/tests/testthat/test-blueprint.R b/tests/testthat/test-blueprint.R index b16b0e123..2d22aff6e 100644 --- a/tests/testthat/test-blueprint.R +++ b/tests/testthat/test-blueprint.R @@ -20,5 +20,4 @@ test_that("epi_recipe blueprint keeps the class, mold works", { bp <- hardhat:::update_blueprint(bp, recipe = r) run_mm <- run_mold(bp, data = jhu) expect_false(is.factor(run_mm$extras$roles$geo_value$geo_value)) - }) diff --git a/tests/testthat/test-dist_quantiles.R b/tests/testthat/test-dist_quantiles.R index cdefcb7fe..07d1530d2 100644 --- a/tests/testthat/test-dist_quantiles.R +++ b/tests/testthat/test-dist_quantiles.R @@ -4,10 +4,10 @@ test_that("constructor returns reasonable quantiles", { expect_error(new_quantiles(rnorm(5), rnorm(5))) expect_silent(new_quantiles(sort(rnorm(5)), sort(runif(5)))) expect_error(new_quantiles(sort(rnorm(5)), sort(runif(2)))) - expect_silent(new_quantiles(1:5, 1:5/10)) - expect_error(new_quantiles(c(2,1,3,4,5), c(.1,.1,.2,.5,.8))) - expect_error(new_quantiles(c(2,1,3,4,5), c(.1,.15,.2,.5,.8))) - expect_error(new_quantiles(c(1,2,3), c(.1, .2, 3))) + expect_silent(new_quantiles(1:5, 1:5 / 10)) + expect_error(new_quantiles(c(2, 1, 3, 4, 5), c(.1, .1, .2, .5, .8))) + expect_error(new_quantiles(c(2, 1, 3, 4, 5), c(.1, .15, .2, .5, .8))) + expect_error(new_quantiles(c(1, 2, 3), c(.1, .2, 3))) }) test_that("tail functions give reasonable output", { @@ -30,11 +30,11 @@ test_that("single dist_quantiles works, quantiles are accessible", { expect_equal(quantile(z, c(.3, .7), middle = "cubic"), Q(c(.3, .7))) expect_identical( extrapolate_quantiles(z, c(.3, .7), middle = "linear"), - new_quantiles(q = c(1,1.5,2,3,4,4.5,5), tau = 2:8/10)) + new_quantiles(q = c(1, 1.5, 2, 3, 4, 4.5, 5), tau = 2:8 / 10) + ) }) test_that("quantile extrapolator works", { 
- dstn <- dist_normal(c(10, 2), c(5, 10)) qq <- extrapolate_quantiles(dstn, p = c(.25, 0.5, .75)) expect_s3_class(qq, "distribution") @@ -42,7 +42,7 @@ test_that("quantile extrapolator works", { expect_length(parameters(qq[1])$q[[1]], 3L) - dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2,.4,.6,.8))) + dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) qq <- extrapolate_quantiles(dstn, p = c(.25, 0.5, .75)) expect_s3_class(qq, "distribution") expect_s3_class(vctrs::vec_data(qq[1])[[1]], "dist_quantiles") @@ -50,23 +50,23 @@ test_that("quantile extrapolator works", { }) test_that("unary math works on quantiles", { - dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2,.4,.6,.8))) - dstn2 <- dist_quantiles(list(log(1:4), log(8:11)), list(c(.2,.4,.6,.8))) + dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) + dstn2 <- dist_quantiles(list(log(1:4), log(8:11)), list(c(.2, .4, .6, .8))) expect_identical(log(dstn), dstn2) - dstn2 <- dist_quantiles(list(cumsum(1:4), cumsum(8:11)), list(c(.2,.4,.6,.8))) + dstn2 <- dist_quantiles(list(cumsum(1:4), cumsum(8:11)), list(c(.2, .4, .6, .8))) expect_identical(cumsum(dstn), dstn2) }) test_that("arithmetic works on quantiles", { - dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2,.4,.6,.8))) - dstn2 <- dist_quantiles(list(1:4+1, 8:11+1), list(c(.2,.4,.6,.8))) + dstn <- dist_quantiles(list(1:4, 8:11), list(c(.2, .4, .6, .8))) + dstn2 <- dist_quantiles(list(1:4 + 1, 8:11 + 1), list(c(.2, .4, .6, .8))) expect_identical(dstn + 1, dstn2) expect_identical(1 + dstn, dstn2) - dstn2 <- dist_quantiles(list(1:4 / 4, 8:11 / 4), list(c(.2,.4,.6,.8))) + dstn2 <- dist_quantiles(list(1:4 / 4, 8:11 / 4), list(c(.2, .4, .6, .8))) expect_identical(dstn / 4, dstn2) - expect_identical((1/4) * dstn, dstn2) + expect_identical((1 / 4) * dstn, dstn2) expect_error(sum(dstn)) expect_error(suppressWarnings(dstn + distributional::dist_normal())) diff --git a/tests/testthat/test-enframer.R b/tests/testthat/test-enframer.R index bf5d730a9..c555ea9b2 100644 --- a/tests/testthat/test-enframer.R +++ b/tests/testthat/test-enframer.R @@ -1,12 +1,13 @@ test_that("enframer errors/works as needed", { - template1 <- data.frame(aa = 1:5, a=NA, b=NA, c=NA) - template2 <- data.frame(aa = 1:5, a=2:6, b=2:6, c=2:6) + template1 <- data.frame(aa = 1:5, a = NA, b = NA, c = NA) + template2 <- data.frame(aa = 1:5, a = 2:6, b = 2:6, c = 2:6) expect_error(enframer(1:5, letters[1])) expect_error(enframer(data.frame(a = 1:5), 1:3)) expect_error(enframer(data.frame(a = 1:5), letters[1:3])) expect_identical(enframer(data.frame(aa = 1:5), letters[1:3]), template1) - expect_error(enframer(data.frame(aa = 1:5), letters[1:2], fill=1:4)) + expect_error(enframer(data.frame(aa = 1:5), letters[1:2], fill = 1:4)) expect_identical( - enframer(data.frame(aa = 1:5), letters[1:3], fill=2:6), - template2) + enframer(data.frame(aa = 1:5), letters[1:3], fill = 2:6), + template2 + ) }) diff --git a/tests/testthat/test-epi_keys.R b/tests/testthat/test-epi_keys.R index c960f1ed4..3e794542e 100644 --- a/tests/testthat/test-epi_keys.R +++ b/tests/testthat/test-epi_keys.R @@ -15,7 +15,7 @@ test_that("epi_keys returns possible keys if they exist", { test_that("Extracts keys from an epi_df", { - expect_equal(epi_keys(case_death_rate_subset), c("time_value","geo_value")) + expect_equal(epi_keys(case_death_rate_subset), c("time_value", "geo_value")) }) test_that("Extracts keys from a recipe; roles are NA, giving an empty vector", { @@ -34,15 +34,18 @@ test_that("epi_keys_mold extracts time_value and geo_value, 
but not raw", { add_model(linear_reg()) %>% fit(data = case_death_rate_subset) - expect_setequal(epi_keys_mold(my_workflow$pre$mold), - c("time_value","geo_value")) + expect_setequal( + epi_keys_mold(my_workflow$pre$mold), + c("time_value", "geo_value") + ) }) test_that("epi_keys_mold extracts additional keys when they are present", { my_data <- tibble::tibble( geo_value = rep(c("ca", "fl", "pa"), each = 3), time_value = rep(seq(as.Date("2020-06-01"), as.Date("2020-06-03"), - by = "day"), length.out = length(geo_value)), + by = "day" + ), length.out = length(geo_value)), pol = rep(c("blue", "swing", "swing"), each = 3), # extra key state = rep(c("ca", "fl", "pa"), each = 3), # extra key value = 1:length(geo_value) + 0.01 * rnorm(length(geo_value)) @@ -52,12 +55,13 @@ test_that("epi_keys_mold extracts additional keys when they are present", { ) my_recipe <- epi_recipe(my_data) %>% - step_epi_ahead(value , ahead = 7) %>% + step_epi_ahead(value, ahead = 7) %>% step_epi_naomit() my_workflow <- epi_workflow(my_recipe, linear_reg()) %>% fit(my_data) expect_setequal( epi_keys_mold(my_workflow$pre$mold), - c("time_value", "geo_value", "state", "pol")) + c("time_value", "geo_value", "state", "pol") + ) }) diff --git a/tests/testthat/test-epi_recipe.R b/tests/testthat/test-epi_recipe.R index f74221691..df169adda 100644 --- a/tests/testthat/test-epi_recipe.R +++ b/tests/testthat/test-epi_recipe.R @@ -1,5 +1,3 @@ - - test_that("epi_recipe produces default recipe", { # these all call recipes::recipe(), but the template will always have 1 row tib <- tibble( @@ -7,25 +5,23 @@ test_that("epi_recipe produces default recipe", { time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5) ) rec <- recipes::recipe(tib) - rec$template <- rec$template[1,] + rec$template <- rec$template[1, ] expect_identical(rec, epi_recipe(tib)) expect_equal(nrow(rec$template), 1L) - rec <- recipes::recipe(y~x, tib) - rec$template <- rec$template[1,] + rec <- recipes::recipe(y ~ x, tib) + rec$template <- rec$template[1, ] expect_identical(rec, epi_recipe(y ~ x, tib)) expect_equal(nrow(rec$template), 1L) m <- as.matrix(tib) rec <- recipes::recipe(m) - rec$template <- rec$template[1,] + rec$template <- rec$template[1, ] expect_identical(rec, epi_recipe(m)) expect_equal(nrow(rec$template), 1L) - }) test_that("epi_recipe formula works", { - tib <- tibble( x = 1:5, y = 1:5, time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5), @@ -35,7 +31,7 @@ test_that("epi_recipe formula works", { # simple case r <- epi_recipe(y ~ x, tib) ref_var_info <- tibble::tribble( - ~ variable, ~ type, ~ role, ~ source, + ~variable, ~type, ~role, ~source, "x", c("integer", "numeric"), "predictor", "original", "y", c("integer", "numeric"), "outcome", "original", "time_value", "date", "time_value", "original", @@ -50,7 +46,8 @@ test_that("epi_recipe formula works", { tibble::add_row( variable = "geo_value", type = list(c("string", "unordered", "nominal")), role = "predictor", - source = "original", .after = 1) + source = "original", .after = 1 + ) expect_identical(r$var_info, ref_var_info) expect_equal(nrow(r$template), 1L) @@ -67,14 +64,13 @@ test_that("epi_recipe formula works", { tibble::add_row( variable = "z", type = list(c("string", "unordered", "nominal")), role = "key", - source = "original") + source = "original" + ) expect_identical(r$var_info, ref_var_info) - }) test_that("epi_recipe epi_df works", { - tib <- tibble( x = 1:5, y = 1:5, time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5), @@ -83,7 +79,7 @@ 
test_that("epi_recipe epi_df works", { r <- epi_recipe(tib) ref_var_info <- tibble::tribble( - ~ variable, ~ type, ~ role, ~ source, + ~variable, ~type, ~role, ~source, "time_value", "date", "time_value", "original", "geo_value", c("string", "unordered", "nominal"), "geo_value", "original", "x", c("integer", "numeric"), "raw", "original", @@ -94,7 +90,7 @@ test_that("epi_recipe epi_df works", { r <- epi_recipe(tib, formula = y ~ x) ref_var_info <- tibble::tribble( - ~ variable, ~ type, ~ role, ~ source, + ~variable, ~type, ~role, ~source, "x", c("integer", "numeric"), "predictor", "original", "y", c("integer", "numeric"), "outcome", "original", "time_value", "date", "time_value", "original", @@ -116,5 +112,3 @@ test_that("epi_recipe epi_df works", { expect_identical(r$var_info, ref_var_info) expect_equal(nrow(r$template), 1L) }) - - diff --git a/tests/testthat/test-epi_shift.R b/tests/testthat/test-epi_shift.R index 89e2a4c8b..b0ab3a21f 100644 --- a/tests/testthat/test-epi_shift.R +++ b/tests/testthat/test-epi_shift.R @@ -1,5 +1,5 @@ x <- data.frame(x1 = 1:10, x2 = -10:-1) -lags <- list(c(0,4), 1:3) +lags <- list(c(0, 4), 1:3) test_that("epi shift works with NULL keys", { time_value <- 1:10 @@ -10,7 +10,7 @@ test_that("epi shift works with NULL keys", { }) test_that("epi shift works with groups", { - keys <- data.frame(a = rep(letters[1:2], each=5), b = "z") + keys <- data.frame(a = rep(letters[1:2], each = 5), b = "z") time_value <- 1:10 out <- epi_shift(x, lags, time_value, keys) expect_length(out, 8L) @@ -27,5 +27,4 @@ test_that("epi shift single works, renames", { ess <- epi_shift_single(tib, "x", 1, "test", epi_keys(tib)) expect_named(ess, c("time_value", "geo_value", "test")) expect_equal(ess$time_value, tib$time_value + 1) - }) diff --git a/tests/testthat/test-epi_workflow.R b/tests/testthat/test-epi_workflow.R index 63f44f869..41708708a 100644 --- a/tests/testthat/test-epi_workflow.R +++ b/tests/testthat/test-epi_workflow.R @@ -1,4 +1,3 @@ - test_that("postprocesser was evaluated", { r <- epi_recipe(case_death_rate_subset) s <- parsnip::linear_reg() @@ -30,5 +29,5 @@ test_that("outcome of the two methods are the same", { ef <- epi_workflow(r, s, f) ef2 <- epi_workflow(r, s) %>% add_frosting(f) - expect_equal(ef,ef2) + expect_equal(ef, ef2) }) diff --git a/tests/testthat/test-extract_argument.R b/tests/testthat/test-extract_argument.R index f9de817de..974a50888 100644 --- a/tests/testthat/test-extract_argument.R +++ b/tests/testthat/test-extract_argument.R @@ -8,18 +8,21 @@ test_that("layer argument extractor works", { expect_error(extract_argument(f$layers[[1]], "layer_predict", "bubble")) expect_identical( extract_argument(f$layers[[2]], "layer_residual_quantiles", "probs"), - c(0.0275, 0.9750)) + c(0.0275, 0.9750) + ) expect_error(extract_argument(f, "layer_thresh", "probs")) expect_identical( extract_argument(f, "layer_residual_quantiles", "probs"), - c(0.0275, 0.9750)) + c(0.0275, 0.9750) + ) wf <- epi_workflow(postprocessor = f) expect_error(extract_argument(epi_workflow(), "layer_residual_quantiles", "probs")) expect_identical( extract_argument(wf, "layer_residual_quantiles", "probs"), - c(0.0275, 0.9750)) + c(0.0275, 0.9750) + ) expect_error(extract_argument(wf, "layer_predict", c("type", "opts"))) }) @@ -46,16 +49,13 @@ test_that("recipe argument extractor works", { expect_error(extract_argument(r, "step_lightly", "probs")) expect_identical( extract_argument(r, "step_epi_lag", "lag"), - list(c(0,7,14), c(0,7,14)) + list(c(0, 7, 14), c(0, 7, 14)) ) wf <- 
epi_workflow(preprocessor = r) expect_error(extract_argument(epi_workflow(), "step_epi_lag", "lag")) expect_identical( extract_argument(wf, "step_epi_lag", "lag"), - list(c(0,7,14), c(0,7,14)) + list(c(0, 7, 14), c(0, 7, 14)) ) }) - - - diff --git a/tests/testthat/test-frosting.R b/tests/testthat/test-frosting.R index d5cec1c4d..77674f4e5 100644 --- a/tests/testthat/test-frosting.R +++ b/tests/testthat/test-frosting.R @@ -26,7 +26,6 @@ test_that("frosting can be created/added/removed", { test_that("prediction works without any postprocessor", { - jhu <- case_death_rate_subset %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) r <- epi_recipe(jhu) %>% @@ -49,7 +48,6 @@ test_that("prediction works without any postprocessor", { test_that("layer_predict is added by default if missing", { - jhu <- case_death_rate_subset %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) @@ -75,6 +73,4 @@ test_that("layer_predict is added by default if missing", { wf2 <- wf %>% add_frosting(f2) expect_equal(predict(wf1, latest), predict(wf2, latest)) - }) - diff --git a/tests/testthat/test-get_test_data.R b/tests/testthat/test-get_test_data.R index 535830df9..035fc6463 100644 --- a/tests/testthat/test-get_test_data.R +++ b/tests/testthat/test-get_test_data.R @@ -9,8 +9,10 @@ test_that("return expected number of rows and returned dataset is ungrouped", { test <- get_test_data(recipe = r, x = case_death_rate_subset) - expect_equal(nrow(test), - dplyr::n_distinct(case_death_rate_subset$geo_value) * 29) + expect_equal( + nrow(test), + dplyr::n_distinct(case_death_rate_subset$geo_value) * 29 + ) expect_false(dplyr::is.grouped_df(test)) }) @@ -28,7 +30,7 @@ test_that("expect insufficient training data error", { test_that("expect error that geo_value or time_value does not exist", { - r <- epi_recipe(case_death_rate_subset) %>% + r <- epi_recipe(case_death_rate_subset) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% @@ -42,50 +44,49 @@ test_that("expect error that geo_value or time_value does not exist", { test_that("NA fill behaves as desired", { - df <- tibble::tibble( - geo_value = rep(c("ca", "ny"), each = 10), - time_value = rep(1:10, times = 2), - x1 = rnorm(20), - x2 = rnorm(20)) %>% - epiprocess::as_epi_df() - - r <- epi_recipe(df) %>% - step_epi_ahead(x1, ahead = 3) %>% - step_epi_lag(x1, x2, lag = c(1, 3)) %>% - step_epi_naomit() - - expect_silent(tt <- get_test_data(r, df)) - expect_s3_class(tt, "epi_df") + df <- tibble::tibble( + geo_value = rep(c("ca", "ny"), each = 10), + time_value = rep(1:10, times = 2), + x1 = rnorm(20), + x2 = rnorm(20) + ) %>% + epiprocess::as_epi_df() - expect_error(get_test_data(r, df, "A")) - expect_error(get_test_data(r, df, TRUE, -3)) + r <- epi_recipe(df) %>% + step_epi_ahead(x1, ahead = 3) %>% + step_epi_lag(x1, x2, lag = c(1, 3)) %>% + step_epi_naomit() - df2 <- df - df2$x1[df2$geo_value == "ca"] <- NA + expect_silent(tt <- get_test_data(r, df)) + expect_s3_class(tt, "epi_df") - td <- get_test_data(r, df2) - expect_true(any(is.na(td))) - expect_error(get_test_data(r, df2, TRUE)) + expect_error(get_test_data(r, df, "A")) + expect_error(get_test_data(r, df, TRUE, -3)) - df1 <- df2 - df1$x1[1:4] <- 1:4 - td1 <- get_test_data(r, df1, TRUE, n_recent = 7) - expect_true(!any(is.na(td1))) + df2 <- df + df2$x1[df2$geo_value == "ca"] <- NA - df2$x1[7:8] <- 1:2 - td2 <- get_test_data(r, df2, TRUE) - expect_true(!any(is.na(td2))) + td <- 
get_test_data(r, df2) + expect_true(any(is.na(td))) + expect_error(get_test_data(r, df2, TRUE)) + df1 <- df2 + df1$x1[1:4] <- 1:4 + td1 <- get_test_data(r, df1, TRUE, n_recent = 7) + expect_true(!any(is.na(td1))) + df2$x1[7:8] <- 1:2 + td2 <- get_test_data(r, df2, TRUE) + expect_true(!any(is.na(td2))) }) test_that("forecast date behaves", { - df <- tibble::tibble( geo_value = rep(c("ca", "ny"), each = 10), time_value = rep(1:10, times = 2), x1 = rnorm(20), - x2 = rnorm(20)) %>% + x2 = rnorm(20) + ) %>% epiprocess::as_epi_df() r <- epi_recipe(df) %>% @@ -109,8 +110,10 @@ test_that("Omit end rows according to minimum lag when that’s not lag 0", { # Simple toy ex toy_epi_df <- tibble::tibble( - time_value = seq(as.Date("2020-01-01"), by = 1, - length.out = 10), + time_value = seq(as.Date("2020-01-01"), + by = 1, + length.out = 10 + ), geo_value = "ak", x = 1:10 ) %>% epiprocess::as_epi_df() @@ -127,8 +130,8 @@ test_that("Omit end rows according to minimum lag when that’s not lag 0", { expect_equal(ncol(toy_td_res), 6L) expect_equal(nrow(toy_td_res), 1L) expect_equal(toy_td_res$time_value, as.Date("2020-01-10")) - expect_equal(toy_epi_df[toy_epi_df$time_value == as.Date("2020-01-08"),]$x, toy_td_res$lag_2_x) - expect_equal(toy_epi_df[toy_epi_df$time_value == as.Date("2020-01-06"),]$x, toy_td_res$lag_4_x) + expect_equal(toy_epi_df[toy_epi_df$time_value == as.Date("2020-01-08"), ]$x, toy_td_res$lag_2_x) + expect_equal(toy_epi_df[toy_epi_df$time_value == as.Date("2020-01-06"), ]$x, toy_td_res$lag_4_x) expect_equal(toy_td_res$x, NA_integer_) expect_equal(toy_td_res$ahead_3_x, NA_integer_) @@ -145,12 +148,12 @@ test_that("Omit end rows according to minimum lag when that’s not lag 0", { td <- get_test_data(rec, ca) td_res <- bake(prep(rec, ca), td) - td_row1to5_res <- bake(prep(rec, ca), td[1:5, ]) + td_row1to5_res <- bake(prep(rec, ca), td[1:5, ]) expect_equal(td_res, td_row1to5_res) expect_equal(nrow(td_res), 1L) expect_equal(td_res$time_value, as.Date("2021-12-31")) - expect_equal(ca[ca$time_value == as.Date("2021-12-29"),]$case_rate, td_res$lag_2_case_rate) - expect_equal(ca[ca$time_value == as.Date("2021-12-27"),]$case_rate, td_res$lag_4_case_rate) - expect_equal(ca[ca$time_value == as.Date("2021-12-25"),]$case_rate, td_res$lag_6_case_rate) + expect_equal(ca[ca$time_value == as.Date("2021-12-29"), ]$case_rate, td_res$lag_2_case_rate) + expect_equal(ca[ca$time_value == as.Date("2021-12-27"), ]$case_rate, td_res$lag_4_case_rate) + expect_equal(ca[ca$time_value == as.Date("2021-12-25"), ]$case_rate, td_res$lag_6_case_rate) }) diff --git a/tests/testthat/test-grab_names.R b/tests/testthat/test-grab_names.R index 2e7954ab3..6e0376f5a 100644 --- a/tests/testthat/test-grab_names.R +++ b/tests/testthat/test-grab_names.R @@ -1,7 +1,8 @@ -df <- data.frame(b=1,c=2,ca=3,cat=4) +df <- data.frame(b = 1, c = 2, ca = 3, cat = 4) test_that("Names are grabbed properly", { - expect_identical(grab_names(df,dplyr::starts_with("ca")), - subset(names(df),startsWith(names(df), "ca")) - ) + expect_identical( + grab_names(df, dplyr::starts_with("ca")), + subset(names(df), startsWith(names(df), "ca")) + ) }) diff --git a/tests/testthat/test-layer_add_forecast_date.R b/tests/testthat/test-layer_add_forecast_date.R index 5d965e7b3..1830118dc 100644 --- a/tests/testthat/test-layer_add_forecast_date.R +++ b/tests/testthat/test-layer_add_forecast_date.R @@ -9,18 +9,17 @@ wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14) -test_that("layer validation 
works",{ +test_that("layer validation works", { f <- frosting() expect_error(layer_add_forecast_date(f, "a")) expect_error(layer_add_forecast_date(f, "2022-05-31", id = c("a", "b"))) expect_silent(layer_add_forecast_date(f, "2022-05-31")) expect_silent(layer_add_forecast_date(f)) expect_silent(layer_add_forecast_date(f, as.Date("2022-05-31"))) - expect_silent(layer_add_forecast_date(f, as.Date("2022-05-31"), id="a")) + expect_silent(layer_add_forecast_date(f, as.Date("2022-05-31"), id = "a")) }) test_that("Specify a `forecast_date` that is greater than or equal to `as_of` date", { - f <- frosting() %>% layer_predict() %>% layer_add_forecast_date(forecast_date = as.Date("2022-05-31")) %>% @@ -36,15 +35,16 @@ test_that("Specify a `forecast_date` that is greater than or equal to `as_of` da }) test_that("Specify a `forecast_date` that is less than `as_of` date", { - f2 <- frosting() %>% layer_predict() %>% layer_add_forecast_date(forecast_date = as.Date("2021-12-31")) %>% layer_naomit(.pred) wf2 <- wf %>% add_frosting(f2) - expect_warning(p2 <- predict(wf2, latest), - "forecast_date is less than the most recent update date of the data.") + expect_warning( + p2 <- predict(wf2, latest), + "forecast_date is less than the most recent update date of the data." + ) expect_equal(ncol(p2), 4L) expect_s3_class(p2, "epi_df") expect_equal(nrow(p2), 3L) @@ -53,15 +53,16 @@ test_that("Specify a `forecast_date` that is less than `as_of` date", { }) test_that("Do not specify a forecast_date in `layer_add_forecast_date()`", { - f3 <- frosting() %>% layer_predict() %>% layer_add_forecast_date() %>% layer_naomit(.pred) wf3 <- wf %>% add_frosting(f3) - expect_warning(p3 <- predict(wf3, latest), - "forecast_date is less than the most recent update date of the data.") + expect_warning( + p3 <- predict(wf3, latest), + "forecast_date is less than the most recent update date of the data." 
+ ) expect_equal(ncol(p3), 4L) expect_s3_class(p3, "epi_df") expect_equal(nrow(p3), 3L) diff --git a/tests/testthat/test-layer_add_target_date.R b/tests/testthat/test-layer_add_target_date.R index b8627571c..287956612 100644 --- a/tests/testthat/test-layer_add_target_date.R +++ b/tests/testthat/test-layer_add_target_date.R @@ -10,7 +10,6 @@ latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14) test_that("Use ahead + max time value from pre, fit, post", { - f <- frosting() %>% layer_predict() %>% layer_add_target_date() %>% @@ -38,11 +37,9 @@ test_that("Use ahead + max time value from pre, fit, post", { expect_equal(nrow(p2), 3L) expect_equal(p2$target_date, rep(as.Date("2022-01-07"), times = 3)) expect_named(p2, c("geo_value", "time_value", ".pred", "forecast_date", "target_date")) - }) test_that("Use ahead + specified forecast date", { - f <- frosting() %>% layer_predict() %>% layer_add_forecast_date(forecast_date = "2022-05-31") %>% @@ -56,11 +53,9 @@ test_that("Use ahead + specified forecast date", { expect_equal(nrow(p), 3L) expect_equal(p$target_date, rep(as.Date("2022-06-07"), times = 3)) expect_named(p, c("geo_value", "time_value", ".pred", "forecast_date", "target_date")) - }) test_that("Specify own target date", { - # No forecast date layer f <- frosting() %>% layer_predict() %>% diff --git a/tests/testthat/test-layer_naomit.R b/tests/testthat/test-layer_naomit.R index b7ba2eac6..1d5b4ee25 100644 --- a/tests/testthat/test-layer_naomit.R +++ b/tests/testthat/test-layer_naomit.R @@ -1,11 +1,11 @@ jhu <- case_death_rate_subset %>% - dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) + dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) r <- epi_recipe(jhu) %>% - step_epi_lag(death_rate, lag = c(0, 7, 14, 30)) %>% - step_epi_ahead(death_rate, ahead = 7) %>% - recipes::step_naomit(all_predictors()) %>% - recipes::step_naomit(all_outcomes(), skip = TRUE) + step_epi_lag(death_rate, lag = c(0, 7, 14, 30)) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + recipes::step_naomit(all_predictors()) %>% + recipes::step_naomit(all_outcomes(), skip = TRUE) wf <- epipredict::epi_workflow(r, parsnip::linear_reg()) %>% parsnip::fit(jhu) @@ -24,7 +24,5 @@ test_that("Removing NA after predict", { expect_silent(p <- predict(wf1, latest)) expect_s3_class(p, "epi_df") expect_equal(nrow(p), 2L) # ak is NA so removed - expect_named(p, c("geo_value", "time_value",".pred")) + expect_named(p, c("geo_value", "time_value", ".pred")) }) - - diff --git a/tests/testthat/test-layer_predict.R b/tests/testthat/test-layer_predict.R index f98bec2a0..bd10de08c 100644 --- a/tests/testthat/test-layer_predict.R +++ b/tests/testthat/test-layer_predict.R @@ -11,7 +11,6 @@ latest <- jhu %>% test_that("predict layer works alone", { - f <- frosting() %>% layer_predict() wf1 <- wf %>% add_frosting(f) @@ -23,7 +22,6 @@ test_that("predict layer works alone", { }) test_that("prediction with interval works", { - f <- frosting() %>% layer_predict(type = "pred_int") wf2 <- wf %>% add_frosting(f) diff --git a/tests/testthat/test-layer_residual_quantiles.R b/tests/testthat/test-layer_residual_quantiles.R index bb1e74fbe..967eee1a5 100644 --- a/tests/testthat/test-layer_residual_quantiles.R +++ b/tests/testthat/test-layer_residual_quantiles.R @@ -22,7 +22,7 @@ test_that("Returns expected number or rows and columns", { expect_equal(ncol(p), 4L) expect_s3_class(p, "epi_df") expect_equal(nrow(p), 3L) - expect_named(p, c("geo_value", "time_value",".pred",".pred_distn")) + 
expect_named(p, c("geo_value", "time_value", ".pred", ".pred_distn")) nested <- p %>% dplyr::mutate(.quantiles = nested_quantiles(.pred_distn)) unnested <- nested %>% tidyr::unnest(.quantiles) diff --git a/tests/testthat/test-layer_threshold_preds.R b/tests/testthat/test-layer_threshold_preds.R index 56787763f..80b6a42a9 100644 --- a/tests/testthat/test-layer_threshold_preds.R +++ b/tests/testthat/test-layer_threshold_preds.R @@ -10,7 +10,6 @@ latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14) test_that("Default pred_lower and pred_upper work as intended", { - f <- frosting() %>% layer_predict() %>% layer_threshold(.pred) %>% @@ -27,7 +26,6 @@ test_that("Default pred_lower and pred_upper work as intended", { }) test_that("Specified pred_lower and pred_upper work as intended", { - f <- frosting() %>% layer_predict() %>% layer_threshold(.pred, lower = 0.180, upper = 0.31) %>% @@ -43,7 +41,6 @@ test_that("Specified pred_lower and pred_upper work as intended", { }) test_that("thresholds additional columns", { - f <- frosting() %>% layer_predict() %>% layer_residual_quantiles(probs = c(.1, .9)) %>% @@ -62,5 +59,5 @@ test_that("thresholds additional columns", { dplyr::mutate(.quantiles = nested_quantiles(.pred_distn)) %>% tidyr::unnest(.quantiles) expect_equal(round(p$q, digits = 3), c(0.180, 0.31, 0.180, .18, 0.310, .31)) - expect_equal(p$tau, rep(c(.1,.9), times = 3)) + expect_equal(p$tau, rep(c(.1, .9), times = 3)) }) diff --git a/tests/testthat/test-pad_to_end.R b/tests/testthat/test-pad_to_end.R index 43c73c291..474b9001b 100644 --- a/tests/testthat/test-pad_to_end.R +++ b/tests/testthat/test-pad_to_end.R @@ -30,6 +30,8 @@ test_that("test set padding works", { expect_identical(p$value, as.integer(c(1, 3, 4, 6, 2, NA, 5, 7))) # make sure it maintains the epi_df - dat <- dat %>% dplyr::rename(geo_value = gr1) %>% as_epi_df(dat) + dat <- dat %>% + dplyr::rename(geo_value = gr1) %>% + as_epi_df(dat) expect_s3_class(pad_to_end(dat, "geo_value", 2), "epi_df") }) diff --git a/tests/testthat/test-pivot_quantiles.R b/tests/testthat/test-pivot_quantiles.R index a77825493..85694aace 100644 --- a/tests/testthat/test-pivot_quantiles.R +++ b/tests/testthat/test-pivot_quantiles.R @@ -6,7 +6,7 @@ test_that("quantile pivotting behaves", { d1 <- c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:5, 1:4 / 5)) # different quantiles - tib <- tib[1:2,] + tib <- tib[1:2, ] tib$d1 <- d1 expect_error(pivot_quantiles(tib, d1)) diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index c44c3dec5..165d042a3 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -4,49 +4,61 @@ library(workflows) ## Preprocessing test_that("Column names can be passed with and without the tidy way", { - pop_data = data.frame(states = c("ak","al","ar","as","az","ca"), - value = c(1000, 2000, 3000, 4000, 5000, 6000)) + pop_data <- data.frame( + states = c("ak", "al", "ar", "as", "az", "ca"), + value = c(1000, 2000, 3000, 4000, 5000, 6000) + ) - newdata = case_death_rate_subset %>% filter(geo_value %in% c("ak","al","ar","as","az","ca")) + newdata <- case_death_rate_subset %>% filter(geo_value %in% c("ak", "al", "ar", "as", "az", "ca")) r1 <- epi_recipe(newdata) %>% step_population_scaling(c("case_rate", "death_rate"), - df = pop_data, - df_pop_col = "value", by = c("geo_value" = "states")) + df = pop_data, + df_pop_col = "value", by = c("geo_value" = "states") + ) r2 <- epi_recipe(newdata) %>% step_population_scaling(case_rate, 
death_rate, - df = pop_data, - df_pop_col = "value", by = c("geo_value" = "states")) + df = pop_data, + df_pop_col = "value", by = c("geo_value" = "states") + ) prep1 <- prep(r1, newdata) prep2 <- prep(r2, newdata) expect_equal(bake(prep1, newdata), bake(prep2, newdata)) - }) test_that("Number of columns and column names returned correctly, Upper and lower cases handled properly ", { - pop_data = data.frame(states = c(rep("a",5), rep("B", 5)), - counties = c("06059","06061","06067", - "12111","12113","12117", - "42101","42103","42105", "42111"), - value = 1000:1009) - - newdata = tibble(geo_value = c(rep("a",5), rep("b", 5)), - county = c("06059","06061","06067", - "12111","12113","12117", - "42101","42103","42105", "42111"), - time_value = rep(as.Date("2021-01-01") + 0:4, 2), - case = 1:10, - death = 1:10) %>% + pop_data <- data.frame( + states = c(rep("a", 5), rep("B", 5)), + counties = c( + "06059", "06061", "06067", + "12111", "12113", "12117", + "42101", "42103", "42105", "42111" + ), + value = 1000:1009 + ) + + newdata <- tibble( + geo_value = c(rep("a", 5), rep("b", 5)), + county = c( + "06059", "06061", "06067", + "12111", "12113", "12117", + "42101", "42103", "42105", "42111" + ), + time_value = rep(as.Date("2021-01-01") + 0:4, 2), + case = 1:10, + death = 1:10 + ) %>% epiprocess::as_epi_df() - r <-epi_recipe(newdata) %>% + r <- epi_recipe(newdata) %>% step_population_scaling(c("case", "death"), - df = pop_data, - df_pop_col = "value", by = c("geo_value" = "states", "county" = "counties"), - suffix = "_rate") + df = pop_data, + df_pop_col = "value", by = c("geo_value" = "states", "county" = "counties"), + suffix = "_rate" + ) prep <- prep(r, newdata) @@ -57,19 +69,20 @@ test_that("Number of columns and column names returned correctly, Upper and lowe - r <-epi_recipe(newdata) %>% - step_population_scaling(df = pop_data, - df_pop_col = "value", - by = c("geo_value" = "states", "county" = "counties"), - c("case", "death"), - suffix = "_rate", # unused - create_new = FALSE) + r <- epi_recipe(newdata) %>% + step_population_scaling( + df = pop_data, + df_pop_col = "value", + by = c("geo_value" = "states", "county" = "counties"), + c("case", "death"), + suffix = "_rate", # unused + create_new = FALSE + ) expect_warning(prep <- prep(r, newdata)) expect_warning(b <- bake(prep, newdata)) expect_equal(ncol(b), 5L) - }) ## Postprocessing @@ -78,16 +91,19 @@ test_that("Postprocessing workflow works and values correct", { dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, cases) - pop_data = data.frame(states = c("ca", "ny"), - value = c(20000, 30000)) + pop_data <- data.frame( + states = c("ca", "ny"), + value = c(20000, 30000) + ) r <- epi_recipe(jhu) %>% step_population_scaling(cases, - df = pop_data, - df_pop_col = "value", - by = c("geo_value" = "states"), - role = "raw", - suffix = "_scaled") %>% + df = pop_data, + df_pop_col = "value", + by = c("geo_value" = "states"), + role = "raw", + suffix = "_scaled" + ) %>% step_epi_lag(cases_scaled, lag = c(0, 7, 14)) %>% step_epi_ahead(cases_scaled, ahead = 7, role = "outcome") %>% step_naomit(all_predictors()) %>% @@ -97,19 +113,25 @@ test_that("Postprocessing workflow works and values correct", { layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% - layer_population_scaling(.pred, df = pop_data, - by = c("geo_value" = "states"), - df_pop_col = "value") + layer_population_scaling(.pred, + df = pop_data, + by = c("geo_value" = "states"), + df_pop_col = "value" + ) wf 
<- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f) - latest <- get_test_data(recipe = r, - x = epiprocess::jhu_csse_daily_subset %>% - dplyr::filter(time_value > "2021-11-01", - geo_value %in% c("ca", "ny")) %>% - dplyr::select(geo_value, time_value, cases)) + latest <- get_test_data( + recipe = r, + x = epiprocess::jhu_csse_daily_subset %>% + dplyr::filter( + time_value > "2021-11-01", + geo_value %in% c("ca", "ny") + ) %>% + dplyr::select(geo_value, time_value, cases) + ) expect_silent(p <- predict(wf, latest)) @@ -121,9 +143,11 @@ test_that("Postprocessing workflow works and values correct", { layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% - layer_population_scaling(.pred, df = pop_data, rate_rescaling = 10000, - by = c("geo_value" = "states"), - df_pop_col = "value") + layer_population_scaling(.pred, + df = pop_data, rate_rescaling = 10000, + by = c("geo_value" = "states"), + df_pop_col = "value" + ) wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f) @@ -131,7 +155,6 @@ test_that("Postprocessing workflow works and values correct", { expect_equal(nrow(p), 2L) expect_equal(ncol(p), 4L) expect_equal(p$.pred_scaled, p$.pred * c(2, 3)) - }) test_that("Postprocessing to get cases from case rate", { @@ -139,14 +162,18 @@ test_that("Postprocessing to get cases from case rate", { dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, case_rate) - reverse_pop_data = data.frame(states = c("ca", "ny"), - value = c(1/20000, 1/30000)) + reverse_pop_data <- data.frame( + states = c("ca", "ny"), + value = c(1 / 20000, 1 / 30000) + ) r <- epi_recipe(jhu) %>% - step_population_scaling(df = reverse_pop_data, - df_pop_col = "value", - by = c("geo_value" = "states"), - case_rate, suffix = "_scaled") %>% + step_population_scaling( + df = reverse_pop_data, + df_pop_col = "value", + by = c("geo_value" = "states"), + case_rate, suffix = "_scaled" + ) %>% step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>% # cases step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>% # cases step_naomit(all_predictors()) %>% @@ -156,25 +183,31 @@ test_that("Postprocessing to get cases from case rate", { layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% - layer_population_scaling(.pred, df = reverse_pop_data, - by = c("geo_value" = "states"), - df_pop_col = "value") + layer_population_scaling(.pred, + df = reverse_pop_data, + by = c("geo_value" = "states"), + df_pop_col = "value" + ) wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f) - latest <- get_test_data(recipe = r, - x = case_death_rate_subset %>% - dplyr::filter(time_value > "2021-11-01", - geo_value %in% c("ca", "ny")) %>% - dplyr::select(geo_value, time_value, case_rate)) + latest <- get_test_data( + recipe = r, + x = case_death_rate_subset %>% + dplyr::filter( + time_value > "2021-11-01", + geo_value %in% c("ca", "ny") + ) %>% + dplyr::select(geo_value, time_value, case_rate) + ) expect_silent(p <- predict(wf, latest)) expect_equal(nrow(p), 2L) expect_equal(ncol(p), 4L) - expect_equal(p$.pred_scaled, p$.pred * c(1/20000, 1/30000)) + expect_equal(p$.pred_scaled, p$.pred * c(1 / 20000, 1 / 30000)) }) @@ -184,15 +217,18 @@ test_that("test joining by default columns", { dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, case_rate) - reverse_pop_data = data.frame(geo_value = c("ca", "ny"), - values = c(1/20000, 1/30000)) + 
reverse_pop_data <- data.frame( + geo_value = c("ca", "ny"), + values = c(1 / 20000, 1 / 30000) + ) r <- epi_recipe(jhu) %>% step_population_scaling(case_rate, - df = reverse_pop_data, - df_pop_col = "values", - by = NULL, - suffix = "_scaled") %>% + df = reverse_pop_data, + df_pop_col = "values", + by = NULL, + suffix = "_scaled" + ) %>% step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>% # cases step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>% # cases step_naomit(all_predictors()) %>% @@ -206,9 +242,11 @@ test_that("test joining by default columns", { layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% - layer_population_scaling(.pred, df = reverse_pop_data, - by = NULL, - df_pop_col = "values") + layer_population_scaling(.pred, + df = reverse_pop_data, + by = NULL, + df_pop_col = "values" + ) suppressMessages( wf <- epi_workflow(r, parsnip::linear_reg()) %>% @@ -227,7 +265,6 @@ test_that("test joining by default columns", { ) suppressMessages(p <- predict(wf, latest)) - }) @@ -237,15 +274,18 @@ test_that("expect error if `by` selector does not match", { dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, case_rate) - reverse_pop_data = data.frame(geo_value = c("ca", "ny"), - values = c(1/20000, 1/30000)) + reverse_pop_data <- data.frame( + geo_value = c("ca", "ny"), + values = c(1 / 20000, 1 / 30000) + ) r <- epi_recipe(jhu) %>% step_population_scaling(case_rate, - df = reverse_pop_data, - df_pop_col = "values", - by = c("a" = "b"), - suffix = "_scaled") %>% + df = reverse_pop_data, + df_pop_col = "values", + by = c("a" = "b"), + suffix = "_scaled" + ) %>% step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>% # cases step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>% # cases step_naomit(all_predictors()) %>% @@ -255,21 +295,25 @@ test_that("expect error if `by` selector does not match", { layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% - layer_population_scaling(.pred, df = reverse_pop_data, - by = NULL, - df_pop_col = "values") + layer_population_scaling(.pred, + df = reverse_pop_data, + by = NULL, + df_pop_col = "values" + ) expect_error( wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% - add_frosting(f)) + add_frosting(f) + ) r <- epi_recipe(jhu) %>% step_population_scaling(case_rate, - df = reverse_pop_data, - df_pop_col = "values", - by = c("geo_value" = "geo_value"), - suffix = "_scaled") %>% + df = reverse_pop_data, + df_pop_col = "values", + by = c("geo_value" = "geo_value"), + suffix = "_scaled" + ) %>% step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>% # cases step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>% # cases step_naomit(all_predictors()) %>% @@ -279,16 +323,22 @@ test_that("expect error if `by` selector does not match", { layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% - layer_population_scaling(.pred, df = reverse_pop_data, - by = c("nothere" = "nope"), - df_pop_col = "values") + layer_population_scaling(.pred, + df = reverse_pop_data, + by = c("nothere" = "nope"), + df_pop_col = "values" + ) - latest <- get_test_data(recipe = r, - x = case_death_rate_subset %>% - dplyr::filter(time_value > "2021-11-01", - geo_value %in% c("ca", "ny")) %>% - dplyr::select(geo_value, time_value, case_rate)) + latest <- get_test_data( + recipe = r, + x = case_death_rate_subset %>% + dplyr::filter( + time_value > "2021-11-01", + geo_value %in% c("ca", "ny") + ) %>% + dplyr::select(geo_value, 
time_value, case_rate) + ) wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% @@ -299,34 +349,46 @@ test_that("expect error if `by` selector does not match", { test_that("Rate rescaling behaves as expected", { - x <- tibble(geo_value = rep("place",50), - time_value = as.Date("2021-01-01") + 0:49, - case_rate = rep(0.0005, 50), - cases = rep(5000, 50)) %>% + x <- tibble( + geo_value = rep("place", 50), + time_value = as.Date("2021-01-01") + 0:49, + case_rate = rep(0.0005, 50), + cases = rep(5000, 50) + ) %>% as_epi_df() - reverse_pop_data = data.frame(states = c("place"), - value = c(1/1000)) + reverse_pop_data <- data.frame( + states = c("place"), + value = c(1 / 1000) + ) r <- epi_recipe(x) %>% - step_population_scaling(df = reverse_pop_data, - df_pop_col = "value", - rate_rescaling = 100, # cases per 100 - by = c("geo_value" = "states"), - case_rate, suffix = "_scaled") + step_population_scaling( + df = reverse_pop_data, + df_pop_col = "value", + rate_rescaling = 100, # cases per 100 + by = c("geo_value" = "states"), + case_rate, suffix = "_scaled" + ) - expect_equal(unique(bake(prep(r,x),x)$case_rate_scaled), - 0.0005*100/(1/1000)) # done testing step_* + expect_equal( + unique(bake(prep(r, x), x)$case_rate_scaled), + 0.0005 * 100 / (1 / 1000) + ) # done testing step_* f <- frosting() %>% - layer_population_scaling(.pred, df = reverse_pop_data, - rate_rescaling = 100, # revert back to case rate per 100 - by = c("geo_value" = "states"), - df_pop_col = "value") - - x <- tibble(geo_value = rep("place",50), - time_value = as.Date("2021-01-01") + 0:49, - case_rate = rep(0.0005, 50)) %>% + layer_population_scaling(.pred, + df = reverse_pop_data, + rate_rescaling = 100, # revert back to case rate per 100 + by = c("geo_value" = "states"), + df_pop_col = "value" + ) + + x <- tibble( + geo_value = rep("place", 50), + time_value = as.Date("2021-01-01") + 0:49, + case_rate = rep(0.0005, 50) + ) %>% as_epi_df() r <- epi_recipe(x) %>% @@ -339,10 +401,12 @@ test_that("Rate rescaling behaves as expected", { layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% - layer_population_scaling(.pred, df = reverse_pop_data, - rate_rescaling = 100, # revert back to case rate per 100 - by = c("geo_value" = "states"), - df_pop_col = "value") + layer_population_scaling(.pred, + df = reverse_pop_data, + rate_rescaling = 100, # revert back to case rate per 100 + by = c("geo_value" = "states"), + df_pop_col = "value" + ) wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(x) %>% @@ -351,8 +415,10 @@ test_that("Rate rescaling behaves as expected", { latest <- get_test_data(recipe = r, x = x) # suppress warning: prediction from a rank-deficient fit may be misleading - suppressWarnings(expect_equal(unique(predict(wf, latest)$.pred)*(1/1000)/100, - unique(predict(wf, latest)$.pred_scaled))) + suppressWarnings(expect_equal( + unique(predict(wf, latest)$.pred) * (1 / 1000) / 100, + unique(predict(wf, latest)$.pred_scaled) + )) }) test_that("Extra Columns are ignored", { diff --git a/tests/testthat/test-replace_Inf.R b/tests/testthat/test-replace_Inf.R index 8f4e9c334..f9993ca13 100644 --- a/tests/testthat/test-replace_Inf.R +++ b/tests/testthat/test-replace_Inf.R @@ -4,12 +4,12 @@ test_that("replace_inf works", { expect_identical(vec_replace_inf(x, 3), as.double(1:5)) df <- tibble( geo_value = letters[1:5], time_value = 1:5, - v1 = 1:5, v2 = c(1,2,Inf, -Inf,NA) + v1 = 1:5, v2 = c(1, 2, Inf, -Inf, NA) ) library(dplyr) ok <- c("geo_value", "time_value") df2 <- df %>% mutate(across(!all_of(ok), ~ 
vec_replace_inf(.x, NA))) - expect_identical(df[,1:3], df2[,1:3]) - expect_identical(df2$v2, c(1,2,NA,NA,NA)) + expect_identical(df[, 1:3], df2[, 1:3]) + expect_identical(df2$v2, c(1, 2, NA, NA, NA)) }) diff --git a/tests/testthat/test-step_epi_naomit.R b/tests/testthat/test-step_epi_naomit.R index d65734ff6..2fb173f01 100644 --- a/tests/testthat/test-step_epi_naomit.R +++ b/tests/testthat/test-step_epi_naomit.R @@ -3,16 +3,18 @@ library(parsnip) library(workflows) # Random generated dataset -x <- tibble(geo_value = rep("nowhere",200), - time_value = as.Date("2021-01-01") + 0:199, - case_rate = 1:200, - death_rate = 1:200) %>% +x <- tibble( + geo_value = rep("nowhere", 200), + time_value = as.Date("2021-01-01") + 0:199, + case_rate = 1:200, + death_rate = 1:200 +) %>% epiprocess::as_epi_df() # Preparing the datasets to be used for comparison r <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = 7) %>% - step_epi_lag(death_rate, lag = c(0,7,14)) + step_epi_lag(death_rate, lag = c(0, 7, 14)) test_that("Argument must be a recipe", { expect_error(step_epi_naomit(x)) @@ -25,13 +27,13 @@ z2 <- r %>% # Checks the behaviour of a step function, omitting the quosure and id that # differ from one another, even with identical behaviour -behav <- function(recipe,step_num) recipe$steps[[step_num]][-1][-5] +behav <- function(recipe, step_num) recipe$steps[[step_num]][-1][-5] # Checks the class type of an object -step_class <- function(recipe,step_num) class(recipe$steps[step_num]) +step_class <- function(recipe, step_num) class(recipe$steps[step_num]) test_that("Check that both functions behave the same way", { - expect_identical(behav(z1,3),behav(z2,3)) - expect_identical(behav(z1,4),behav(z2,4)) - expect_identical(step_class(z1,3),step_class(z2,3)) - expect_identical(step_class(z1,4),step_class(z2,4)) + expect_identical(behav(z1, 3), behav(z2, 3)) + expect_identical(behav(z1, 4), behav(z2, 4)) + expect_identical(step_class(z1, 3), step_class(z2, 3)) + expect_identical(step_class(z1, 4), step_class(z2, 4)) }) diff --git a/tests/testthat/test-step_epi_shift.R b/tests/testthat/test-step_epi_shift.R index 24898ad64..da04fd0f2 100644 --- a/tests/testthat/test-step_epi_shift.R +++ b/tests/testthat/test-step_epi_shift.R @@ -4,10 +4,12 @@ library(parsnip) library(workflows) # Random generated dataset -x <- tibble(geo_value = rep("place",200), - time_value = as.Date("2021-01-01") + 0:199, - case_rate = sqrt(1:200) + atan(0.1 * 1:200) + sin(5*1:200) + 1, - death_rate = atan(0.1 * 1:200) + cos(5*1:200) + 1) %>% +x <- tibble( + geo_value = rep("place", 200), + time_value = as.Date("2021-01-01") + 0:199, + case_rate = sqrt(1:200) + atan(0.1 * 1:200) + sin(5 * 1:200) + 1, + death_rate = atan(0.1 * 1:200) + cos(5 * 1:200) + 1 +) %>% as_epi_df() slm_fit <- function(recipe, data = x) { @@ -54,12 +56,12 @@ test_that("Values for ahead and lag cannot be duplicates", { test_that("Check that epi_lag shifts applies the shift", { r5 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = 7) %>% - step_epi_lag(death_rate, lag = c(0,7,14)) + step_epi_lag(death_rate, lag = c(0, 7, 14)) # Two steps passed here - expect_equal(length(r5$steps),2) + expect_equal(length(r5$steps), 2) fit5 <- slm_fit(r5) # Should have four predictors, including the intercept - expect_equal(length(fit5$fit$fit$fit$coefficients),4) + expect_equal(length(fit5$fit$fit$fit$coefficients), 4) }) diff --git a/tests/testthat/test-step_growth_rate.R b/tests/testthat/test-step_growth_rate.R index 2e478f54a..d0dec170e 100644 --- 
a/tests/testthat/test-step_growth_rate.R +++ b/tests/testthat/test-step_growth_rate.R @@ -24,7 +24,6 @@ test_that("step_growth_rate validates arguments", { expect_error(step_growth_rate(r, value, replace_Inf = c(1, 2))) expect_silent(step_growth_rate(r, value, replace_Inf = NULL)) expect_silent(step_growth_rate(r, value, replace_Inf = NA)) - }) @@ -33,7 +32,10 @@ test_that("step_growth_rate works for a single signal", { edf <- as_epi_df(df) r <- epi_recipe(edf) - res <- r %>% step_growth_rate(value, horizon = 1) %>% prep() %>% bake(edf) + res <- r %>% + step_growth_rate(value, horizon = 1) %>% + prep() %>% + bake(edf) expect_equal(res$gr_1_rel_change_value, c(NA, 1 / 6:9)) df <- dplyr::bind_rows( @@ -42,20 +44,27 @@ test_that("step_growth_rate works for a single signal", { ) edf <- as_epi_df(df) r <- epi_recipe(edf) - res <- r %>% step_growth_rate(value, horizon = 1) %>% prep() %>% bake(edf) + res <- r %>% + step_growth_rate(value, horizon = 1) %>% + prep() %>% + bake(edf) expect_equal(res$gr_1_rel_change_value, rep(c(NA, 1 / 6:9), each = 2)) - }) test_that("step_growth_rate works for a two signals", { - df <- data.frame(time_value = 1:5, - geo_value = rep("a", 5), - v1 = 6:10, v2 = 1:5) + df <- data.frame( + time_value = 1:5, + geo_value = rep("a", 5), + v1 = 6:10, v2 = 1:5 + ) edf <- as_epi_df(df) r <- epi_recipe(edf) - res <- r %>% step_growth_rate(v1, v2, horizon = 1) %>% prep() %>% bake(edf) + res <- r %>% + step_growth_rate(v1, v2, horizon = 1) %>% + prep() %>% + bake(edf) expect_equal(res$gr_1_rel_change_v1, c(NA, 1 / 6:9)) expect_equal(res$gr_1_rel_change_v2, c(NA, 1 / 1:4)) @@ -65,8 +74,10 @@ test_that("step_growth_rate works for a two signals", { ) edf <- as_epi_df(df) r <- epi_recipe(edf) - res <- r %>% step_growth_rate(v1, v2, horizon = 1) %>% prep() %>% bake(edf) + res <- r %>% + step_growth_rate(v1, v2, horizon = 1) %>% + prep() %>% + bake(edf) expect_equal(res$gr_1_rel_change_v1, rep(c(NA, 1 / 6:9), each = 2)) expect_equal(res$gr_1_rel_change_v2, rep(c(NA, 1 / 1:4), each = 2)) - }) diff --git a/tests/testthat/test-step_lag_difference.R b/tests/testthat/test-step_lag_difference.R index 2d1581aef..dc61d12d4 100644 --- a/tests/testthat/test-step_lag_difference.R +++ b/tests/testthat/test-step_lag_difference.R @@ -17,7 +17,6 @@ test_that("step_lag_difference validates arguments", { expect_error(step_lag_difference(r, value, trained = 1)) expect_error(step_lag_difference(r, value, skip = 1)) expect_error(step_lag_difference(r, value, columns = letters[1:5])) - }) @@ -28,12 +27,14 @@ test_that("step_lag_difference works for a single signal", { res <- r %>% step_lag_difference(value, horizon = 1) %>% - prep() %>% bake(edf) + prep() %>% + bake(edf) expect_equal(res$lag_diff_1_value, c(NA, rep(1, 4))) res <- r %>% step_lag_difference(value, horizon = 1:2) %>% - prep() %>% bake(edf) + prep() %>% + bake(edf) expect_equal(res$lag_diff_1_value, c(NA, rep(1, 4))) expect_equal(res$lag_diff_2_value, c(NA, NA, rep(2, 3))) @@ -45,22 +46,27 @@ test_that("step_lag_difference works for a single signal", { ) edf <- as_epi_df(df) r <- epi_recipe(edf) - res <- r %>% step_lag_difference(value, horizon = 1) %>% prep() %>% bake(edf) + res <- r %>% + step_lag_difference(value, horizon = 1) %>% + prep() %>% + bake(edf) expect_equal(res$lag_diff_1_value, c(NA, NA, rep(1, 8))) - }) test_that("step_lag_difference works for a two signals", { - df <- data.frame(time_value = 1:5, - geo_value = rep("a", 5), - v1 = 6:10, v2 = 1:5 * 2) + df <- data.frame( + time_value = 1:5, + geo_value = rep("a", 5), + v1 = 
6:10, v2 = 1:5 * 2
+  )
   edf <- as_epi_df(df)
   r <- epi_recipe(edf)
   res <- r %>%
     step_lag_difference(v1, v2, horizon = 1:2) %>%
-    prep() %>% bake(edf)
+    prep() %>%
+    bake(edf)
   expect_equal(res$lag_diff_1_v1, c(NA, rep(1, 4)))
   expect_equal(res$lag_diff_2_v1, c(NA, NA, rep(2, 3)))
   expect_equal(res$lag_diff_1_v2, c(NA, rep(2, 4)))
@@ -74,10 +80,10 @@ test_that("step_lag_difference works for a two signals", {
   r <- epi_recipe(edf)
   res <- r %>%
     step_lag_difference(v1, v2, horizon = 1:2) %>%
-    prep() %>% bake(edf)
+    prep() %>%
+    bake(edf)
   expect_equal(res$lag_diff_1_v1, rep(c(NA, rep(1, 4)), each = 2))
   expect_equal(res$lag_diff_2_v1, rep(c(NA, NA, rep(2, 3)), each = 2))
   expect_equal(res$lag_diff_1_v2, c(NA, NA, rep(2:1, 4)))
   expect_equal(res$lag_diff_2_v2, c(rep(NA, 4), rep(c(4, 2), 3)))
-
 })
diff --git a/tests/testthat/test-step_training_window.R b/tests/testthat/test-step_training_window.R
index 4b185b99a..c8a17f43f 100644
--- a/tests/testthat/test-step_training_window.R
+++ b/tests/testthat/test-step_training_window.R
@@ -1,6 +1,8 @@
 toy_epi_df <- tibble::tibble(
-  time_value = rep(seq(as.Date("2020-01-01"), by = 1,
-    length.out = 100), times = 2),
+  time_value = rep(seq(as.Date("2020-01-01"),
+    by = 1,
+    length.out = 100
+  ), times = 2),
   geo_value = rep(c("ca", "hi"), each = 100),
   x = 1:200, y = 1:200,
 ) %>% epiprocess::as_epi_df()
@@ -16,8 +18,10 @@ test_that("step_training_window works with default n_recent", {
   expect_equal(ncol(p), 4L)
   expect_s3_class(p, "epi_df")
   expect_named(p, c("time_value", "geo_value", "x", "y"))
-  expect_equal(p$time_value,
-    rep(seq(as.Date("2020-02-20"), as.Date("2020-04-09"), by = 1), times = 2))
+  expect_equal(
+    p$time_value,
+    rep(seq(as.Date("2020-02-20"), as.Date("2020-04-09"), by = 1), times = 2)
+  )
   expect_equal(p$geo_value, rep(c("ca", "hi"), each = 50))
 })
@@ -31,35 +35,41 @@ test_that("step_training_window works with specified n_recent", {
   expect_equal(ncol(p2), 4L)
   expect_s3_class(p2, "epi_df")
   expect_named(p2, c("time_value", "geo_value", "x", "y"))
-  expect_equal(p2$time_value,
-    rep(seq(as.Date("2020-04-05"), as.Date("2020-04-09"), by = 1), times = 2))
+  expect_equal(
+    p2$time_value,
+    rep(seq(as.Date("2020-04-05"), as.Date("2020-04-09"), by = 1), times = 2)
+  )
   expect_equal(p2$geo_value, rep(c("ca", "hi"), each = 5))
 })
 test_that("step_training_window does not proceed with specified new_data", {
-# Should just return whatever the new_data is, unaffected by the step
-# because step_training_window only effects training data, not
-# testing data.
+  # Should just return whatever the new_data is, unaffected by the step
+  # because step_training_window only affects training data, not
+  # testing data.
p3 <- epi_recipe(y ~ x, data = toy_epi_df) %>% step_training_window(n_recent = 3) %>% recipes::prep(toy_epi_df) %>% - recipes::bake(new_data = toy_epi_df[1:10,]) + recipes::bake(new_data = toy_epi_df[1:10, ]) expect_equal(nrow(p3), 10L) expect_equal(ncol(p3), 4L) expect_s3_class(p3, "epi_df") # cols will be predictors, outcomes, time_value, geo_value expect_named(p3, c("x", "y", "time_value", "geo_value")) - expect_equal(p3$time_value, - rep(seq(as.Date("2020-01-01"), as.Date("2020-01-10"), by = 1), times = 1)) + expect_equal( + p3$time_value, + rep(seq(as.Date("2020-01-01"), as.Date("2020-01-10"), by = 1), times = 1) + ) expect_equal(p3$geo_value, rep("ca", times = 10)) }) test_that("step_training_window works with multiple keys", { toy_epi_df2 <- tibble::tibble( x = 1:200, y = 1:200, - time_value = rep(seq(as.Date("2020-01-01"), by = 1, - length.out = 100), times = 2), + time_value = rep(seq(as.Date("2020-01-01"), + by = 1, + length.out = 100 + ), times = 2), geo_value = rep(c("ca", "hi"), each = 100), additional_key = as.factor(rep(1:4, each = 50)), ) %>% epiprocess::as_epi_df() @@ -77,9 +87,13 @@ test_that("step_training_window works with multiple keys", { expect_named(p4, c("time_value", "geo_value", "additional_key", "x", "y")) expect_equal( p4$time_value, - rep(c(seq(as.Date("2020-02-17"), as.Date("2020-02-19"), length.out = 3), - seq(as.Date("2020-04-07"), as.Date("2020-04-09"), - length.out = 3)), times = 2)) + rep(c( + seq(as.Date("2020-02-17"), as.Date("2020-02-19"), length.out = 3), + seq(as.Date("2020-04-07"), as.Date("2020-04-09"), + length.out = 3 + ) + ), times = 2) + ) expect_equal(p4$geo_value, rep(c("ca", "hi"), each = 6)) }) @@ -88,9 +102,12 @@ test_that("step_training_window and step_naomit interact", { tib <- tibble::tibble( x = 1:10, y = 1:10, - time_value = rep(seq(as.Date("2020-01-01"), by = 1, - length.out = 5), times = 2), - geo_value = rep(c("ca", "hi"), each = 5)) %>% + time_value = rep(seq(as.Date("2020-01-01"), + by = 1, + length.out = 5 + ), times = 2), + geo_value = rep(c("ca", "hi"), each = 5) + ) %>% as_epi_df() e1 <- epi_recipe(y ~ x, data = tib) %>% diff --git a/vignettes/articles/sliding.Rmd b/vignettes/articles/sliding.Rmd index eeb389b4c..67af7289d 100644 --- a/vignettes/articles/sliding.Rmd +++ b/vignettes/articles/sliding.Rmd @@ -62,16 +62,16 @@ versions for the less up-to-date input archive. theme_set(theme_bw()) y <- readRDS(system.file( - "extdata", "all_states_covidcast_signals.rds", + "extdata", "all_states_covidcast_signals.rds", package = "epipredict", mustWork = TRUE -)) - +)) + y <- purrr::map(y, ~ select(.x, geo_value, time_value, version = issue, value)) x <- epix_merge( y[[1]] %>% rename(percent_cli = value) %>% as_epi_archive(compactify = FALSE), y[[2]] %>% rename(case_rate = value) %>% as_epi_archive(compactify = FALSE), - sync = "locf", + sync = "locf", compactify = TRUE ) rm(y) @@ -87,10 +87,10 @@ output. 
```{r make-arx-kweek, warning = FALSE} # Latest snapshot of data, and forecast dates -x_latest <- epix_as_of(x, max_version = max(x$versions_end)) +x_latest <- epix_as_of(x, max_version = max(x$versions_end)) fc_time_values <- seq( - from = as.Date("2020-08-01"), - to = as.Date("2021-11-01"), + from = as.Date("2020-08-01"), + to = as.Date("2021-11-01"), by = "1 month" ) aheads <- c(7, 14, 21, 28) @@ -99,31 +99,36 @@ k_week_ahead <- function(epi_df, outcome, predictors, ahead = 7, engine) { epi_slide( epi_df, ~ arx_forecaster( - .x, outcome, predictors, engine, - args_list = arx_args_list(ahead = ahead)) %>% - extract2("predictions") %>% - select(-geo_value), - before = 120 - 1, - ref_time_values = fc_time_values, + .x, outcome, predictors, engine, + args_list = arx_args_list(ahead = ahead) + ) %>% + extract2("predictions") %>% + select(-geo_value), + before = 120 - 1, + ref_time_values = fc_time_values, new_col_name = "fc" - ) %>% + ) %>% select(geo_value, time_value, starts_with("fc")) %>% mutate(engine_type = engine$engine) } # Generate the forecasts and bind them together fc <- bind_rows( - map(aheads, - ~ k_week_ahead( - x_latest, "case_rate", c("case_rate", "percent_cli"), .x, - engine = linear_reg()) - ) %>% list_rbind() , - map(aheads, - ~ k_week_ahead( - x_latest, "case_rate", c("case_rate", "percent_cli"), .x, - engine = rand_forest(mode = "regression")) + map( + aheads, + ~ k_week_ahead( + x_latest, "case_rate", c("case_rate", "percent_cli"), .x, + engine = linear_reg() + ) + ) %>% list_rbind(), + map( + aheads, + ~ k_week_ahead( + x_latest, "case_rate", c("case_rate", "percent_cli"), .x, + engine = rand_forest(mode = "regression") + ) ) %>% list_rbind() -) %>% +) %>% pivot_quantiles(fc_.pred_distn) ``` @@ -142,11 +147,13 @@ model performance while keeping the graphic simple. fc_cafl <- fc %>% filter(geo_value %in% c("ca", "fl")) x_latest_cafl <- x_latest %>% filter(geo_value %in% c("ca", "fl")) -ggplot(fc_cafl, aes(fc_target_date, group = time_value, fill = engine_type)) + - geom_line(data = x_latest_cafl, aes(x = time_value, y = case_rate), - inherit.aes = FALSE, color = "gray50") + +ggplot(fc_cafl, aes(fc_target_date, group = time_value, fill = engine_type)) + + geom_line( + data = x_latest_cafl, aes(x = time_value, y = case_rate), + inherit.aes = FALSE, color = "gray50" + ) + geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`), alpha = 0.4) + - geom_line(aes(y = fc_.pred)) + + geom_line(aes(y = fc_.pred)) + geom_point(aes(y = fc_.pred), size = 0.5) + geom_vline(aes(xintercept = time_value), linetype = 2, alpha = 0.5) + facet_grid(vars(geo_value), vars(engine_type), scales = "free") + @@ -192,13 +199,13 @@ linear regression with those from using boosted regression trees. 
```{r get-can-fc, warning = FALSE} # source("drafts/canada-case-rates.R) can <- readRDS(system.file( - "extdata", "can_prov_cases.rds", + "extdata", "can_prov_cases.rds", package = "epipredict", mustWork = TRUE )) can <- can %>% - group_by(version, geo_value) %>% - arrange(time_value) %>% + group_by(version, geo_value) %>% + arrange(time_value) %>% mutate(cr_7dav = RcppRoll::roll_meanr(case_rate, n = 7L)) %>% as_epi_archive(compactify = TRUE) @@ -206,52 +213,71 @@ can_latest <- epix_as_of(can, max_version = max(can$DT$version)) # Generate the forecasts, and bind them together can_fc <- bind_rows( - map(aheads, - ~ k_week_ahead(can_latest, "cr_7dav", "cr_7dav", .x, linear_reg()) + map( + aheads, + ~ k_week_ahead(can_latest, "cr_7dav", "cr_7dav", .x, linear_reg()) ) %>% list_rbind(), - map(aheads, - ~ k_week_ahead( - can_latest, "cr_7dav", "cr_7dav", .x, - boost_tree(mode = "regression", trees = 20)) + map( + aheads, + ~ k_week_ahead( + can_latest, "cr_7dav", "cr_7dav", .x, + boost_tree(mode = "regression", trees = 20) + ) ) %>% list_rbind() -) %>% +) %>% pivot_quantiles(fc_.pred_distn) ``` The figures below shows the results for all of the provinces. ```{r plot-can-fc-lr, message = FALSE, warning = FALSE, fig.width = 9, fig.height = 12} -ggplot(can_fc %>% filter(engine_type == "lm"), - aes(x = fc_target_date, group = time_value)) + +ggplot( + can_fc %>% filter(engine_type == "lm"), + aes(x = fc_target_date, group = time_value) +) + coord_cartesian(xlim = lubridate::ymd(c("2020-12-01", NA))) + - geom_line(data = can_latest, aes(x = time_value, y = cr_7dav), - inherit.aes = FALSE, color = "gray50") + + geom_line( + data = can_latest, aes(x = time_value, y = cr_7dav), + inherit.aes = FALSE, color = "gray50" + ) + geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`, fill = geo_value), - alpha = 0.4) + - geom_line(aes(y = fc_.pred)) + geom_point(aes(y = fc_.pred), size = 0.5) + + alpha = 0.4 + ) + + geom_line(aes(y = fc_.pred)) + + geom_point(aes(y = fc_.pred), size = 0.5) + geom_vline(aes(xintercept = time_value), linetype = 2, alpha = 0.5) + facet_wrap(~geo_value, scales = "free_y", ncol = 3) + scale_x_date(minor_breaks = "month", date_labels = "%b %y") + - labs(title = "Using simple linear regression", x = "Date", - y = "Reported COVID-19 case rates") + - theme(legend.position = "none") + labs( + title = "Using simple linear regression", x = "Date", + y = "Reported COVID-19 case rates" + ) + + theme(legend.position = "none") ``` ```{r plot-can-fc-boost, message = FALSE, warning = FALSE, fig.width = 9, fig.height = 12} -ggplot(can_fc %>% filter(engine_type == "xgboost"), - aes(x = fc_target_date, group = time_value)) + +ggplot( + can_fc %>% filter(engine_type == "xgboost"), + aes(x = fc_target_date, group = time_value) +) + coord_cartesian(xlim = lubridate::ymd(c("2020-12-01", NA))) + - geom_line(data = can_latest, aes(x = time_value, y = cr_7dav), - inherit.aes = FALSE, color = "gray50") + + geom_line( + data = can_latest, aes(x = time_value, y = cr_7dav), + inherit.aes = FALSE, color = "gray50" + ) + geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`, fill = geo_value), - alpha = 0.4) + - geom_line(aes(y = fc_.pred)) + geom_point(aes(y = fc_.pred), size = 0.5) + + alpha = 0.4 + ) + + geom_line(aes(y = fc_.pred)) + + geom_point(aes(y = fc_.pred), size = 0.5) + geom_vline(aes(xintercept = time_value), linetype = 2, alpha = 0.5) + - facet_wrap(~ geo_value, scales = "free_y", ncol = 3) + + facet_wrap(~geo_value, scales = "free_y", ncol = 3) + scale_x_date(minor_breaks = "month", date_labels = "%b %y") + - 
labs(title = "Using boosted regression trees", x = "Date", - y = "Reported COVID-19 case rates") + - theme(legend.position = "none") + labs( + title = "Using boosted regression trees", x = "Date", + y = "Reported COVID-19 case rates" + ) + + theme(legend.position = "none") ``` Both approaches tend to produce quite volatile forecasts (point predictions) @@ -280,17 +306,20 @@ k_week_version_aware <- function(ahead = 7, version_aware = TRUE) { x, ~ arx_forecaster( .x, "case_rate", c("case_rate", "percent_cli"), - args_list = arx_args_list(ahead = ahead)) %>% + args_list = arx_args_list(ahead = ahead) + ) %>% extract2("predictions"), - before = 120 - 1, - ref_time_values = fc_time_values, - new_col_name = "fc") %>% + before = 120 - 1, + ref_time_values = fc_time_values, + new_col_name = "fc" + ) %>% mutate(engine_type = "lm", version_aware = version_aware) %>% rename(geo_value = fc_geo_value) } else { k_week_ahead( - x_latest, "case_rate", c("case_rate", "percent_cli"), - ahead, linear_reg()) %>% mutate(version_aware = version_aware) + x_latest, "case_rate", c("case_rate", "percent_cli"), + ahead, linear_reg() + ) %>% mutate(version_aware = version_aware) } } @@ -304,17 +333,22 @@ fc <- bind_rows( Now we can plot the results on top of the latest case rates. As before, we will only display and focus on the results for FL and CA for simplicity. ```{r plot-ar-asof, message = FALSE, warning = FALSE, fig.width = 9, fig.height = 6} -fc_cafl = fc %>% filter(geo_value %in% c("ca", "fl")) -x_latest_cafl = x_latest %>% filter(geo_value %in% c("ca", "fl")) +fc_cafl <- fc %>% filter(geo_value %in% c("ca", "fl")) +x_latest_cafl <- x_latest %>% filter(geo_value %in% c("ca", "fl")) ggplot(fc_cafl, aes(x = fc_target_date, group = time_value, fill = version_aware)) + - geom_line(data = x_latest_cafl, aes(x = time_value, y = case_rate), - inherit.aes = FALSE, color = "gray50") + + geom_line( + data = x_latest_cafl, aes(x = time_value, y = case_rate), + inherit.aes = FALSE, color = "gray50" + ) + geom_ribbon(aes(ymin = `0.05`, ymax = `0.95`), alpha = 0.4) + - geom_line(aes(y = fc_.pred)) + geom_point(aes(y = fc_.pred), size = 0.5) + + geom_line(aes(y = fc_.pred)) + + geom_point(aes(y = fc_.pred), size = 0.5) + geom_vline(aes(xintercept = time_value), linetype = 2, alpha = 0.5) + - facet_grid(geo_value ~ version_aware, scales = "free", - labeller = labeller(version_aware = label_both)) + + facet_grid(geo_value ~ version_aware, + scales = "free", + labeller = labeller(version_aware = label_both) + ) + scale_x_date(minor_breaks = "month", date_labels = "%b %y") + labs(x = "Date", y = "Reported COVID-19 case rates") + scale_fill_brewer(palette = "Set1") + diff --git a/vignettes/epipredict.Rmd b/vignettes/epipredict.Rmd index b0eeeb5a9..17a604504 100644 --- a/vignettes/epipredict.Rmd +++ b/vignettes/epipredict.Rmd @@ -12,7 +12,8 @@ knitr::opts_chunk$set( echo = TRUE, collapse = FALSE, comment = "#>", - out.width = "100%") + out.width = "100%" +) ``` ```{r setup, message=FALSE} @@ -103,7 +104,7 @@ We'll estimate the model jointly across all locations using only the most recent ```{r demo-workflow} jhu <- jhu %>% filter(time_value >= max(time_value) - 30) out <- arx_forecaster( - jhu, + jhu, outcome = "death_rate", predictors = c("case_rate", "death_rate") ) @@ -115,11 +116,11 @@ The `out` object has two components: 1. The predictions which is just another `epi_df`. It contains the predictions for each location along with additional columns. 
By default, these are a 90% predictive interval, the `forecast_date` (the date on which the forecast was putatively made), and the `target_date` (the date for which the forecast is being made). ```{r} - out$predictions +out$predictions ``` 2. A list object of class `epi_workflow`. This object encapsulates all the instructions necessary to create the prediction. More details on this below. ```{r} - out$epi_workflow +out$epi_workflow ``` Note that the `time_value` in the predictions is not necessarily meaningful, @@ -137,13 +138,14 @@ knitr::opts_chunk$set(warning = FALSE, message = FALSE) ```{r differential-lags} out2week <- arx_forecaster( - jhu, - outcome = "death_rate", + jhu, + outcome = "death_rate", predictors = c("case_rate", "death_rate"), args_list = arx_args_list( lags = list(c(0, 1, 2, 3, 7, 14), c(0, 7, 14)), - ahead = 14) + ahead = 14 ) +) ``` Here, we've used different lags on the `case_rate` and are now predicting 2 weeks ahead. This example also illustrates a major difficulty with the "iterative" versions of AR models. This model doesn't produce forecasts for `case_rate`, and so would not have data to "plug in" for the necessary lags.[^1] @@ -155,8 +157,9 @@ Another property of the basic model is the predictive interval. We describe this ```{r differential-levels} out_q <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), args_list = arx_args_list( - levels = c(.01, .025, seq(.05, .95, by = .05), .975, .99)) + levels = c(.01, .025, seq(.05, .95, by = .05), .975, .99) ) +) ``` The column `.pred_distn` in the `predictions` object is actually a "distribution" here parameterized by its quantiles. For this default forecaster, these are created using the quantiles of the residuals of the predictive model (possibly symmetrized). Here, we used 23 quantiles, but one can grab a particular quantile, @@ -168,7 +171,7 @@ head(quantile(out_q$predictions$.pred_distn, p = .4)) or extract the entire distribution into a "long" `epi_df` with `tau` being the probability and `q` being the value associated with that quantile. ```{r q2} -out_q$predictions %>% +out_q$predictions %>% # first create a "nested" list-column mutate(.pred_distn = nested_quantiles(.pred_distn)) %>% unnest(.pred_distn) # then unnest it @@ -178,7 +181,7 @@ Additional simple adjustments to the basic forecaster can be made using the func ```{r, eval = FALSE} arx_args_list( - lags = c(0L, 7L, 14L), ahead = 7L, n_training = Inf, + lags = c(0L, 7L, 14L), ahead = 7L, n_training = Inf, forecast_date = NULL, target_date = NULL, levels = c(0.05, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), nafill_buffer = Inf @@ -192,22 +195,28 @@ The `trainer` argument determines the type of model we want. This takes a [`{parsnip}`](https://parsnip.tidymodels.org) model.
The default is linear regression, but we could instead use a random forest with the `{ranger}` package: ```{r ranger, warning = FALSE} -out_rf <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), - rand_forest(mode = "regression")) +out_rf <- arx_forecaster( + jhu, "death_rate", c("case_rate", "death_rate"), + rand_forest(mode = "regression") +) ``` Or boosted regression trees with `{xgboost}`: ```{r xgboost, warning = FALSE} -out_gb <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), - boost_tree(mode = "regression", trees = 20)) +out_gb <- arx_forecaster( + jhu, "death_rate", c("case_rate", "death_rate"), + boost_tree(mode = "regression", trees = 20) +) ``` Or quantile regression, using our custom forecasting engine `quantile_reg()`: ```{r quantreg, warning = FALSE} -out_gb <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), - quantile_reg()) +out_qr <- arx_forecaster( + jhu, "death_rate", c("case_rate", "death_rate"), + quantile_reg() +) ``` FWIW, this last case (using quantile regression) is not far from what the Delphi production forecast team used for its Covid forecasts over the past few years. @@ -283,15 +292,19 @@ do "linear regression". Above we switched from `lm()` to `xgboost()` without any issue despite the fact that these functions couldn't be more different. ```{r, eval = FALSE} -lm(formula, data, subset, weights, na.action, method = "qr", - model = TRUE, x = FALSE, y = FALSE, qr = TRUE, singular.ok = TRUE, - contrasts = NULL, offset, ...) - -xgboost(data = NULL, label = NULL, missing = NA, weight = NULL, - params = list(), nrounds, verbose = 1, print_every_n = 1L, - early_stopping_rounds = NULL, maximize = NULL, save_period = NULL, - save_name = "xgboost.model", xgb_model = NULL, callbacks = list(), - ...) +lm(formula, data, subset, weights, na.action, + method = "qr", + model = TRUE, x = FALSE, y = FALSE, qr = TRUE, singular.ok = TRUE, + contrasts = NULL, offset, ... +) + +xgboost( + data = NULL, label = NULL, missing = NA, weight = NULL, + params = list(), nrounds, verbose = 1, print_every_n = 1L, + early_stopping_rounds = NULL, maximize = NULL, save_period = NULL, + save_name = "xgboost.model", xgb_model = NULL, callbacks = list(), + ... +) ``` `{epipredict}` provides a few engines/modules (the flatline forecaster and @@ -327,8 +340,9 @@ intervals at 0.
The code to do this (inside the forecaster) is f <- frosting() %>% layer_predict() %>% layer_residual_quantiles( - probs = c(.01, .025, seq(.05, .95, by = .05), .975, .99), - symmetrize = TRUE) %>% + probs = c(.01, .025, seq(.05, .95, by = .05), .975, .99), + symmetrize = TRUE + ) %>% layer_add_forecast_date() %>% layer_add_target_date() %>% layer_threshold(starts_with(".pred")) @@ -338,7 +352,9 @@ At predict time, we add this object onto the `epi_workflow` and call `predict()` ```{r, warning=FALSE} test_data <- get_test_data(er, jhu) -ewf %>% add_frosting(f) %>% predict(test_data) +ewf %>% + add_frosting(f) %>% + predict(test_data) ``` The above `get_test_data()` function examines the recipe and ensures that enough @@ -365,7 +381,7 @@ r <- epi_recipe(jhu) %>% add_role(all_of(epi_keys(jhu)), new_role = "predictor") # bit of a weird hack to get the latest values per key -latest <- get_test_data(epi_recipe(jhu), jhu) +latest <- get_test_data(epi_recipe(jhu), jhu) f <- frosting() %>% layer_predict() %>% diff --git a/vignettes/preprocessing-and-models.Rmd b/vignettes/preprocessing-and-models.Rmd index ac0e2e08c..f85f35f71 100644 --- a/vignettes/preprocessing-and-models.Rmd +++ b/vignettes/preprocessing-and-models.Rmd @@ -12,32 +12,33 @@ knitr::opts_chunk$set( echo = TRUE, collapse = FALSE, comment = "#>", - out.width = "100%") + out.width = "100%" +) ``` -## Introduction +## Introduction -The `epipredict` package utilizes the `tidymodels` framework, namely -[`{recipes}`](https://recipes.tidymodels.org/) for -[dplyr](https://dplyr.tidyverse.org/)-like pipeable sequences -of feature engineering and [`{parsnip}`](https://parsnip.tidymodels.org/) for a -unified interface to a range of models. +The `epipredict` package utilizes the `tidymodels` framework, namely +[`{recipes}`](https://recipes.tidymodels.org/) for +[dplyr](https://dplyr.tidyverse.org/)-like pipeable sequences +of feature engineering and [`{parsnip}`](https://parsnip.tidymodels.org/) for a +unified interface to a range of models. -`epipredict` has additional customized feature engineering and preprocessing -steps, such as `step_epi_lag()`, `step_population_scaling()`, -`step_epi_naomit()`. They can be used along with -steps from the `{recipes}` package for more feature engineering. +`epipredict` has additional customized feature engineering and preprocessing +steps, such as `step_epi_lag()`, `step_population_scaling()`, +and `step_epi_naomit()`. These can be used along with +steps from the `{recipes}` package for more feature engineering. In this vignette, we will illustrate some examples of how to use `epipredict` with `recipes` and `parsnip` for different purposes of epidemiological forecasting. -We will focus on basic autoregressive models, in which COVID cases and -deaths in the near future are predicted using a linear combination of cases and +We will focus on basic autoregressive models, in which COVID cases and +deaths in the near future are predicted using a linear combination of cases and deaths in the near past. -The remaining vignette will be split into three sections. The first section, we +The remainder of this vignette is split into three sections. In the first section, we will use a Poisson regression to predict death counts. In the second section, we will use a linear regression to predict death rates. Last but not least, we -will create a classification model for hotspot predictions. +will create a classification model for hotspot predictions.
```{r, warning=FALSE, message=FALSE} library(tidyr) @@ -49,18 +50,18 @@ library(workflows) library(poissonreg) ``` -## Poisson Regression +## Poisson Regression During COVID-19, the US Centers for Disease Control and Prevention (CDC) collected models and forecasts to characterize the state of an outbreak and its course. They used -it to inform public health decision makers on potential consequences of +it to inform public health decision makers on potential consequences of deploying control measures. One of the outcomes that the CDC forecasts is [death counts from COVID-19](https://www.cdc.gov/coronavirus/2019-ncov/science/forecasting/forecasting-us.html). -Although there are many state-of-the-art models, we choose to use Poisson +Although there are many state-of-the-art models, we choose to use Poisson regression, the textbook example for modeling count data, as an illustration -for using the `epipredict` package with other existing tidymodels packages. +for using the `epipredict` package with other existing tidymodels packages. ```{r poisson-reg-data} x <- pub_covidcast( @@ -69,7 +70,8 @@ x <- pub_covidcast( time_type = "day", geo_type = "state", time_values = epirange(20210604, 20211231), - geo_values = "ca,fl,tx,ny,nj") %>% + geo_values = "ca,fl,tx,ny,nj" +) %>% select(geo_value, time_value, cases = value) y <- pub_covidcast( @@ -78,72 +80,73 @@ y <- pub_covidcast( time_type = "day", geo_type = "state", time_values = epirange(20210604, 20211231), - geo_values = "ca,fl,tx,ny,nj") %>% + geo_values = "ca,fl,tx,ny,nj" +) %>% select(geo_value, time_value, deaths = value) counts_subset <- full_join(x, y, by = c("geo_value", "time_value")) %>% as_epi_df() ``` -The `counts_subset` dataset comes from the `epidatr` package, and -contains the number of confirmed cases and deaths from June 4, 2021 to -Dec 31, 2021 in some U.S. states. +The `counts_subset` dataset comes from the `epidatr` package, and +contains the number of confirmed cases and deaths from June 4, 2021 to +Dec 31, 2021 in some U.S. states. We wish to predict the 7-day ahead death counts with lagged cases and deaths. -Furthermore, we will let each state be a dummy variable. Using differential +Furthermore, we will include a dummy variable for each state. Using differential intercept coefficients, we can allow for an intercept shift between states. The model takes the form \begin{aligned} -\log\left( \mu_{t+7} \right) &= \beta_0 + \delta_1 s_{\text{state}_1} + -\delta_2 s_{\text{state}_2} + \cdots + \nonumber \\ &\quad\beta_1 \text{deaths}_{t} + -\beta_2 \text{deaths}_{t-7} + \beta_3 \text{cases}_{t} + -\beta_4 \text{cases}_{t-7}, +\log\left( \mu_{t+7} \right) &= \beta_0 + \delta_1 s_{\text{state}_1} + +\delta_2 s_{\text{state}_2} + \cdots + \nonumber \\ &\quad\beta_1 \text{deaths}_{t} + +\beta_2 \text{deaths}_{t-7} + \beta_3 \text{cases}_{t} + +\beta_4 \text{cases}_{t-7}, \end{aligned} -where $\mu_{t+7} = \mathbb{E}(y_{t+7})$, and $y_{t+7}$ is assumed to follow a -Poisson distribution with mean $\mu_{t+7}$; $s_{\text{state}}$ are dummy -variables for each state and take values of either 0 or 1. +where $\mu_{t+7} = \mathbb{E}(y_{t+7})$, and $y_{t+7}$ is assumed to follow a +Poisson distribution with mean $\mu_{t+7}$; $s_{\text{state}}$ are dummy +variables for each state and take values of either 0 or 1. Preprocessing steps will be performed to prepare the -data for model fitting.
But before diving into them, it will be helpful to understand what `roles` are in the `recipes` framework. --- #### Aside on `recipes` -`recipes` can assign one or more roles to each column in the data. The roles -are not restricted to a predefined set; they can be anything. -For most conventional situations, they are typically “predictor” and/or -"outcome". Additional roles enable targeted `step_*()` operations on specific +`recipes` can assign one or more roles to each column in the data. The roles +are not restricted to a predefined set; they can be anything. +For most conventional situations, they are typically "predictor" and/or +"outcome". Additional roles enable targeted `step_*()` operations on specific variables or groups of variables. In our case, the role `predictor` is given to explanatory variables on the -right-hand side of the model (in the equation above). -The role `outcome` is the response variable -that we wish to predict. `geo_value` and `time_value` are predefined roles -that are unique to the `epipredict` package. Since we work with `epi_df` +right-hand side of the model (in the equation above). +The role `outcome` is the response variable +that we wish to predict. `geo_value` and `time_value` are predefined roles +that are unique to the `epipredict` package. Since we work with `epi_df` objects, all datasets should have `geo_value` and `time_value` passed through automatically with these two roles assigned to the appropriate columns in the data. - -The `recipes` package also allows [manual alterations of roles](https://recipes.tidymodels.org/reference/roles.html) -in bulk. There are a few handy functions that can be used together to help us -manipulate variable roles easily. -> `update_role()` alters an existing role in the recipe or assigns an initial role +The `recipes` package also allows [manual alterations of roles](https://recipes.tidymodels.org/reference/roles.html) +in bulk. There are a few handy functions that can be used together to help us +manipulate variable roles easily. + +> `update_role()` alters an existing role in the recipe or assigns an initial role > to variables that do not yet have a declared role. -> -> `add_role()` adds an additional role to variables that already have a role in +> +> `add_role()` adds an additional role to variables that already have a role in > the recipe, without overwriting old roles. -> +> > `remove_role()` eliminates a single existing role in the recipe.
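+For concreteness, here is a minimal sketch of the three helpers on a toy
+recipe built from `counts_subset` (the role name `"auxiliary"` is arbitrary,
+and this toy recipe is only illustrative; it is not used below):
+
+```{r, eval = FALSE}
+library(recipes)
+toy_rec <- recipe(deaths ~ cases + geo_value, data = counts_subset) %>%
+  update_role(geo_value, new_role = "id") %>% # "predictor" becomes "id"
+  add_role(cases, new_role = "auxiliary") %>% # "cases" now has two roles
+  remove_role(cases, old_role = "auxiliary") # and now just one again
+summary(toy_rec) # one row per variable-role pair
+```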
#### End aside --- -Notice in the following preprocessing steps, we used `add_role()` on +Notice in the following preprocessing steps, we use `add_role()` on `geo_value_factor` since, currently, the default role for it is `raw`, but -we would like to reuse this variable as `predictor`s. +we would like to use this variable as a `predictor`. ```{r} counts_subset <- counts_subset %>% @@ -157,7 +160,7 @@ r <- epi_recipe(counts_subset) %>% step_dummy(geo_value_factor) %>% ## Occasionally, data reporting errors / corrections result in negative ## cases / deaths - step_mutate(cases = pmax(cases, 0), deaths = pmax(deaths, 0)) %>% + step_mutate(cases = pmax(cases, 0), deaths = pmax(deaths, 0)) %>% step_epi_lag(cases, deaths, lag = c(0, 7)) %>% step_epi_ahead(deaths, ahead = 7, role = "outcome") %>% step_epi_naomit() @@ -165,7 +168,7 @@ After specifying the preprocessing steps, we will use the `parsnip` package for modeling and producing the prediction for death count, 7 days after the -latest available date in the dataset. +latest available date in the dataset. ```{r} latest <- get_test_data(r, counts_subset) wf <- epi_workflow(r, parsnip::poisson_reg()) %>% fit(counts_subset) @@ -176,71 +179,71 @@ predict(wf, latest) %>% filter(!is.na(.pred)) ``` -Note that the `time_value` corresponds to the last available date in the -training set, **NOT** to the target date of the forecast +Note that the `time_value` corresponds to the last available date in the +training set, **NOT** to the target date of the forecast (`r max(latest$time_value) + 7`). - Let's take a look at the fit: + ```{r} extract_fit_engine(wf) ``` -Up to now, we've used the Poisson regression to model count data. Poisson +Up to now, we've used the Poisson regression to model count data. Poisson regression can also be used to model rate data, such as case rates or death -rates, by incorporating offset terms in the model. +rates, by incorporating offset terms in the model. To model death rates, the Poisson regression would be expressed as: \begin{aligned} -\log\left( \mu_{t+7} \right) &= \log(\text{population}) + -\beta_0 + \delta_1 s_{\text{state}_1} + -\delta_2 s_{\text{state}_2} + \cdots + \nonumber \\ &\quad\beta_1 \text{deaths}_{t} + -\beta_2 \text{deaths}_{t-7} + \beta_3 \text{cases}_{t} + -\beta_4 \text{cases}_{t-7}\end{aligned} -where $\log(\text{population})$ is the log of the state population that was +\log\left( \mu_{t+7} \right) &= \log(\text{population}) + +\beta_0 + \delta_1 s_{\text{state}_1} + +\delta_2 s_{\text{state}_2} + \cdots + \nonumber \\ &\quad\beta_1 \text{deaths}_{t} + +\beta_2 \text{deaths}_{t-7} + \beta_3 \text{cases}_{t} + +\beta_4 \text{cases}_{t-7}\end{aligned} +where $\log(\text{population})$ is the log of the state population that was used to scale the count data on the left-hand side of the equation. This offset is simply a predictor with coefficient fixed at 1 rather than estimated.
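+To make the offset concrete, here is a minimal sketch in base R, assuming a
+hypothetical data frame `dat` with `deaths`, lagged predictor columns, and a
+`population` column (this is not part of the pipeline used in this vignette):
+
+```{r, eval = FALSE}
+# the offset enters the linear predictor with its coefficient fixed at 1
+glm(
+  deaths ~ geo_value + lag_0_cases + lag_7_cases + offset(log(population)),
+  family = poisson,
+  data = dat
+)
+```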
-There are several ways to model rate data given count and population data. -First, in the `parsnip` framework, we could specify the formula in `fit()`. -However, by doing so we lose the ability to use the `recipes` framework to -create new variables since variables that do not exist in the -original dataset (such as, here, the lags and leads) cannot be called directly in `fit()`. +There are several ways to model rate data given count and population data. +First, in the `parsnip` framework, we could specify the formula in `fit()`. +However, by doing so we lose the ability to use the `recipes` framework to +create new variables since variables that do not exist in the +original dataset (such as, here, the lags and leads) cannot be called directly in `fit()`. -Alternatively, `step_population_scaling()` and `layer_population_scaling()` -in the `epipredict` package can perform the population scaling if we provide the +Alternatively, `step_population_scaling()` and `layer_population_scaling()` +in the `epipredict` package can perform the population scaling if we provide the population data, which we will illustrate in the next section. +## Linear Regression -## Linear Regression - -For COVID-19, the CDC required submission of case and death count predictions. However, the Delphi Group preferred to train on rate data instead, because it +puts different locations on a similar scale (eliminating the need for location-specific intercepts). We can use a linear regression to predict the death -rates and use state population data to scale the rates to counts.[^pois] We will do so -using `layer_population_scaling()` from the `epipredict` package. +rates and use state population data to scale the rates to counts.[^pois] We will do so +using `layer_population_scaling()` from the `epipredict` package. [^pois]: We could continue with the Poisson model, but we'll switch to the Gaussian likelihood just for simplicity. -Additionally, when forecasts are submitted, prediction intervals should be +Additionally, when forecasts are submitted, prediction intervals should be provided along with the point estimates. This can be obtained via postprocessing using -`layer_residual_quantiles()`. It is worth pointing out, however, that -`layer_residual_quantiles()` should be used before population scaling or else -the transformation will make the results uninterpretable. +`layer_residual_quantiles()`. It is worth pointing out, however, that +`layer_residual_quantiles()` should be used before population scaling or else +the transformation will make the results uninterpretable. We now wish to predict the 7-day ahead death counts with lagged case rates and death rates, along with some extra behavioural predictors. Namely, we will use survey data from [COVID-19 Trends and Impact Survey](https://cmu-delphi.github.io/delphi-epidata/api/covidcast-signals/fb-survey.html#behavior-indicators). -The survey data provides the estimated percentage of people who wore a mask for -most or all of the time while in public in the past 7 days and the estimated -percentage of respondents who reported that all or most people they encountered -in public in the past 7 days maintained a distance of at least 6 feet. +The survey data provides the estimated percentage of people who wore a mask for +most or all of the time while in public in the past 7 days and the estimated +percentage of respondents who reported that all or most people they encountered +in public in the past 7 days maintained a distance of at least 6 feet. State-wise population data from the 2019 U.S. Census is included in this package -and will be used in `layer_population_scaling()`. +and will be used in `layer_population_scaling()`.
+ ```{r} behav_ind_mask <- pub_covidcast( source = "fb-survey", @@ -248,7 +251,8 @@ behav_ind_mask <- pub_covidcast( time_type = "day", geo_type = "state", time_values = epirange(20210604, 20211231), - geo_values = "ca,fl,tx,ny,nj") %>% + geo_values = "ca,fl,tx,ny,nj" +) %>% select(geo_value, time_value, masking = value) behav_ind_distancing <- pub_covidcast( @@ -257,13 +261,14 @@ behav_ind_distancing <- pub_covidcast( time_type = "day", geo_type = "state", time_values = epirange(20210604, 20211231), - geo_values = "ca,fl,tx,ny,nj") %>% - select(geo_value, time_value, distancing = value) + geo_values = "ca,fl,tx,ny,nj" +) %>% + select(geo_value, time_value, distancing = value) pop_dat <- state_census %>% select(abbr, pop) behav_ind <- behav_ind_mask %>% - full_join(behav_ind_distancing, by = c("geo_value", "time_value")) + full_join(behav_ind_distancing, by = c("geo_value", "time_value")) ``` Rather than using raw mask-wearing / social-distancing metrics, for the sake @@ -277,50 +282,53 @@ behav_ind %>% geom_density(alpha = 0.5) + scale_fill_brewer(palette = "Set1", name = "") + theme_bw() + - scale_x_continuous(expand = c(0,0)) + - scale_y_continuous(expand = expansion(c(0,.05))) + + scale_x_continuous(expand = c(0, 0)) + + scale_y_continuous(expand = expansion(c(0, .05))) + facet_wrap(~name, scales = "free") + theme(legend.position = "bottom") ``` -We will take a subset of death rate and case rate data from the built-in dataset +We will take a subset of death rate and case rate data from the built-in dataset `case_death_rate_subset`. ```{r} jhu <- filter( case_death_rate_subset, - time_value >= "2021-06-04", + time_value >= "2021-06-04", time_value <= "2021-12-31", - geo_value %in% c("ca","fl","tx","ny","nj") + geo_value %in% c("ca", "fl", "tx", "ny", "nj") ) ``` Preprocessing steps will again rely on functions from the `epipredict` package as well as the `recipes` package. -There are also many functions in the `recipes` package that allow for +There are also many functions in the `recipes` package that allow for [scalar transformations](https://recipes.tidymodels.org/reference/#step-functions-individual-transformations), -such as log transformations and data centering. In our case, we will -center the numerical predictors to allow for a more meaningful interpretation of the -intercept. +such as log transformations and data centering. In our case, we will +center the numerical predictors to allow for a more meaningful interpretation of the +intercept. ```{r} jhu <- jhu %>% mutate(geo_value_factor = as.factor(geo_value)) %>% left_join(behav_ind, by = c("geo_value", "time_value")) %>% as_epi_df() - + r <- epi_recipe(jhu) %>% add_role(geo_value_factor, new_role = "predictor") %>% step_dummy(geo_value_factor) %>% step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) %>% - step_mutate(masking = cut_number(masking, 5), - distancing = cut_number(distancing, 5)) %>% + step_mutate( + masking = cut_number(masking, 5), + distancing = cut_number(distancing, 5) + ) %>% step_epi_ahead(death_rate, ahead = 7, role = "outcome") %>% step_center(contains("lag"), role = "predictor") %>% step_epi_naomit() ``` As a sanity check we can examine the structure of the training data: + ```{r, warning = FALSE} glimpse(slice_sample(bake(prep(r, jhu), jhu), n = 6)) ``` @@ -334,16 +342,17 @@ to create median predictions and a 90% prediction interval. 
```{r, warning=FALSE} f <- frosting() %>% layer_predict() %>% - layer_add_target_date("2022-01-07") %>% + layer_add_target_date("2022-01-07") %>% layer_threshold(.pred, lower = 0) %>% layer_quantile_distn() %>% layer_naomit(.pred) %>% layer_population_scaling( - .pred, .pred_distn, - df = pop_dat, + .pred, .pred_distn, + df = pop_dat, rate_rescaling = 1e5, - by = c("geo_value" = "abbr"), - df_pop_col = "pop") + by = c("geo_value" = "abbr"), + df_pop_col = "pop" + ) wf <- epi_workflow(r, quantile_reg(tau = c(.05, .5, .95))) %>% fit(jhu) %>% @@ -358,6 +367,7 @@ The columns marked `*_scaled` have been rescaled to the correct units, in this case `deaths` rather than deaths per 100K people (these remain in `.pred`). To look at the prediction intervals: + ```{r} p %>% select(geo_value, target_date, .pred_scaled, .pred_distn_scaled) %>% @@ -366,9 +376,9 @@ p %>% pivot_wider(names_from = tau, values_from = q) ``` - -Last but not least, let's take a look at the regression fit and check the +Last but not least, let's take a look at the regression fit and check the coefficients: + ```{r, echo = FALSE} extract_fit_engine(wf) ``` @@ -379,63 +389,66 @@ Sometimes it is preferable to create a predictive model for surges or upswings rather than for raw values. In this case, the target is to predict if the future will have increased case rates (denoted `up`), decreased case rates (`down`), or flat case rates (`flat`) relative to the current -level. Such models may be -referred to as "hotspot prediction models". We will follow the analysis +level. Such models may be +referred to as "hotspot prediction models". We will follow the analysis in [McDonald, Bien, Green, Hu, et al.](#references) but extend the application -to predict three categories instead of two. +to predict three categories instead of two. -Hotspot prediction uses a categorical outcome variable defined in terms of the -relative change of $Y_{\ell, t+a}$ compared to $Y_{\ell, t}$. -Where $Y_{\ell, t}$ denotes the case rates in location $\ell$ at time $t$. +Hotspot prediction uses a categorical outcome variable defined in terms of the +relative change of $Y_{\ell, t+a}$ compared to $Y_{\ell, t}$, +where $Y_{\ell, t}$ denotes the case rate in location $\ell$ at time $t$. We define the response variables as follows: $$ Z_{\ell, t}= \begin{cases} \text{up}, & \text{if}\ Y^{\Delta}_{\ell, t} > 0.25 \\ \text{down}, & \text{if}\ Y^{\Delta}_{\ell, t} < -0.20\\ \text{flat}, & \text{otherwise} \end{cases} $$ -where $Y^{\Delta}_{\ell, t} = (Y_{\ell, t}- Y_{\ell, t-7})\ /\ (Y_{\ell, t-7})$. -We say location $\ell$ is a hotspot at time $t$ when $Z_{\ell,t}$ is -`up`, meaning the number of newly reported cases over the past 7 days has -increased by at least 25% compared to the preceding week. When $Z_{\ell,t}$ -is categorized as `down`, it suggests that there has been at least a 20% -decrease in newly reported cases over the past 7 days (a 20% decrease is the inverse of a 25% increase). Otherwise, we will -consider the trend to be `flat`. +where $Y^{\Delta}_{\ell, t} = (Y_{\ell, t}- Y_{\ell, t-7})\ /\ (Y_{\ell, t-7})$. +We say location $\ell$ is a hotspot at time $t$ when $Z_{\ell,t}$ is +`up`, meaning the number of newly reported cases over the past 7 days has +increased by at least 25% compared to the preceding week. When $Z_{\ell,t}$ +is categorized as `down`, it suggests that there has been at least a 20% +decrease in newly reported cases over the past 7 days (a 20% decrease is the inverse of a 25% increase). Otherwise, we will +consider the trend to be `flat`.
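+In code, this response definition is just a thresholded relative change. As a
+minimal sketch (the case-rate vectors are made up; the recipe below implements
+the same logic with `step_mutate()`):
+
+```{r, eval = FALSE}
+y_now <- c(130, 115, 79) # hypothetical case rates now
+y_wk_ago <- c(100, 120, 80) # case rates 7 days earlier
+rel_change <- (y_now - y_wk_ago) / y_wk_ago
+dplyr::case_when(
+  rel_change > 0.25 ~ "up",
+  rel_change < -0.20 ~ "down",
+  TRUE ~ "flat"
+)
+#> [1] "up"   "flat" "flat"
+```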
The expression of the multinomial regression we will use is as follows: + $$ \pi_{j}(x) = \text{Pr}(Z_{\ell,t} = j|x) = \frac{e^{g_j(x)}}{1 + \sum_{k=1}^{2} e^{g_k(x)}} $$ + where $j$ is either down, flat, or up \begin{aligned} -g_{\text{down}}(x) &= 0.\\ -g_{\text{flat}}(x)&= \text{ln}\left(\frac{Pr(Z_{\ell,t}=\text{flat}|x)}{Pr(Z_{\ell,t}=\text{down}|x)}\right) = -\beta_{10} + \beta_{11}t + \delta_{10} s_{\text{state_1}} + -\delta_{11} s_{\text{state_2}} + \cdots \nonumber \\ -&\quad + \beta_{12} Y^{\Delta}_{\ell, t} + \beta_{13} Y^{\Delta}_{\ell, t-7} \\ -g_{\text{flat}}(x) &= \text{ln}\left(\frac{Pr(Z_{\ell,t}=\text{up}|x)}{Pr(Z_{\ell,t}=\text{down}|x)}\right) = -\beta_{20} + \beta_{21}t + \delta_{20} s_{\text{state_1}} + -\delta_{21} s_{\text{state}_2} + \cdots \nonumber \\ -&\quad + \beta_{22} Y^{\Delta}_{\ell, t} + -\beta_{23} Y^{\Delta}_{\ell, t-7} +g_{\text{down}}(x) &= 0.\\ +g_{\text{flat}}(x) &= \ln\left(\frac{\text{Pr}(Z_{\ell,t}=\text{flat}|x)}{\text{Pr}(Z_{\ell,t}=\text{down}|x)}\right) = +\beta_{10} + \beta_{11}t + \delta_{10} s_{\text{state}_1} + +\delta_{11} s_{\text{state}_2} + \cdots \nonumber \\ +&\quad + \beta_{12} Y^{\Delta}_{\ell, t} + +\beta_{13} Y^{\Delta}_{\ell, t-7} \\ +g_{\text{up}}(x) &= \ln\left(\frac{\text{Pr}(Z_{\ell,t}=\text{up}|x)}{\text{Pr}(Z_{\ell,t}=\text{down}|x)}\right) = +\beta_{20} + \beta_{21}t + \delta_{20} s_{\text{state}_1} + +\delta_{21} s_{\text{state}_2} + \cdots \nonumber \\ +&\quad + \beta_{22} Y^{\Delta}_{\ell, t} + +\beta_{23} Y^{\Delta}_{\ell, t-7} \end{aligned} - -
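+To see how the linear predictors map to class probabilities, here is a small
+numeric sketch (the values of $g_{\text{flat}}$ and $g_{\text{up}}$ are made
+up; since $g_{\text{down}}(x) = 0$, its exponential contributes the 1 in the
+denominator):
+
+```{r, eval = FALSE}
+g <- c(down = 0, flat = 0.4, up = -0.2) # hypothetical linear predictors
+probs <- exp(g) / sum(exp(g)) # e^{g_j} / (1 + e^{g_flat} + e^{g_up})
+probs # sums to 1; "flat" is the most likely category here
+```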
Preprocessing steps are similar to the previous models with an additional step -of categorizing the response variables. Again, we will use a subset of death rate and case rate data from our built-in dataset +Preprocessing steps are similar to the previous models with an additional step +of categorizing the response variables. Again, we will use a subset of death rate and case rate data from our built-in dataset `case_death_rate_subset`. + ```{r} jhu <- case_death_rate_subset %>% - dplyr::filter(time_value >= "2021-06-04", - time_value <= "2021-12-31", - geo_value %in% c("ca","fl","tx","ny","nj")) %>% + dplyr::filter( + time_value >= "2021-06-04", + time_value <= "2021-12-31", + geo_value %in% c("ca", "fl", "tx", "ny", "nj") + ) %>% mutate(geo_value_factor = as.factor(geo_value)) %>% as_epi_df() @@ -447,21 +460,29 @@ r <- epi_recipe(jhu) %>% step_mutate( pct_diff_ahead = case_when( lag_7_case_rate == 0 ~ 0, - TRUE ~ (ahead_7_case_rate - lag_0_case_rate) / lag_0_case_rate), + TRUE ~ (ahead_7_case_rate - lag_0_case_rate) / lag_0_case_rate + ), pct_diff_wk1 = case_when( - lag_7_case_rate == 0 ~ 0, - TRUE ~ (lag_0_case_rate - lag_7_case_rate) / lag_7_case_rate), + lag_7_case_rate == 0 ~ 0, + TRUE ~ (lag_0_case_rate - lag_7_case_rate) / lag_7_case_rate + ), pct_diff_wk2 = case_when( lag_14_case_rate == 0 ~ 0, - TRUE ~ (lag_7_case_rate - lag_14_case_rate) / lag_14_case_rate)) %>% + TRUE ~ (lag_7_case_rate - lag_14_case_rate) / lag_14_case_rate + ) + ) %>% step_mutate( response = case_when( pct_diff_ahead < -0.20 ~ "down", pct_diff_ahead > 0.25 ~ "up", - TRUE ~ "flat"), - role = "outcome") %>% - step_rm(death_rate, case_rate, lag_0_case_rate, lag_7_case_rate, - lag_14_case_rate, ahead_7_case_rate, pct_diff_ahead) %>% + TRUE ~ "flat" + ), + role = "outcome" + ) %>% + step_rm( + death_rate, case_rate, lag_0_case_rate, lag_7_case_rate, + lag_14_case_rate, ahead_7_case_rate, pct_diff_ahead + ) %>% step_epi_naomit() ``` @@ -476,15 +497,17 @@ predict(wf, latest) %>% filter(!is.na(.pred_class)) ``` We can also look at the estimated coefficients and model summary information: + ```{r} extract_fit_engine(wf) ``` -One could also use a formula in `epi_recipe()` to achieve the same results as -above. However, only one of `add_formula()`, `add_recipe()`, or -`workflow_variables()` can be specified. For the purpose of demonstrating +One could also use a formula in `epi_workflow()` to achieve the same results as +above. However, only one of `add_formula()`, `add_recipe()`, or +`workflow_variables()` can be specified. For the purpose of demonstrating `add_formula` rather than `add_recipe`, we will `prep` and `bake` our recipe to return a `data.frame` that could be used for model fitting. + ```{r} b <- bake(prep(r, jhu), jhu) @@ -497,13 +520,14 @@ epi_workflow() %>% ## Benefits of Lagging and Leading in `epipredict` The `step_epi_ahead` and `step_epi_lag` functions in the `epipredict` package -is handy for creating correct lags and leads for future predictions. +are handy for creating correct lags and leads for future predictions. Let's start with a simple dataset and preprocessing: + ```{r} ex <- filter( - case_death_rate_subset, - time_value >= "2021-12-01", + case_death_rate_subset, + time_value >= "2021-12-01", time_value <= "2021-12-31", geo_value == "ca" ) @@ -511,10 +535,11 @@ ex <- filter( dim(ex) ``` -We want to predict death rates on `r max(ex$time_value) + 7`, which is 7 days ahead of the -latest available date in our dataset. +We want to predict death rates on `r max(ex$time_value) + 7`, which is 7 days ahead of the +latest available date in our dataset.
We will compare two methods of creating lags and leads: + ```{r} p1 <- epi_recipe(ex) %>% step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% @@ -528,13 +553,15 @@ b1 p2 <- epi_recipe(ex) %>% - step_mutate(lag0case_rate = lag(case_rate, 0), - lag7case_rate = lag(case_rate, 7), - lag14case_rate = lag(case_rate, 14), - lag0death_rate = lag(death_rate, 0), - lag7death_rate = lag(death_rate, 7), - lag14death_rate = lag(death_rate, 14), - ahead7death_rate = lead(death_rate, 7)) %>% + step_mutate( + lag0case_rate = lag(case_rate, 0), + lag7case_rate = lag(case_rate, 7), + lag14case_rate = lag(case_rate, 14), + lag0death_rate = lag(death_rate, 0), + lag7death_rate = lag(death_rate, 7), + lag14death_rate = lag(death_rate, 14), + ahead7death_rate = lead(death_rate, 7) + ) %>% step_epi_naomit() %>% prep() @@ -542,37 +569,37 @@ b2 <- bake(p2, ex) b2 ``` -Notice the difference in number of rows `b1` and `b2` returns. This is because +Notice the difference in the number of rows that `b1` and `b2` return. This is because the second version, the one that doesn't use `step_epi_ahead` and `step_epi_lag`, has omitted dates compared to the one that used the `epipredict` functions. + ```{r} -dates_used_in_training1 <- b1 %>% - select(-ahead_7_death_rate) %>% - na.omit() %>% +dates_used_in_training1 <- b1 %>% + select(-ahead_7_death_rate) %>% + na.omit() %>% pull(time_value) dates_used_in_training1 -dates_used_in_training2 <- b2 %>% - select(-ahead7death_rate) %>% - na.omit() %>% +dates_used_in_training2 <- b2 %>% + select(-ahead7death_rate) %>% + na.omit() %>% pull(time_value) dates_used_in_training2 ``` -The model that is trained based on the `{recipes}` functions will predict 7 days ahead from +The model trained using the plain `{recipes}` functions will predict 7 days ahead from `r max(dates_used_in_training2)` instead of 7 days ahead from `r max(dates_used_in_training1)`. ## References -McDonald, Bien, Green, Hu, et al. "Can auxiliary indicators improve COVID-19 -forecasting and hotspot prediction?." Proceedings of the National Academy of -Sciences 118.51 (2021): e2111453118. [doi:10.1073/pnas.2111453118]( https://doi.org/10.1073/pnas.2111453118) +McDonald, Bien, Green, Hu, et al. "Can auxiliary indicators improve COVID-19 +forecasting and hotspot prediction?" Proceedings of the National Academy of +Sciences 118.51 (2021): e2111453118. [doi:10.1073/pnas.2111453118](https://doi.org/10.1073/pnas.2111453118) ## Attribution This object contains a modified part of the [COVID-19 Data Repository by the Center for Systems Science and Engineering (CSSE) at Johns Hopkins University](https://github.com/CSSEGISandData/COVID-19) as [republished in the COVIDcast Epidata API.](https://cmu-delphi.github.io/delphi-epidata/api/covidcast-signals/jhu-csse.html) -This data set is licensed under the terms of the [Creative Commons Attribution 4.0 International license](https://creativecommons.org/licenses/by/4.0/) by the Johns Hopkins University +This data set is licensed under the terms of the [Creative Commons Attribution 4.0 International license](https://creativecommons.org/licenses/by/4.0/) by the Johns Hopkins University on behalf of its Center for Systems Science and Engineering. Copyright Johns Hopkins University 2020.