Skip to content

Commit cab925c

Browse files
Expand test coverage of forecast option combinations
1 parent 99fcf22 commit cab925c

File tree

2 files changed

+157
-14
lines changed

2 files changed

+157
-14
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,16 +1502,21 @@ def sample_statespace_matrices(
15021502

15031503
@staticmethod
15041504
def _validate_forecast_args(
1505-
time_index: pd.Index,
1505+
time_index: pd.RangeIndex | pd.DatetimeIndex,
15061506
start: int | pd.Timestamp,
15071507
periods: int | None = None,
15081508
end: int | pd.Timestamp = None,
15091509
scenario: pd.DataFrame | np.ndarray | None = None,
15101510
use_scenario_index: bool = False,
15111511
verbose: bool = True,
15121512
):
1513-
if start not in time_index:
1514-
raise ValueError("start must be in the data index used to fit the model.")
1513+
if isinstance(start, pd.Timestamp) and start not in time_index:
1514+
raise ValueError("Datetime start must be in the data index used to fit the model.")
1515+
elif isinstance(start, int):
1516+
if abs(start) > len(time_index):
1517+
raise ValueError(
1518+
"Integer start must be within the range of the data index used to fit the model."
1519+
)
15151520
if periods is None and end is None:
15161521
raise ValueError("Must specify one of either periods or end")
15171522
if periods is not None and end is not None:
@@ -1559,6 +1564,22 @@ def _validate_scenario_data(
15591564
name: str | None = None,
15601565
verbose=True,
15611566
):
1567+
"""
1568+
Validate the scenario data provided to the forecast method by checking that it has the correct shape and
1569+
dimensions.
1570+
1571+
Parameters
1572+
----------
1573+
scenario
1574+
name
1575+
verbose
1576+
1577+
Returns
1578+
-------
1579+
scenario: pd.DataFrame | np.ndarray | dict[str, pd.DataFrame | np.ndarray]
1580+
Scenario data, validated and potentially modified.
1581+
1582+
"""
15621583
if not self._needs_exog_data:
15631584
return scenario
15641585

@@ -1758,18 +1779,27 @@ def get_or_create_index(x, time_index, start=None):
17581779
forecast_index = None
17591780

17601781
if is_datetime:
1761-
freq = time_index.inferred_freq
1782+
freq = time_index.freq
17621783
if isinstance(start, int):
17631784
start = time_index[start]
1785+
if isinstance(end, int):
1786+
raise ValueError(
1787+
"end must be a timestamp if using a datetime index. To specify a number of "
1788+
"timesteps from the start date, use the periods argument instead."
1789+
)
17641790
if end is not None:
17651791
forecast_index = pd.date_range(start, end=end, freq=freq)
17661792
if periods is not None:
1767-
# date_range include both start and end, but we're going to pop off the start later (it will be
1768-
# interpreted as x0). So we need to add 1 to the periods so the user gets "periods" number of
1769-
# forecasts back
1793+
# date_range includes both the start and end date, but we're going to pop off the start later
1794+
# (it will be interpreted as x0). So we need to add 1 to the periods so the user gets "periods"
1795+
# number of forecasts back
17701796
forecast_index = pd.date_range(start, periods=periods + 1, freq=freq)
17711797

17721798
else:
1799+
# If the user provided a positive integer as start, directly interpret it as the start time. If it's
1800+
# negative, interpret it as a positional index.
1801+
if start < 0:
1802+
start = time_index[start]
17731803
if end is not None:
17741804
forecast_index = pd.RangeIndex(start, end, step=1, dtype="int")
17751805
if periods is not None:

tests/statespace/test_statespace.py

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ def pymc_mod(ss_mod):
114114
return pymc_mod
115115

116116

117+
@pytest.fixture(scope="session")
118+
def ss_mod_no_exog(rng):
119+
ll = st.LevelTrendComponent(order=2, innovations_order=1)
120+
return ll.build()
121+
122+
123+
@pytest.fixture(scope="session")
124+
def ss_mod_no_exog_dt(rng):
125+
ll = st.LevelTrendComponent(order=2, innovations_order=1)
126+
return ll.build()
127+
128+
117129
@pytest.fixture(scope="session")
118130
def exog_ss_mod(rng):
119131
ll = st.LevelTrendComponent()
@@ -143,6 +155,42 @@ def exog_pymc_mod(exog_ss_mod, rng):
143155
return m
144156

145157

158+
@pytest.fixture(scope="session")
159+
def pymc_mod_no_exog(ss_mod_no_exog, rng):
160+
y = pd.DataFrame(rng.normal(size=(100, 1)).astype(floatX), columns=["y"])
161+
162+
with pm.Model(coords=ss_mod_no_exog.coords) as m:
163+
initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
164+
P0_sigma = pm.Exponential("P0_sigma", 1)
165+
P0 = pm.Deterministic(
166+
"P0", pt.eye(ss_mod_no_exog.k_states) * P0_sigma, dims=["state", "state_aux"]
167+
)
168+
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
169+
ss_mod_no_exog.build_statespace_graph(y)
170+
171+
return m
172+
173+
174+
@pytest.fixture(scope="session")
175+
def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
176+
y = pd.DataFrame(
177+
rng.normal(size=(100, 1)).astype(floatX),
178+
columns=["y"],
179+
index=pd.date_range("2020-01-01", periods=100, freq="D"),
180+
)
181+
182+
with pm.Model(coords=ss_mod_no_exog_dt.coords) as m:
183+
initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
184+
P0_sigma = pm.Exponential("P0_sigma", 1)
185+
P0 = pm.Deterministic(
186+
"P0", pt.eye(ss_mod_no_exog_dt.k_states) * P0_sigma, dims=["state", "state_aux"]
187+
)
188+
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
189+
ss_mod_no_exog_dt.build_statespace_graph(y)
190+
191+
return m
192+
193+
146194
@pytest.fixture(scope="session")
147195
def idata(pymc_mod, rng):
148196
with pymc_mod:
@@ -162,6 +210,24 @@ def idata_exog(exog_pymc_mod, rng):
162210
return idata
163211

164212

213+
@pytest.fixture(scope="session")
214+
def idata_no_exog(pymc_mod_no_exog, rng):
215+
with pymc_mod_no_exog:
216+
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
217+
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
218+
idata.extend(idata_prior)
219+
return idata
220+
221+
222+
@pytest.fixture(scope="session")
223+
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng):
224+
with pymc_mod_no_exog_dt:
225+
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
226+
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
227+
idata.extend(idata_prior)
228+
return idata
229+
230+
165231
def test_invalid_filter_name_raises():
166232
msg = "The following are valid filter types: " + ", ".join(list(FILTER_FACTORY.keys()))
167233
with pytest.raises(NotImplementedError, match=msg):
@@ -664,28 +730,75 @@ def test_invalid_scenarios():
664730
ss_mod._validate_scenario_data(scenario)
665731

666732

733+
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
667734
@pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"])
668-
def test_forecast(filter_output, ss_mod, idata, rng):
669-
time_idx = idata.posterior.coords["time"].values
670-
forecast_idata = ss_mod.forecast(
671-
idata, start=time_idx[-1], periods=10, filter_output=filter_output, random_seed=rng
735+
@pytest.mark.parametrize(
736+
"mod_name, idata_name, start, end, periods",
737+
[
738+
("ss_mod_no_exog", "idata_no_exog", None, None, 10),
739+
("ss_mod_no_exog", "idata_no_exog", -1, None, 10),
740+
("ss_mod_no_exog", "idata_no_exog", 10, None, 10),
741+
("ss_mod_no_exog", "idata_no_exog", 10, 21, None),
742+
("ss_mod_no_exog_dt", "idata_no_exog_dt", None, None, 10),
743+
("ss_mod_no_exog_dt", "idata_no_exog_dt", -1, None, 10),
744+
("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, None, 10),
745+
("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, "2020-01-21", None),
746+
("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", "2020-03-11", None),
747+
("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", None, 10),
748+
],
749+
ids=[
750+
"range_default",
751+
"range_negative",
752+
"range_int",
753+
"range_end",
754+
"datetime_default",
755+
"datetime_negative",
756+
"datetime_int",
757+
"datetime_int_end",
758+
"datetime_datetime_end",
759+
"datetime_datetime",
760+
],
761+
)
762+
def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng, request):
763+
mod = request.getfixturevalue(mod_name)
764+
idata = request.getfixturevalue(idata_name)
765+
time_idx = mod._get_fit_time_index()
766+
is_datetime = isinstance(time_idx, pd.DatetimeIndex)
767+
768+
if isinstance(start, str):
769+
t0 = pd.Timestamp(start)
770+
elif isinstance(start, int):
771+
t0 = time_idx[start]
772+
else:
773+
t0 = time_idx[-1]
774+
775+
delta = time_idx.freq if is_datetime else 1
776+
777+
forecast_idata = mod.forecast(
778+
idata, start=start, end=end, periods=periods, filter_output=filter_output, random_seed=rng
672779
)
673780

674-
assert forecast_idata.coords["time"].values.shape == (10,)
781+
forecast_idx = forecast_idata.coords["time"].values
782+
forecast_idx = pd.DatetimeIndex(forecast_idx) if is_datetime else pd.Index(forecast_idx)
783+
784+
assert forecast_idx.shape == (10,)
675785
assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state")
676786
assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state")
677787

678788
assert not np.any(np.isnan(forecast_idata.forecast_latent.values))
679789
assert not np.any(np.isnan(forecast_idata.forecast_observed.values))
680790

791+
assert forecast_idx[0] == (t0 + delta)
792+
681793

682794
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
683-
def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog):
795+
@pytest.mark.parametrize("start", [None, -1, 10])
796+
def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
684797
scenario = pd.DataFrame(np.zeros((10, 3)), columns=["a", "b", "c"])
685798
scenario.iloc[5, 0] = 1e9
686799

687800
forecast_idata = exog_ss_mod.forecast(
688-
idata_exog, periods=10, random_seed=rng, scenario=scenario
801+
idata_exog, start=start, periods=10, random_seed=rng, scenario=scenario
689802
)
690803

691804
components = exog_ss_mod.extract_components_from_idata(forecast_idata)

0 commit comments

Comments
 (0)