
statespace: Leveraging RegressionComponent yields error #297

Closed
A108669 opened this issue Jan 9, 2024 · 6 comments
Labels
bug Something isn't working statespace

Comments

@A108669

A108669 commented Jan 9, 2024

Hello!
I have been experimenting with the Structural Time Series module; however, I have been running into trouble when I attempt to add an external regressor to a model. Below is an example including dummy data as well as the full stack trace. I have not been able to get any external regressors added to the models.

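import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Generate dummy monthly data with a linear trend and one exogenous regressor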
n_samples = 1000
np.random.seed(100)

trend_data = np.arange(n_samples) * .1
regressor_data = np.random.normal(scale=2, size=n_samples)
y = trend_data + regressor_data + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(
    data={
        'time_index': pd.date_range("2001-01-01", freq="M", periods=len(y)),
        'x': regressor_data,
        'y': y,
    }
)

trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=1)
error = st.MeasurementError(name="error")

df = df.set_index("time_index")
df.index.freq = 'M'

mod = trend + error + regressor
ss_mod = mod.build(name="test")
initial_trend_dims, sigma_obs_dims, regressor_dims, _, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5, dims=["observed_state"])

    beta_xreg = pm.Normal("beta_xreg", .2, 1)
    data_xreg = pm.MutableData("data_xreg", df[["x"]])

    ss_mod.build_statespace_graph(df[['y']], mode="JAX")
    idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9)
ValueError: Argument [[-0.85651 ... 33556989]] given to the scan node is not compatible with its corresponding loop function variable *4-<Matrix(float64, shape=(?, ?))>

Full Stack Trace:

The following parameters should be assigned priors inside a PyMC model block: 
        initial_trend -- shape: (2,), constraints: None, dims: ('trend_state',)
        sigma_error -- shape: (1,), constraints: Positive, dims: None
        beta_xreg -- shape: (1,), constraints: None, dims: ('exog_state',)
        data_xreg -- shape: (None, 1), constraints: None, dims: ('time', 'exog_state')
        P0 -- shape: (3, 3), constraints: Positive semi-definite, dims: ('state', 'state_aux')
Compiling...
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[69], line 17
     14 data_xreg = pm.MutableData("data_xreg", df[["x"]])
     16 ss_mod.build_statespace_graph(df[['y']], mode="JAX")
---> 17 idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9)

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/mcmc.py:696, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    692     if not isinstance(step, NUTS):
    693         raise ValueError(
    694             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    695         )
--> 696     return _sample_external_nuts(
    697         sampler=nuts_sampler,
    698         draws=draws,
    699         tune=tune,
    700         chains=chains,
    701         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    702         random_seed=random_seed,
    703         initvals=initvals,
    704         model=model,
    705         progressbar=progressbar,
    706         idata_kwargs=idata_kwargs,
    707         nuts_sampler_kwargs=nuts_sampler_kwargs,
    708         **kwargs,
    709     )
    711 if isinstance(step, list):
    712     step = CompoundStep(step)

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/mcmc.py:350, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    347 elif sampler == "numpyro":
    348     import pymc.sampling.jax as pymc_jax
--> 350     idata = pymc_jax.sample_numpyro_nuts(
    351         draws=draws,
    352         tune=tune,
    353         chains=chains,
    354         target_accept=target_accept,
    355         random_seed=random_seed,
    356         initvals=initvals,
    357         model=model,
    358         progressbar=progressbar,
    359         idata_kwargs=idata_kwargs,
    360         **nuts_sampler_kwargs,
    361     )
    362     return idata
    364 elif sampler == "blackjax":

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/jax.py:669, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, idata_kwargs, nuts_kwargs, postprocessing_chunks)
    660 logger.info("Compiling...")
    662 init_params = _get_batched_jittered_initial_points(
    663     model=model,
    664     chains=chains,
    665     initvals=initvals,
    666     random_seed=random_seed,
    667 )
--> 669 logp_fn = get_jaxified_logp(model, negative_logp=False)
    671 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
    672 nuts_kernel = NUTS(
    673     potential_fn=logp_fn,
    674     target_accept_prob=target_accept,
    675     **nuts_kwargs,
    676 )

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/jax.py:151, in get_jaxified_logp(model, negative_logp)
    149 if not negative_logp:
    150     model_logp = -model_logp
--> 151 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    153 def logp_fn_wrap(x):
    154     return logp_fn(*x)[0]

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/jax.py:126, in get_jaxified_graph(inputs, outputs)
    120 def get_jaxified_graph(
    121     inputs: Optional[List[TensorVariable]] = None,
    122     outputs: Optional[List[TensorVariable]] = None,
    123 ) -> List[TensorVariable]:
    124     """Compile an PyTensor graph into an optimized JAX function"""
--> 126     graph = _replace_shared_variables(outputs) if outputs is not None else None
    128     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
    129     # We need to add a Supervisor to the fgraph to be able to run the
    130     # JAX sequential optimizer without warnings. We made sure there
    131     # are no mutable input variables, so we only need to check for
    132     # "destroyers". This should be automatically handled by PyTensor
    133     # once https://github.com/aesara-devs/aesara/issues/637 is fixed.

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/jax.py:116, in _replace_shared_variables(graph)
    109     raise ValueError(
    110         "Graph contains shared variables with default_update which cannot "
    111         "be safely replaced."
    112     )
    114 replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
--> 116 new_graph = clone_replace(graph, replace=replacements)
    117 return new_graph

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/graph/replace.py:87, in clone_replace(output, replace, **rebuild_kwds)
     84 _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
     86 # TODO Explain why we call it twice ?!
---> 87 _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
     89 return outs

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:317, in rebuild_collect_shared(outputs, inputs, replace, updates, rebuild_strict, copy_inputs_over, no_default_updates, clone_inner_graphs)
    315 for v in outputs:
    316     if isinstance(v, Variable):
--> 317         cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
    318         cloned_outputs.append(cloned_v)
    319     elif isinstance(v, Out):

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:193, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    191 if owner not in clone_d:
    192     for i in owner.inputs:
--> 193         clone_v_get_shared_updates(i, copy_inputs_over)
    194     clone_node_and_cache(
    195         owner,
    196         clone_d,
    197         strict=rebuild_strict,
    198         clone_inner_graphs=clone_inner_graphs,
    199     )
    200 return clone_d.setdefault(v, v)

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:193, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    191 if owner not in clone_d:
    192     for i in owner.inputs:
--> 193         clone_v_get_shared_updates(i, copy_inputs_over)
    194     clone_node_and_cache(
    195         owner,
    196         clone_d,
    197         strict=rebuild_strict,
    198         clone_inner_graphs=clone_inner_graphs,
    199     )
    200 return clone_d.setdefault(v, v)

    [... skipping similar frames: rebuild_collect_shared.<locals>.clone_v_get_shared_updates at line 193 (3 times)]

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:193, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    191 if owner not in clone_d:
    192     for i in owner.inputs:
--> 193         clone_v_get_shared_updates(i, copy_inputs_over)
    194     clone_node_and_cache(
    195         owner,
    196         clone_d,
    197         strict=rebuild_strict,
    198         clone_inner_graphs=clone_inner_graphs,
    199     )
    200 return clone_d.setdefault(v, v)

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:194, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    192         for i in owner.inputs:
    193             clone_v_get_shared_updates(i, copy_inputs_over)
--> 194         clone_node_and_cache(
    195             owner,
    196             clone_d,
    197             strict=rebuild_strict,
    198             clone_inner_graphs=clone_inner_graphs,
    199         )
    200     return clone_d.setdefault(v, v)
    201 elif isinstance(v, SharedVariable):

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/graph/basic.py:1200, in clone_node_and_cache(node, clone_d, clone_inner_graphs, **kwargs)
   1196 new_op: Optional["Op"] = cast(Optional["Op"], clone_d.get(node.op))
   1198 cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]
-> 1200 new_node = node.clone_with_new_inputs(
   1201     cloned_inputs,
   1202     # Only clone inner-graph `Op`s when there isn't a cached clone (and
   1203     # when `clone_inner_graphs` is enabled)
   1204     clone_inner_graph=clone_inner_graphs if new_op is None else False,
   1205     **kwargs,
   1206 )
   1208 if new_op:
   1209     # If we didn't clone the inner-graph `Op` above, because
   1210     # there was a cached version, set the cloned `Apply` to use
   1211     # the cached clone `Op`
   1212     new_node.op = new_op

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/graph/basic.py:283, in Apply.clone_with_new_inputs(self, inputs, strict, clone_inner_graph)
    280     if isinstance(new_op, HasInnerGraph) and clone_inner_graph:  # type: ignore
    281         new_op = new_op.clone()  # type: ignore
--> 283     new_node = new_op.make_node(*new_inputs)
    284     new_node.tag = copy(self.tag).__update__(new_node.tag)
    285 else:

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/scan/op.py:1195, in Scan.make_node(self, *inputs)
   1193     new_inputs.append(outer_nonseq)
   1194     if not outer_nonseq.type.in_same_class(inner_nonseq.type):
-> 1195         raise ValueError(
   1196             f"Argument {outer_nonseq} given to the scan node is not"
   1197             f" compatible with its corresponding loop function variable {inner_nonseq}"
   1198         )
   1200 for outer_nitsot in self.outer_nitsot(inputs):
   1201     # For every nit_sot input we get as input a int/uint that
   1202     # depicts the size in memory for that sequence. This feature is
   1203     # used by truncated BPTT and by scan space optimization
   1204     if (
   1205         str(outer_nitsot.type.dtype) not in integer_dtypes
   1206         or outer_nitsot.ndim != 0
   1207     ):

ValueError: Argument [[-0.85651 ... 33556989]] given to the scan node is not compatible with its corresponding loop function variable *4-<Matrix(float64, shape=(?, ?))>
@ricardoV94
Member

CC @jessegrabowski

@jessegrabowski
Member

jessegrabowski commented Jan 10, 2024

Hey, thanks for giving the module a try!

I can reproduce the problem on my end, so it's definitely a bug. It looks like a JAX problem: I can run your code if I use the default PyMC sampler. Also, if you have more than one regressor, it works. For example, this runs:

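import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Generate dummy monthly data with a linear trend and k_exog exogenous regressors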
n_samples = 1000
k_exog = 3

np.random.seed(100)

trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(np.c_[y, regressor_data],  # y first, to match the column labels below
                  index = pd.date_range("2001-01-01", freq="M", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'M'


trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=['x0', 'x1', 'x2'])
error = st.MeasurementError(name="error")

mod = trend + error + regressor
ss_mod = mod.build(name="test")
trend_dims, obs_dims, regressor_dims, regression_data_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values)
    
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5, dims=["observed_state"])

    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)

    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)
    # prior = pm.sample_prior_predictive(samples=10)

Probably the data is being incorrectly squeezed somewhere. I'll look closely and push a fix ASAP. Thanks for finding this bug and opening an issue!

@ricardoV94 ricardoV94 added bug Something isn't working statespace labels Jan 10, 2024
@jessegrabowski
Member

jessegrabowski commented Feb 15, 2024

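I finally had some time to look closely at this. It appears to be a bug that arises because, in PyTensor, a dimension is only treated as broadcastable when its static shape is known to be 1. pm.MutableData creates a shared variable whose static shape is unknown, (None, None), so the Scan in the model graph is built with a loop variable of type Matrix(float64, shape=(?, ?)). When you sample with JAX, PyMC first replaces shared variables with constants (the _replace_shared_variables frame in the traceback). A constant built from (1000, 1) data has a broadcastable second dimension, so its type no longer matches the loop variable, and Scan.make_node raises the error above. That replacement only happens on the JAX path, which is presumably why the default PyMC sampler works. It is also why the model works if you have more than one exogenous variable: the second dimension of the exogenous data isn't 1 anymore, so both types agree. Might be related to pymc-devs/pytensor#408, but not sure.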
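The type check that fails can be sketched as a small pure-Python model. This is a simplified, hypothetical re-implementation, not PyTensor's API: it assumes the rule that a dimension is broadcastable only when its static size is exactly 1, and that the check Scan.make_node performs in the traceback (`in_same_class`) compares broadcastable patterns (the dtype comparison is omitted here). The helper names `broadcastable`, `shared_static_shape`, `constant_static_shape` and `same_class` are illustrative only.

```python
from typing import Optional, Tuple

# None stands for "size unknown when the graph is built"
StaticShape = Tuple[Optional[int], ...]


def broadcastable(shape: StaticShape) -> Tuple[bool, ...]:
    # A dimension counts as broadcastable only if its static size is known to be 1.
    return tuple(s == 1 for s in shape)


def shared_static_shape(data_shape: Tuple[int, ...]) -> StaticShape:
    # Models a shared variable created without a `shape` argument:
    # every dimension is assumed resizable, so no static sizes are recorded.
    return (None,) * len(data_shape)


def constant_static_shape(data_shape: Tuple[int, ...]) -> StaticShape:
    # Models a constant: the static shape is taken directly from the data.
    return tuple(data_shape)


def same_class(a: StaticShape, b: StaticShape) -> bool:
    # Hypothetical stand-in for the Scan compatibility check (dtype ignored).
    return broadcastable(a) == broadcastable(b)


for k_exog in (1, 3):
    data_shape = (1000, k_exog)
    inner = shared_static_shape(data_shape)    # type the Scan loop was built with
    outer = constant_static_shape(data_shape)  # type after shared -> constant replacement
    print(k_exog, broadcastable(inner), broadcastable(outer), same_class(inner, outer))
# prints: 1 (False, False) (False, True) False
# prints: 3 (False, False) (False, False) True
```

Only the k_exog = 1 case produces a pattern mismatch, which matches the observation that three regressors sample fine while one does not.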

For now, I can think of two possible work-arounds:

  1. Explicitly specify the shape of the exogenous data when you create the pm.MutableData, by passing a shape keyword argument.
  2. Use pm.ConstantData instead of pm.MutableData

Despite my choice of ordering, I think option 2 is preferable.

Here is a working example:

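import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Generate dummy monthly data with a linear trend and a single exogenous regressor (k_exog = 1)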
n_samples = 1000
k_exog = 1

np.random.seed(100)

trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(np.c_[y, regressor_data],
                  index = pd.date_range("2001-01-01", freq="ME", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'ME'


trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=[f'x{i}' for i in range(k_exog)])
error = st.MeasurementError(name="error")

mod = trend + error + regressor
ss_mod = mod.build(name="test")
trend_dims, obs_dims, regressor_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    
    # Option 1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values, 
                               dims=['time', 'exog_state'],
                               shape=(n_samples, k_exog)) # <--- Key line

    # Option 2:
    # data_xreg = pm.ConstantData("data_xreg", df.drop(columns='y').values, 
    #                           dims=['time', 'exog_state'])
    
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5)

    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)
    
    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)

Note that I also specified n_samples in the shape. If you don't, JAX will complain about dynamic shapes when you try to do post-estimation sampling (ss_mod.sample_conditional_posterior, for example).

I'll let you know when I come up with a more long-term solution.

@ricardoV94
Member

You only need to specify the shape of broadcastable dims, and only if you intend them to broadcast. You can pass shape=(None, 1) if you still want the other dim to be resizable (but then it cannot be broadcast with other parameters).
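Why a partial shape such as (None, 1) is enough can be illustrated with a simplified pure-Python model of PyTensor's broadcasting rule (a dimension is broadcastable only if its static size is 1). This is a sketch, not PyMC or PyTensor code; it assumes that a `shape` passed to pm.MutableData becomes the shared variable's static shape, and that pm.ConstantData yields a constant whose static shape is taken from the data. The dictionary keys are just labels for the declarations discussed in this thread.

```python
from typing import Dict, Optional, Tuple

StaticShape = Tuple[Optional[int], ...]  # None = size unknown at graph-build time


def broadcastable(shape: StaticShape) -> Tuple[bool, ...]:
    # Only a static size of exactly 1 marks a dimension as broadcastable.
    return tuple(s == 1 for s in shape)


n_samples, k_exog = 1000, 1
# Static shape of the constant that replaces the data when sampling with JAX
replaced: StaticShape = (n_samples, k_exog)

# Static shape each declaration gives the data when the Scan is built
declared: Dict[str, StaticShape] = {
    "no shape": (None, None),                         # original report
    "shape=(n_samples, k_exog)": (n_samples, k_exog),  # option 1
    "shape=(None, 1)": (None, 1),                     # resizable first dim
    "ConstantData": (n_samples, k_exog),              # option 2
}

compatible = {
    name: broadcastable(shape) == broadcastable(replaced)
    for name, shape in declared.items()
}
for name, ok in compatible.items():
    print(f"{name}: {'ok' if ok else 'type mismatch'}")
```

Every declaration that pins the size-1 dimension agrees with the replaced constant, while the first dimension can stay None. This model only covers the Scan type check; it says nothing about the dynamic-slicing error JAX can still raise later during posterior predictive sampling.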

@jessegrabowski
Member

Yes, this works as well (with pm.MutableData), but JAX will still raise an error in pm.sample_posterior_predictive, complaining about dynamic slicing. So for now I recommend declaring both dimensions, since conditional forecasting with exogenous time series isn't supported yet anyway.

@jessegrabowski
Member

jessegrabowski commented Apr 16, 2024

This should be fixed by #326. Feel free to open a new issue if you're still hitting problems. I updated the structural example notebook, but it still needs more work. Still, it should give you an idea of how to include exogenous regressors.
