Skip to content

Commit 3d77e01

Browse files
committed
clean up some imports
1 parent 5b77c1a commit 3d77e01

File tree

7 files changed

+25
-11
lines changed

7 files changed

+25
-11
lines changed

DESCRIPTION

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Imports:
1818
cli,
1919
dplyr,
2020
epiprocess,
21+
generics,
2122
glue,
2223
hardhat (>= 1.0.0.9000),
2324
magrittr,

NAMESPACE

+3-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ export(smooth_arx_forecaster)
4545
export(step_epi_ahead)
4646
export(step_epi_lag)
4747
import(recipes)
48-
importFrom(broom,augment)
48+
importFrom(generics,augment)
49+
importFrom(generics,fit)
4950
importFrom(magrittr,"%>%")
5051
importFrom(rlang,"!!")
5152
importFrom(rlang,":=")
@@ -60,4 +61,5 @@ importFrom(stats,predict)
6061
importFrom(stats,quantile)
6162
importFrom(stats,residuals)
6263
importFrom(stats,setNames)
64+
importFrom(tibble,as_tibble)
6365
importFrom(tibble,tibble)

R/epi_recipe.R

+1
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ kill_levels <- function(x, keys) {
417417
x
418418
}
419419

420+
#' @importFrom tibble as_tibble
420421
#' @export
421422
as_tibble.epi_df <- function(x, ...) {
422423
# so that downstream calls to as_tibble don't clobber our metadata

R/epi_workflow.R

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#' @return A new `epi_workflow` object.
1414
#' @seealso workflows::workflow
1515
#' @importFrom rlang is_null
16+
#' @importFrom stats predict
17+
#' @importFrom generics fit
18+
#' @importFrom generics augment
1619
#' @export
1720
#' @examples
1821
#' library(dplyr)
@@ -168,7 +171,6 @@ grab_forged_keys <- function(forged, mold, new_data) {
168171
#' @param ... Arguments passed on to the predict method.
169172
#'
170173
#' @return new_data with additional columns containing the predicted values
171-
#' @importFrom broom augment
172174
#' @export
173175
augment.epi_workflow <- function (x, new_data, ...) {
174176
predictions <- predict(x, new_data, ...)

R/layer_predict.R

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
layer_predict <-
2-
function(frosting, id = rand_id("predict_default")) {
2+
function(frosting, type = NULL, opts = list(), ..., id = rand_id("predict_default")) {
33
add_layer(
44
frosting,
55
layer_predict_new(
6+
type = type,
7+
opts = opts,
8+
dots_list = rlang::list2(...), # can't figure how to use this
69
id = id
710
)
811
)
912
}
1013

1114

12-
layer_predict_new <- function(id) {
13-
layer("predict", id = id)
15+
layer_predict_new <- function(type, opts, dots_list, id) {
16+
layer("predict", type = type, opts = opts, dots_list = dots_list, id = id)
1417
}
1518

1619
#' @export
1720
slather.layer_predict <- function(object, components, the_fit) {
18-
components$predictions <- predict(the_fit, components$forged$predictors)
21+
22+
components$predictions <- predict(the_fit, components$forged$predictors,
23+
type = object$type, opts = object$opts)
1924
components$predictions <- dplyr::bind_cols(components$keys, components$predictions)
2025
components
2126
}

tests/testthat/test-frosting.R

+5-3
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ test_that("prediction works without any postprocessor", {
2424
step_epi_ahead(death_rate, ahead = 7) %>%
2525
step_naomit(all_predictors()) %>%
2626
step_naomit(all_outcomes(), skip = TRUE)
27-
wf <- epi_workflow(r, linear_reg()) %>% fit(jhu)
28-
latest <- jhu %>% filter(time_value >= max(time_value) - 14)
27+
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
28+
latest <- jhu %>%
29+
dplyr::filter(time_value >= max(time_value) - 14)
2930

3031
expect_silent(predict(wf, latest))
31-
p <- predict(wf, latest) %>% dplyr::filter(!is.na(.pred))
32+
p <- predict(wf, latest) %>%
33+
dplyr::filter(!is.na(.pred))
3234
expect_equal(nrow(p), 3)
3335
expect_s3_class(p, "epi_df")
3436
expect_equal(tail(p$time_value, 1), as.Date("2021-12-31"))

tests/testthat/test-layer_predict.R

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ test_that("predict layer works alone", {
66
step_epi_ahead(death_rate, ahead = 7) %>%
77
step_naomit(all_predictors()) %>%
88
step_naomit(all_outcomes(), skip = TRUE)
9-
wf <- epi_workflow(r, linear_reg()) %>% fit(jhu)
10-
latest <- jhu %>% filter(time_value >= max(time_value) - 14)
9+
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
10+
latest <- jhu %>%
11+
dplyr::filter(time_value >= max(time_value) - 14)
1112

1213
f <- frosting() %>% layer_predict()
1314
wf <- wf %>% add_frosting(f)

0 commit comments

Comments
 (0)