-
Notifications
You must be signed in to change notification settings - Fork 10
ARX classifier vignette #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 10 commits
Commits
Show all changes
41 commits
Select commit
Hold shift + click to select a range
0abdac9
Add arx classifier vignette
rachlobay e090eb8
Remove
rachlobay e7ec836
Try adding vignette again
rachlobay d974e43
Remove space
rachlobay aeef3c3
Spacing
rachlobay 3df82b2
Try to get rid of error
rachlobay e4d8fb6
Space
rachlobay 29afcc3
Style file
rachlobay 1408283
Add packages
rachlobay bbf8554
Add ggplot2
rachlobay e76be9c
- cache
rachlobay 2893035
Update title header
rachlobay 05467cd
Update vignettes/arx-classifier-vignette.Rmd
rachlobay 18b7440
Update vignettes/arx-classifier-vignette.Rmd
rachlobay 3015abb
help() instead of question mark
rachlobay c27f688
Update vignettes/arx-classifier-vignette.Rmd
rachlobay 0600ac3
Add purrr to suggests
rachlobay 45da6d2
Rename file to arx-classifier.Rmd
rachlobay 40f71b4
Update references to vignettes in this package
rachlobay 1a2f485
Remove horizon stuff
rachlobay 8092e0a
Add a bit about n_training
rachlobay 5641c95
remove surges
rachlobay b73ccac
Run use_this(purrr) and edit conclusion
rachlobay b3079dd
translating
rachlobay 198d397
scale_colour_brewer()
rachlobay f1f80d3
Try 10% increase
rachlobay 04553a6
see previous
rachlobay a78069e
description
rachlobay fb6ba33
Try to pass check
rachlobay 24ff1e7
map over larger window
rachlobay 29205c3
description
rachlobay dd357e7
Add purrr to suggests
rachlobay 3845afa
Style file
rachlobay ec24435
Moved arx-classifier to articles
rachlobay 9ba5580
Merge branch 'dev' into vignette-classifier
rachlobay b1adbd5
Please work
rachlobay b2cb7e5
Merge branch 'vignette-classifier' of https://github.com/cmu-delphi/e…
rachlobay cdd5a91
test
rachlobay 3c6a44a
make sure version is updated
rachlobay fe1db83
Use latest version of description
rachlobay 2b25c35
Move update vignette. May be parsnip warning
rachlobay File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ Imports: | |
glue, | ||
hardhat (>= 1.3.0), | ||
magrittr, | ||
purrr, | ||
quantreg, | ||
recipes (>= 1.0.4), | ||
rlang, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
--- | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
title: "arx_classifier" | ||
--- | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
```{r setup, include = FALSE} | ||
knitr::opts_chunk$set( | ||
echo = TRUE, | ||
collapse = FALSE, | ||
comment = "#>", | ||
warning = FALSE, | ||
message = FALSE, | ||
cache = TRUE | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
``` | ||
|
||
## Load required packages | ||
|
||
```{r, message = FALSE, warning = FALSE} | ||
library(dplyr) | ||
library(purrr) | ||
library(ggplot2) | ||
library(epipredict) | ||
``` | ||
|
||
## Introducing the ARX classifier | ||
|
||
The `arx_classifier()` function is an autoregressive classification model for `epi_df` data that is used to predict a discrete class for each case under consideration. It is a direct forecaster in that it estimates the classes at a specific horizon or ahead value. | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
To get a sense of how the `arx_classifier()` works, let's consider a simple example with minimal inputs. For this, we will use the built-in `case_death_rate_subset` that contains confirmed COVID-19 cases and deaths from JHU CSSE for all states over Dec 31, 2020 to Dec 31, 2021. From this, we'll take a subset of data for five states over June 4, 2021 to December 31, 2021. Our objective is to predict whether the case rates are increasing when considering the 0, 7 and 14 day case rates: | ||
|
||
```{r} | ||
jhu <- case_death_rate_subset %>% | ||
dplyr::filter( | ||
time_value >= "2021-06-04", | ||
time_value <= "2021-12-31", | ||
geo_value %in% c("ca", "fl", "tx", "ny", "nj") | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
out <- arx_classifier(jhu, | ||
outcome = "case_rate", | ||
predictors = "case_rate" | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
out$predictions | ||
``` | ||
|
||
The key takeaway from the predictions is that there are two prediction classes: (-Inf, 0.25] and (0.25, Inf). This is because for our goal of classification the classes must be discrete. The discretization of the real-valued outcome is controlled by the `breaks` argument, which defaults to 0.25. Such breaks will be automatically extended to cover the entire real line. For example, the default break of 0.25 is silently extended to breaks = c(-Inf, .25, Inf) and, therefore, results in two classes: [-Inf, 0.25] and (0.25, Inf). These two classes are used to discretize the outcome. The conversion of the outcome to such classes is handled internally. So if discrete classes already exist for the outcome in the `epi_df`, then we recommend to code a classifier from scratch using the `epi_workflow` framework for more control. | ||
|
||
The `trainer` is a `parsnip` model describing the type of estimation such that `mode = "classification"` is enforced. The two typical trainers that are used are `parsnip::logistic_reg()` for two classes or `parsnip::multinom_reg()` for more than two classes. | ||
|
||
```{r} | ||
workflows::extract_spec_parsnip(out$epi_workflow) | ||
``` | ||
|
||
From the parsnip model specification, we can see that the trainer used is logistic regression, which is expected for our binary outcome. More complicated trainers like `parsnip::naive_Bayes()` or `parsnip::rand_forest()` may also be used (however, we will stick to the basics in this gentle introduction to the classifier). | ||
|
||
If you use the default trainer of logistic regression for binary classification and you decide against using the default break of 0.25, then you should only input one break so that there are two classification bins to properly dichotomize the outcome. For example, let's set a break of 0.5 instead of relying on the default of 0.25. We can do this by passing 0.5 to the `breaks` argument in `arx_class_args_list()` as follows: | ||
|
||
```{r} | ||
out_break_0.5 <- arx_classifier(jhu, | ||
outcome = "case_rate", | ||
predictors = "case_rate", | ||
args_list = arx_class_args_list( | ||
breaks = 0.5 | ||
) | ||
) | ||
|
||
out_break_0.5$predictions | ||
``` | ||
Indeed, we can observe that the two `.pred_class` are now (-Inf, 0.5] and (0.5, Inf). | ||
|
||
```{r} | ||
?arx_class_args_list | ||
``` | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Additional arguments that may be supplied to `arx_class_args_list()` include the expected `lags` and `ahead` arguments for an autoregressive-type model. These have default values of 0, 7, and 14 days for the lags of the predictors and 7 days ahead of the forecast date for predicting the outcome. There is also `n_training` to indicate the upper bound for the number of training rows per key as well as `forecast_date` and `target_date` to specify the date that the forecast is created and intended, respectively. We will not dwell on these arguments here as they are not unique to this classifier or absolutely essential to understanding how it operates. The remaining arguments will be discussed organically, as they are needed to serve our purposes. For information on any remaining arguments that are not discussed here, please see the function documentation for a complete list and their definitions. | ||
|
||
## Example of using the ARX classifier | ||
|
||
Now, to demonstrate the power and utility of this built-in arx classifier, we will adapt the classification example that was written from scratch in the [Examples of Preprocessing and Models vignette](https://cmu-delphi.github.io/epipredict/articles/preprocessing-and-models.html). However, to keep things simple and not merely a direct translation, we will only consider two prediction categories and leave the extension to three as an exercise for the reader. | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
To motivate this example, a major use of autoregressive classification models is to predict upsurges or downsurges like in hotspot prediction models to anticipate the direction of the outcome (see [McDonald, Bien, Green, Hu, et al. (2021)](https://www.pnas.org/doi/full/10.1073/pnas.2111453118) for more on these). In our case, one simple question that such models can help answer is... Do we expect that the future will have increased case rates or not relative to the present? #%% Ok to say relative to the present? Also, should I be more precise here and say relative change or best to keep it simple? | ||
|
||
To answer this question, we can create a predictive model for upsurges and downsurges of case rates rather than the raw case rates themselves. For this situation, our target is to predict whether there is an increase in case rates or not. Following [McDonald, Bien, Green, Hu, et al. (2021)](https://www.pnas.org/doi/full/10.1073/pnas.2111453118), we look at the relative change between $Y_{l,t}$ and $Y_{l, t+a}$, where the former is the case rate at location $l$ at time $t$ and the latter is the rate for that location at time $t+a$. Using these variables, we define a categorical response variable with two classes | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
$$\begin{align} | ||
Z_{l,t} = \left\{\begin{matrix} | ||
\text{up,} & \text{if } Y_{l,t}^\Delta > 0.25\\ | ||
\text{not up,} & \text{otherwise} | ||
\end{matrix}\right. | ||
\end{align}$$ | ||
where $Y_{l,t}^\Delta = (Y_{l, t} - Y_{l, t- 7} / Y_{l, t- 7}$. If $Y_{l,t}^\Delta$ > 0.25, meaning that the number of new cases over the week has increased by over 25\%, then $Z_{l,t}$ is up. This is the criteria for location $l$ to be a hotspot at time $t$. On the other hand, if $Y_{l,t}^\Delta$ \leq 0.25$, then then $Z_{l,t}$ is categorized as not up, meaning that there has not been a >25\% increase in the new cases over the past week. | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
The logistic regression model we use to predict this binary response can be considered to be a simplification of the multinomial regression model presented in the [Examples of Preprocessing and Models vignette](https://cmu-delphi.github.io/epipredict/articles/preprocessing-and-models.html): | ||
|
||
$$\begin{align} | ||
\pi_{\text{up}}(x) &= Pr(Z_{l, t} = \text{up}|x) = \frac{e^{g_{\text{up}}(x)}}{1 + e^{g_{\text{up}}(x)}}, \\ | ||
\pi_{\text{not up}}(x)&= Pr(Z_{l, t} = \text{not up}|x) = 1 - Pr(Z_{l, t} = \text{up}|x) = \frac{1}{1 + e^{g_{\text{up}}(x)}} | ||
\end{align}$$ | ||
where | ||
|
||
$$ | ||
g_{\text{up}}(x) = \log\left ( \frac{\Pr(Z_{l, t} = \text{up} \vert x)}{\Pr(Z_{l, t} = \text{not up} \vert x)} \right ) = \beta_{10} + \beta_{11}t + \delta_{10}s_{state_1} + \delta_{10}s_{state_2} ... + \beta_{12}Y_{l,t}^\Delta + \beta_{13}Y_{l,t-7}^\Delta + \beta_{14}Y_{l,t-14}^\Delta. | ||
$$ | ||
|
||
Now then, we will operate on the same subset of the `case_death_rate_subset` that we used in our above example. This time, we will use it to investigate whether the number of newly reported cases over the past 7 days has increased by at least 25\% compared to the preceding week for our sample of states. | ||
|
||
Notice that by using the `arx_classifier()` function we've completely eliminated the need to manually categorize the response variable and implement pre-processing steps, which was necessary in the [Examples of Preprocessing and Models vignette](https://cmu-delphi.github.io/epipredict/articles/preprocessing-and-models.html). | ||
|
||
```{r} | ||
log_res <- arx_classifier( | ||
jhu, | ||
outcome = "case_rate", | ||
predictors = "case_rate", | ||
args_list = arx_class_args_list( | ||
breaks = 0.25 / 7 # division by 7 gives weekly not daily | ||
) | ||
) | ||
|
||
workflows::extract_preprocessor(log_res$epi_workflow) | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
``` | ||
|
||
Comparing the pre-processing steps for this to those in the other vignette, we can see that they are not precisely the same, but they cover the same essentials of transforming `case_rate` to the growth rate scale (`step_growth_rate()`), lagging the predictors (`step_epi_lag()`), leading the response (`step_epi_ahead()`), which are both constructed from the growth rates, and constructing the binary classification response variable (`step_mutate()`). | ||
|
||
On this topic, it is important to understand that we are not actually concerned about the case values themselves. Rather we are concerned whether the quantity of cases in the future is a lot larger than that in the present. For this reason, the outcome does not remain as cases, but rather it is transformed by using either growth rates (as the predictors and outcome in our example are) or lagged differences. While the latter is closer to the requirements for the [2022-23 CDC Flusight Hospitalization Experimental Target](https://github.com/cdcepi/Flusight-forecast-data/blob/745511c436923e1dc201dea0f4181f21a8217b52/data-experimental/README.md), and it is conceptually easy to understand because it is simply the change of the value for the horizon, it is not the default. The default is `growth_rate`. One reason for this choice is because the growth rate is on a rate scale, not on the absolute scale, so it fosters comparability across locations without any conscious effort (on the other hand, when using the `lag_difference` one would need to take care to operate on rates per 100k and not raw counts). We utilize `epiprocess::growth_rate()` to create the outcome using some of the additional arguments. One important argument for the growth rate calculation is the `method`. Only `rel_change` for relative change should be used as the method because the test data is the only data that is accessible and the other methods require access to the training data. #%% Last sentence correct? | ||
|
||
The other optional arguments for controlling the growth rate calculation (that can be inputted as `additional_gr_args`) can be found in the documentation for [`epiprocess::growth_rate()`](https://cmu-delphi.github.io/epiprocess/reference/growth_rate.html) and the related [vignette](https://cmu-delphi.github.io/epiprocess/articles/growth_rate.html). | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
### Visualizing the results | ||
|
||
To visualize the prediction classes across the states for the target date, we can plot our results as a heatmap. However, if we were to plot the results for only one target date, like our 7-day ahead predictions, then that would be a pretty sad heatmap (which would look more like a bar chart than a heatmap)... So instead of doing that, let's get predictions for several aheads and plot a heatmap across the target dates. To get the predictions across several ahead values, we will use the map function in the same way that we did in other vignettes: | ||
|
||
```{r} | ||
multi_log_res <- map(1:14, ~ arx_classifier( | ||
jhu, | ||
outcome = "case_rate", | ||
predictors = "case_rate", | ||
args_list = arx_class_args_list( | ||
breaks = 0.25 / 7, | ||
ahead = .x, | ||
horizon = .x, | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
)$predictions) %>% list_rbind() | ||
``` | ||
|
||
Notice that the ahead and horizon are what we are are changing when we use `map()`. The reason we must also change `horizon` is because that is passed to the `h` argument of `epiprocess::growth_rate()` and determines the amount of data used to calculate the growth rate. So, for nearly all intents and purposes, it should be the same as `ahead`. | ||
|
||
We can plot a the heatmap of the results over the aheads to see if there's anything novel or interesting: | ||
|
||
```{r} | ||
ggplot(multi_log_res, aes(target_date, geo_value, fill = .pred_class)) + | ||
geom_tile() + | ||
ylab("State") + | ||
xlab("Target date") | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
``` | ||
|
||
While there is a bit of variability near to the beginning, we can clearly see that there are upswings for all states starting from the beginning of January 2022, which we can recall was when there was a massive spike in cases for many states. So our results seem to align well with what actually happened at the beginning of January 2022. | ||
|
||
## A brief reflection | ||
|
||
The most noticeable benefit of using the `arx_classifier()` function is the simplification and reduction of the manual implementation of the classifier from about 30 down to 3 lines. However, as we noted before, the trade-off for simplicity is control over the precise pre-processing, post-processing, and additional features embedded in the coding of a classifier. So the good thing is that `epipredict` provides both - a built-in `arx_classifer()` or the means to implement your own classifier from scratch by using the `epi_workflow` framework. And which you choose depends on the circumstances. Our advice is to start with using the built-in classifier for ostensibly simple projects and begin to implement your own when the modelling project takes a complicated turn. To get some practice on coding up a classifier by hand, consider to take this opportunity to translate this binary classification model example to an `epi_workflow`, akin to that in [Examples of Preprocessing and Models vignette](https://cmu-delphi.github.io/epipredict/articles/preprocessing-and-models.html). #%% takes a complex turn, increases in complexity | ||
rachlobay marked this conversation as resolved.
Show resolved
Hide resolved
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.