@@ -128,32 +128,51 @@ def ss_mod_no_exog_dt(rng):
128
128
129
129
130
130
@pytest.fixture(scope="session")
def exog_data(rng):
    """Simulate a small daily dataset with one binary regressor and a gappy target.

    Parameters
    ----------
    rng
        Random generator fixture; must provide ``choice`` and ``normal``.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by ``date`` (2023-05-01 through 2023-05-10) with a float
        0/1 column ``x1`` and a normally distributed column ``y`` in which
        three rows have been set to NaN.
    """
    dates = pd.date_range(start="2023-05-01", end="2023-05-10", freq="D")
    n_obs = len(dates)

    # Draw x1 before y so the random stream is consumed in the same order
    # as when both were built inline in the DataFrame literal.
    x1 = rng.choice(2, size=n_obs, replace=True).astype(float)
    y = rng.normal(size=(n_obs,))

    frame = pd.DataFrame({"date": dates, "x1": x1, "y": y})

    # The frame still carries its default integer index here, so these labels
    # pick the 2nd, 4th and 10th rows of the target.
    frame.loc[[1, 3, 9], ["y"]] = np.nan

    # NOTE(review): the dates go in as a column and are promoted with
    # set_index rather than being passed as the index directly. Tests that
    # consume this data filter a "No frequency" warning, so the resulting
    # index presumably carries no freq -- confirm before simplifying.
    return frame.set_index("date")
137
143
138
144
139
145
@pytest.fixture(scope="session")
def exog_ss_mod(exog_data):
    """Build a structural model with a level component plus one exogenous regression.

    Parameters
    ----------
    exog_data
        Simulated data fixture; only the name of its ``x1`` column is used here.

    Returns
    -------
    The object returned by ``build`` on the combined
    ``LevelTrendComponent + RegressionComponent`` model.
    """
    # State names come from the data so the component and the frame stay in sync.
    exog_names = exog_data[["x1"]].columns.tolist()

    level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
    exog = st.RegressionComponent(
        name="exog",  # Name of this exogenous variable component
        k_exog=len(exog_names),  # One regressor; derived rather than hard-coded
        innovations=False,  # Fixed effect (no stochastic evolution)
        state_names=exog_names,
    )

    combined_model = level_trend + exog

    # Pass verbose=False, as the previous version of this fixture did, so that
    # building the session-scoped model stays quiet.
    return combined_model.build(verbose=False)
157
+
158
+
159
@pytest.fixture(scope="session")
def exog_pymc_mod(exog_ss_mod, exog_data):
    """Create the PyMC model that accompanies ``exog_ss_mod``.

    Declares ``P0_diag``, ``P0``, ``initial_trend``, ``data_exog`` and
    ``beta_exog`` inside a ``pm.Model`` built on ``exog_ss_mod.coords``, then
    calls ``exog_ss_mod.build_statespace_graph`` on the ``y`` column.

    Parameters
    ----------
    exog_ss_mod
        Built statespace model fixture.
    exog_data
        Simulated data fixture providing the ``x1`` regressor and ``y`` target.

    Returns
    -------
    pm.Model
        The populated model context.
    """
    with pm.Model(coords=exog_ss_mod.coords) as model:
        # Initial state covariance: diagonal matrix with Gamma-distributed entries.
        p0_diag = pm.Gamma("P0_diag", alpha=2, beta=4, dims=["state"])
        pm.Deterministic("P0", pt.diag(p0_diag), dims=["state", "state_aux"])

        # The remaining variables are registered on the model under their string
        # names, so no Python handle is kept for them.
        pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])

        # Single regressor reshaped to a column so it matches the two dims below.
        x1_column = exog_data["x1"].values[:, None]
        pm.Data("data_exog", x1_column, dims=["time", "exog_state"])
        pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])

        exog_ss_mod.build_statespace_graph(exog_data["y"])

    return model
157
176
158
177
159
178
@pytest .fixture (scope = "session" )
@@ -844,10 +863,14 @@ def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng,
844
863
assert forecast_idx [0 ] == (t0 + delta )
845
864
846
865
866
+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
867
+ @pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
847
868
@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
848
- @pytest .mark .parametrize ("start" , [None , - 1 , 10 ])
869
+ @pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
870
+ @pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
871
+ @pytest .mark .parametrize ("start" , [None , - 1 , 5 ])
849
872
def test_forecast_with_exog_data (rng , exog_ss_mod , idata_exog , start ):
850
- scenario = pd .DataFrame (np .zeros ((10 , 3 )), columns = ["a" , "b" , "c " ])
873
+ scenario = pd .DataFrame (np .zeros ((10 , 1 )), columns = ["x1 " ])
851
874
scenario .iloc [5 , 0 ] = 1e9
852
875
853
876
forecast_idata = exog_ss_mod .forecast (
@@ -856,17 +879,50 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
856
879
857
880
components = exog_ss_mod .extract_components_from_idata (forecast_idata )
858
881
level = components .forecast_latent .sel (state = "LevelTrend[level]" )
859
- betas = components .forecast_latent .sel (state = ["exog[a]" , "exog[b]" , "exog[c ]" ])
882
+ betas = components .forecast_latent .sel (state = ["exog[x1 ]" ])
860
883
861
884
scenario .index .name = "time"
862
885
scenario_xr = (
863
886
scenario .unstack ()
864
887
.to_xarray ()
865
888
.rename ({"level_0" : "state" })
866
- .assign_coords (state = ["exog[a]" , "exog[b]" , "exog[c ]" ])
889
+ .assign_coords (state = ["exog[x1 ]" ])
867
890
)
868
891
869
892
regression_effect = forecast_idata .forecast_observed .isel (observed_state = 0 ) - level
870
893
regression_effect_expected = (betas * scenario_xr ).sum (dim = ["state" ])
871
894
872
895
assert_allclose (regression_effect , regression_effect_expected )
896
+
897
+
898
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
    """Forecasting with ``use_scenario_index=True`` yields the scenario's time index.

    Regression test for the issue reported at
    https://github.com/pymc-devs/pymc-extras/issues/424

    The forecast is drawn from the prior with a scenario taken from the tail of
    the observed exogenous data. It must carry exactly the scenario's index on
    its ``time`` coordinate, contain no NaNs, and have ``n_periods`` steps.
    """
    with exog_pymc_mod:
        idata = pm.sample_prior_predictive()

    # Define start date and forecast period
    start_date, n_periods = pd.to_datetime("2023-05-05"), 5

    # Extract exogenous data for the forecast period. The double-bracket
    # selection already yields a one-column DataFrame, so no re-wrapping is needed.
    scenario = {"data_exog": exog_data[["x1"]].loc[start_date:].iloc[:n_periods]}

    # Generate the forecast
    forecasts = exog_ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)
    assert "forecast_latent" in forecasts
    assert "forecast_observed" in forecasts

    # array_equal compares shapes first, so a forecast index of the wrong length
    # fails cleanly instead of being broadcast against the scenario index.
    assert np.array_equal(forecasts.coords["time"].values, scenario["data_exog"].index.values)
    assert not np.any(np.isnan(forecasts.forecast_latent.values))
    assert not np.any(np.isnan(forecasts.forecast_observed.values))

    # NOTE(review): axis 2 is taken to be the time axis of the forecast arrays,
    # as in the original assertions -- confirm against the forecast output dims.
    assert forecasts.forecast_latent.shape[2] == n_periods
    assert forecasts.forecast_observed.shape[2] == n_periods
0 commit comments