Skip to content

including cyclic or seasonal components causes error messages from build_statespace_graph since last bug fix #289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
rklees opened this issue Dec 21, 2023 · 0 comments · Fixed by #288

Comments

@rklees
Copy link

rklees commented Dec 21, 2023

I still do not get pymc_experimental running when adding a cycle and/or a seasonal to the model. Below an example code. In this case, the model consists of an integrated random walk + cycle + seasonal with 1 harmonic + measurement noise. I can generate the data associated with this model, but when trying to integrate into pymc I get an error message from build_statespace_graph. The error message points to inconsistencies in the shape of the cyclic and/or seasonal component (shape (2,) versus shape (1,) ) in base code. The error appears when I add any combination of cyclic and/or seasonal components. Before the last error fix related to cyclic components documented on Github it only appeared when a cyclic component was added.

Here is the code and the error message:

import jax

jax.config.update("jax_platform_name", "cpu")
import numpyro

#import blackjax
import nutpie

numpyro.set_host_device_count(4)

import sys

sys.path.append("..")
print(sys.path)

import pymc_experimental.statespace

import importlib
importlib.reload(pymc_experimental.statespace.structural)

from pymc_experimental.statespace import structural as st
from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG, MATRIX_NAMES
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
import pytensor
import pytensor.tensor as pt
import numpy as np
import pandas as pd
from patsy import dmatrix
from pymc_experimental.statespace.core.representation import PytensorRepresentation
import xarray as xr

%reload_ext autoreload
%autoreload complete

from importlib.metadata import version
print('pymc version = ', version('pymc'))
print('pytensor version = ', version('pytensor'))

print('pandas version = ', version('pandas'))
print('pandas version = ', pd.version)
print('arviz version = ', version('arviz'))
print('numpy version = ', version('numpy'))
print('pytensor version = ', version('pytensor'))
#print('blackjax version = ', version('blackjax'))
print('nutpie version = ', version('nutpie'))
print('xarray version = ', version('xarray'))

plt.rcParams.update(
{
"figure.figsize": (14, 4),
"figure.dpi": 144,
"figure.constrained_layout.use": True,
"axes.grid": True,
"grid.linewidth": 0.5,
"grid.linestyle": "--",
"axes.spines.top": False,
"axes.spines.bottom": False,
"axes.spines.left": False,
"axes.spines.right": False,
}
)

def unpack_statespace(ssm):
return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES]

def unpack_symbolic_matrices_with_params(mod, param_dict):
f_matrices = pytensor.function(
list(mod._name_to_variable.values()), unpack_statespace(mod.ssm), on_unused_input="ignore"
)
x0, P0, c, d, T, Z, R, H, Q = f_matrices(**param_dict)
return x0, P0, c, d, T, Z, R, H, Q

def simulate_from_numpy_model(mod, rng, param_dict, steps=100):
"""
Helper function to visualize the components outside of a PyMC model context
"""
x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, param_dict)
Z_time_varies = Z.ndim == 3

k_states = mod.k_states
k_posdef = mod.k_posdef

x = np.zeros((steps, k_states))
y = np.zeros(steps)

x[0] = x0
if Z_time_varies:
    y[0] = Z[0] @ x0
else:
    y[0] = Z @ x0

if not np.allclose(H, 0):
    y[0] += rng.multivariate_normal(mean=np.zeros(1), cov=H)

for t in range(1, steps):
    if k_posdef > 0:
        shock = rng.multivariate_normal(mean=np.zeros(k_posdef), cov=Q)
        innov = R @ shock
    else:
        innov = 0

    if not np.allclose(H, 0):
        error = rng.multivariate_normal(mean=np.zeros(1), cov=H)
    else:
        error = 0

    x[t] = c + T @ x[t - 1] + innov

    if Z_time_varies:
        y[t] = d + Z[t] @ x[t] + error
    else:
        y[t] = d + Z @ x[t] + error

return x, y

def simulate_many_trajectories(mod, rng, param_dict, n_simulations, steps=100):
k_states = mod.k_states
k_posdef = mod.k_posdef

xs = np.zeros((n_simulations, steps, k_states))
ys = np.zeros((n_simulations, steps))

for i in range(n_simulations):
    x, y = simulate_from_numpy_model(mod, rng, param_dict, steps)
    xs[i] = x
    ys[i] = y
return xs, ys

seed = sum(map(ord, "Structural Timeseries"))
rng = np.random.default_rng(seed)

measurement_error = st.MeasurementError(name="obs")
IRW = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
cycle = st.CycleComponent(name="annual_cycle", cycle_length=12, innovations=True) # cycle length is the period in number of samples; non-integer periods are allowed
SA_cycle = st.FrequencySeasonality(
name="SA_cycle", season_length=5.347, n=1, innovations=True # season_length is the period in units of number of sampels; non-integer values allowed
)

param_dict = {
"initial_trend": np.zeros((2,)), "sigma_trend": np.array([0.2]),
"annual_cycle": np.array([10., 0.]), "sigma_annual_cycle": np.array([1.0]),
"SA_cycle": np.array([20., 0.]), "sigma_SA_cycle": np.array([0.5]),
"sigma_obs": np.array([0.1]),
}

mod = IRW + cycle + SA_cycle + measurement_error

x, y = simulate_from_numpy_model(mod, rng, param_dict, steps=144)

plt.figure(figsize=(10,5))
plt.plot(y), plt.title('IRW plus annual cycle plus SA_cycle')

plt.figure(figsize=(10,5))
plt.plot(x[:, 0]), plt.title('level component')

plt.figure(figsize=(10,5))
plt.plot(x[:, 1]), plt.title('trend component')

plt.figure(figsize=(10,5))
plt.plot(x[:, 2]), plt.title('annual cycle component')

plt.figure(figsize=(10,5))
plt.plot(x[:, 4]), plt.title('SA cycle component')

time = np.arange(144)/12
data = pd.DataFrame({'time': time, 'meas': y})
nobs = len(data['meas'])
dt = np.mean(np.diff(data['time'])) # sampling period in units of years

mod = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += st.CycleComponent(name='annual_cycle', cycle_length=12, innovations=True)
mod += st.FrequencySeasonality(name='SA_cycle', season_length=5.347, n=1, innovations=True)
mod += st.MeasurementError(name="obs")

model = mod.build(name="IRW+cycle+measurement_error")

with pm.Model(coords=coords) as model_1:
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=10, dims=sigma_trend_dims)
annual_cycle = pm.Normal("annual_cycle", sigma=5)
sigma_annual_cycle = pm.Gamma("sigma_annual_cycle", alpha=2, beta=5)
SA_cycle = pm.Normal("SA_cycle", sigma=5)
sigma_SA_cycle = pm.Gamma("sigma_SA_cycle", alpha=2, beta=5)
sigma_obs = pm.Gamma("sigma_obs", alpha=2, beta=5, dims=('observed_state',))

with model_1:
model.build_statespace_graph(data['meas'], mode="JAX")


TypeError Traceback (most recent call last)
Cell In[12], line 2
1 with model_1:
----> 2 model.build_statespace_graph(data['meas'], mode="JAX")

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pymc_experimental/statespace/core/statespace.py:779, in PyMCStateSpace.build_statespace_graph(self, data, register_data, mode, missing_fill_value, cov_jitter, return_updates, include_smoother)
721 """
722 Given a parameter vector theta, constructs the full computational graph describing the state space model and
723 the associated log probability of the data. Hidden states and log probabilities are computed via the Kalman
(...)
775 If return_updates is False, the method will return None.
776 """
777 pm_mod = modelcontext(None)
--> 779 self._insert_random_variables()
780 obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
782 self.data_len = data.shape[0]

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pymc_experimental/statespace/core/statespace.py:637, in PyMCStateSpace._insert_random_variables(self)
633 matrices = list(self._unpack_statespace_with_placeholders())
634 replacement_dict = {
635 var: pt.atleast_1d(pymc_model[name]) for name, var in self._name_to_variable.items()
636 }
--> 637 self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/replace.py:205, in graph_replace(outputs, replace, strict)
197 raise ValueError(f"{key} is not a part of graph")
199 sorted_replacements = sorted(
200 fg_replace.items(),
201 # sort based on the fg toposort, if a variable has no owner, it goes first
202 key=partial(toposort_key, fg, toposort),
203 reverse=True,
204 )
--> 205 fg.replace_all(sorted_replacements, import_missing=True)
206 if as_list:
207 return list(fg.outputs)

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:515, in FunctionGraph.replace_all(self, pairs, **kwargs)
513 """Replace variables in the FunctionGraph according to (var, new_var) pairs in a list."""
514 for var, new_var in pairs:
--> 515 self.replace(var, new_var, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:508, in FunctionGraph.replace(self, var, new_var, reason, verbose, import_missing)
501 raise AssertionError(
502 "The replacement variable has a test value with "
503 "a shape different from the original variable's "
504 f"test value. Original: {tval_shape}, new: {new_tval_shape}"
505 )
507 for node, i in list(self.clients[var]):
--> 508 self.change_node_input(
509 node, i, new_var, reason=reason, import_missing=import_missing
510 )

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:428, in FunctionGraph.change_node_input(self, node, i, new_var, reason, import_missing, check)
426 r = node.inputs[i]
427 if check and not r.type.is_super(new_var.type):
--> 428 raise TypeError(
429 f"The type of the replacement ({new_var.type}) must be "
430 f"compatible with the type of the original Variable ({r.type})."
431 )
432 node.inputs[i] = new_var
434 if r is new_var:

TypeError: The type of the replacement (Vector(float64, shape=(1,))) must be compatible with the type of the original Variable (Vector(float64, shape=(2,))).

@rklees rklees changed the title including cycles generates error messages from build_statespace_graph including cyclic or seasonal components causes error messages from build_statespace_graph since last bug fix Dec 21, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant