Skip to content

Commit 9f23987

Browse files
committed
added logic to update static shape of target when forecasting with exogenous variables
1 parent bca5360 commit 9f23987

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,6 +2069,14 @@ def forecast(
20692069
data_dims=["data_time", OBS_STATE_DIM],
20702070
)
20712071

2072+
for name in self.data_names:
2073+
if name in scenario.keys():
2074+
pm.set_data(
2075+
{"data": np.zeros_like(scenario[name])},
2076+
coords={"data_time": np.arange(len(forecast_index))},
2077+
)
2078+
break
2079+
20722080
group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
20732081
mu, cov = grouped_outputs[group_idx]
20742082

0 commit comments

Comments
 (0)