Skip to content

Refactor arx #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
^epipredict\.Rproj$
^\.Rproj\.user$
^LICENSE\.md$
^drafts$
^man-roxygen$
12 changes: 10 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.1
Version: 0.0.1.9000
Authors@R:
c(
person(given = "Jacob",
Expand All @@ -21,7 +21,15 @@ RoxygenNote: 7.1.2
Imports:
dplyr,
magrittr,
data.table
tibble,
rlang,
purrr,
cli,
stats,
splines,
tidyr,
assertthat,
tidyselect
Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
17 changes: 17 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,4 +1,21 @@
# Generated by roxygen2: do not edit by hand

export("%>%")
export(arx_args_list)
export(arx_forecaster)
export(create_lags_and_leads)
export(df_mat_mul)
export(get_precision)
export(grab_names)
export(smooth_arx_args_list)
export(smooth_arx_forecaster)
importFrom(magrittr,"%>%")
importFrom(rlang,"!!")
importFrom(rlang,":=")
importFrom(stats,lm)
importFrom(stats,poly)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,residuals)
importFrom(stats,setNames)
importFrom(tibble,tibble)
106 changes: 106 additions & 0 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#' Autoregressive forecaster with optional exogenous covariates
#'
#' Builds lagged features from `x` and `y`, fits a least-squares linear model
#' of the led response on those features, and returns the prediction at the
#' most recent `time_value` together with residual-based quantiles.
#'
#' @param x Covariates. May be missing, in which case the model is an AR on
#'   `y` alone.
#' @param y Response.
#' @param key_vars Factor(s), or `NULL`. One prediction is produced for each
#'   unique combination.
#' @param time_value The time value associated with each row of measurements.
#' @param args Additional arguments specifying the forecasting task, as
#'   produced by `arx_args_list()`.
#'
#' @return A data frame of point (and optionally interval) forecasts at a
#'   single ahead (unique horizon) for each unique combination of `key_vars`.
#' @export
arx_forecaster <- function(x, y, key_vars, time_value,
                           args = arx_args_list()) {

  # TODO: function to verify standard forecaster signature inputs

  # Unpack every element of `args` (lags, ahead, min_train_window, levels,
  # intercept, symmetrize, nonneg, max_lags) into this function's frame.
  assign_arg_list(args)

  # With no keys, a one-row placeholder column (`.dump`) lets bind_cols()
  # work below; it is stripped again before anything is returned.
  keys <- NULL
  distinct_keys <- tibble(.dump = NA)
  if (!is.null(key_vars)) {
    keys <- tibble::tibble(key_vars)
    distinct_keys <- dplyr::distinct(keys)
  }
  drop_placeholder <- function(df) {
    dplyr::select(df, !dplyr::any_of(".dump"))
  }

  # Return NA if insufficient training data
  n_required <- min_train_window + max_lags + ahead
  if (length(y) < n_required) {
    qnames <- probs_to_string(levels)
    na_point <- drop_placeholder(dplyr::bind_cols(distinct_keys, point = NA))
    # NOTE(review): enframer() presumably adds the quantile columns named by
    # `qnames` -- confirm against its definition.
    return(enframer(na_point, qnames))
  }

  design <- create_lags_and_leads(x, y, lags, ahead, time_value, keys)
  # The formula below suppresses lm()'s own intercept (`+ 0`); an intercept
  # is instead supplied as an explicit constant feature.
  if (intercept) design$x0 <- 1

  # Only the feature (x*) and response (y*) columns enter the regression;
  # key and time columns are left out.
  model_cols <- dplyr::select(design, dplyr::starts_with(c("x", "y")))
  fit <- stats::lm(y1 ~ . + 0, data = model_cols)

  point <- make_predictions(fit, design, time_value, keys)

  # Residuals, simplest case, requires
  # 1. same quantiles for all keys
  # 2. `residuals(fit)` works
  resids <- residuals(fit)
  bands <- residual_quantiles(resids, point, levels, symmetrize)

  # Harder case requires handling failures of 1 and or 2, neither implemented
  # 1. different quantiles by key, need to bind the keys, then group_modify
  # 2 fails. need to bind the keys, grab, y and yhat, subtract
  if (nonneg) {
    # Truncate every returned column at zero.
    bands <- dplyr::mutate(
      bands,
      dplyr::across(dplyr::everything(), ~ pmax(.x, 0))
    )
  }

  drop_placeholder(dplyr::bind_cols(distinct_keys, bands))
}


#' Build the argument list for the ARX forecaster
#'
#' Validates the supplied settings and bundles them into the list expected by
#' [arx_forecaster()].
#'
#' @template param-lags
#' @template param-ahead
#' @template param-min_train_window
#' @template param-levels
#' @template param-intercept
#' @template param-symmetrize
#' @template param-nonneg
#' @param quantile_by_key Not currently implemented
#'
#' @return A list containing updated parameter choices.
#' @export
#'
#' @examples
#' arx_args_list()
#' arx_args_list(symmetrize = FALSE)
#' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120)
arx_args_list <- function(
    lags = c(0, 7, 14), ahead = 7, min_train_window = 20,
    levels = c(0.05, 0.95), intercept = TRUE,
    symmetrize = TRUE,
    nonneg = TRUE,
    quantile_by_key = FALSE) {

  # Remember the caller's `lags` exactly as given (vector or list); a
  # flattened copy is used only for validation and for finding the maximum.
  original_lags <- lags
  if (is.list(lags)) {
    lags <- unlist(lags)
  }

  arg_is_scalar(ahead, min_train_window)
  arg_is_nonneg_int(ahead, min_train_window, lags)
  arg_is_lgl(intercept, symmetrize, nonneg)
  arg_is_probabilities(levels, allow_null=TRUE)

  # `quantile_by_key` is accepted but neither checked nor returned.
  list(
    lags = original_lags,
    ahead = as.integer(ahead),
    min_train_window = min_train_window,
    levels = levels,
    intercept = intercept,
    symmetrize = symmetrize,
    nonneg = nonneg,
    max_lags = max(lags)
  )
}
23 changes: 23 additions & 0 deletions R/assign_arg_list.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#' Assign the elements of a named list into an environment
#'
#' This function is similar to `attach()` but without the
#' need to detach. Calling it at the beginning of a forecaster
#' makes all members of the `arg_list` available inside the
#' forecaster without the ugly `args$member` syntax.
#'
#' All names are validated before anything is assigned, so a malformed
#' list never leaves `env` partially modified.
#'
#' @param l List of named arguments. Every element must have a non-empty,
#'   non-missing name. An empty list is allowed and assigns nothing.
#' @param env The environment where the args should be assigned.
#'   The default goes into the calling environment.
#'
#' @return `NULL`, invisibly. Called for the side effects.
#' @examples
#' \dontrun{
#' l <- list(a = 1, b = c(12, 10), ff = function() -5)
#' assign_arg_list(l)
#' a
#' }
assign_arg_list <- function(l, env = parent.frame()) {
  # Resolve the default now, while we are directly in this function's frame.
  force(env)
  stopifnot(is.list(l))

  nm <- names(l)
  all_named <- length(nm) == length(l) && !anyNA(nm) && all(nzchar(nm))
  if (!all_named) {
    stop("`l` must be a list in which every element is named.",
         call. = FALSE)
  }

  # Assign in order, so a later duplicate name overrides an earlier one.
  for (a in seq_along(l)) assign(nm[a], l[[a]], envir = env)
  invisible(NULL)
}
40 changes: 18 additions & 22 deletions R/create_leads_and_lags.R → R/create_lags_and_leads.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#' Create lags and leads of predictors and response
#'
#' @param x Data frame or matrix. Predictor variables.
#' @param x Data frame or matrix. Predictor variables. May be
#' missing.
#' @param y Response vector. Typical usage will "lead" y by the
#' number of steps forward for the prediction horizon (ahead).
#' @param xy_lags Vector or list. If a vector, the lags will apply
Expand All @@ -11,6 +12,10 @@
#' @param y_leads Scalar or vector. If a scalar, we "lead" `y` by this
#' amount. A vector will produce multiple columns of `y` if this is
#' useful for your model. Negative values will "lag" the variable.
#' @param time_value Vector of time values at which the data are
#' observed
#' @param key_vars Factors representing different groups. May be
#' `NULL` (the default).
#'
#' @return A `data.frame`.
#' @export
Expand All @@ -19,13 +24,15 @@
#'
#' x <- 1:20
#' y <- -20:-1
#' create_lags_and_leads(x, y, c(1, 2), 1)
#' create_lags_and_leads(x, y, list(c(1, 2), 1), 1)
#' create_lags_and_leads(x, y, list(c(-1, 1), NULL), 1)
#' create_lags_and_leads(x, y, c(1, 2), c(0, 1))
create_lags_and_leads <- function(x, y, xy_lags, y_leads) {
# TODO: make it so we don't clobber names if they exist.
if (!missing(x)) x <- data.frame(x, y) else x <- data.frame(y)
#' time_value <- c(1:18, 20, 21)
#' create_lags_and_leads(x, y, c(1, 2), 1, time_value)
#' create_lags_and_leads(x, y, list(c(1, 2), 1), 1, time_value)
#' create_lags_and_leads(x, y, list(c(-1, 1), NULL), 1, time_value)
#' create_lags_and_leads(x, y, c(1, 2), c(0, 1), time_value)
create_lags_and_leads <- function(x, y, xy_lags, y_leads,
time_value, key_vars = NULL) {

if (!missing(x)) x <- tibble(x, y) else x <- tibble(y)
if (!is.list(xy_lags)) xy_lags <- list(xy_lags)
p = ncol(x)
assertthat::assert_that(
Expand All @@ -34,19 +41,8 @@ create_lags_and_leads <- function(x, y, xy_lags, y_leads) {
"If a list, it must have length 1 or `ncol(x) + 1`."))
xy_lags = rep(xy_lags, length.out = p)

# Build features and response
dat <- do.call(
data.frame,
unlist( # Below we loop through and build the lagged features
purrr::map(1:p, function(i) {
purrr::map(xy_lags[[i]], function(lag) dplyr::lag(x[,i], n = lag))
}),
recursive = FALSE)) %>%
magrittr::set_names(paste0("x", 1:length(unlist(xy_lags))))

# Technically, can produce multiple cols of y
y <- suppressMessages(purrr::map_dfc(y_leads, ~ dplyr::lead(y, n = .x)))
if (ncol(y) > 1) names(y) <- paste0("y", 1:ncol(y)) else names(y) <- "y"
xdat <- epi_shift(x, xy_lags, time_value, key_vars)
ydat <- epi_shift(y, -1 * y_leads, time_value, key_vars, "y")

return(dplyr::bind_cols(y, dat))
suppressMessages(dplyr::full_join(ydat, xdat))
}
34 changes: 34 additions & 0 deletions R/df_mat_mul.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#' Multiply columns of a `data.frame` by a matrix
#'
#' The selected columns are converted to a matrix and right-multiplied by
#' `mat`. The selected columns are replaced by the `ncol(mat)` resulting
#' columns.
#'
#' @param dat A data.frame
#' @param mat A matrix. Must have as many rows as there are selected columns.
#' @param out_names Character vector. Creates the names of the resulting
#'   columns after multiplication. If a scalar, this is treated as a
#'   prefix and the resulting columns will be numbered sequentially.
#'   Otherwise it must contain one name per column of `mat`.
#' @param ... <[`tidy-select`][dplyr::dplyr_tidy_select]> One or more unquoted
#'   expressions separated by commas. Variable names can be used as if they
#'   were positions in the data frame, so expressions like `x:y` can
#'   be used to select a range of variables.
#'
#' @return A data.frame with the new columns at the right. The selected
#'   columns are removed; unselected columns are kept.
#' @export
#'
#' @examples
#' df <- data.frame(matrix(1:200, ncol = 10))
#' mat <- matrix(1:10, ncol = 2)
#' df_mat_mul(df, mat, "z", dplyr::num_range("X", 2:6))
#' df_mat_mul(df, mat, c("a", "b"), dplyr::num_range("X", 2:6))
df_mat_mul <- function(dat, mat, out_names = "out", ...) {

  stopifnot(is.matrix(mat), is.data.frame(dat))
  arg_is_chr(out_names)
  if (length(out_names) > 1) {
    # The product has one column per column of `mat`, so that is the number
    # of names required (not the number of rows of `mat`).
    stopifnot(length(out_names) == ncol(mat))
  } else {
    out_names <- paste0(out_names, seq_len(ncol(mat)))
  }

  dat_mat <- dplyr::select(dat, ...)
  # Fail early with a readable message instead of a non-conformable error.
  stopifnot(ncol(dat_mat) == nrow(mat))

  nm <- grab_names(dat_mat, dplyr::everything())
  dat_neg <- dplyr::select(dat, !dplyr::all_of(nm))
  new_cols <- as.matrix(dat_mat) %*% mat
  colnames(new_cols) <- out_names
  dplyr::bind_cols(dat_neg, as.data.frame(new_cols))
}
31 changes: 31 additions & 0 deletions R/epi_shift.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#' Shift predictors while maintaining grouping and time_value ordering
#'
#' This is a lower-level function. As such it performs no error checking.
#'
#' Each requested (column, lag) pair becomes its own small table whose
#' `time_value` has been moved forward by the lag; these tables are then
#' full-joined on the key column(s) and `time_value`.
#'
#' @param x Data frame. Variables to lag. Anything else is passed through
#'   `data.frame()` first.
#' @param lags List. Each list element is a vector of lags.
#'   Negative values produce leads. The list should have the same
#'   length as the number of columns in `x`.
#' @param time_value Vector. Same length as `x` giving time stamps.
#' @param keys Data frame, vector, or `NULL`. Additional grouping vars.
#'   When `NULL`, a constant key column named `keys` is used.
#' @param out_name Chr. Prefix for the shifted columns, which are named
#'   `<out_name>1`, `<out_name>2`, ... in order.
#'
#' @return A single tibble containing the key column(s), `time_value`, and
#'   one column per shifted feature.
epi_shift <- function(x, lags, time_value, keys = NULL, out_name = "x") {
  if (!is.data.frame(x)) x <- data.frame(x)
  if (is.null(keys)) keys <- rep("empty", nrow(x))
  p_in <- ncol(x)

  out_list <- tibble::tibble(i = seq_len(p_in), lag = lags) %>%
    # Expand the list-column so there is one row per (column, lag) pair
    tidyr::unchop(lag) %>%
    dplyr::mutate(name = paste0(out_name, seq_len(nrow(.)))) %>%
    # One list element for each lagged feature
    purrr::pmap(function(i, lag, name) {
      # `keys` is deliberately passed unnamed: a data frame is spliced in
      # under its own column names, a vector becomes a column called `keys`.
      tibble(keys,
             # Re-stamping an observation made at time t as t + lag means
             # the row for time s carries the value observed at s - lag.
             time_value = time_value + lag,
             !!name := x[[i]])
    })

  if (is.data.frame(keys)) {
    common_names <- c(names(keys), "time_value")
  } else {
    common_names <- c("keys", "time_value")
  }

  purrr::reduce(out_list, dplyr::full_join, by = common_names)
}
22 changes: 22 additions & 0 deletions R/grab_names.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#' Resolve a tidy selection to column names
#'
#' Applies `<tidy-select>` expressions to a data.frame and reports which
#' variables they pick out, by name.
#'
#' No checks are performed on the inputs.
#'
#' @param dat a data.frame
#' @param ... <[`tidy-select`][dplyr::dplyr_tidy_select]> One or more unquoted
#'   expressions separated by commas. Variable names can be used as if they
#'   were positions in the data frame, so expressions like `x:y` can
#'   be used to select a range of variables.
#'
#' @export
#' @return a character vector
#' @examples
#' df <- data.frame(a = 1, b = 2, cc = rep(NA, 3))
#' grab_names(df, dplyr::starts_with("c"))
grab_names <- function(dat, ...) {
  # Wrap the dots in `c()` so several selections combine into one expression;
  # eval_select() returns named column positions and only the names are kept.
  selected <- tidyselect::eval_select(rlang::expr(c(...)), dat)
  names(selected)
}
4 changes: 4 additions & 0 deletions R/imports.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#' @importFrom tibble tibble
#' @importFrom rlang := !!
#' @importFrom stats poly predict lm residuals quantile
NULL
26 changes: 26 additions & 0 deletions R/make_predictions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Point predictions from a fitted model at the most recent time value.
#
# Arguments:
#   obj        Fitted model object accepted by `stats::predict()`.
#   dat        Data frame of features; must contain the key column(s) and
#              `time_value` so it can be joined onto the key/time grid.
#   time_value Vector of time values, one per original observation.
#   key_vars   Key column(s) aligned with `time_value`, or `NULL` for a
#              single constant key.
#
# Returns whatever `stats::predict()` returns for the rows of `dat` whose
# `time_value` equals `max(time_value)`.
make_predictions <- function(obj, dat, time_value, key_vars = NULL) {
  # TODO: validate arguments
  stopifnot(is.data.frame(dat))
  if (is.null(key_vars)) {
    keys <- rep("empty", length(time_value))
  } else {
    keys <- key_vars
  }
  # Arguments are passed unnamed on purpose: the resulting column names
  # (taken from `keys` itself when it is a data frame) are the join keys.
  time_keys <- data.frame(keys, time_value)
  common_names <- names(time_keys)
  key_names <- setdiff(common_names, "time_value")

  # Align features to the observed key/time grid, then carry the last
  # available feature values forward within each key.
  # NOTE(review): fill() works in row order, so this presumably relies on
  # rows being in increasing time order within each key -- confirm.
  dat <- dplyr::left_join(time_keys, dat, by = common_names) %>%
    dplyr::group_by(dplyr::across(dplyr::all_of(key_names))) %>%
    tidyr::fill(dplyr::starts_with("x")) %>%
    dplyr::ungroup()

  test_time_value <- max(time_value)
  # Inside filter(), `time_value` refers to the column of `dat`.
  newdata <- dplyr::filter(dat, time_value == test_time_value)

  stats::predict(obj, newdata = newdata)
}
Loading