Skip to content

Commit e8cfd8e

Browse files
authored
Merge pull request #20 from lzlzlizi/main
Add support for knn enhanced ar(x) forecastors
2 parents c5b3715 + b4a45e6 commit e8cfd8e

18 files changed

+880
-22
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ Imports:
2222
stats,
2323
tibble,
2424
tidyr,
25-
tidyselect
25+
tidyselect,
26+
tensr
2627
Suggests:
2728
covidcast,
2829
data.table,

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ export(epi_keys)
2121
export(epi_recipe)
2222
export(get_precision)
2323
export(grab_names)
24+
export(knn_iteraive_ar_args_list)
25+
export(knn_iteraive_ar_forecaster)
26+
export(knnarx_args_list)
27+
export(knnarx_forecaster)
2428
export(smooth_arx_args_list)
2529
export(smooth_arx_forecaster)
2630
export(step_epi_ahead)

R/arx_forecaster.R

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ arx_forecaster <- function(x, y, key_vars, time_value,
3737
if (intercept) dat$x0 <- 1
3838

3939
obj <- stats::lm(
40-
y1 ~ . + 0, data = dat %>% dplyr::select(starts_with(c("x","y"))))
40+
y1 ~ . + 0,
41+
data = dat %>% dplyr::select(starts_with(c("x", "y")))
42+
)
4143

4244
point <- make_predictions(obj, dat, time_value, keys)
4345

@@ -50,8 +52,9 @@ arx_forecaster <- function(x, y, key_vars, time_value,
5052
# Harder case requires handling failures of 1 and or 2, neither implemented
5153
# 1. different quantiles by key, need to bind the keys, then group_modify
5254
# 2 fails. need to bind the keys, grab, y and yhat, subtract
53-
if (nonneg)
55+
if (nonneg) {
5456
q <- dplyr::mutate(q, dplyr::across(dplyr::everything(), ~ pmax(.x, 0)))
57+
}
5558

5659
return(
5760
dplyr::bind_cols(distinct_keys, q) %>%
@@ -80,12 +83,11 @@ arx_forecaster <- function(x, y, key_vars, time_value,
8083
#' arx_args_list()
8184
#' arx_args_list(symmetrize = FALSE)
8285
#' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120)
83-
arx_args_list <- function(
84-
lags = c(0, 7, 14), ahead = 7, min_train_window = 20,
85-
levels = c(0.05, 0.95), intercept = TRUE,
86-
symmetrize = TRUE,
87-
nonneg = TRUE,
88-
quantile_by_key = FALSE) {
86+
arx_args_list <- function(lags = c(0, 7, 14), ahead = 7, min_train_window = 20,
87+
levels = c(0.05, 0.95), intercept = TRUE,
88+
symmetrize = TRUE,
89+
nonneg = TRUE,
90+
quantile_by_key = FALSE) {
8991

9092
# error checking if lags is a list
9193
.lags <- lags
@@ -94,13 +96,15 @@ arx_args_list <- function(
9496
arg_is_scalar(ahead, min_train_window)
9597
arg_is_nonneg_int(ahead, min_train_window, lags)
9698
arg_is_lgl(intercept, symmetrize, nonneg)
97-
arg_is_probabilities(levels, allow_null=TRUE)
99+
arg_is_probabilities(levels, allow_null = TRUE)
98100

99101
max_lags <- max(lags)
100102

101-
list(lags = .lags, ahead = as.integer(ahead),
102-
min_train_window = min_train_window,
103-
levels = levels, intercept = intercept,
104-
symmetrize = symmetrize, nonneg = nonneg,
105-
max_lags = max_lags)
106-
}
103+
list(
104+
lags = .lags, ahead = as.integer(ahead),
105+
min_train_window = min_train_window,
106+
levels = levels, intercept = intercept,
107+
symmetrize = symmetrize, nonneg = nonneg,
108+
max_lags = max_lags
109+
)
110+
}

R/knn_iterative_ar_forecaster.R

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#' KNN enhanced iterative AR forecaster with optional covariates
2+
#'
3+
#' @param x Unused covariates. Must to be missing (resulting in AR on `y`) .
4+
#' @param y Response.
5+
#' @param key_vars Factor(s). A prediction will be made for each unique
6+
#' combination.
7+
#' @param time_value the time value associated with each row of measurements.
8+
#' @param args Additional arguments specifying the forecasting task. Created
9+
#' by calling `knn_iteraive_ar_args_list()`.
10+
#'
11+
#' @return A data frame of point (and optionally interval) forecasts at multiple
12+
#' aheads (multiple horizons from one to specified `ahead`) for each unique combination of `key_vars`.
13+
#' @export
14+
15+
knn_iteraive_ar_forecaster <- function(x, y, key_vars, time_value,
16+
args = knn_iteraive_ar_args_list()) {
17+
18+
# TODO: function to verify standard forecaster signature inputs
19+
assign_arg_list(args)
20+
if (is.null(key_vars)) { # this is annoying/repetitive, seemingly necessary?
21+
keys <- NULL
22+
distinct_keys <- tibble(.dump = NA)
23+
} else {
24+
keys <- tibble::tibble(key_vars)
25+
distinct_keys <- dplyr::distinct(keys)
26+
}
27+
if (!is.null(x)) warning("The current version for KNN enhanced iterative forecasting strategy does not support covariates. 'x' will not be used!")
28+
29+
30+
# generate data
31+
pool <- create_lags_and_leads(NULL, y, c(1:query_window_len), 1:ahead, time_value, keys)
32+
# Return NA if insufficient training data
33+
if (nrow(pool) < topK) {
34+
qnames <- probs_to_string(levels)
35+
out <- dplyr::bind_cols(distinct_keys, point = NA) %>%
36+
dplyr::select(!dplyr::any_of(".dump"))
37+
return(enframer(out, qnames))
38+
}
39+
# get test data
40+
time_keys <- data.frame(keys, time_value)
41+
test_time_value <- max(time_value)
42+
common_names <- names(time_keys)
43+
key_names <- setdiff(common_names, "time_value")
44+
Querys <- dplyr::left_join(time_keys, pool, by = common_names) %>%
45+
dplyr::group_by(dplyr::across(dplyr::all_of(key_names))) %>%
46+
tidyr::fill(dplyr::starts_with("x")) %>%
47+
dplyr::filter(time_value == test_time_value) %>%
48+
select(!dplyr::starts_with("y")) %>%
49+
drop_na()
50+
51+
# embed querys and pool
52+
pool_raw <- pool
53+
pool <- pool %>%
54+
select(common_names, dplyr::starts_with("x"), "y1") %>%
55+
drop_na()
56+
pool_idx <- pool[common_names]
57+
58+
Querys_idx <- Querys[common_names]
59+
pool_emb <- embedding(pool %>% select(-common_names, -dplyr::starts_with("y")))
60+
# iterative prediction procedure
61+
62+
tmp <- data.frame()
63+
for (i in 1:nrow(Querys)) {
64+
query <- as.numeric(Querys[i, -c(1:2)])
65+
66+
for (h in 1:ahead) {
67+
if (h == 1 | update_model) {
68+
query_emb <- embedding(t(query))[1, ]
69+
sims <- pool_emb %*% query_emb
70+
topk_id <- tensr:::topK(sims, topK)
71+
train_id <- pool_idx[topk_id, ]
72+
train_da <- train_id %>%
73+
left_join(pool, by = common_names) %>%
74+
select("y1", paste("x", lags, sep = ""))
75+
76+
if (intercept) train_da$x0 <- 1
77+
obj <- stats::lm(
78+
y1 ~ . + 0,
79+
data = train_da
80+
)
81+
}
82+
83+
test_da <- data.frame(t(query[lags]))
84+
names(test_da) <- paste("x", lags, sep = "")
85+
if (intercept) test_da$x0 <- 1
86+
point <- stats::predict(obj, test_da)
87+
88+
yname <- paste("y", h, sep = "")
89+
residual_pool <- pool_raw %>%
90+
select(common_names, dplyr::starts_with("x"), yname) %>%
91+
drop_na()
92+
residual_pool_emb <- embedding(residual_pool %>% select(-common_names, -yname))
93+
sims <- residual_pool_emb %*% query_emb
94+
topk_id <- tensr:::topK(sims, topK)
95+
96+
residual_da <- residual_pool[topk_id, ]
97+
gty <- residual_da[yname]
98+
residual_da <- residual_pool[topk_id, ] %>%
99+
select(-common_names, -yname) %>%
100+
as.matrix()
101+
for (j in 1:h) {
102+
residual_tmp <- data.frame(residual_da[, lags])
103+
names(residual_tmp) <- paste("x", lags, sep = "")
104+
if (intercept) residual_tmp$x0 <- 1
105+
pred <- stats::predict(obj, residual_tmp)
106+
residual_da <- cbind(pred, residual_da[, -query_window_len])
107+
}
108+
109+
r <- (gty - pred)[, 1] / pred
110+
r[is.na(r)] <- 0
111+
q <- residual_quantiles_normlized(r, point, levels, symmetrize)
112+
q <- cbind(Querys_idx[i, key_names], q)
113+
q$ahead <- h
114+
tmp <- bind_rows(tmp, q)
115+
116+
query <- c(point, query[-query_window_len])
117+
}
118+
}
119+
if (nonneg) {
120+
tmp <- dplyr::mutate(tmp, dplyr::across(!ahead, ~ pmax(.x, 0)))
121+
}
122+
123+
res <- tmp %>%
124+
dplyr::select(!dplyr::any_of(".dump")) %>%
125+
dplyr::relocate(ahead)
126+
return(res)
127+
}
128+
129+
130+
131+
#'KNN enhanced iterative AR forecaster argument constructor
132+
#'
133+
#' Constructs a list of arguments for [knn_iteraive_ar_forecaster()].
134+
#'
135+
#' @template param-lags
136+
#' @template param-query_window_len
137+
#' @template param-topK
138+
#' @template param-ahead
139+
#' @template param-min_train_window
140+
#' @template param-levels
141+
#' @template param-intercept
142+
#' @template param-symmetrize
143+
#' @template param-nonneg
144+
#' @template param-update_model
145+
#' @param quantile_by_key Not currently implemented
146+
#'
147+
#' @return A list containing updated parameter choices.
148+
#' @export
149+
#'
150+
#' @examples
151+
#' arx_args_list()
152+
#' arx_args_list(symmetrize = FALSE)
153+
#' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120)
154+
knn_iteraive_ar_args_list <- function(lags = c(0, 7, 14),
155+
query_window_len = 50,
156+
topK = 500,
157+
ahead = 7,
158+
min_train_window = 20,
159+
levels = c(0.05, 0.95),
160+
intercept = TRUE,
161+
symmetrize = TRUE,
162+
nonneg = TRUE,
163+
quantile_by_key = FALSE,
164+
update_model = TRUE) {
165+
166+
# error checking if lags is a list
167+
.lags <- lags
168+
if (is.list(lags)) lags <- unlist(lags)
169+
170+
arg_is_scalar(ahead, min_train_window, query_window_len, topK)
171+
arg_is_nonneg_int(ahead, min_train_window, lags, query_window_len, topK)
172+
arg_is_lgl(intercept, symmetrize, nonneg, update_model)
173+
arg_is_probabilities(levels, allow_null = TRUE)
174+
175+
max_lags <- max(lags)
176+
177+
list(
178+
lags = .lags, ahead = as.integer(ahead),
179+
query_window_len = query_window_len,
180+
topK = topK,
181+
min_train_window = min_train_window,
182+
levels = levels, intercept = intercept,
183+
symmetrize = symmetrize, nonneg = nonneg,
184+
max_lags = max_lags,
185+
update_model = update_model
186+
)
187+
}

0 commit comments

Comments
 (0)