
Error message from build_statespace_graph when cycle is one of the model components. #281


Closed
rklees opened this issue Dec 13, 2023 · 0 comments · Fixed by #288


rklees commented Dec 13, 2023

When I build a PyMC state-space model with pymc-experimental, I always get an error as soon as I add a cycle component to the model:

"TypeError: The type of the replacement (Vector(float64, shape=(2,))) must be compatible with the type of the original Variable (Vector(float64, shape=(1,)))."
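PyTensor variables carry a static type that records the dtype, the number of dimensions and, where known, the length of each dimension. When a graph is rewritten by substituting one variable for another, the substitute must have a compatible type, so a vector with static shape (2,) cannot replace one declared with shape (1,). The following is a simplified, hypothetical model of that rule (`is_compatible_replacement` and `replace_vector` are illustrative names, not PyTensor functions, and the real check also considers dtype and other details): wherever both shapes state a known length, the lengths must agree.

```python
from typing import Optional, Sequence

# A static shape: one entry per dimension, None meaning "length not known".
Shape = Sequence[Optional[int]]


def is_compatible_replacement(original: Shape, replacement: Shape) -> bool:
    """Simplified static-shape check (illustrative only, not PyTensor's code).

    The number of dimensions must match, and wherever both shapes give a
    known length, those lengths must be equal.
    """
    if len(original) != len(replacement):
        return False
    return all(
        o is None or r is None or o == r
        for o, r in zip(original, replacement)
    )


def replace_vector(original: Shape, replacement: Shape) -> None:
    """Raise a TypeError worded like the one in the report if shapes clash."""
    if not is_compatible_replacement(original, replacement):
        raise TypeError(
            "The type of the replacement "
            f"(Vector(float64, shape={tuple(replacement)})) must be compatible "
            "with the type of the original Variable "
            f"(Vector(float64, shape={tuple(original)}))."
        )


# The situation in the traceback: a (1,) placeholder swapped for a (2,) vector.
print(is_compatible_replacement((1,), (2,)))     # False
print(is_compatible_replacement((None,), (2,)))  # True
```

Calling `replace_vector((1,), (2,))` raises a TypeError with the same wording as the message quoted above.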

The code is below. The error message refers to the line "annual_cycle = ...".

```python
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Integrated random walk trend + stochastic annual cycle + measurement error
mod = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += st.CycleComponent(name='annual_cycle', cycle_length=12, innovations=True)
mod += st.MeasurementError(name="obs")
model = mod.build(name="IRW+cycle+measurement_error")

model.param_dims  # inspect the dims expected for each parameter
initial_trend_dims, sigma_trend_dims, annual_cycle_dims, sigma_obs_dims, P0_dims = model.param_dims.values()
coords = model.coords

with pm.Model(coords=coords) as model_1:
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
    initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=10, dims=sigma_trend_dims)
    annual_cycle = pm.Normal("annual_cycle", sigma=5, dims=annual_cycle_dims)  # line named in the error
    sigma_annual_cycle = pm.Gamma("sigma_annual_cycle", alpha=2, beta=5)
    sigma_obs = pm.Gamma("sigma_obs", alpha=2, beta=5, dims=['observed_state'])

    # `data` is the observed time series (not shown here)
    model.build_statespace_graph(data, mode="JAX")
    idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9)
```

I suspect the problem is that annual_cycle should have shape (2,) but has shape (1,). Could this point to a bug in the base code? For instance, at line 1329, init_state is given shape=(1,), and I wonder whether it should be shape=(2,). However, changing that alone did not fix the problem in my tests, so further changes may be needed.
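The guess that the cycle needs a length-2 state is consistent with the textbook trigonometric cycle from structural time-series modelling, where the state is a pair that is rotated each step by the angle λ = 2π / cycle_length. The pure-Python sketch below assumes that textbook form (it is not necessarily pymc-experimental's exact internal parametrization, and `cycle_transition` and `step` are hypothetical names). It builds the 2×2 transition matrix and checks that, with cycle_length=12, twelve steps bring the state back to where it started.

```python
import math

Matrix = list[list[float]]


def cycle_transition(cycle_length: float) -> Matrix:
    """2x2 rotation matrix for a deterministic trigonometric cycle."""
    lam = 2 * math.pi / cycle_length  # angular frequency per time step
    c, s = math.cos(lam), math.sin(lam)
    return [[c, s], [-s, c]]


def step(T: Matrix, state: list[float]) -> list[float]:
    """Apply one transition x_{t+1} = T @ x_t (no innovations)."""
    return [T[0][0] * state[0] + T[0][1] * state[1],
            T[1][0] * state[0] + T[1][1] * state[1]]


T = cycle_transition(12)
initial_state = [1.0, 0.0]  # two entries: a single number cannot encode a rotation
state = initial_state
for _ in range(12):  # one full annual period
    state = step(T, state)

print(len(initial_state))  # 2
print(state)  # approximately [1.0, 0.0], up to floating-point error
```

In the textbook form, innovations add a noise term to each of the two entries (and a damping factor is often included), but neither changes the size of the state.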
