Skip to content

325 workflow adj #328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method(Add_model,epi_workflow)
S3method(Add_model,workflow)
S3method(Math,dist_quantiles)
S3method(Ops,dist_quantiles)
S3method(add_model,epi_workflow)
S3method(Remove_model,epi_workflow)
S3method(Remove_model,workflow)
S3method(Update_model,epi_workflow)
S3method(Update_model,workflow)
S3method(adjust_epi_recipe,epi_recipe)
S3method(adjust_epi_recipe,epi_workflow)
S3method(adjust_frosting,epi_workflow)
Expand Down Expand Up @@ -92,7 +97,6 @@ S3method(print,step_population_scaling)
S3method(print,step_training_window)
S3method(quantile,dist_quantiles)
S3method(refresh_blueprint,default_epi_recipe_blueprint)
S3method(remove_model,epi_workflow)
S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
S3method(slather,layer_add_forecast_date)
Expand All @@ -115,10 +119,12 @@ S3method(tidy,check_enough_train_data)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
S3method(update_model,epi_workflow)
S3method(vec_ptype_abbr,dist_quantiles)
S3method(vec_ptype_full,dist_quantiles)
export("%>%")
export(Add_model)
export(Remove_model)
export(Update_model)
export(add_epi_recipe)
export(add_frosting)
export(add_layer)
Expand Down
101 changes: 0 additions & 101 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,107 +59,6 @@ is_epi_workflow <- function(x) {
}


#' Add a model to an `epi_workflow`
#'
#' @seealso [workflows::add_model()]
#' - `add_model()` adds a parsnip model to the `epi_workflow`.
#'
#' - `remove_model()` removes the model specification as well as any fitted
#' model object. Any extra formulas are also removed.
#'
#' - `update_model()` first removes the model then adds the new
#' specification to the workflow.
#'
#' @details
#' Has the same behaviour as [workflows::add_model()] but also ensures
#' that the returned object is an `epi_workflow`.
#'
#' @inheritParams workflows::add_model
#'
#' @param x An `epi_workflow`.
#'
#' @param spec A parsnip model specification.
#'
#' @param ... Not used.
#'
#' @return
#' `x`, updated with a new, updated, or removed model.
#'
#' @export
#' @examples
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(
#' time_value > "2021-11-01",
#' geo_value %in% c("ak", "ca", "ny")
#' )
#'
#' r <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7)
#'
#' rf_model <- rand_forest(mode = "regression")
#'
#' wf <- epi_workflow(r)
#'
#' wf <- wf %>% add_model(rf_model)
#' wf
#'
#' lm_model <- parsnip::linear_reg()
#'
#' wf <- update_model(wf, lm_model)
#' wf
#'
#' wf <- remove_model(wf)
#' wf
#' @export
add_model <- function(x, spec, ..., formula = NULL) {
UseMethod("add_model")
}

#' @rdname add_model
#' @export
remove_model <- function(x) {
UseMethod("remove_model")
}

#' @rdname add_model
#' @export
update_model <- function(x, spec, ..., formula = NULL) {
UseMethod("update_model")
}

#' @rdname add_model
#' @export
add_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
workflows::add_model(x, spec, ..., formula = formula)
}

#' @rdname add_model
#' @export
remove_model.epi_workflow <- function(x) {
workflows:::validate_is_workflow(x)

if (!workflows:::has_spec(x)) {
rlang::warn("The workflow has no model to remove.")
}

new_epi_workflow(
pre = x$pre,
fit = workflows:::new_stage_fit(),
post = x$post,
trained = FALSE
)
}

#' @rdname add_model
#' @export
update_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
rlang::check_dots_empty()
x <- remove_model(x)
workflows::add_model(x, spec, ..., formula = formula)
}


#' Fit an `epi_workflow` object
#'
#' @description
Expand Down
133 changes: 133 additions & 0 deletions R/model-methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#' Add a model to an `epi_workflow`
#'
#' @seealso [workflows::add_model()]
#' - `Add_model()` adds a parsnip model to the `epi_workflow`.
#'
#' - `Remove_model()` removes the model specification as well as any fitted
#' model object. Any extra formulas are also removed.
#'
#' - `Update_model()` first removes the model then adds the new
#' specification to the workflow.
#'
#' @details
#' Has the same behaviour as [workflows::add_model()] but also ensures
#' that the returned object is an `epi_workflow`.
#'
#' This family is called `Add_*` / `Update_*` / `Remove_*` to avoid
#' masking the related functions in `{workflows}`. We also provide
#' aliases with the lower-case names. However, in the event that
#' `{workflows}` is loaded after `{epipredict}`, these may fail to function
#' properly.
#'
#' @inheritParams workflows::add_model
#'
#' @param x An `epi_workflow`.
#'
#' @param spec A parsnip model specification.
#'
#' @param ... Not used.
#'
#' @return
#' `x`, updated with a new, updated, or removed model.
#'
#' @export
#' @examples
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(
#' time_value > "2021-11-01",
#' geo_value %in% c("ak", "ca", "ny")
#' )
#'
#' r <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7)
#'
#' rf_model <- rand_forest(mode = "regression")
#'
#' wf <- epi_workflow(r)
#'
#' wf <- wf %>% Add_model(rf_model)
#' wf
#'
#' lm_model <- parsnip::linear_reg()
#'
#' wf <- Update_model(wf, lm_model)
#' wf
#'
#' wf <- Remove_model(wf)
#' wf
#' @export
Add_model <- function(x, spec, ..., formula = NULL) {
UseMethod("Add_model")
}

#' @rdname Add_model
#' @export
Remove_model <- function(x) {
UseMethod("Remove_model")
}

#' @rdname Add_model
#' @export
Update_model <- function(x, spec, ..., formula = NULL) {
UseMethod("Update_model")
}

#' @rdname Add_model
#' @export
Add_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
workflows::add_model(x, spec, ..., formula = formula)
}

#' @rdname Add_model
#' @export
Remove_model.epi_workflow <- function(x) {
workflows:::validate_is_workflow(x)

if (!workflows:::has_spec(x)) {
rlang::warn("The workflow has no model to remove.")
}

new_epi_workflow(
pre = x$pre,
fit = workflows:::new_stage_fit(),
post = x$post,
trained = FALSE
)
}

#' @rdname Add_model
#' @export
Update_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
rlang::check_dots_empty()
x <- Remove_model(x)
Add_model(x, spec, ..., formula = formula)
}


#' @rdname Add_model
#' @export
Add_model.workflow <- workflows::add_model

#' @rdname Add_model
#' @export
Remove_model.workflow <- workflows::remove_model

#' @rdname Add_model
#' @export
Update_model.workflow <- workflows::update_model


# Aliases -----------------------------------------------------------------

#' @rdname Add_model
#' @export
add_model <- Add_model

#' @rdname Add_model
#' @export
remove_model <- Remove_model

#' @rdname Add_model
#' @export
update_model <- Update_model
45 changes: 23 additions & 22 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ navbar:
type: light

articles:
- title: Get started
navbar: ~
contents:
- epipredict
- preprocessing-and-models
- arx-classifier
- articles/update
- title: Get started
navbar: ~
contents:
- epipredict
- preprocessing-and-models
- arx-classifier
- update

- title: Advanced methods
contents:
- articles/sliding
- articles/smooth-qr
- articles/symptom-surveys
- panel-data
- title: Advanced methods
contents:
- articles/sliding
- articles/smooth-qr
- articles/symptom-surveys
- panel-data

repo:
url:
Expand Down Expand Up @@ -78,15 +78,16 @@ reference:
- smooth_quantile_reg
- title: Custom panel data forecasting workflows
contents:
- epi_recipe
- epi_workflow
- add_epi_recipe
- adjust_epi_recipe
- add_model
- predict.epi_workflow
- fit.epi_workflow
- augment.epi_workflow
- forecast.epi_workflow
- epi_recipe
- epi_workflow
- add_epi_recipe
- adjust_epi_recipe
- Add_model
- predict.epi_workflow
- fit.epi_workflow
- augment.epi_workflow
- forecast.epi_workflow

- title: Epi recipe preprocessing steps
contents:
- starts_with("step_")
Expand Down
Loading
Loading