Skip to content

Commit 10e5df2

Browse files
authored
Merge pull request #137 from cmu-delphi/98-vignette-on-parsnip-and-recipes
population scaling function adjustments
2 parents a7922cf + 8940f6d commit 10e5df2

File tree

3 files changed

+135
-46
lines changed

3 files changed

+135
-46
lines changed

R/step_population_scaling.R

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
#' Create a recipe step that scales variables using population data
1+
#' Convert raw scale predictions to per-capita
22
#'
33
#' `step_population_scaling` creates a specification of a recipe step
4-
#' that will add a population scaled column in the data. For example,
5-
#' load a dataset that contains county population, and join to an `epi_df`
6-
#' that currently only contains number of new cases by county. Once scaled,
7-
#' predictions can be made on case rate. Although worth noting that there is
8-
#' nothing special about "population". The function can be used to scale by any
9-
#' variable. Population is simply the most natural and common use case.
4+
#' that will perform per-capita scaling. Typical usage would
5+
#' load a dataset that contains state-level population, and use it to convert
6+
#' predictions made from a raw scale model to rate-scale by dividing by
7+
#' the population.
8+
#' Note, however, that there is nothing special about "population".
9+
#' The function can be used to scale by any variable. Population is the
10+
#' standard use case in the epidemiology forecasting scenario. The values in
11+
#' `df_pop_col` will *divide* the selected variables while the `rate_rescaling`
12+
#' argument is a common *multiplier* of the selected variables.
1013
#'
1114
#' @param recipe A recipe object. The step will be added to the sequence of
1215
#' operations for this recipe. The recipe should contain information about the
@@ -19,25 +22,30 @@
1922
#' be and are not limited to "outcome".
2023
#' @param trained A logical to indicate if the quantities for preprocessing
2124
#' have been estimated.
22-
#' @param df a data frame that contains the population data used for scaling.
23-
#' @param by A character vector of variables to left join by.
25+
#' @param df a data frame that contains the population data to be used for
26+
#' inverting the existing scaling.
27+
#' @param by A (possibly named) character vector of variables to join by.
2428
#'
2529
#' If `NULL`, the default, the function will perform a natural join, using all
26-
#' variables in common across the `epi_df` and the user-provided dataset.
27-
#' If columns in `epi_df` and `df` have the same name (and aren't
28-
#' included in by), `.df` is added to the one from the user-provided data
30+
#' variables in common across the `epi_df` produced by the `predict()` call
31+
#' and the user-provided dataset.
32+
#' If columns in that `epi_df` and `df` have the same name (and aren't
33+
#' included in `by`), `.df` is added to the one from the user-provided data
2934
#' to disambiguate.
3035
#'
3136
#' To join by different variables on the `epi_df` and `df`, use a named vector.
32-
#' For example, by = c("geo_value" = "states") will match `epi_df$geo_value`
37+
#' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
3338
#' to `df$states`. To join by multiple variables, use a vector with length > 1.
34-
#' For example, by = c("geo_value" = "states", "county" = "county") will match
39+
#' For example, `by = c("geo_value" = "states", "county" = "county")` will match
3540
#' `epi_df$geo_value` to `df$states` and `epi_df$county` to `df$county`.
3641
#'
3742
#' See [dplyr::left_join()] for more details.
3843
#' @param df_pop_col the name of the column in the data frame `df` that
3944
#' contains the population data and will be used for scaling.
4045
#' This should be one column.
46+
#' @param rate_rescaling Sometimes raw scales are "per 100K" or "per 1M".
47+
#' Adjustments can be made here. For example, if the original
48+
#' scale is "per 100K", then set `rate_rescaling = 1e5` to get rates.
4149
#' @param create_new TRUE to create a new column and keep the original column
4250
#' in the `epi_df`
4351
#' @param suffix a character. The suffix added to the column name if
@@ -61,8 +69,7 @@
6169
#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
6270
#' dplyr::select(geo_value, time_value, cases)
6371
#'
64-
#' pop_data = data.frame(states = c("ca", "ny"),
65-
#' value = c(20000, 30000))
72+
#' pop_data = data.frame(states = c("ca", "ny"), value = c(20000, 30000))
6673
#'
6774
#' r <- epi_recipe(jhu) %>%
6875
#' step_population_scaling(df = pop_data,
@@ -86,11 +93,12 @@
8693
#' parsnip::fit(jhu) %>%
8794
#' add_frosting(f)
8895
#'
89-
#' latest <- get_test_data(recipe = r,
90-
#' x = epiprocess::jhu_csse_daily_subset %>%
91-
#' dplyr::filter(time_value > "2021-11-01",
92-
#' geo_value %in% c("ca", "ny")) %>%
93-
#' dplyr::select(geo_value, time_value, cases))
96+
#' latest <- get_test_data(
97+
#' recipe = r,
98+
#' x = epiprocess::jhu_csse_daily_subset %>%
99+
#' dplyr::filter(time_value > "2021-11-01",
100+
#' geo_value %in% c("ca", "ny")) %>%
101+
#' dplyr::select(geo_value, time_value, cases))
94102
#'
95103
#'
96104
#' predict(wf, latest)
@@ -102,11 +110,19 @@ step_population_scaling <-
102110
df,
103111
by = NULL,
104112
df_pop_col,
113+
rate_rescaling = 1,
105114
create_new = TRUE,
106115
suffix = "_scaled",
107116
columns = NULL,
108117
skip = FALSE,
109118
id = rand_id("population_scaling")){
119+
arg_is_scalar(role, trained, df_pop_col, rate_rescaling, create_new, suffix, id)
120+
arg_is_lgl(create_new, skip)
121+
arg_is_chr(df_pop_col, suffix, id)
122+
arg_is_chr(by, columns, allow_null = TRUE)
123+
if (rate_rescaling <= 0)
124+
cli_stop("`rate_rescaling` should be a positive number")
125+
110126
add_step(
111127
recipe,
112128
step_population_scaling_new(
@@ -116,6 +132,7 @@ step_population_scaling <-
116132
df = df,
117133
by = by,
118134
df_pop_col = df_pop_col,
135+
rate_rescaling = rate_rescaling,
119136
create_new = create_new,
120137
suffix = suffix,
121138
columns = columns,
@@ -126,7 +143,7 @@ step_population_scaling <-
126143
}
127144

128145
step_population_scaling_new <-
129-
function(role, trained, df, by, df_pop_col, terms, create_new,
146+
function(role, trained, df, by, df_pop_col, rate_rescaling, terms, create_new,
130147
suffix, columns, skip, id) {
131148
step(
132149
subclass = "population_scaling",
@@ -136,6 +153,7 @@ step_population_scaling_new <-
136153
df = df,
137154
by = by,
138155
df_pop_col = df_pop_col,
156+
rate_rescaling = rate_rescaling,
139157
create_new = create_new,
140158
suffix = suffix,
141159
columns = columns,
@@ -153,6 +171,7 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) {
153171
df = x$df,
154172
by = x$by,
155173
df_pop_col = x$df_pop_col,
174+
rate_rescaling = x$rate_rescaling,
156175
create_new = x$create_new,
157176
suffix = x$suffix,
158177
columns = recipes_eval_select(x$terms, training, info),
@@ -172,8 +191,9 @@ bake.step_population_scaling <- function(object,
172191
try_join <- try(dplyr::left_join(new_data, object$df,
173192
by= object$by),
174193
silent = TRUE)
175-
if (any(grepl("Join columns must be present in data", unlist(try_join)))){
176-
stop("columns in `by` selectors of `step_population_scaling` must be present in data and match")}
194+
if (any(grepl("Join columns must be present in data", unlist(try_join)))) {
195+
cli_stop(c("columns in `by` selectors of `step_population_scaling` ",
196+
"must be present in data and match"))}
177197

178198
if(object$suffix != "_scaled" && object$create_new == FALSE){
179199
message("`suffix` not used to generate new column in `step_population_scaling`")
@@ -194,7 +214,7 @@ bake.step_population_scaling <- function(object,
194214
dplyr::mutate(
195215
dplyr::across(
196216
dplyr::all_of(object$columns),
197-
~.x/!!pop_col ,
217+
~.x * object$rate_rescaling /!!pop_col ,
198218
.names = "{.col}{suffix}")) %>%
199219
# removed so the models do not use the population column
200220
dplyr::select(- !!pop_col)

man/step_population_scaling.Rd

Lines changed: 31 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-population_scaling.R

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,62 @@ test_that("expect error if `by` selector does not match", {
295295
)
296296
})
297297

298+
299+
test_that("Rate rescaling behaves as expected", {
300+
x <- tibble(geo_value = rep("place",50),
301+
time_value = as.Date("2021-01-01") + 0:49,
302+
case_rate = rep(0.0005, 50),
303+
cases = rep(5000, 50)) %>%
304+
as_epi_df()
305+
306+
reverse_pop_data = data.frame(states = c("place"),
307+
value = c(1/1000))
308+
309+
r <- epi_recipe(x) %>%
310+
step_population_scaling(df = reverse_pop_data,
311+
df_pop_col = "value",
312+
rate_rescaling = 100, # cases per 100
313+
by = c("geo_value" = "states"),
314+
case_rate, suffix = "_scaled")
315+
316+
expect_equal(unique(bake(prep(r,x),x)$case_rate_scaled),
317+
0.0005*100/(1/1000)) # done testing step_*
318+
319+
f <- frosting() %>%
320+
layer_population_scaling(.pred, df = reverse_pop_data,
321+
rate_rescaling = 100, # revert back to case rate per 100
322+
by = c("geo_value" = "states"),
323+
df_pop_col = "value")
324+
325+
x <- tibble(geo_value = rep("place",50),
326+
time_value = as.Date("2021-01-01") + 0:49,
327+
case_rate = rep(0.0005, 50)) %>%
328+
as_epi_df()
329+
330+
r <- epi_recipe(x) %>%
331+
step_epi_lag(case_rate, lag = c(7, 14)) %>% # cases
332+
step_epi_ahead(case_rate, ahead = 7, role = "outcome") %>% # cases
333+
step_naomit(all_predictors()) %>%
334+
step_naomit(all_outcomes(), skip = TRUE)
335+
336+
f <- frosting() %>%
337+
layer_predict() %>%
338+
layer_threshold(.pred) %>%
339+
layer_naomit(.pred) %>%
340+
layer_population_scaling(.pred, df = reverse_pop_data,
341+
rate_rescaling = 100, # revert back to case rate per 100
342+
by = c("geo_value" = "states"),
343+
df_pop_col = "value")
344+
345+
wf <- epi_workflow(r, parsnip::linear_reg()) %>%
346+
fit(x) %>%
347+
add_frosting(f)
348+
349+
latest <- get_test_data(recipe = r, x = x)
350+
351+
# suppress warning: prediction from a rank-deficient fit may be misleading
352+
suppressWarnings(expect_equal(unique(predict(wf, latest)$.pred)*(1/1000)/100,
353+
unique(predict(wf, latest)$.pred_scaled)))
354+
})
355+
356+

0 commit comments

Comments
 (0)