Jax sampling type-hints incompatible with Python 3.9 #6941

Closed
lt-brs opened this issue Oct 4, 2023 · 8 comments · Fixed by #6945 or #6969

Comments

@lt-brs

lt-brs commented Oct 4, 2023

Describe the issue:

The latest release of PyMC (5.9.0) seems to be incompatible with NumPyro-powered fitting.
I'm using PyMC through bambi, a lightweight wrapper for building Bayesian linear regression models.

Downgrading PyMC to 5.7.2 makes the code sample work.

Reproducible code example:

import bambi as bmb

SEED = 7355608
data = bmb.load_data("ESCS")
model = bmb.Model("drugs ~ o + c + e + a + n", data)

fitted = model.fit(
    tune=2000, draws=2000, 
    inference_method="nuts_numpyro",
    init="adapt_diag", random_seed=SEED)

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/home/lambert_bruyas/git/mmo-prism/notebooks/LaB/Support/pymc_jax_bug.ipynb Cell 4 line 2
      1 model = bmb.Model("drugs ~ o + c + e + a + n", data)
----> 2 fitted = model.fit(
      3     tune=2000, draws=2000, 
      4     inference_method="nuts_numpyro",
      5     init="adapt_diag", random_seed=SEED)

File ~/anaconda3/envs/prism_dev/lib/python3.9/site-packages/bambi/models.py:325, in Model.fit(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
    318     response = self.components[self.response_name]
    319     _log.info(
    320         "Modeling the probability that %s==%s",
    321         response.response_term.name,
    322         str(response.response_term.success),
    323     )
--> 325 return self.backend.run(
    326     draws=draws,
    327     tune=tune,
    328     discard_tuned_samples=discard_tuned_samples,
    329     omit_offsets=omit_offsets,
    330     include_mean=include_mean,
    331     inference_method=inference_method,
    332     init=init,
    333     n_init=n_init,
    334     chains=chains,
    335     cores=cores,
    336     random_seed=random_seed,
    337     **kwargs,
    338 )

File ~/anaconda3/envs/prism_dev/lib/python3.9/site-packages/bambi/backend/pymc.py:96, in PyMCModel.run(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
     94 # NOTE: Methods return different types of objects (idata, approximation, and dictionary)
     95 if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
---> 96     result = self._run_mcmc(
     97         draws,
     98         tune,
     99         discard_tuned_samples,
    100         omit_offsets,
    101         include_mean,
    102         init,
    103         n_init,
    104         chains,
    105         cores,
    106         random_seed,
    107         inference_method,
    108         **kwargs,
    109     )
    110 elif inference_method == "vi":
    111     result = self._run_vi(**kwargs)

File ~/anaconda3/envs/prism_dev/lib/python3.9/site-packages/bambi/backend/pymc.py:206, in PyMCModel._run_mcmc(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, init, n_init, chains, cores, random_seed, sampler_backend, **kwargs)
    204             raise
    205 elif sampler_backend == "nuts_numpyro":
--> 206     import pymc.sampling_jax  # pylint: disable=import-outside-toplevel
    208     if not chains:
    209         # sample_numpyro_nuts does not handle chains = None like pm.sample does
    210         chains = 4

File ~/anaconda3/envs/prism_dev/lib/python3.9/site-packages/pymc/sampling_jax.py:22
     19 import warnings
     21 warnings.warn("This module is deprecated, use pymc.sampling.jax", DeprecationWarning)
---> 22 from pymc.sampling.jax import *

File ~/anaconda3/envs/prism_dev/lib/python3.9/site-packages/pymc/sampling/jax.py:185
    178 def _device_put(input, device: str):
    179     return jax.device_put(input, jax.devices(device)[0])
    182 def _postprocess_samples(
    183     jax_fn: Callable,
    184     raw_mcmc_samples: List[TensorVariable],
--> 185     postprocessing_backend: Literal["cpu", "gpu"] | None = None,
    186     postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
    187 ) -> List[TensorVariable]:
    188     if postprocessing_vectorize == "scan":
    189         t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]

TypeError: unsupported operand type(s) for |: '_LiteralGenericAlias' and 'NoneType'
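The failure does not depend on PyMC or NumPyro themselves. On Python 3.9 the annotations of a `def` are evaluated when the function is defined, and `typing` aliases such as `Literal[...]` only gained support for the `|` operator in Python 3.10 (PEP 604). The standalone sketch below evaluates the same expression as line 185 of `pymc/sampling/jax.py`; the helper name `evaluate_union_annotation` is hypothetical and only used for illustration.

```python
import sys
from typing import Literal, Optional


def evaluate_union_annotation() -> object:
    """Evaluate the annotation expression used at pymc/sampling/jax.py line 185.

    Returns the resulting type on Python >= 3.10, or the TypeError raised on 3.9.
    """
    try:
        return Literal["cpu", "gpu"] | None
    except TypeError as exc:
        # Python 3.9: typing aliases do not implement `|`, so this is the
        # same TypeError reported in the traceback above
        return exc


result = evaluate_union_annotation()
if sys.version_info >= (3, 10):
    # On 3.10+ the expression builds the same type as Optional[Literal[...]]
    print("OK:", result)
else:
    print("TypeError:", result)
```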

PyMC version information:

pymc=5.9.0
bambi=0.12.0
numpyro=0.13.0
jax=0.4.14
jaxlib=0.4.14
pytensor=2.17.1

Context for the issue:

No response

@lt-brs lt-brs added the bug label Oct 4, 2023
@welcome

welcome bot commented Oct 4, 2023

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

@digicosmos86
Contributor

Having the same error here. All NumPyro-based sampling fails with this error. Is PyMC dropping support for Python 3.9?

@ColCarroll
Member

Right, it looks like numpyro uses the `X | Y` union-type syntax (PEP 604), which was introduced in Python 3.10. The easiest solution would be to install a previous version of numpyro. A harder solution would be to use a more recent version of Python.

Note that the scientific python spec suggests dropping support for Python 3.9 tomorrow!
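For code that must keep running on Python 3.9, the hint from the traceback can be spelled with `typing.Optional`, which is equivalent to `Union[..., None]`. Adding `from __future__ import annotations` (PEP 563) at the top of a module would also avoid the error, because annotations are then stored as strings instead of being evaluated at definition time (evaluating them later, e.g. with `typing.get_type_hints`, would still fail on 3.9). The sketch below only illustrates the signature: `postprocess_samples_sketch` and its trivial body are hypothetical stand-ins, not PyMC's actual `_postprocess_samples`.

```python
from typing import Callable, List, Literal, Optional, get_type_hints


def postprocess_samples_sketch(
    jax_fn: Callable,
    raw_mcmc_samples: List,
    # Optional[...] evaluates fine on Python 3.9; `Literal[...] | None` needs 3.10+
    postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> List:
    """Hypothetical stand-in: applies jax_fn to each sample and ignores
    the backend/vectorize options that the real function uses."""
    return [jax_fn(sample) for sample in raw_mcmc_samples]


# The hints resolve without errors on 3.9 as well as on newer versions
hints = get_type_hints(postprocess_samples_sketch)
print(hints["postprocessing_backend"])
```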

@junpenglao
Member

junpenglao commented Oct 4, 2023

> Note that the scientific python spec suggests dropping support for Python 3.9 tomorrow!

Wow already?! Life comes at you fast...

@ricardoV94
Member

PyMC still supports Python 3.9, so we should revert those type-hint changes. CC @ferrine

@ricardoV94 ricardoV94 changed the title Pymc powered by Numpyro : TypeError: unsupported operand type(s) for |: '_LiteralGenericAlias' and 'NoneType' Jax sampling type-hints incompatible with Python 3.9 Oct 6, 2023
@michaelosthege michaelosthege self-assigned this Oct 6, 2023
@aarojas20

Hi @ricardoV94! I'm still getting this error. The PR fixed the type hints in jax.py at lines 339 and 549, but line 185 seems to have been missed (it is the line shown in the traceback @lt-brs originally posted).

@ricardoV94 ricardoV94 reopened this Oct 17, 2023
@ricardoV94
Member

@aarojas20 thanks for reporting back

@ricardoV94
Member

We should run pre-commit on the oldest supported version of Python to avoid issues like this in the future.
