Skip to content

Commit cc30c3a

Browse files
committed
refactor: use checkmate for validation where possible, fix tests
1 parent 3a94603 commit cc30c3a

21 files changed

+103
-222
lines changed

DESCRIPTION

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Description: This package introduces a common data structure for epidemiological
2626
License: MIT + file LICENSE
2727
Copyright: file inst/COPYRIGHTS
2828
Imports:
29+
checkmate,
2930
cli,
3031
data.table,
3132
dplyr (>= 1.0.0),

NAMESPACE

+16
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ export(slice)
6161
export(ungroup)
6262
export(unnest)
6363
importFrom(R6,R6Class)
64+
importFrom(checkmate,anyNaN)
65+
importFrom(checkmate,assert)
66+
importFrom(checkmate,assert_character)
67+
importFrom(checkmate,assert_class)
68+
importFrom(checkmate,assert_data_frame)
69+
importFrom(checkmate,assert_int)
70+
importFrom(checkmate,assert_list)
71+
importFrom(checkmate,assert_logical)
72+
importFrom(checkmate,assert_numeric)
73+
importFrom(checkmate,assert_scalar)
74+
importFrom(checkmate,assert_set_equal)
75+
importFrom(checkmate,assert_subset)
76+
importFrom(checkmate,check_atomic)
77+
importFrom(checkmate,check_data_frame)
78+
importFrom(checkmate,check_scalar)
79+
importFrom(checkmate,vname)
6480
importFrom(cli,cli_abort)
6581
importFrom(cli,cli_inform)
6682
importFrom(cli,cli_warn)

R/archive.R

+12-45
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ validate_version_bound <- function(version_bound, x, na_ok = FALSE,
2727
x_arg = rlang::caller_arg(version_bound)) {
2828

2929
if (na_ok && is.na(version_bound)) return(invisible(NULL))
30-
checkmate::assert_set_equal(class(version_bound), class(x[["version"]]), .var.name = version_bound_arg)
31-
checkmate::assert_set_equal(typeof(version_bound), typeof(x[["version"]]), .var.name = version_bound_arg)
30+
assert_set_equal(class(version_bound), class(x[["version"]]), .var.name = version_bound_arg)
31+
assert_set_equal(typeof(version_bound), typeof(x[["version"]]), .var.name = version_bound_arg)
3232

3333
return(invisible(NULL))
3434
}
@@ -246,26 +246,9 @@ epi_archive <-
246246
initialize = function(x, geo_type, time_type, other_keys,
247247
additional_metadata, compactify,
248248
clobberable_versions_start, versions_end) {
249-
# Check that we have a data frame
250-
if (!is.data.frame(x)) {
251-
cli_abort("`x` must be a data frame.")
252-
}
253-
254-
# Check that we have geo_value, time_value, version columns
255-
if (!("geo_value" %in% names(x))) {
256-
cli_abort("`x` must contain a `geo_value` column.")
257-
}
258-
if (!("time_value" %in% names(x))) {
259-
cli_abort("`x` must contain a `time_value` column.")
260-
}
261-
if (!("version" %in% names(x))) {
262-
cli_abort("`x` must contain a `version` column.")
263-
}
264-
if (anyNA(x$version)) {
265-
cli_abort("`x$version` must not contain `NA`s",
266-
class = "epiprocess__version_values_must_not_be_na"
267-
)
268-
}
249+
assert_data_frame(x)
250+
assert_subset(c("geo_value", "time_value", "version"), names(x))
251+
assert(!anyNaN(x$version))
269252

270253
# If geo type is missing, then try to guess it
271254
if (missing(geo_type)) {
@@ -461,22 +444,13 @@ epi_archive <-
461444
if (length(other_keys) == 0) other_keys <- NULL
462445

463446
# Check a few things on max_version
464-
if (!identical(class(max_version), class(self$DT$version)) ||
465-
!identical(typeof(max_version), typeof(self$DT$version))) {
466-
cli_abort("`max_version` and `DT$version` must have same `class` and `typeof`.")
467-
}
468-
if (length(max_version) != 1) {
469-
cli_abort("`max_version` cannot be a vector.")
470-
}
471-
if (is.na(max_version)) {
472-
cli_abort("`max_version` must not be NA.")
473-
}
447+
assert_set_equal(class(max_version), class(self$DT$version))
448+
assert_set_equal(typeof(max_version), typeof(self$DT$version))
449+
assert_scalar(max_version, na.ok = FALSE)
474450
if (max_version > self$versions_end) {
475451
cli_abort("`max_version` must be at most `self$versions_end`.")
476452
}
477-
if (!rlang::is_bool(all_versions)) {
478-
cli_abort("`all_versions` must be TRUE or FALSE.")
479-
}
453+
assert_logical(all_versions, len = 1)
480454
if (!is.na(self$clobberable_versions_start) && max_version >= self$clobberable_versions_start) {
481455
cli_warn('Getting data as of some recent version which could still be overwritten (under routine circumstances) without assigning a new version number (a.k.a. "clobbered"). Thus, the snapshot that we produce here should not be expected to be reproducible later. See `?epi_archive` for more info and `?epix_as_of` on how to muffle.',
482456
class = "epiprocess__snapshot_as_of_clobberable_version"
@@ -589,16 +563,9 @@ epi_archive <-
589563
#' @param x as in [`epix_truncate_versions_after`]
590564
#' @param max_version as in [`epix_truncate_versions_after`]
591565
truncate_versions_after = function(max_version) {
592-
if (length(max_version) != 1) {
593-
cli_abort("`max_version` cannot be a vector.")
594-
}
595-
if (is.na(max_version)) {
596-
cli_abort("`max_version` must not be NA.")
597-
}
598-
if (!identical(class(max_version), class(self$DT$version)) ||
599-
!identical(typeof(max_version), typeof(self$DT$version))) {
600-
cli_abort("`max_version` and `DT$version` must have same `class` and `typeof`.")
601-
}
566+
assert_set_equal(class(max_version), class(self$DT$version))
567+
assert_set_equal(typeof(max_version), typeof(self$DT$version))
568+
assert_scalar(max_version, na.ok = FALSE)
602569
if (max_version > self$versions_end) {
603570
cli_abort("`max_version` must be at most `self$versions_end`.")
604571
}

R/correlation.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@
7878
epi_cor <- function(x, var1, var2, dt1 = 0, dt2 = 0, shift_by = geo_value,
7979
cor_by = geo_value, use = "na.or.complete",
8080
method = c("pearson", "kendall", "spearman")) {
81-
# Check we have an `epi_df` object
82-
if (!inherits(x, "epi_df")) cli_abort("`x` must be of class `epi_df`.")
81+
assert_class(x, "epi_df")
8382

8483
# Check that we have variables to do computations on
8584
if (missing(var1)) cli_abort("`var1` must be specified.")

R/epi_df.R

+3-14
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,9 @@ NULL
114114
#' @export
115115
new_epi_df <- function(x = tibble::tibble(), geo_type, time_type, as_of,
116116
additional_metadata = list(), ...) {
117-
# Check that we have a data frame
118-
if (!is.data.frame(x)) {
119-
cli_abort("`x` must be a data frame.")
120-
}
117+
assert_data_frame(x)
118+
assert_list(additional_metadata)
121119

122-
if (!is.list(additional_metadata)) {
123-
cli_abort("`additional_metadata` must be a list type.")
124-
}
125120
if (is.null(additional_metadata[["other_keys"]])) {
126121
additional_metadata[["other_keys"]] <- character(0L)
127122
}
@@ -302,13 +297,7 @@ as_epi_df.epi_df <- function(x, ...) {
302297
#' @export
303298
as_epi_df.tbl_df <- function(x, geo_type, time_type, as_of,
304299
additional_metadata = list(), ...) {
305-
# Check that we have geo_value and time_value columns
306-
if (!("geo_value" %in% names(x))) {
307-
cli_abort("`x` must contain a `geo_value` column.")
308-
}
309-
if (!("time_value" %in% names(x))) {
310-
cli_abort("`x` must contain a `time_value` column.")
311-
}
300+
assert_subset(c("geo_value", "time_value"), names(x))
312301

313302
new_epi_df(
314303
x, geo_type, time_type, as_of,

R/epiprocess.R

+4
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,9 @@
77
#' @docType package
88
#' @name epiprocess
99
#' @importFrom cli cli_abort cli_inform cli_warn
10+
#' @importFrom checkmate assert assert_scalar assert_data_frame assert_subset
11+
#' assert_set_equal anyNaN assert_logical assert_list
12+
#' assert_character assert_class assert_int assert_numeric
13+
#' check_scalar check_data_frame vname check_atomic
1014
NULL
1115
utils::globalVariables(c(".x", ".group_key", ".ref_time_value"))

R/grouped_epi_archive.R

+12-52
Original file line numberDiff line numberDiff line change
@@ -59,35 +59,14 @@ grouped_epi_archive <-
5959
epiprocess__ungrouped_groups = groups(ungrouped)
6060
)
6161
}
62-
if (!inherits(ungrouped, "epi_archive")) {
63-
cli_abort("`ungrouped` must be an epi_archive",
64-
class = "epiprocess__grouped_epi_archive__ungrouped_arg_is_not_epi_archive",
65-
epiprocess__ungrouped_class = class(ungrouped)
66-
)
67-
}
68-
if (!is.character(vars)) {
69-
cli_abort("`vars` must be a character vector (any tidyselection should have already occurred in a helper method).",
70-
class = "epiprocess__grouped_epi_archive__vars_is_not_chr",
71-
epiprocess__vars_class = class(vars),
72-
epiprocess__vars_type = typeof(vars)
73-
)
74-
}
75-
if (!all(vars %in% names(ungrouped$DT))) {
76-
cli_abort("`vars` must be selected from the names of columns of `ungrouped$DT`",
77-
class = "epiprocess__grouped_epi_archive__vars_contains_invalid_entries",
78-
epiprocess__vars = vars,
79-
epiprocess__DT_names = names(ungrouped$DT)
80-
)
81-
}
62+
assert_class(ungrouped, "epi_archive")
63+
assert_character(vars)
64+
assert_subset(vars, names(ungrouped$DT))
8265
if ("version" %in% vars) {
8366
cli_abort("`version` has a special interpretation and cannot be used by itself as a grouping variable")
8467
}
85-
if (!rlang::is_bool(drop)) {
86-
cli_abort("`drop` must be a Boolean",
87-
class = "epiprocess__grouped_epi_archive__drop_is_not_bool",
88-
epiprocess__drop = drop
89-
)
90-
}
68+
assert_logical(drop)
69+
9170
# -----
9271
private$ungrouped <- ungrouped
9372
private$vars <- vars
@@ -136,9 +115,7 @@ grouped_epi_archive <-
136115
invisible(self)
137116
},
138117
group_by = function(..., .add = FALSE, .drop = dplyr::group_by_drop_default(self)) {
139-
if (!rlang::is_bool(.add)) {
140-
cli_abort("`.add` must be a Boolean")
141-
}
118+
assert_logical(.add)
142119
if (!.add) {
143120
cli_abort('`group_by` on a `grouped_epi_archive` with `.add=FALSE` is forbidden
144121
(neither automatic regrouping nor nested grouping is supported).
@@ -230,15 +207,8 @@ grouped_epi_archive <-
230207

231208
if (missing(ref_time_values)) {
232209
ref_time_values <- epix_slide_ref_time_values_default(private$ungrouped)
233-
} else if (length(ref_time_values) == 0L) {
234-
cli_abort("`ref_time_values` must have at least one element.")
235-
} else if (any(is.na(ref_time_values))) {
236-
cli_abort("`ref_time_values` must not include `NA`.")
237-
} else if (anyDuplicated(ref_time_values) != 0L) {
238-
cli_abort("`ref_time_values` must not contain any duplicates; use `unique` if appropriate.")
239-
} else if (any(ref_time_values > private$ungrouped$versions_end)) {
240-
cli_abort("All `ref_time_values` must be `<=` the `versions_end`.")
241210
} else {
211+
assert_numeric(ref_time_values, upper = private$ungrouped$versions_end, min.len = 1L, null.ok = FALSE, any.missing = FALSE)
242212
# Sort, for consistency with `epi_slide`, although the current
243213
# implementation doesn't take advantage of it.
244214
ref_time_values <- sort(ref_time_values)
@@ -253,9 +223,7 @@ grouped_epi_archive <-
253223
`before=365000`).")
254224
}
255225
before <- vctrs::vec_cast(before, integer())
256-
if (length(before) != 1L || is.na(before) || before < 0L) {
257-
cli_abort("`before` must be length-1, non-NA, non-negative.")
258-
}
226+
assert_int(before, lower = 0L, null.ok = FALSE, na.ok = FALSE)
259227

260228
# If a custom time step is specified, then redefine units
261229

@@ -265,15 +233,9 @@ grouped_epi_archive <-
265233
new_col <- sym(new_col_name)
266234

267235
# Validate rest of parameters:
268-
if (!rlang::is_bool(as_list_col)) {
269-
cli_abort("`as_list_col` must be TRUE or FALSE.")
270-
}
271-
if (!(rlang::is_string(names_sep) || is.null(names_sep))) {
272-
cli_abort("`names_sep` must be a (single) string or NULL.")
273-
}
274-
if (!rlang::is_bool(all_versions)) {
275-
cli_abort("`all_versions` must be TRUE or FALSE.")
276-
}
236+
assert_logical(as_list_col, len = 1L)
237+
assert_logical(all_versions, len = 1L)
238+
assert_character(names_sep, len = 1L, null.ok = TRUE)
277239

278240
# Computation for one group, one time value
279241
comp_one_grp <- function(.data_group, .group_key,
@@ -290,9 +252,7 @@ grouped_epi_archive <-
290252
.data_group <- .data_group$DT
291253
}
292254

293-
if (!(is.atomic(comp_value) || is.data.frame(comp_value))) {
294-
cli_abort("The slide computation must return an atomic vector or a data frame.")
295-
}
255+
assert(check_atomic(comp_value, any.missing = TRUE), check_data_frame(comp_value), combine = "or", .var.name = vname(comp_value))
296256

297257
# Label every result row with the `ref_time_value`
298258
res <- list(time_value = ref_time_value)

R/growth_rate.R

+1-3
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@ growth_rate <- function(x = seq_along(y), y, x0 = x,
119119
dup_rm = FALSE, na_rm = FALSE, ...) {
120120
# Check x, y, x0
121121
if (length(x) != length(y)) cli_abort("`x` and `y` must have the same length.")
122-
if (!all(x0 %in% x)) cli_abort("`x0` must be a subset of `x`.")
123-
124-
# Check the method
122+
assert_subset(x0, x)
125123
method <- match.arg(method)
126124

127125
# Arrange in increasing order of x

R/methods-epi_archive.R

+8-24
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,14 @@
7070
#' max_version = max(archive_cases_dv_subset$DT$version)
7171
#' )
7272
#' },
73-
#' epiprocess__snapshot_as_of_clobberable_version = function(wrn) invokeRestart("mufflecli_warning")
73+
#' epiprocess__snapshot_as_of_clobberable_version = function(wrn) invokeRestart("muffleWarning")
7474
#' )
7575
#' # Since R 4.0, there is a `globalCallingHandlers` function that can be used
7676
#' # to globally toggle these warnings.
7777
#'
7878
#' @export
7979
epix_as_of <- function(x, max_version, min_time_value = -Inf, all_versions = FALSE) {
80-
if (!inherits(x, "epi_archive")) cli_abort("`x` must be of class `epi_archive`.")
80+
assert_class(x, "epi_archive")
8181
return(x$as_of(max_version, min_time_value, all_versions = all_versions))
8282
}
8383

@@ -113,7 +113,7 @@ epix_as_of <- function(x, max_version, min_time_value = -Inf, all_versions = FAL
113113
#' @return An `epi_archive`
114114
epix_fill_through_version <- function(x, fill_versions_end,
115115
how = c("na", "locf")) {
116-
if (!inherits(x, "epi_archive")) cli_abort("`x` must be of class `epi_archive`.")
116+
assert_class(x, "epi_archive")
117117
# Enclosing parentheses drop the invisibility flag. See description above of
118118
# potential mutation and aliasing behavior.
119119
(x$clone()$fill_through_version(fill_versions_end, how = how))
@@ -179,14 +179,8 @@ epix_fill_through_version <- function(x, fill_versions_end,
179179
epix_merge <- function(x, y,
180180
sync = c("forbid", "na", "locf", "truncate"),
181181
compactify = TRUE) {
182-
if (!inherits(x, "epi_archive")) {
183-
cli_abort("`x` must be of class `epi_archive`.")
184-
}
185-
186-
if (!inherits(y, "epi_archive")) {
187-
cli_abort("`y` must be of class `epi_archive`.")
188-
}
189-
182+
assert_class(x, "epi_archive")
183+
assert_class(y, "epi_archive")
190184
sync <- rlang::arg_match(sync)
191185

192186
if (!identical(x$geo_type, y$geo_type)) {
@@ -409,11 +403,7 @@ epix_merge <- function(x, y,
409403
#'
410404
#' @noRd
411405
new_col_modify_recorder_df <- function(parent_df) {
412-
if (!inherits(parent_df, "data.frame")) {
413-
cli_abort('`parent_df` must inherit class `"data.frame"`',
414-
internal = TRUE
415-
)
416-
}
406+
assert_class(parent_df, "data.frame")
417407
`class<-`(parent_df, c("col_modify_recorder_df", class(parent_df)))
418408
}
419409

@@ -425,11 +415,7 @@ new_col_modify_recorder_df <- function(parent_df) {
425415
#'
426416
#' @noRd
427417
destructure_col_modify_recorder_df <- function(col_modify_recorder_df) {
428-
if (!inherits(col_modify_recorder_df, "col_modify_recorder_df")) {
429-
cli_abort('`col_modify_recorder_df` must inherit class `"col_modify_recorder_df"`',
430-
internal = TRUE
431-
)
432-
}
418+
assert_class(col_modify_recorder_df, "col_modify_recorder_df")
433419
list(
434420
unchanged_parent_df = col_modify_recorder_df %>%
435421
`attr<-`("epiprocess::col_modify_recorder_df::cols", NULL) %>%
@@ -676,9 +662,7 @@ epix_detailed_restricted_mutate <- function(.data, ...) {
676662
group_by.epi_archive <- function(.data, ..., .add = FALSE, .drop = dplyr::group_by_drop_default(.data)) {
677663
# `add` makes no difference; this is an ungrouped `epi_archive`.
678664
detailed_mutate <- epix_detailed_restricted_mutate(.data, ...)
679-
if (!rlang::is_bool(.drop)) {
680-
cli_abort("`.drop` must be TRUE or FALSE")
681-
}
665+
assert_logical(.drop)
682666
if (!.drop) {
683667
grouping_cols <- as.list(detailed_mutate[["archive"]][["DT"]])[detailed_mutate[["request_names"]]]
684668
grouping_col_is_factor <- purrr::map_lgl(grouping_cols, is.factor)

R/outliers.R

+2-4
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,8 @@ detect_outlr <- function(x = seq_along(y), y,
108108
results <- do.call(method, args = c(list("x" = x, "y" = y), args))
109109

110110
# Validate the output
111-
if (!is.data.frame(results) ||
112-
!all(c("lower", "upper", "replacement") %in% colnames(results))) {
113-
cli_abort("Outlier detection method must return a data frame with columns `lower`, `upper`, and `replacement`.")
114-
}
111+
assert_data_frame(results)
112+
assert_subset(c("lower", "upper", "replacement"), colnames(results))
115113

116114
# Update column names with model abbreviation
117115
colnames(results) <- paste(abbr, colnames(results), sep = "_")

0 commit comments

Comments
 (0)