Skip to content

Check that the f passed to epi[x]_slide takes enough args #302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f0465c4
check num args before ... in f
nmdefries Apr 20, 2023
138cdb0
only warn if dots provided in args
nmdefries Apr 21, 2023
e2c8152
test num args errors and warnings
nmdefries Apr 21, 2023
95fc474
use older support for not raising errors in tests
nmdefries Apr 24, 2023
066fb4e
factor out args check
nmdefries Apr 24, 2023
ed2a5e9
test slide fn and arg check fn arg warnings
nmdefries Apr 24, 2023
1ec3597
reduce slide arg check test coverage
nmdefries Apr 24, 2023
a3f227c
Rename check->assert_sufficient_f_args, tweak warning text & fields
lcbrooks May 8, 2023
6825a5c
rename error class to match func name
nmdefries May 10, 2023
51b12ee
drop regexp in tests where also specify error class
nmdefries May 10, 2023
d5466d3
check if required fields already have defaults set
nmdefries May 11, 2023
46cf783
test default checking
nmdefries May 11, 2023
a572464
Suppress forwarded warning from warning+error sufficient-args test
lcbrooks May 15, 2023
b982526
Account for `...` forwarding in `assert_sufficient_f_args`
lcbrooks May 16, 2023
290b4b9
Consider unnamed dots forwarding in `assert_sufficient_f_args`
lcbrooks May 16, 2023
34649f8
factor out "dots_i -1" to var
nmdefries May 17, 2023
6ebb1a5
import tail
nmdefries May 17, 2023
f6836d3
Message about right args when `f` default is suspiciously replaced
brookslogan May 18, 2023
10fb9e4
Message about right args when they fall into `f`'s dots
lcbrooks May 18, 2023
f0f0105
Also message about args fed to `f` dots when it has dots first
lcbrooks May 18, 2023
a872728
Fix and test some other corner cases in f arg checking
lcbrooks May 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ Description: This package introduces a common data structure for epidemiological
work with revisions to these data sets over time, and offers associated
utilities to perform basic signal processing tasks.
License: MIT + file LICENSE
Imports:
Imports:
cli,
data.table,
dplyr (>= 1.0.0),
fabletools,
Expand All @@ -48,7 +49,7 @@ Suggests:
knitr,
outbreaks,
rmarkdown,
testthat (>= 3.0.0),
testthat (>= 3.1.5),
waldo (>= 0.3.1),
withr
VignetteBuilder:
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@ importFrom(dplyr,ungroup)
importFrom(lubridate,days)
importFrom(lubridate,weeks)
importFrom(magrittr,"%>%")
importFrom(purrr,map_lgl)
importFrom(rlang,"!!!")
importFrom(rlang,"!!")
importFrom(rlang,.data)
importFrom(rlang,.env)
importFrom(rlang,arg_match)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
importFrom(rlang,is_missing)
importFrom(rlang,is_quosure)
importFrom(rlang,quo_is_missing)
importFrom(rlang,sym)
Expand All @@ -101,3 +103,4 @@ importFrom(tidyr,unnest)
importFrom(tidyselect,eval_select)
importFrom(tidyselect,starts_with)
importFrom(tsibble,as_tsibble)
importFrom(utils,tail)
5 changes: 5 additions & 0 deletions R/grouped_epi_archive.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ grouped_epi_archive =
ref_time_values = sort(ref_time_values)
}

# Check that `f` takes enough args
if (!missing(f) && is.function(f)) {
assert_sufficient_f_args(f, ...)
}

# Validate and pre-process `before`:
if (missing(before)) {
Abort("`before` is required (and must be passed by name);
Expand Down
7 changes: 6 additions & 1 deletion R/slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,12 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,

# Check we have an `epi_df` object
if (!inherits(x, "epi_df")) Abort("`x` must be of class `epi_df`.")


# Check that `f` takes enough args
if (!missing(f) && is.function(f)) {
assert_sufficient_f_args(f, ...)
}

# Arrange by increasing time_value
x = arrange(x, time_value)

Expand Down
81 changes: 81 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,87 @@ paste_lines = function(lines) {
Abort = function(msg, ...) rlang::abort(break_str(msg, init = "Error: "), ...)
Warn = function(msg, ...) rlang::warn(break_str(msg, init = "Warning: "), ...)

#' Assert that a sliding computation function takes enough args
#'
#' @param f Function; specifies a computation to slide over an `epi_df` or
#' `epi_archive` in `epi_slide` or `epix_slide`.
#' @param ... Dots that will be forwarded to `f` from the dots of `epi_slide` or
#' `epix_slide`.
#'
#' @importFrom rlang is_missing
#' @importFrom purrr map_lgl
#' @importFrom utils tail
#'
#' @noRd
assert_sufficient_f_args <- function(f, ...) {
mandatory_f_args_labels <- c("window data", "group key")
n_mandatory_f_args <- length(mandatory_f_args_labels)
args = formals(args(f))
args_names = names(args)
# Remove named arguments forwarded from `epi[x]_slide`'s `...`:
forwarded_dots_names = names(rlang::call_match(dots_expand = FALSE)[["..."]])
args_matched_in_dots =
# positional calling args will skip over args matched by named calling args
args_names %in% forwarded_dots_names &
# extreme edge case: `epi[x]_slide(<stuff>, dot = 1, `...` = 2)`
args_names != "..."
remaining_args = args[!args_matched_in_dots]
remaining_args_names = names(remaining_args)
# note that this doesn't include unnamed args forwarded through `...`.
dots_i <- which(remaining_args_names == "...") # integer(0) if no match
n_f_args_before_dots <- dots_i - 1L
if (length(dots_i) != 0L) { # `f` has a dots "arg"
# Keep all arg names before `...`
mandatory_args_mapped_names <- remaining_args_names[seq_len(n_f_args_before_dots)]

if (n_f_args_before_dots < n_mandatory_f_args) {
mandatory_f_args_in_f_dots =
tail(mandatory_f_args_labels, n_mandatory_f_args - n_f_args_before_dots)
cli::cli_warn(
"`f` might not have enough positional arguments before its `...`; in the current `epi[x]_slide` call, the {mandatory_f_args_in_f_dots} will be included in `f`'s `...`; if `f` doesn't expect those arguments, it may produce confusing error messages",
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots",
epiprocess__f = f,
epiprocess__mandatory_f_args_in_f_dots = mandatory_f_args_in_f_dots
)
}
} else { # `f` doesn't have a dots "arg"
if (length(args_names) < n_mandatory_f_args + rlang::dots_n(...)) {
# `f` doesn't take enough args.
if (rlang::dots_n(...) == 0L) {
# common case; try for friendlier error message
Abort(sprintf("`f` must take at least %s arguments", n_mandatory_f_args),
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args",
epiprocess__f = f)
} else {
# less common; highlight that they are (accidentally?) using dots forwarding
Abort(sprintf("`f` must take at least %s arguments plus the %s arguments forwarded through `epi[x]_slide`'s `...`, or a named argument to `epi[x]_slide` was misspelled", n_mandatory_f_args, rlang::dots_n(...)),
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded",
epiprocess__f = f)
}
}
}
# Check for args with defaults that are filled with mandatory positional
# calling args. If `f` has fewer than n_mandatory_f_args before `...`, then we
# only need to check those args for defaults. Note that `n_f_args_before_dots` is
# length 0 if `f` doesn't accept `...`.
n_remaining_args_for_default_check = min(c(n_f_args_before_dots, n_mandatory_f_args))
default_check_args = remaining_args[seq_len(n_remaining_args_for_default_check)]
default_check_args_names = names(default_check_args)
has_default_replaced_by_mandatory = map_lgl(default_check_args, ~!is_missing(.x))
if (any(has_default_replaced_by_mandatory)) {
default_check_mandatory_args_labels =
mandatory_f_args_labels[seq_len(n_remaining_args_for_default_check)]
# ^ excludes any mandatory args absorbed by f's `...`'s:
mandatory_args_replacing_defaults =
default_check_mandatory_args_labels[has_default_replaced_by_mandatory]
args_with_default_replaced_by_mandatory =
rlang::syms(default_check_args_names[has_default_replaced_by_mandatory])
cli::cli_abort("`epi[x]_slide` would pass the {mandatory_args_replacing_defaults} to `f`'s {args_with_default_replaced_by_mandatory} argument{?s}, which {?has a/have} default value{?s}; we suspect that `f` doesn't expect {?this arg/these args} at all and may produce confusing error messages. Please add additional arguments to `f` or remove defaults as appropriate.",
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults",
epiprocess__f = f)
}
}

##########

in_range = function(x, rng) pmin(pmax(x, rng[1]), rng[2])
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test-epi_slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,14 @@ test_that("these doesn't produce an error; the error appears only if the ref tim
dplyr::select("geo_value","slide_value_value"),
dplyr::tibble(geo_value = c("ak", "al"), slide_value_value = c(2, -2))) # not out of range for either group
})

test_that("epi_slide alerts if the provided f doesn't take enough args", {
f_xg = function(x, g) dplyr::tibble(value=mean(x$value), count=length(x$value))
# If `regexp` is NA, asserts that there should be no errors/messages.
expect_error(epi_slide(grouped, f_xg, before = 1L, ref_time_values = d+1), regexp = NA)
expect_warning(epi_slide(grouped, f_xg, before = 1L, ref_time_values = d+1), regexp = NA)

f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$value), count=length(x$value))
expect_warning(epi_slide(grouped, f_x_dots, before = 1L, ref_time_values = d+1),
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
})
11 changes: 11 additions & 0 deletions tests/testthat/test-epix_slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,14 @@ test_that("epix_slide with all_versions option works as intended",{

expect_identical(xx1,xx3) # This and * Imply xx2 and xx3 are identical
})

test_that("epix_slide alerts if the provided f doesn't take enough args", {
f_xg = function(x, g) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
# If `regexp` is NA, asserts that there should be no errors/messages.
expect_error(epix_slide(xx, f = f_xg, before = 2L), regexp = NA)
expect_warning(epix_slide(xx, f = f_xg, before = 2L), regexp = NA)

f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
expect_warning(epix_slide(xx, f_x_dots, before = 2L),
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
})
78 changes: 77 additions & 1 deletion tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,80 @@ test_that("enlist works",{
my_list <- enlist(x=1,y=2,z=3)
expect_equal(my_list$x,1)
expect_true(inherits(my_list,"list"))
})
})

test_that("assert_sufficient_f_args alerts if the provided f doesn't take enough args", {
f_xg = function(x, g) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
f_xg_dots = function(x, g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))

# If `regexp` is NA, asserts that there should be no errors/messages.
expect_error(assert_sufficient_f_args(f_xg), regexp = NA)
expect_warning(assert_sufficient_f_args(f_xg), regexp = NA)
expect_error(assert_sufficient_f_args(f_xg_dots), regexp = NA)
expect_warning(assert_sufficient_f_args(f_xg_dots), regexp = NA)

f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
f_dots = function(...) dplyr::tibble(value=c(5), count=c(2))
f_x = function(x) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
f = function() dplyr::tibble(value=c(5), count=c(2))

expect_warning(assert_sufficient_f_args(f_x_dots),
regexp = ", the group key will be included",
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
expect_warning(assert_sufficient_f_args(f_dots),
regexp = ", the window data and group key will be included",
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
expect_error(assert_sufficient_f_args(f_x),
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args")
expect_error(assert_sufficient_f_args(f),
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args")

f_xs_dots = function(x, setting="a", ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
f_xs = function(x, setting="a") dplyr::tibble(value=mean(x$binary), count=length(x$binary))
expect_warning(assert_sufficient_f_args(f_xs_dots, setting="b"),
class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots")
expect_error(assert_sufficient_f_args(f_xs, setting="b"),
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded")

expect_error(assert_sufficient_f_args(f_xg, "b"),
class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded")
})

test_that("assert_sufficient_f_args alerts if the provided f has defaults for the required args", {
f_xg = function(x, g=1) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
f_xg_dots = function(x=1, g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
f_x_dots = function(x=1, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))

expect_error(assert_sufficient_f_args(f_xg),
regexp = "pass the group key to `f`'s g argument,",
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
expect_error(assert_sufficient_f_args(f_xg_dots),
regexp = "pass the window data to `f`'s x argument,",
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
expect_error(suppressWarnings(assert_sufficient_f_args(f_x_dots)),
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")

f_xsg = function(x, setting="a", g) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
f_xsg_dots = function(x, setting="a", g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))
f_xs_dots = function(x=1, setting="a", ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary))

# forwarding named dots should prevent some complaints:
expect_no_error(assert_sufficient_f_args(f_xsg, setting = "b"))
expect_no_error(assert_sufficient_f_args(f_xsg_dots, setting = "b"))
expect_error(suppressWarnings(assert_sufficient_f_args(f_xs_dots, setting = "b")),
regexp = "window data to `f`'s x argument",
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")

# forwarding unnamed dots should not:
expect_error(assert_sufficient_f_args(f_xsg, "b"),
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
expect_error(assert_sufficient_f_args(f_xsg_dots, "b"),
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
expect_error(assert_sufficient_f_args(f_xs_dots, "b"),
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")

# forwarding no dots should produce a different error message in some cases:
expect_error(assert_sufficient_f_args(f_xs_dots),
regexp = "window data and group key to `f`'s x and setting argument",
class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults")
})