Skip to content

Commit 2d91421

Browse files
committed
refactor: use checkmate for validation where possible, fix tests
1 parent f0338e2 commit 2d91421

20 files changed

+164
-227
lines changed

NAMESPACE

+12
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,21 @@ export(ungroup)
6969
export(unnest)
7070
importFrom(R6,R6Class)
7171
importFrom(checkmate,anyInfinite)
72+
importFrom(checkmate,anyMissing)
7273
importFrom(checkmate,assert)
7374
importFrom(checkmate,assert_character)
75+
importFrom(checkmate,assert_class)
76+
importFrom(checkmate,assert_data_frame)
7477
importFrom(checkmate,assert_int)
78+
importFrom(checkmate,assert_list)
79+
importFrom(checkmate,assert_logical)
80+
importFrom(checkmate,assert_numeric)
81+
importFrom(checkmate,assert_scalar)
82+
importFrom(checkmate,check_atomic)
83+
importFrom(checkmate,check_data_frame)
84+
importFrom(checkmate,test_set_equal)
85+
importFrom(checkmate,test_subset)
86+
importFrom(checkmate,vname)
7587
importFrom(cli,cli_abort)
7688
importFrom(cli,cli_inform)
7789
importFrom(cli,cli_warn)

R/archive.R

+46-47
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,26 @@
2525
validate_version_bound <- function(
    version_bound, x, na_ok = FALSE,
    version_bound_arg = rlang::caller_arg(version_bound),
    x_arg = rlang::caller_arg(x)) {
  # Purpose: check that `version_bound` is comparable with the version column
  # of `x`, i.e. that it has the same set of classes and the same `typeof()`
  # as `x[["version"]]`.
  #
  # Arguments:
  #   version_bound     - the candidate bound to validate.
  #   x                 - object whose `version` element is the reference.
  #   na_ok             - if TRUE, a single `NA` bound is accepted without any
  #                       class/type comparison.
  #   version_bound_arg - label used for `version_bound` in error messages;
  #                       defaults to the caller's expression.
  #   x_arg             - label used for `x` in error messages; defaults to
  #                       the caller's expression for `x`.
  #
  # Returns `invisible(NULL)` when every check passes; otherwise aborts via
  # `cli_abort()`.

  if (is.null(version_bound)) {
    cli_abort("{version_bound_arg} cannot be NULL")
  }

  # `&&` signals an error for operands of length > 1 since R 4.3, so only a
  # length-1 bound is asked whether it is `NA`; longer inputs fall through to
  # the class/type comparison below.
  if (na_ok && length(version_bound) == 1L && is.na(version_bound)) {
    return(invisible(NULL))
  }

  # `[[` instead of `$` so no partial name matching can occur.
  version_col <- x[["version"]]

  if (!test_set_equal(class(version_bound), class(version_col))) {
    cli_abort(
      "{version_bound_arg} must have the same classes as {x_arg}$version,
       which is {class(version_col)}"
    )
  }
  if (!test_set_equal(typeof(version_bound), typeof(version_col))) {
    cli_abort(
      "{version_bound_arg} must have the same types as {x_arg}$version,
       which is {typeof(version_col)}"
    )
  }

  invisible(NULL)
}
@@ -251,26 +267,15 @@ epi_archive <-
251267
initialize = function(x, geo_type, time_type, other_keys,
252268
additional_metadata, compactify,
253269
clobberable_versions_start, versions_end) {
254-
# Check that we have a data frame
255-
if (!is.data.frame(x)) {
256-
cli_abort("`x` must be a data frame.")
257-
}
258-
259-
# Check that we have geo_value, time_value, version columns
260-
if (!("geo_value" %in% names(x))) {
261-
cli_abort("`x` must contain a `geo_value` column.")
262-
}
263-
if (!("time_value" %in% names(x))) {
264-
cli_abort("`x` must contain a `time_value` column.")
265-
}
266-
if (!("version" %in% names(x))) {
267-
cli_abort("`x` must contain a `version` column.")
268-
}
269-
if (anyNA(x$version)) {
270-
cli_abort("`x$version` must not contain `NA`s",
271-
class = "epiprocess__version_values_must_not_be_na"
270+
assert_data_frame(x)
271+
if (!test_subset(c("geo_value", "time_value", "version"), names(x))) {
272+
cli_abort(
273+
"Columns `geo_value`, `time_value`, and `version` must be present in `x`."
272274
)
273275
}
276+
if (anyMissing(x$version)) {
277+
cli_abort("Column `version` must not contain missing values.")
278+
}
274279

275280
# If geo type is missing, then try to guess it
276281
if (missing(geo_type)) {
@@ -285,7 +290,7 @@ epi_archive <-
285290
# Finish off with small checks on keys variables and metadata
286291
if (missing(other_keys)) other_keys <- NULL
287292
if (missing(additional_metadata)) additional_metadata <- list()
288-
if (!all(other_keys %in% names(x))) {
293+
if (!test_subset(other_keys, names(x))) {
289294
cli_abort("`other_keys` must be contained in the column names of `x`.")
290295
}
291296
if (any(c("geo_value", "time_value", "version") %in% other_keys)) {
@@ -298,10 +303,8 @@ epi_archive <-
298303
# Conduct checks and apply defaults for `compactify`
299304
if (missing(compactify)) {
300305
compactify <- NULL
301-
} else if (!rlang::is_bool(compactify) &&
302-
!rlang::is_null(compactify)) {
303-
cli_abort("compactify must be boolean or null.")
304306
}
307+
assert_logical(compactify, len = 1, null.ok = TRUE)
305308

306309
# Apply defaults and conduct checks for
307310
# `clobberable_versions_start`, `versions_end`:
@@ -384,7 +387,7 @@ epi_archive <-
384387
elim <- tibble::tibble()
385388
}
386389

387-
# cli_warns about redundant rows
390+
# Warns about redundant rows
388391
if (is.null(compactify) && nrow(elim) > 0) {
389392
warning_intro <- cli::format_inline(
390393
"Found rows that appear redundant based on
@@ -436,7 +439,7 @@ epi_archive <-
436439
)
437440
)
438441

439-
return(invisible(self$DT %>% print))
442+
return(invisible(self$DT %>% print()))
440443
},
441444
#####
442445
#' @description Generates a snapshot in `epi_df` format as of a given version.
@@ -467,22 +470,21 @@ epi_archive <-
467470
if (length(other_keys) == 0) other_keys <- NULL
468471

469472
# Check a few things on max_version
470-
if (!identical(class(max_version), class(self$DT$version)) ||
471-
!identical(typeof(max_version), typeof(self$DT$version))) {
472-
cli_abort("`max_version` and `DT$version` must have same `class` and `typeof`.")
473-
}
474-
if (length(max_version) != 1) {
475-
cli_abort("`max_version` cannot be a vector.")
473+
if (!test_set_equal(class(max_version), class(self$DT$version))) {
474+
cli_abort(
475+
"`max_version` must have the same classes as `self$DT$version`."
476+
)
476477
}
477-
if (is.na(max_version)) {
478-
cli_abort("`max_version` must not be NA.")
478+
if (!test_set_equal(typeof(max_version), typeof(self$DT$version))) {
479+
cli_abort(
480+
"`max_version` must have the same types as `self$DT$version`."
481+
)
479482
}
483+
assert_scalar(max_version, na.ok = FALSE)
480484
if (max_version > self$versions_end) {
481485
cli_abort("`max_version` must be at most `self$versions_end`.")
482486
}
483-
if (!rlang::is_bool(all_versions)) {
484-
cli_abort("`all_versions` must be TRUE or FALSE.")
485-
}
487+
assert_logical(all_versions, len = 1)
486488
if (!is.na(self$clobberable_versions_start) && max_version >= self$clobberable_versions_start) {
487489
cli_warn(
488490
'Getting data as of some recent version which could still be
@@ -599,16 +601,13 @@ epi_archive <-
599601
#' @param x as in [`epix_truncate_versions_after`]
600602
#' @param max_version as in [`epix_truncate_versions_after`]
601603
truncate_versions_after = function(max_version) {
602-
if (length(max_version) != 1) {
603-
cli_abort("`max_version` cannot be a vector.")
604-
}
605-
if (is.na(max_version)) {
606-
cli_abort("`max_version` must not be NA.")
604+
if (!test_set_equal(class(max_version), class(self$DT$version))) {
605+
cli_abort("`max_version` must have the same classes as `self$DT$version`.")
607606
}
608-
if (!identical(class(max_version), class(self$DT$version)) ||
609-
!identical(typeof(max_version), typeof(self$DT$version))) {
610-
cli_abort("`max_version` and `DT$version` must have same `class` and `typeof`.")
607+
if (!test_set_equal(typeof(max_version), typeof(self$DT$version))) {
608+
cli_abort("`max_version` must have the same types as `self$DT$version`.")
611609
}
610+
assert_scalar(max_version, na.ok = FALSE)
612611
if (max_version > self$versions_end) {
613612
cli_abort("`max_version` must be at most `self$versions_end`.")
614613
}

R/correlation.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@
7878
epi_cor <- function(x, var1, var2, dt1 = 0, dt2 = 0, shift_by = geo_value,
7979
cor_by = geo_value, use = "na.or.complete",
8080
method = c("pearson", "kendall", "spearman")) {
81-
# Check we have an `epi_df` object
82-
if (!inherits(x, "epi_df")) cli_abort("`x` must be of class `epi_df`.")
81+
assert_class(x, "epi_df")
8382

8483
# Check that we have variables to do computations on
8584
if (missing(var1)) cli_abort("`var1` must be specified.")

R/epi_df.R

+5-14
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,9 @@ NULL
114114
#' @export
115115
new_epi_df <- function(x = tibble::tibble(), geo_type, time_type, as_of,
116116
additional_metadata = list(), ...) {
117-
# Check that we have a data frame
118-
if (!is.data.frame(x)) {
119-
cli_abort("`x` must be a data frame.")
120-
}
117+
assert_data_frame(x)
118+
assert_list(additional_metadata)
121119

122-
if (!is.list(additional_metadata)) {
123-
cli_abort("`additional_metadata` must be a list type.")
124-
}
125120
if (is.null(additional_metadata[["other_keys"]])) {
126121
additional_metadata[["other_keys"]] <- character(0L)
127122
}
@@ -302,13 +297,9 @@ as_epi_df.epi_df <- function(x, ...) {
302297
#' @export
303298
as_epi_df.tbl_df <- function(x, geo_type, time_type, as_of,
304299
additional_metadata = list(), ...) {
305-
# Check that we have geo_value and time_value columns
306-
if (!("geo_value" %in% names(x))) {
307-
cli_abort("`x` must contain a `geo_value` column.")
308-
}
309-
if (!("time_value" %in% names(x))) {
310-
cli_abort("`x` must contain a `time_value` column.")
311-
}
300+
if (!test_subset(c("geo_value", "time_value"), names(x))) cli_abort(
301+
"Columns `geo_value` and `time_value` must be present in `x`."
302+
)
312303

313304
new_epi_df(
314305
x, geo_type, time_type, as_of,

R/epiprocess.R

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
#' measured over space and time, and offers associated utilities to perform
55
#' basic signal processing tasks.
66
#'
7-
#' @importFrom checkmate assert assert_character assert_int anyInfinite
7+
#' @importFrom checkmate assert assert_scalar assert_data_frame anyMissing
8+
#' assert_logical assert_list assert_character assert_class
9+
#' assert_int assert_numeric check_data_frame vname check_atomic
10+
#' anyInfinite test_subset test_set_equal
811
#' @importFrom cli cli_abort cli_inform cli_warn
912
#' @name epiprocess
1013
"_PACKAGE"

R/grouped_epi_archive.R

+20-50
Original file line numberDiff line numberDiff line change
@@ -59,35 +59,18 @@ grouped_epi_archive <-
5959
epiprocess__ungrouped_groups = groups(ungrouped)
6060
)
6161
}
62-
if (!inherits(ungrouped, "epi_archive")) {
63-
cli_abort("`ungrouped` must be an epi_archive",
64-
class = "epiprocess__grouped_epi_archive__ungrouped_arg_is_not_epi_archive",
65-
epiprocess__ungrouped_class = class(ungrouped)
66-
)
67-
}
68-
if (!is.character(vars)) {
69-
cli_abort("`vars` must be a character vector (any tidyselection should have already occurred in a helper method).",
70-
class = "epiprocess__grouped_epi_archive__vars_is_not_chr",
71-
epiprocess__vars_class = class(vars),
72-
epiprocess__vars_type = typeof(vars)
73-
)
74-
}
75-
if (!all(vars %in% names(ungrouped$DT))) {
76-
cli_abort("`vars` must be selected from the names of columns of `ungrouped$DT`",
77-
class = "epiprocess__grouped_epi_archive__vars_contains_invalid_entries",
78-
epiprocess__vars = vars,
79-
epiprocess__DT_names = names(ungrouped$DT)
62+
assert_class(ungrouped, "epi_archive")
63+
assert_character(vars)
64+
if (!test_subset(vars, names(ungrouped$DT))) {
65+
cli_abort(
66+
"All grouping variables `vars` must be present in the data.",
8067
)
8168
}
8269
if ("version" %in% vars) {
8370
cli_abort("`version` has a special interpretation and cannot be used by itself as a grouping variable")
8471
}
85-
if (!rlang::is_bool(drop)) {
86-
cli_abort("`drop` must be a Boolean",
87-
class = "epiprocess__grouped_epi_archive__drop_is_not_bool",
88-
epiprocess__drop = drop
89-
)
90-
}
72+
assert_logical(drop, len = 1)
73+
9174
# -----
9275
private$ungrouped <- ungrouped
9376
private$vars <- vars
@@ -136,9 +119,7 @@ grouped_epi_archive <-
136119
invisible(self)
137120
},
138121
group_by = function(..., .add = FALSE, .drop = dplyr::group_by_drop_default(self)) {
139-
if (!rlang::is_bool(.add)) {
140-
cli_abort("`.add` must be a Boolean")
141-
}
122+
assert_logical(.add, len = 1)
142123
if (!.add) {
143124
cli_abort('`group_by` on a `grouped_epi_archive` with `.add=FALSE` is forbidden
144125
(neither automatic regrouping nor nested grouping is supported).
@@ -230,15 +211,14 @@ grouped_epi_archive <-
230211

231212
if (missing(ref_time_values)) {
232213
ref_time_values <- epix_slide_ref_time_values_default(private$ungrouped)
233-
} else if (length(ref_time_values) == 0L) {
234-
cli_abort("`ref_time_values` must have at least one element.")
235-
} else if (any(is.na(ref_time_values))) {
236-
cli_abort("`ref_time_values` must not include `NA`.")
237-
} else if (anyDuplicated(ref_time_values) != 0L) {
238-
cli_abort("`ref_time_values` must not contain any duplicates; use `unique` if appropriate.")
239-
} else if (any(ref_time_values > private$ungrouped$versions_end)) {
240-
cli_abort("All `ref_time_values` must be `<=` the `versions_end`.")
241214
} else {
215+
assert_numeric(ref_time_values, min.len = 1L, null.ok = FALSE, any.missing = FALSE)
216+
if (any(ref_time_values > private$ungrouped$versions_end)) {
217+
cli_abort("Some `ref_time_values` are greater than the latest version in the archive.")
218+
}
219+
if (anyDuplicated(ref_time_values) != 0L) {
220+
cli_abort("Some `ref_time_values` are duplicated.")
221+
}
242222
# Sort, for consistency with `epi_slide`, although the current
243223
# implementation doesn't take advantage of it.
244224
ref_time_values <- sort(ref_time_values)
@@ -253,9 +233,7 @@ grouped_epi_archive <-
253233
`before=365000`).")
254234
}
255235
before <- vctrs::vec_cast(before, integer())
256-
if (length(before) != 1L || is.na(before) || before < 0L) {
257-
cli_abort("`before` must be length-1, non-NA, non-negative.")
258-
}
236+
assert_int(before, lower = 0L, null.ok = FALSE, na.ok = FALSE)
259237

260238
# If a custom time step is specified, then redefine units
261239

@@ -265,15 +243,9 @@ grouped_epi_archive <-
265243
new_col <- sym(new_col_name)
266244

267245
# Validate rest of parameters:
268-
if (!rlang::is_bool(as_list_col)) {
269-
cli_abort("`as_list_col` must be TRUE or FALSE.")
270-
}
271-
if (!(rlang::is_string(names_sep) || is.null(names_sep))) {
272-
cli_abort("`names_sep` must be a (single) string or NULL.")
273-
}
274-
if (!rlang::is_bool(all_versions)) {
275-
cli_abort("`all_versions` must be TRUE or FALSE.")
276-
}
246+
assert_logical(as_list_col, len = 1L)
247+
assert_logical(all_versions, len = 1L)
248+
assert_character(names_sep, len = 1L, null.ok = TRUE)
277249

278250
# Computation for one group, one time value
279251
comp_one_grp <- function(.data_group, .group_key,
@@ -290,9 +262,7 @@ grouped_epi_archive <-
290262
.data_group <- .data_group$DT
291263
}
292264

293-
if (!(is.atomic(comp_value) || is.data.frame(comp_value))) {
294-
cli_abort("The slide computation must return an atomic vector or a data frame.")
295-
}
265+
assert(check_atomic(comp_value, any.missing = TRUE), check_data_frame(comp_value), combine = "or", .var.name = vname(comp_value))
296266

297267
# Label every result row with the `ref_time_value`
298268
res <- list(time_value = ref_time_value)

R/growth_rate.R

-2
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ growth_rate <- function(x = seq_along(y), y, x0 = x,
120120
# Check x, y, x0
121121
if (length(x) != length(y)) cli_abort("`x` and `y` must have the same length.")
122122
if (!all(x0 %in% x)) cli_abort("`x0` must be a subset of `x`.")
123-
124-
# Check the method
125123
method <- match.arg(method)
126124

127125
# Arrange in increasing order of x

0 commit comments

Comments
 (0)