|
| 1 | +#' KNN enhanced iterative AR forecaster with optional covariates |
| 2 | +#' |
| 3 | +#' @param x Unused covariates. Must to be missing (resulting in AR on `y`) . |
| 4 | +#' @param y Response. |
| 5 | +#' @param key_vars Factor(s). A prediction will be made for each unique |
| 6 | +#' combination. |
| 7 | +#' @param time_value the time value associated with each row of measurements. |
| 8 | +#' @param args Additional arguments specifying the forecasting task. Created |
| 9 | +#' by calling `knn_iteraive_ar_args_list()`. |
| 10 | +#' |
| 11 | +#' @return A data frame of point (and optionally interval) forecasts at multiple |
| 12 | +#' aheads (multiple horizons from one to specified `ahead`) for each unique combination of `key_vars`. |
| 13 | +#' @export |
| 14 | + |
| 15 | +knn_iteraive_ar_forecaster <- function(x, y, key_vars, time_value, |
| 16 | + args = knn_iteraive_ar_args_list()) { |
| 17 | + |
| 18 | + # TODO: function to verify standard forecaster signature inputs |
| 19 | + assign_arg_list(args) |
| 20 | + if (is.null(key_vars)) { # this is annoying/repetitive, seemingly necessary? |
| 21 | + keys <- NULL |
| 22 | + distinct_keys <- tibble(.dump = NA) |
| 23 | + } else { |
| 24 | + keys <- tibble::tibble(key_vars) |
| 25 | + distinct_keys <- dplyr::distinct(keys) |
| 26 | + } |
| 27 | + if (!is.null(x)) warning("The current version for KNN enhanced iterative forecasting strategy does not support covariates. 'x' will not be used!") |
| 28 | + |
| 29 | + |
| 30 | + # generate data |
| 31 | + pool <- create_lags_and_leads(NULL, y, c(1:query_window_len), 1:ahead, time_value, keys) |
| 32 | + # Return NA if insufficient training data |
| 33 | + if (nrow(pool) < topK) { |
| 34 | + qnames <- probs_to_string(levels) |
| 35 | + out <- dplyr::bind_cols(distinct_keys, point = NA) %>% |
| 36 | + dplyr::select(!dplyr::any_of(".dump")) |
| 37 | + return(enframer(out, qnames)) |
| 38 | + } |
| 39 | + # get test data |
| 40 | + time_keys <- data.frame(keys, time_value) |
| 41 | + test_time_value <- max(time_value) |
| 42 | + common_names <- names(time_keys) |
| 43 | + key_names <- setdiff(common_names, "time_value") |
| 44 | + Querys <- dplyr::left_join(time_keys, pool, by = common_names) %>% |
| 45 | + dplyr::group_by(dplyr::across(dplyr::all_of(key_names))) %>% |
| 46 | + tidyr::fill(dplyr::starts_with("x")) %>% |
| 47 | + dplyr::filter(time_value == test_time_value) %>% |
| 48 | + select(!dplyr::starts_with("y")) %>% |
| 49 | + drop_na() |
| 50 | + |
| 51 | + # embed querys and pool |
| 52 | + pool_raw <- pool |
| 53 | + pool <- pool %>% |
| 54 | + select(common_names, dplyr::starts_with("x"), "y1") %>% |
| 55 | + drop_na() |
| 56 | + pool_idx <- pool[common_names] |
| 57 | + |
| 58 | + Querys_idx <- Querys[common_names] |
| 59 | + pool_emb <- embedding(pool %>% select(-common_names, -dplyr::starts_with("y"))) |
| 60 | + # iterative prediction procedure |
| 61 | + |
| 62 | + tmp <- data.frame() |
| 63 | + for (i in 1:nrow(Querys)) { |
| 64 | + query <- as.numeric(Querys[i, -c(1:2)]) |
| 65 | + |
| 66 | + for (h in 1:ahead) { |
| 67 | + if (h == 1 | update_model) { |
| 68 | + query_emb <- embedding(t(query))[1, ] |
| 69 | + sims <- pool_emb %*% query_emb |
| 70 | + topk_id <- tensr:::topK(sims, topK) |
| 71 | + train_id <- pool_idx[topk_id, ] |
| 72 | + train_da <- train_id %>% |
| 73 | + left_join(pool, by = common_names) %>% |
| 74 | + select("y1", paste("x", lags, sep = "")) |
| 75 | + |
| 76 | + if (intercept) train_da$x0 <- 1 |
| 77 | + obj <- stats::lm( |
| 78 | + y1 ~ . + 0, |
| 79 | + data = train_da |
| 80 | + ) |
| 81 | + } |
| 82 | + |
| 83 | + test_da <- data.frame(t(query[lags])) |
| 84 | + names(test_da) <- paste("x", lags, sep = "") |
| 85 | + if (intercept) test_da$x0 <- 1 |
| 86 | + point <- stats::predict(obj, test_da) |
| 87 | + |
| 88 | + yname <- paste("y", h, sep = "") |
| 89 | + residual_pool <- pool_raw %>% |
| 90 | + select(common_names, dplyr::starts_with("x"), yname) %>% |
| 91 | + drop_na() |
| 92 | + residual_pool_emb <- embedding(residual_pool %>% select(-common_names, -yname)) |
| 93 | + sims <- residual_pool_emb %*% query_emb |
| 94 | + topk_id <- tensr:::topK(sims, topK) |
| 95 | + |
| 96 | + residual_da <- residual_pool[topk_id, ] |
| 97 | + gty <- residual_da[yname] |
| 98 | + residual_da <- residual_pool[topk_id, ] %>% |
| 99 | + select(-common_names, -yname) %>% |
| 100 | + as.matrix() |
| 101 | + for (j in 1:h) { |
| 102 | + residual_tmp <- data.frame(residual_da[, lags]) |
| 103 | + names(residual_tmp) <- paste("x", lags, sep = "") |
| 104 | + if (intercept) residual_tmp$x0 <- 1 |
| 105 | + pred <- stats::predict(obj, residual_tmp) |
| 106 | + residual_da <- cbind(pred, residual_da[, -query_window_len]) |
| 107 | + } |
| 108 | + |
| 109 | + r <- (gty - pred)[, 1] / pred |
| 110 | + r[is.na(r)] <- 0 |
| 111 | + q <- residual_quantiles_normlized(r, point, levels, symmetrize) |
| 112 | + q <- cbind(Querys_idx[i, key_names], q) |
| 113 | + q$ahead <- h |
| 114 | + tmp <- bind_rows(tmp, q) |
| 115 | + |
| 116 | + query <- c(point, query[-query_window_len]) |
| 117 | + } |
| 118 | + } |
| 119 | + if (nonneg) { |
| 120 | + tmp <- dplyr::mutate(tmp, dplyr::across(!ahead, ~ pmax(.x, 0))) |
| 121 | + } |
| 122 | + |
| 123 | + res <- tmp %>% |
| 124 | + dplyr::select(!dplyr::any_of(".dump")) %>% |
| 125 | + dplyr::relocate(ahead) |
| 126 | + return(res) |
| 127 | +} |
| 128 | + |
| 129 | + |
| 130 | + |
| 131 | +#'KNN enhanced iterative AR forecaster argument constructor |
| 132 | +#' |
| 133 | +#' Constructs a list of arguments for [knn_iteraive_ar_forecaster()]. |
| 134 | +#' |
| 135 | +#' @template param-lags |
| 136 | +#' @template param-query_window_len |
| 137 | +#' @template param-topK |
| 138 | +#' @template param-ahead |
| 139 | +#' @template param-min_train_window |
| 140 | +#' @template param-levels |
| 141 | +#' @template param-intercept |
| 142 | +#' @template param-symmetrize |
| 143 | +#' @template param-nonneg |
| 144 | +#' @template param-update_model |
| 145 | +#' @param quantile_by_key Not currently implemented |
| 146 | +#' |
| 147 | +#' @return A list containing updated parameter choices. |
| 148 | +#' @export |
| 149 | +#' |
| 150 | +#' @examples |
| 151 | +#' arx_args_list() |
| 152 | +#' arx_args_list(symmetrize = FALSE) |
| 153 | +#' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120) |
| 154 | +knn_iteraive_ar_args_list <- function(lags = c(0, 7, 14), |
| 155 | + query_window_len = 50, |
| 156 | + topK = 500, |
| 157 | + ahead = 7, |
| 158 | + min_train_window = 20, |
| 159 | + levels = c(0.05, 0.95), |
| 160 | + intercept = TRUE, |
| 161 | + symmetrize = TRUE, |
| 162 | + nonneg = TRUE, |
| 163 | + quantile_by_key = FALSE, |
| 164 | + update_model = TRUE) { |
| 165 | + |
| 166 | + # error checking if lags is a list |
| 167 | + .lags <- lags |
| 168 | + if (is.list(lags)) lags <- unlist(lags) |
| 169 | + |
| 170 | + arg_is_scalar(ahead, min_train_window, query_window_len, topK) |
| 171 | + arg_is_nonneg_int(ahead, min_train_window, lags, query_window_len, topK) |
| 172 | + arg_is_lgl(intercept, symmetrize, nonneg, update_model) |
| 173 | + arg_is_probabilities(levels, allow_null = TRUE) |
| 174 | + |
| 175 | + max_lags <- max(lags) |
| 176 | + |
| 177 | + list( |
| 178 | + lags = .lags, ahead = as.integer(ahead), |
| 179 | + query_window_len = query_window_len, |
| 180 | + topK = topK, |
| 181 | + min_train_window = min_train_window, |
| 182 | + levels = levels, intercept = intercept, |
| 183 | + symmetrize = symmetrize, nonneg = nonneg, |
| 184 | + max_lags = max_lags, |
| 185 | + update_model = update_model |
| 186 | + ) |
| 187 | +} |
0 commit comments