Skip to content

Commit 6e1d3ea

Browse files
authored
Merge pull request #226 from dsweber2/formatting
Formatting
2 parents 76fb5ff + 77cd78b commit 6e1d3ea

File tree

79 files changed

+1860
-1352
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1860
-1352
lines changed

.git-blame-ignore-revs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# using styler at all
2+
aca7d5e7b66d8bac9d9fbcec3acdb98a087d58fa
3+
f12fcc2bf3fe0a75ba2b10eaaf8a1f1d22486a17

.github/workflows/styler.yml

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
2+
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
3+
on:
4+
push:
5+
paths:
6+
[
7+
"**.[rR]",
8+
"**.[qrR]md",
9+
"**.[rR]markdown",
10+
"**.[rR]nw",
11+
"**.[rR]profile",
12+
]
13+
14+
name: Style
15+
16+
jobs:
17+
style:
18+
runs-on: ubuntu-latest
19+
env:
20+
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
21+
steps:
22+
- name: Checkout repo
23+
uses: actions/checkout@v3
24+
with:
25+
fetch-depth: 0
26+
27+
- name: Setup R
28+
uses: r-lib/actions/setup-r@v2
29+
with:
30+
use-public-rspm: true
31+
32+
- name: Install dependencies
33+
uses: r-lib/actions/setup-r-dependencies@v2
34+
with:
35+
extra-packages: any::styler, any::roxygen2
36+
needs: styler
37+
38+
- name: Enable styler cache
39+
run: styler::cache_activate()
40+
shell: Rscript {0}
41+
42+
- name: Determine cache location
43+
id: styler-location
44+
run: |
45+
cat(
46+
"location=",
47+
styler::cache_info(format = "tabular")$location,
48+
"\n",
49+
file = Sys.getenv("GITHUB_OUTPUT"),
50+
append = TRUE,
51+
sep = ""
52+
)
53+
shell: Rscript {0}
54+
55+
- name: Cache styler
56+
uses: actions/cache@v3
57+
with:
58+
path: ${{ steps.styler-location.outputs.location }}
59+
key: ${{ runner.os }}-styler-${{ github.sha }}
60+
restore-keys: |
61+
${{ runner.os }}-styler-
62+
${{ runner.os }}-
63+
64+
- name: Style
65+
run: styler::style_pkg()
66+
shell: Rscript {0}
67+
68+
- name: Commit and push changes
69+
run: |
70+
if FILES_TO_COMMIT=($(git diff-index --name-only ${{ github.sha }} \
71+
| egrep --ignore-case '\.(R|[qR]md|Rmarkdown|Rnw|Rprofile)$'))
72+
then
73+
git config --local user.name "$GITHUB_ACTOR"
74+
git config --local user.email "[email protected]"
75+
git commit ${FILES_TO_COMMIT[*]} -m "Style code (GHA)"
76+
git pull --ff-only
77+
git push origin
78+
else
79+
echo "No changes to commit."
80+
fi

R/arx_classifier.R

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ arx_classifier <- function(
4747
predictors,
4848
trainer = parsnip::logistic_reg(),
4949
args_list = arx_class_args_list()) {
50-
51-
if (!is_classification(trainer))
50+
if (!is_classification(trainer)) {
5251
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
52+
}
5353

5454
wf <- arx_class_epi_workflow(
5555
epi_data, outcome, predictors, trainer, args_list
@@ -65,13 +65,15 @@ arx_classifier <- function(
6565
tibble::as_tibble() %>%
6666
dplyr::select(-time_value)
6767

68-
structure(list(
69-
predictions = preds,
70-
epi_workflow = wf,
71-
metadata = list(
72-
training = attr(epi_data, "metadata"),
73-
forecast_created = Sys.time()
74-
)),
68+
structure(
69+
list(
70+
predictions = preds,
71+
epi_workflow = wf,
72+
metadata = list(
73+
training = attr(epi_data, "metadata"),
74+
forecast_created = Sys.time()
75+
)
76+
),
7577
class = c("arx_class", "canned_epipred")
7678
)
7779
}
@@ -117,12 +119,13 @@ arx_class_epi_workflow <- function(
117119
predictors,
118120
trainer = NULL,
119121
args_list = arx_class_args_list()) {
120-
121122
validate_forecaster_inputs(epi_data, outcome, predictors)
122-
if (!inherits(args_list, c("arx_class", "alist")))
123+
if (!inherits(args_list, c("arx_class", "alist"))) {
123124
rlang::abort("args_list was not created using `arx_class_args_list().")
124-
if (!(is.null(trainer) || is_classification(trainer)))
125+
}
126+
if (!(is.null(trainer) || is_classification(trainer))) {
125127
rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.")
128+
}
126129
lags <- arx_lags_validator(predictors, args_list$lags)
127130

128131
# --- preprocessor
@@ -172,8 +175,10 @@ arx_class_epi_workflow <- function(
172175
o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o))
173176
r <- r %>%
174177
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>%
175-
step_mutate(outcome_class = cut(!!o2, breaks = args_list$breaks),
176-
role = "outcome") %>%
178+
step_mutate(
179+
outcome_class = cut(!!o2, breaks = args_list$breaks),
180+
role = "outcome"
181+
) %>%
177182
step_epi_naomit() %>%
178183
step_training_window(n_recent = args_list$n_training)
179184

@@ -245,9 +250,7 @@ arx_class_args_list <- function(
245250
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
246251
log_scale = FALSE,
247252
additional_gr_args = list(),
248-
nafill_buffer = Inf
249-
) {
250-
253+
nafill_buffer = Inf) {
251254
.lags <- lags
252255
if (is.list(lags)) lags <- unlist(lags)
253256
method <- match.arg(method)
@@ -266,7 +269,8 @@ arx_class_args_list <- function(
266269
cli::cli_abort(
267270
c("`additional_gr_args` must be a {.cls list}.",
268271
"!" = "This is a {.cls {class(additional_gr_args)}}.",
269-
i = "See `?epiprocess::growth_rate` for available arguments.")
272+
i = "See `?epiprocess::growth_rate` for available arguments."
273+
)
270274
)
271275
}
272276

@@ -277,19 +281,20 @@ arx_class_args_list <- function(
277281

278282
max_lags <- max(lags)
279283
structure(
280-
enlist(lags = .lags,
281-
ahead,
282-
n_training,
283-
breaks,
284-
forecast_date,
285-
target_date,
286-
outcome_transform,
287-
max_lags,
288-
horizon,
289-
method,
290-
log_scale,
291-
additional_gr_args,
292-
nafill_buffer
284+
enlist(
285+
lags = .lags,
286+
ahead,
287+
n_training,
288+
breaks,
289+
forecast_date,
290+
target_date,
291+
outcome_transform,
292+
max_lags,
293+
horizon,
294+
method,
295+
log_scale,
296+
additional_gr_args,
297+
nafill_buffer
293298
),
294299
class = c("arx_class", "alist")
295300
)
@@ -300,4 +305,3 @@ print.arx_class <- function(x, ...) {
300305
name <- "ARX Classifier"
301306
NextMethod(name = name, ...)
302307
}
303-

R/arx_forecaster.R

Lines changed: 63 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,24 @@
2525
#' jhu <- case_death_rate_subset %>%
2626
#' dplyr::filter(time_value >= as.Date("2021-12-01"))
2727
#'
28-
#' out <- arx_forecaster(jhu, "death_rate",
29-
#' c("case_rate", "death_rate"))
28+
#' out <- arx_forecaster(
29+
#' jhu, "death_rate",
30+
#' c("case_rate", "death_rate")
31+
#' )
3032
#'
3133
#' out <- arx_forecaster(jhu, "death_rate",
32-
#' c("case_rate", "death_rate"), trainer = quantile_reg(),
33-
#' args_list = arx_args_list(levels = 1:9 / 10))
34+
#' c("case_rate", "death_rate"),
35+
#' trainer = quantile_reg(),
36+
#' args_list = arx_args_list(levels = 1:9 / 10)
37+
#' )
3438
arx_forecaster <- function(epi_data,
3539
outcome,
3640
predictors,
3741
trainer = parsnip::linear_reg(),
3842
args_list = arx_args_list()) {
39-
40-
if (!is_regression(trainer))
43+
if (!is_regression(trainer)) {
4144
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
45+
}
4246

4347
wf <- arx_fcast_epi_workflow(
4448
epi_data, outcome, predictors, trainer, args_list
@@ -54,13 +58,15 @@ arx_forecaster <- function(epi_data,
5458
tibble::as_tibble() %>%
5559
dplyr::select(-time_value)
5660

57-
structure(list(
58-
predictions = preds,
59-
epi_workflow = wf,
60-
metadata = list(
61-
training = attr(epi_data, "metadata"),
62-
forecast_created = Sys.time()
63-
)),
61+
structure(
62+
list(
63+
predictions = preds,
64+
epi_workflow = wf,
65+
metadata = list(
66+
training = attr(epi_data, "metadata"),
67+
forecast_created = Sys.time()
68+
)
69+
),
6470
class = c("arx_fcast", "canned_epipred")
6571
)
6672
}
@@ -85,25 +91,30 @@ arx_forecaster <- function(epi_data,
8591
#' jhu <- case_death_rate_subset %>%
8692
#' dplyr::filter(time_value >= as.Date("2021-12-01"))
8793
#'
88-
#' arx_fcast_epi_workflow(jhu, "death_rate",
89-
#' c("case_rate", "death_rate"))
94+
#' arx_fcast_epi_workflow(
95+
#' jhu, "death_rate",
96+
#' c("case_rate", "death_rate")
97+
#' )
9098
#'
9199
#' arx_fcast_epi_workflow(jhu, "death_rate",
92-
#' c("case_rate", "death_rate"), trainer = quantile_reg(),
93-
#' args_list = arx_args_list(levels = 1:9 / 10))
100+
#' c("case_rate", "death_rate"),
101+
#' trainer = quantile_reg(),
102+
#' args_list = arx_args_list(levels = 1:9 / 10)
103+
#' )
94104
arx_fcast_epi_workflow <- function(
95105
epi_data,
96106
outcome,
97107
predictors,
98108
trainer = NULL,
99109
args_list = arx_args_list()) {
100-
101110
# --- validation
102111
validate_forecaster_inputs(epi_data, outcome, predictors)
103-
if (!inherits(args_list, c("arx_fcast", "alist")))
112+
if (!inherits(args_list, c("arx_fcast", "alist"))) {
104113
cli::cli_abort("args_list was not created using `arx_args_list().")
105-
if (!(is.null(trainer) || is_regression(trainer)))
114+
}
115+
if (!(is.null(trainer) || is_regression(trainer))) {
106116
cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.")
117+
}
107118
lags <- arx_lags_validator(predictors, args_list$lags)
108119

109120
# --- preprocessor
@@ -126,15 +137,17 @@ arx_fcast_epi_workflow <- function(
126137
# add all levels to the forecaster and update postprocessor
127138
tau <- sort(compare_quantile_args(
128139
args_list$levels,
129-
rlang::eval_tidy(trainer$args$tau))
130-
)
140+
rlang::eval_tidy(trainer$args$tau)
141+
))
131142
args_list$levels <- tau
132143
trainer$args$tau <- rlang::enquo(tau)
133144
f <- layer_quantile_distn(f, levels = tau) %>% layer_point_from_distn()
134145
} else {
135146
f <- layer_residual_quantiles(
136-
f, probs = args_list$levels, symmetrize = args_list$symmetrize,
137-
by_key = args_list$quantile_by_key)
147+
f,
148+
probs = args_list$levels, symmetrize = args_list$symmetrize,
149+
by_key = args_list$quantile_by_key
150+
)
138151
}
139152
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
140153
layer_add_target_date(target_date = target_date)
@@ -204,7 +217,6 @@ arx_args_list <- function(
204217
nonneg = TRUE,
205218
quantile_by_key = character(0L),
206219
nafill_buffer = Inf) {
207-
208220
# error checking if lags is a list
209221
.lags <- lags
210222
if (is.list(lags)) lags <- unlist(lags)
@@ -222,17 +234,19 @@ arx_args_list <- function(
222234

223235
max_lags <- max(lags)
224236
structure(
225-
enlist(lags = .lags,
226-
ahead,
227-
n_training,
228-
levels,
229-
forecast_date,
230-
target_date,
231-
symmetrize,
232-
nonneg,
233-
max_lags,
234-
quantile_by_key,
235-
nafill_buffer),
237+
enlist(
238+
lags = .lags,
239+
ahead,
240+
n_training,
241+
levels,
242+
forecast_date,
243+
target_date,
244+
symmetrize,
245+
nonneg,
246+
max_lags,
247+
quantile_by_key,
248+
nafill_buffer
249+
),
236250
class = c("arx_fcast", "alist")
237251
)
238252
}
@@ -248,16 +262,22 @@ compare_quantile_args <- function(alist, tlist) {
248262
default_alist <- eval(formals(arx_args_list)$levels)
249263
default_tlist <- eval(formals(quantile_reg)$tau)
250264
if (setequal(alist, default_alist)) {
251-
if (setequal(tlist, default_tlist)) return(sort(unique(union(alist, tlist))))
252-
else return(sort(unique(tlist)))
265+
if (setequal(tlist, default_tlist)) {
266+
return(sort(unique(union(alist, tlist))))
267+
} else {
268+
return(sort(unique(tlist)))
269+
}
253270
} else {
254-
if (setequal(tlist, default_tlist)) return(sort(unique(alist)))
255-
else {
256-
if (setequal(alist, tlist)) return(sort(unique(alist)))
271+
if (setequal(tlist, default_tlist)) {
272+
return(sort(unique(alist)))
273+
} else {
274+
if (setequal(alist, tlist)) {
275+
return(sort(unique(alist)))
276+
}
257277
rlang::abort(c(
258278
"You have specified different, non-default, quantiles in the trainier and `arx_args` options.",
259-
i = "Please only specify quantiles in one location.")
260-
)
279+
i = "Please only specify quantiles in one location."
280+
))
261281
}
262282
}
263283
}

0 commit comments

Comments
 (0)