@@ -114,6 +114,18 @@ def pymc_mod(ss_mod):
114
114
return pymc_mod
115
115
116
116
117
+ @pytest .fixture (scope = "session" )
118
+ def ss_mod_no_exog (rng ):
119
+ ll = st .LevelTrendComponent (order = 2 , innovations_order = 1 )
120
+ return ll .build ()
121
+
122
+
123
+ @pytest .fixture (scope = "session" )
124
+ def ss_mod_no_exog_dt (rng ):
125
+ ll = st .LevelTrendComponent (order = 2 , innovations_order = 1 )
126
+ return ll .build ()
127
+
128
+
117
129
@pytest .fixture (scope = "session" )
118
130
def exog_ss_mod (rng ):
119
131
ll = st .LevelTrendComponent ()
@@ -143,6 +155,42 @@ def exog_pymc_mod(exog_ss_mod, rng):
143
155
return m
144
156
145
157
158
+ @pytest .fixture (scope = "session" )
159
+ def pymc_mod_no_exog (ss_mod_no_exog , rng ):
160
+ y = pd .DataFrame (rng .normal (size = (100 , 1 )).astype (floatX ), columns = ["y" ])
161
+
162
+ with pm .Model (coords = ss_mod_no_exog .coords ) as m :
163
+ initial_trend = pm .Normal ("initial_trend" , dims = ["trend_state" ])
164
+ P0_sigma = pm .Exponential ("P0_sigma" , 1 )
165
+ P0 = pm .Deterministic (
166
+ "P0" , pt .eye (ss_mod_no_exog .k_states ) * P0_sigma , dims = ["state" , "state_aux" ]
167
+ )
168
+ sigma_trend = pm .Exponential ("sigma_trend" , 1 , dims = ["trend_shock" ])
169
+ ss_mod_no_exog .build_statespace_graph (y )
170
+
171
+ return m
172
+
173
+
174
+ @pytest .fixture (scope = "session" )
175
+ def pymc_mod_no_exog_dt (ss_mod_no_exog_dt , rng ):
176
+ y = pd .DataFrame (
177
+ rng .normal (size = (100 , 1 )).astype (floatX ),
178
+ columns = ["y" ],
179
+ index = pd .date_range ("2020-01-01" , periods = 100 , freq = "D" ),
180
+ )
181
+
182
+ with pm .Model (coords = ss_mod_no_exog_dt .coords ) as m :
183
+ initial_trend = pm .Normal ("initial_trend" , dims = ["trend_state" ])
184
+ P0_sigma = pm .Exponential ("P0_sigma" , 1 )
185
+ P0 = pm .Deterministic (
186
+ "P0" , pt .eye (ss_mod_no_exog_dt .k_states ) * P0_sigma , dims = ["state" , "state_aux" ]
187
+ )
188
+ sigma_trend = pm .Exponential ("sigma_trend" , 1 , dims = ["trend_shock" ])
189
+ ss_mod_no_exog_dt .build_statespace_graph (y )
190
+
191
+ return m
192
+
193
+
146
194
@pytest .fixture (scope = "session" )
147
195
def idata (pymc_mod , rng ):
148
196
with pymc_mod :
@@ -162,6 +210,24 @@ def idata_exog(exog_pymc_mod, rng):
162
210
return idata
163
211
164
212
213
+ @pytest .fixture (scope = "session" )
214
+ def idata_no_exog (pymc_mod_no_exog , rng ):
215
+ with pymc_mod_no_exog :
216
+ idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
217
+ idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
218
+ idata .extend (idata_prior )
219
+ return idata
220
+
221
+
222
+ @pytest .fixture (scope = "session" )
223
+ def idata_no_exog_dt (pymc_mod_no_exog_dt , rng ):
224
+ with pymc_mod_no_exog_dt :
225
+ idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
226
+ idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
227
+ idata .extend (idata_prior )
228
+ return idata
229
+
230
+
165
231
def test_invalid_filter_name_raises ():
166
232
msg = "The following are valid filter types: " + ", " .join (list (FILTER_FACTORY .keys ()))
167
233
with pytest .raises (NotImplementedError , match = msg ):
@@ -664,28 +730,75 @@ def test_invalid_scenarios():
664
730
ss_mod ._validate_scenario_data (scenario )
665
731
666
732
733
+ @pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
667
734
@pytest .mark .parametrize ("filter_output" , ["predicted" , "filtered" , "smoothed" ])
668
- def test_forecast (filter_output , ss_mod , idata , rng ):
669
- time_idx = idata .posterior .coords ["time" ].values
670
- forecast_idata = ss_mod .forecast (
671
- idata , start = time_idx [- 1 ], periods = 10 , filter_output = filter_output , random_seed = rng
735
+ @pytest .mark .parametrize (
736
+ "mod_name, idata_name, start, end, periods" ,
737
+ [
738
+ ("ss_mod_no_exog" , "idata_no_exog" , None , None , 10 ),
739
+ ("ss_mod_no_exog" , "idata_no_exog" , - 1 , None , 10 ),
740
+ ("ss_mod_no_exog" , "idata_no_exog" , 10 , None , 10 ),
741
+ ("ss_mod_no_exog" , "idata_no_exog" , 10 , 21 , None ),
742
+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , None , None , 10 ),
743
+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , - 1 , None , 10 ),
744
+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , 10 , None , 10 ),
745
+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , 10 , "2020-01-21" , None ),
746
+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , "2020-03-01" , "2020-03-11" , None ),
747
+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , "2020-03-01" , None , 10 ),
748
+ ],
749
+ ids = [
750
+ "range_default" ,
751
+ "range_negative" ,
752
+ "range_int" ,
753
+ "range_end" ,
754
+ "datetime_default" ,
755
+ "datetime_negative" ,
756
+ "datetime_int" ,
757
+ "datetime_int_end" ,
758
+ "datetime_datetime_end" ,
759
+ "datetime_datetime" ,
760
+ ],
761
+ )
762
+ def test_forecast (filter_output , mod_name , idata_name , start , end , periods , rng , request ):
763
+ mod = request .getfixturevalue (mod_name )
764
+ idata = request .getfixturevalue (idata_name )
765
+ time_idx = mod ._get_fit_time_index ()
766
+ is_datetime = isinstance (time_idx , pd .DatetimeIndex )
767
+
768
+ if isinstance (start , str ):
769
+ t0 = pd .Timestamp (start )
770
+ elif isinstance (start , int ):
771
+ t0 = time_idx [start ]
772
+ else :
773
+ t0 = time_idx [- 1 ]
774
+
775
+ delta = time_idx .freq if is_datetime else 1
776
+
777
+ forecast_idata = mod .forecast (
778
+ idata , start = start , end = end , periods = periods , filter_output = filter_output , random_seed = rng
672
779
)
673
780
674
- assert forecast_idata .coords ["time" ].values .shape == (10 ,)
781
+ forecast_idx = forecast_idata .coords ["time" ].values
782
+ forecast_idx = pd .DatetimeIndex (forecast_idx ) if is_datetime else pd .Index (forecast_idx )
783
+
784
+ assert forecast_idx .shape == (10 ,)
675
785
assert forecast_idata .forecast_latent .dims == ("chain" , "draw" , "time" , "state" )
676
786
assert forecast_idata .forecast_observed .dims == ("chain" , "draw" , "time" , "observed_state" )
677
787
678
788
assert not np .any (np .isnan (forecast_idata .forecast_latent .values ))
679
789
assert not np .any (np .isnan (forecast_idata .forecast_observed .values ))
680
790
791
+ assert forecast_idx [0 ] == (t0 + delta )
792
+
681
793
682
794
@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
683
- def test_forecast_with_exog_data (rng , exog_ss_mod , idata_exog ):
795
+ @pytest .mark .parametrize ("start" , [None , - 1 , 10 ])
796
+ def test_forecast_with_exog_data (rng , exog_ss_mod , idata_exog , start ):
684
797
scenario = pd .DataFrame (np .zeros ((10 , 3 )), columns = ["a" , "b" , "c" ])
685
798
scenario .iloc [5 , 0 ] = 1e9
686
799
687
800
forecast_idata = exog_ss_mod .forecast (
688
- idata_exog , periods = 10 , random_seed = rng , scenario = scenario
801
+ idata_exog , start = start , periods = 10 , random_seed = rng , scenario = scenario
689
802
)
690
803
691
804
components = exog_ss_mod .extract_components_from_idata (forecast_idata )
0 commit comments