|
| 1 | +library(dplyr) |
| 2 | +train <- jhu_csse_daily_subset |> |
| 3 | + filter(time_value >= as.Date("2021-10-01")) |> |
| 4 | + select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av) |
| 5 | +ngeos <- n_distinct(train$geo_value) |
| 6 | + |
| 7 | +test_that("flatline determines target_date where forecast_date exists", { |
| 8 | + |
| 9 | + flat <- flatline_forecaster( |
| 10 | + train, "dr", |
| 11 | + args_list = flatline_args_list( |
| 12 | + forecast_date = as.Date("2021-12-31"), |
| 13 | + target_date = as.Date("2022-01-01"), |
| 14 | + ahead = 1L |
| 15 | + ) |
| 16 | + ) |
| 17 | + |
| 18 | + # previously, if target_date existed, it could be |
| 19 | + # erroneously incremented by the ahead |
| 20 | + expect_identical( |
| 21 | + flat$predictions$target_date, |
| 22 | + rep(as.Date("2022-01-01"), ngeos) |
| 23 | + ) |
| 24 | + expect_identical( |
| 25 | + flat$predictions$forecast_date, |
| 26 | + rep(as.Date("2021-12-31"), ngeos) |
| 27 | + ) |
| 28 | + |
| 29 | + # potentially resulted in NA predictions |
| 30 | + # see #290 https://github.com/cmu-delphi/epipredict/issues/290 |
| 31 | + expect_true(all(!is.na(flat$predictions$.pred_distn))) |
| 32 | + expect_true(all(!is.na(flat$predictions$.pred))) |
| 33 | +}) |
| 34 | + |
| 35 | +test_that("arx_forecaster determines target_date where forecast_date exists", { |
| 36 | + |
| 37 | + arx <- arx_forecaster( |
| 38 | + train, "dr", c("dr", "cr"), |
| 39 | + args_list = arx_args_list( |
| 40 | + forecast_date = as.Date("2021-12-31"), |
| 41 | + target_date = as.Date("2022-01-01"), |
| 42 | + ahead = 1L |
| 43 | + ) |
| 44 | + ) |
| 45 | + # previously, if target_date existed, it could be |
| 46 | + # erroneously incremented by the ahead |
| 47 | + expect_identical( |
| 48 | + arx$predictions$target_date, |
| 49 | + rep(as.Date("2022-01-01"), ngeos) |
| 50 | + ) |
| 51 | + expect_identical( |
| 52 | + arx$predictions$forecast_date, |
| 53 | + rep(as.Date("2021-12-31"), ngeos) |
| 54 | + ) |
| 55 | + |
| 56 | + # potentially resulted in NA predictions |
| 57 | + # see #290 https://github.com/cmu-delphi/epipredict/issues/290 |
| 58 | + expect_true(all(!is.na(arx$predictions$.pred_distn))) |
| 59 | + expect_true(all(!is.na(arx$predictions$.pred))) |
| 60 | +}) |
| 61 | + |
| 62 | +test_that("arx_classifier determines target_date where forecast_date exists", { |
| 63 | + |
| 64 | + |
| 65 | + arx <- arx_classifier( |
| 66 | + train, "dr", c("dr"), |
| 67 | + trainer = parsnip::nearest_neighbor(mode = "classification"), |
| 68 | + args_list = arx_class_args_list( |
| 69 | + forecast_date = as.Date("2021-12-31"), |
| 70 | + target_date = as.Date("2022-01-01"), |
| 71 | + ahead = 1L |
| 72 | + ) |
| 73 | + ) |
| 74 | + |
| 75 | + # previously, if target_date existed, it could be |
| 76 | + # erroneously incremented by the ahead |
| 77 | + expect_identical( |
| 78 | + arx$predictions$target_date, |
| 79 | + rep(as.Date("2022-01-01"), ngeos) |
| 80 | + ) |
| 81 | + expect_identical( |
| 82 | + arx$predictions$forecast_date, |
| 83 | + rep(as.Date("2021-12-31"), ngeos) |
| 84 | + ) |
| 85 | + |
| 86 | + # potentially resulted in NA predictions |
| 87 | + # see #290 https://github.com/cmu-delphi/epipredict/issues/290 |
| 88 | + expect_true(all(!is.na(arx$predictions$.pred_class))) |
| 89 | +}) |
0 commit comments