Bug in get_domain_of_finite_discrete_rv of Categorical #331

Open
ricardoV94 opened this issue Apr 12, 2024 · 2 comments
Labels
bug Something isn't working good first issue Good for newcomers marginalization

Comments

@ricardoV94
Member

Reported by @jessegrabowski

with MarginalModel(coords=coords) as m:
    x_data = pm.ConstantData('x', df.x, dims=['obs_idx'])
    y_data = pm.ConstantData('y', df.y, dims=['obs_idx'])

    X = pt.concatenate([pt.ones_like(x_data[:, None]), x_data[:, None], x_data[:, None] ** 2], axis=-1)

    mu = pm.Normal('mu', dims=['group'])
    beta_p = pm.Normal('beta_p', dims=['params', 'group'])
    logit_p_group = X @ beta_p
    group_idx = pm.Categorical('group_idx', logit_p=logit_p_group, dims=['obs_idx'])
    sigma = pm.Exponential('sigma', 1)

    mu = pt.switch(pt.lt(group_idx, 1), 
                   mu_trend,
                   pt.switch(pt.lt(group_idx, 2), 
                             p_x[:, 0], 
                             p_x[:, 1])
                  )
    
    y_hat = pm.Normal('y_hat', 
                      mu = mu,
                      sigma = sigma,
                      observed=y_data,
                      dims=['obs_idx'])

m.marginalize(["group_idx"])
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pymc_experimental/model/marginal_model.py:655, in get_domain_of_finite_discrete_rv(rv)
    653 elif isinstance(op, Categorical):
    654     p_param = rv.owner.inputs[3]
--> 655     return tuple(range(pt.get_vector_length(p_param)))
    656 elif isinstance(op, DiscreteUniform):
    657     lower, upper = constant_fold(rv.owner.inputs[3:])

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/tensor/__init__.py:82, in get_vector_length(v)
     79 v = as_tensor_variable(v)
     81 if v.type.ndim != 1:
---> 82     raise TypeError(f"Argument must be a vector; got {v.type}")
     84 static_shape: Optional[int] = v.type.shape[0]
     85 if static_shape is not None:

TypeError: Argument must be a vector; got Matrix(float64, shape=(256, 3))

Instead of trying to get the vector length of p_param (which assumes p is always a vector), we should be constant folding p_param.shape[-1].

@ricardoV94 ricardoV94 added bug Something isn't working good first issue Good for newcomers marginalization labels Apr 12, 2024
@victorgarcia98

Hello @ricardoV94 @jessegrabowski !

Is this still an issue? I would like to help :D

I tried to reproduce the error but the example is incomplete. Nonetheless, I tried with the following adapted from this tutorial:

import pymc as pm
import pytensor.tensor as pt
import pandas as pd
import numpy as np
import pymc_extras as pmx

rng = np.random.default_rng(32)

disaster_data = pd.Series(
    [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
    3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
    2, 2, 3, 4, 2, 1, 3, 0, 2, 1, 1, 1, 1, 3, 0, 0,
    1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
    0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
    3, 3, 1, 0, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
    0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
)


# fmt: on
years = np.arange(1851, 1962)

with pm.Model() as disaster_model:
    switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
    early_rate = pm.Exponential("early_rate", 1.0)
    late_rate = pm.Exponential("late_rate", 1.0)
    rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
    disasters = pm.Poisson("disasters", rate, observed=disaster_data)

with disaster_model:
    before_marg = pm.sample(random_seed=rng)

disaster_model_marginalized = pmx.marginalize(disaster_model, [switchpoint])

with disaster_model_marginalized:
    after_marg = pm.sample(random_seed=rng)

Some observations:

  • I couldn't get samples from disaster_data when it contained np.nan values (see error [1])
  • It looks like MarginalModel is deprecated in favor of using a regular PyMC model.
  • marginalize is no longer a method but a function that accepts a model.

Am I doing something wrong?


Error 1

/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/model/core.py:1288: RuntimeWarning: invalid value encountered in cast
  data = convert_observed_data(data).astype(rv_var.dtype)
/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/model/core.py:1302: ImputationWarning: Data in disasters contains missing values and will be automatically imputed from the sampling distribution.
  warnings.warn(impute_message, ImputationWarning)
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>CompoundStep
>>Metropolis: [switchpoint]
>>Metropolis: [disasters_unobserved]
>NUTS: [early_rate, late_rate]
                                                                                              
                                    Acce…            Step    Grad     Samp…                   
  Progr…   Draws   Tuni…   Scali…   Rate    Diver…   size    evals    Speed   Elaps…   Rema…  
 ──────────────────────────────────────────────────────────────────────────────────────────── 
  ━━━━━━   0       True    0.00     0.00    0        0.00    0        0.00    0:00:…   -:--…  
                                                                      draw…                   
  ━━━━━━   0       True    0.00     0.00    0        0.00    0        0.00    0:00:…   -:--…  
                                                                      draw…                   
  ━━━━━━   0       True    0.00     0.00    0        0.00    0        0.00    0:00:…   -:--…  
                                                                      draw…                   
  ━━━━━━   0       True    0.00     0.00    0        0.00    0        0.00    0:00:…   -:--…  
                                                                      draw…                   
                                                                                              
Traceback (most recent call last):
  File "/home/victor/Work/Projects/pymc-extras/issue-331.py", line 31, in <module>
    before_marg = pm.sample(random_seed=rng)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 935, in sample
    _mp_sample(**sample_args, **parallel_args)
  File "/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 1411, in _mp_sample
    for draw in sampler:
  File "/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/sampling/parallel.py", line 513, in __iter__
    self._progress.update(
  File "/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/util.py", line 886, in update
    self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/step_methods/compound.py", line 340, in update_stats
    stats = update_fn(stats, step_stat, chain_idx)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/step_methods/compound.py", line 340, in update_stats
    stats = update_fn(stats, step_stat, chain_idx)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/victor/Work/Projects/pymc-extras/venv/lib/python3.12/site-packages/pymc/step_methods/metropolis.py", line 354, in update_stats
    stats["tune"][chain_idx] = step_stats["tune"]
                               ~~~~~~~~~~^^^^^^^^
TypeError: string indices must be integers, not 'str'

@jessegrabowski
Member

This is actually a separate bug that is being tracked here. You can avoid it for now by setting progressbar=False in pm.sample. I'll get it fixed today.
