calculation with tensor run errors #5694

Closed
mgendia opened this issue Apr 7, 2022 · 8 comments

mgendia commented Apr 7, 2022

Description of your problem

I was building models on big datasets in PyMC3. They worked, but either the run time was too long or the kernel would simply crash, so I turned to PyMC v4 to make use of the GPU. However, I get an error when building a model that requires data transformations using tensors, even though the same code worked fine in PyMC3. I am sharing dummy code based on Slava Kisilevich's example on GitHub (https://github.com/slavakx/bayesian_mmm), which I am trying to modify to run on PyMC v4.
The project is about building a marketing mix model, where the Bayesian model estimates the adstock and saturation parameters (theta, alpha, and gamma).

These are the functions that cause the error:

def adstock_geometric(x: float, theta: float):
    x_decayed = np.zeros_like(x)
    x_decayed[0] = x[0]

    for xi in range(1, len(x_decayed)):
        x_decayed[xi] = x[xi] + theta * x_decayed[xi - 1]

    return x_decayed


def adstock_geometric_theano_pymc3(x, theta):
    x = tt.as_tensor_variable(x)
    #x = tt.vector("x")
    #theta = tt.scalar("theta")

    def adstock_geometric_recurrence_theano(index, input_x, decay_x, theta):
        return tt.set_subtensor(decay_x[index], tt.sum(input_x + theta * decay_x[index - 1]))

    len_observed = x.shape[0]

    x_decayed = tt.zeros_like(x)
    x_decayed = tt.set_subtensor(x_decayed[0], x[0])

    output, _ = theano.scan(
        fn = adstock_geometric_recurrence_theano,
        sequences = [tt.arange(1, len_observed), x[1:len_observed]],
        outputs_info = x_decayed,
        non_sequences = theta,
        n_steps = len_observed - 1
    )

    return output[-1]



def saturation_hill_pymc3(x, alpha, gamma): 
    
    x_s_hill = x ** alpha / (x ** alpha + gamma ** alpha)
    
    return x_s_hill
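
As a quick sanity check of what these transformations compute, here is a minimal pure-Python sketch (no Theano or Aesara involved) of the same geometric adstock recurrence and Hill saturation. The function names and the sample numbers are illustrative assumptions, not part of the original notebook.

```python
def adstock_geometric_py(x, theta):
    """Carry over a fraction `theta` of the previous decayed value."""
    decayed = [x[0]]
    for value in x[1:]:
        decayed.append(value + theta * decayed[-1])
    return decayed


def saturation_hill_py(x, alpha, gamma):
    """Hill curve: equals 0.5 when x == gamma and approaches 1 for large x."""
    return x ** alpha / (x ** alpha + gamma ** alpha)


spend = [100.0, 0.0, 0.0, 50.0]
decayed = adstock_geometric_py(spend, theta=0.5)
print(decayed)  # [100.0, 50.0, 25.0, 62.5]
print(saturation_hill_py(2.0, alpha=3.0, gamma=2.0))  # 0.5
```

With theta = 0.5, half of each period's decayed spend carries into the next, which is why a single spend of 100 keeps contributing 50 and then 25 in the following periods.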

And this is what the model looks like:

transform_variables = ["trend", "season", "holiday", "competitor_sales_B", "events", "tv_S", "ooh_S", "print_S", "facebook_I", "search_clicks_P", "newsletter"]

delay_channels = ["tv_S", "ooh_S", "print_S", "facebook_I", "search_clicks_P", "newsletter"]

media_channels = ["tv_S", "ooh_S", "print_S", "facebook_I", "search_clicks_P"]

control_variables = ["trend", "season", "holiday", "competitor_sales_B", "events"]

target = "revenue"

# collects each variable's contribution to the mean of the outcome
response_mean = []

with pm.Model() as model_2:
    for channel_name in delay_channels:
        print(f"Delay Channels: Adding {channel_name}")
        
        x = data_transformed[channel_name].values
        
        
        
        adstock_param = pm.Beta(f"{channel_name}_adstock", 3, 3)
        saturation_gamma = pm.Beta(f"{channel_name}_gamma", 2, 2)
        saturation_alpha = pm.Gamma(f"{channel_name}_alpha", 3, 1)
        
        x_new = adstock_geometric_theano_pymc3(x, adstock_param)
        x_new_sliced = x_new[START_ANALYSIS_INDEX:END_ANALYSIS_INDEX]
        saturation_tensor = saturation_hill_pymc3(x_new_sliced, saturation_alpha, saturation_gamma)
        
        channel_b = pm.HalfNormal(f"{channel_name}_media_coef", sd = 3)
        response_mean.append(saturation_tensor * channel_b)
        
    for control_var in control_variables:
        print(f"Control Variables: Adding {control_var}")
        
        x = data_transformed[control_var].values[START_ANALYSIS_INDEX:END_ANALYSIS_INDEX]
        
        control_beta = pm.Normal(f"{control_var}_control_coef", sd = 3)
        control_x = control_beta * x
        response_mean.append(control_x)
        
    intercept = pm.Normal("intercept", np.mean(data_transformed[target].values), sd = 3)
    #intercept = pm.HalfNormal("intercept", 0, sd = 3)
        
    sigma = pm.HalfNormal("sigma", 4)
    
    likelihood = pm.Normal("outcome", mu = intercept + sum(response_mean), sd = sigma, observed = data_transformed[target].values[START_ANALYSIS_INDEX:END_ANALYSIS_INDEX])

And this is the error I get:

KeyError                                  Traceback (most recent call last)
File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/theano/tensor/type.py:265, in TensorType.dtype_specs(self)
    264 try:
--> 265     return {
    266         "float16": (float, "npy_float16", "NPY_FLOAT16"),
    267         "float32": (float, "npy_float32", "NPY_FLOAT32"),
    268         "float64": (float, "npy_float64", "NPY_FLOAT64"),
    269         "bool": (bool, "npy_bool", "NPY_BOOL"),
    270         "uint8": (int, "npy_uint8", "NPY_UINT8"),
    271         "int8": (int, "npy_int8", "NPY_INT8"),
    272         "uint16": (int, "npy_uint16", "NPY_UINT16"),
    273         "int16": (int, "npy_int16", "NPY_INT16"),
    274         "uint32": (int, "npy_uint32", "NPY_UINT32"),
    275         "int32": (int, "npy_int32", "NPY_INT32"),
    276         "uint64": (int, "npy_uint64", "NPY_UINT64"),
    277         "int64": (int, "npy_int64", "NPY_INT64"),
    278         "complex128": (complex, "theano_complex128", "NPY_COMPLEX128"),
    279         "complex64": (complex, "theano_complex64", "NPY_COMPLEX64"),
    280     }[self.dtype]
    281 except KeyError:

KeyError: 'object'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Input In [24], in <cell line: 27>()
     36 saturation_gamma = pm.Beta(f"{channel_name}_gamma", 2, 2)
     37 saturation_alpha = pm.Gamma(f"{channel_name}_alpha", 3, 1)
---> 39 x_new = adstock_geometric_theano_pymc3(x, adstock_param)
     40 x_new_sliced = x_new[START_ANALYSIS_INDEX:END_ANALYSIS_INDEX]
     41 saturation_tensor = saturation_hill_pymc3(x_new_sliced, saturation_alpha, saturation_gamma)

Input In [23], in adstock_geometric_theano_pymc3(x, theta)
     20 x_decayed = tt.zeros_like(x)
     21 x_decayed = tt.set_subtensor(x_decayed[0], x[0])
---> 23 output, _ = theano.scan(
     24     fn = adstock_geometric_recurrence_theano, 
     25     sequences = [tt.arange(1, len_observed), x[1:len_observed]], 
     26     outputs_info = x_decayed,
     27     non_sequences = theta, 
     28     n_steps = len_observed - 1
     29 )
     31 return output[-1]

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/theano/scan/basic.py:347, in scan(fn, sequences, outputs_info, non_sequences, n_steps, truncate_gradient, go_backwards, mode, name, profile, allow_gc, strict, return_list)
    345 for elem in wrap_into_list(non_sequences):
    346     if not isinstance(elem, Variable):
--> 347         non_seqs.append(tt.as_tensor_variable(elem))
    348     else:
    349         non_seqs.append(elem)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/theano/tensor/basic.py:207, in as_tensor_variable(x, name, ndim)
    198 elif isinstance(x, bool):
    199     raise TypeError(
    200         "Cannot cast True or False as a tensor variable. Please use "
    201         "np.array(True) or np.array(False) if you need these constants. "
   (...)
    204         "use theano.tensor.eq(v, w) instead."
    205     )
--> 207 return constant(x, name=name, ndim=ndim)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/theano/tensor/basic.py:255, in constant(x, name, ndim, dtype)
    249             raise ValueError(
    250                 f"ndarray could not be cast to constant with {int(ndim)} dimensions"
    251             )
    253     assert x_.ndim == ndim
--> 255 ttype = TensorType(dtype=x_.dtype, broadcastable=[s == 1 for s in x_.shape])
    257 try:
    258     return TensorConstant(ttype, x_, name=name)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/theano/tensor/type.py:54, in TensorType.__init__(self, dtype, broadcastable, name, sparse_grad)
     51 # broadcastable is immutable, and all elements are either
     52 # True or False
     53 self.broadcastable = tuple(bool(b) for b in broadcastable)
---> 54 self.dtype_specs()  # error checking is done there
     55 self.name = name
     56 self.numpy_dtype = np.dtype(self.dtype)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/theano/tensor/type.py:282, in TensorType.dtype_specs(self)
    265     return {
    266         "float16": (float, "npy_float16", "NPY_FLOAT16"),
    267         "float32": (float, "npy_float32", "NPY_FLOAT32"),
   (...)
    279         "complex64": (complex, "theano_complex64", "NPY_COMPLEX64"),
    280     }[self.dtype]
    281 except KeyError:
--> 282     raise TypeError(
    283         f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
    284     )

TypeError: Unsupported dtype for TensorType: object

If I declare the variables inside the functions as tensor.fvector or tensor.scalar, the calculation step passes, but I then get this error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Input In [29], in <cell line: 27>()
     41     saturation_tensor = saturation_hill_pymc3(x_new_sliced, saturation_alpha, saturation_gamma)
     43     channel_b = pm.HalfNormal(f"{channel_name}_media_coef",  3)
---> 44     response_mean.append(saturation_tensor * channel_b)
     46 for control_var in control_variables:
     47     print(f"Control Variables: Adding {control_var}")

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/var.py:203, in _tensor_py_operators.__rmul__(self, other)
    202 def __rmul__(self, other):
--> 203     return at.math.mul(other, self)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/graph/op.py:294, in Op.__call__(self, *inputs, **kwargs)
    252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    291 
    292 """
    293 return_list = kwargs.pop("return_list", False)
--> 294 node = self.make_node(*inputs, **kwargs)
    296 if config.compute_test_value != "off":
    297     compute_test_value(node)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/elemwise.py:462, in Elemwise.make_node(self, *inputs)
    456 def make_node(self, *inputs):
    457     """
    458     If the inputs have different number of dimensions, their shape
    459     is left-completed to the greatest number of dimensions with 1s
    460     using DimShuffle.
    461     """
--> 462     inputs = [as_tensor_variable(i) for i in inputs]
    463     out_dtypes, out_broadcastables, inputs = self.get_output_info(
    464         DimShuffle, *inputs
    465     )
    466     outputs = [
    467         TensorType(dtype=dtype, shape=broadcastable)()
    468         for dtype, broadcastable in zip(out_dtypes, out_broadcastables)
    469     ]

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/elemwise.py:462, in <listcomp>(.0)
    456 def make_node(self, *inputs):
    457     """
    458     If the inputs have different number of dimensions, their shape
    459     is left-completed to the greatest number of dimensions with 1s
    460     using DimShuffle.
    461     """
--> 462     inputs = [as_tensor_variable(i) for i in inputs]
    463     out_dtypes, out_broadcastables, inputs = self.get_output_info(
    464         DimShuffle, *inputs
    465     )
    466     outputs = [
    467         TensorType(dtype=dtype, shape=broadcastable)()
    468         for dtype, broadcastable in zip(out_dtypes, out_broadcastables)
    469     ]

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/__init__.py:42, in as_tensor_variable(x, name, ndim, **kwargs)
     10 def as_tensor_variable(
     11     x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
     12 ) -> "TensorVariable":
     13     """Convert `x` into an equivalent `TensorVariable`.
     14 
     15     This function can be used to turn ndarrays, numbers, `Scalar` instances,
   (...)
     40 
     41     """
---> 42     return _as_tensor_variable(x, name, ndim, **kwargs)

File /opt/python/3.8.10/lib/python3.8/functools.py:875, in singledispatch.<locals>.wrapper(*args, **kw)
    871 if not args:
    872     raise TypeError(f'{funcname} requires at least '
    873                     '1 positional argument')
--> 875 return dispatch(args[0].__class__)(*args, **kw)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/__init__.py:49, in _as_tensor_variable(x, name, ndim, **kwargs)
     45 @singledispatch
     46 def _as_tensor_variable(
     47     x, name: Optional[str], ndim: Optional[int], **kwargs
     48 ) -> "TensorVariable":
---> 49     raise NotImplementedError(f"Cannot convert {x} to a tensor variable.")

NotImplementedError: Cannot convert Elemwise{true_div,no_inplace}.0 to a tensor variable.

I then made a series of attempts, converting channel_b, then response_mean, and then the intercept to tensor vectors to get past these errors, but I end up with the error below when defining the likelihood, which I can't seem to overcome:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Input In [42], in <cell line: 27>()
     63 response_mean= tt.sum(response_mean) + intercept    
     64 sigma = pm.HalfNormal("sigma", 4)
---> 66 likelihood = pm.Normal("outcome", mu = response_mean, sigma = sigma, observed = data_transformed[target].values[START_ANALYSIS_INDEX:END_ANALYSIS_INDEX])

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/pymc/distributions/distribution.py:266, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    262     rng = model.next_rng()
    264 # Create the RV and process dims and observed to determine
    265 # a shape by which the created RV may need to be resized.
--> 266 rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
    267     cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs
    268 )
    270 if resize_shape:
    271     # A batch size was specified through `dims`, or implied by `observed`.
    272     rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/pymc/distributions/distribution.py:165, in _make_rv_and_resize_shape(cls, dims, model, observed, args, **kwargs)
    162 """Creates the RV and processes dims or observed to determine a resize shape."""
    163 # Create the RV without dims information, because that's not something tracked at the Aesara level.
    164 # If necessary we'll later replicate to a different size implied by already known dims.
--> 165 rv_out = cls.dist(*args, **kwargs)
    166 ndim_actual = rv_out.ndim
    167 resize_shape = None

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/pymc/distributions/continuous.py:554, in Normal.dist(cls, mu, sigma, tau, no_assert, **kwargs)
    551 if not no_assert:
    552     assert_negative_support(sigma, "sigma", "Normal")
--> 554 return super().dist([mu, sigma], **kwargs)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/pymc/distributions/distribution.py:353, in Distribution.dist(cls, dist_params, shape, size, **kwargs)
    348 create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
    349     shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp
    350 )
    351 # Create the RV with a `size` right away.
    352 # This is not necessarily the final result.
--> 353 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
    355 # Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
    356 if shape is not None and Ellipsis in shape:

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/random/basic.py:108, in NormalRV.__call__(self, loc, scale, size, **kwargs)
    107 def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
--> 108     return super().__call__(loc, scale, size=size, **kwargs)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/random/op.py:279, in RandomVariable.__call__(self, size, name, rng, dtype, *args, **kwargs)
    278 def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
--> 279     res = super().__call__(rng, size, dtype, *args, **kwargs)
    281     if name is not None:
    282         res.name = name

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/graph/op.py:294, in Op.__call__(self, *inputs, **kwargs)
    252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    291 
    292 """
    293 return_list = kwargs.pop("return_list", False)
--> 294 node = self.make_node(*inputs, **kwargs)
    296 if config.compute_test_value != "off":
    297     compute_test_value(node)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/random/op.py:312, in RandomVariable.make_node(self, rng, size, dtype, *dist_params)
    287 """Create a random variable node.
    288 
    289 Parameters
   (...)
    308 
    309 """
    310 size = normalize_size_param(size)
--> 312 dist_params = tuple(
    313     as_tensor_variable(p) if not isinstance(p, Variable) else p
    314     for p in dist_params
    315 )
    317 if rng is None:
    318     rng = aesara.shared(np.random.default_rng())

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/random/op.py:313, in <genexpr>(.0)
    287 """Create a random variable node.
    288 
    289 Parameters
   (...)
    308 
    309 """
    310 size = normalize_size_param(size)
    312 dist_params = tuple(
--> 313     as_tensor_variable(p) if not isinstance(p, Variable) else p
    314     for p in dist_params
    315 )
    317 if rng is None:
    318     rng = aesara.shared(np.random.default_rng())

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/__init__.py:42, in as_tensor_variable(x, name, ndim, **kwargs)
     10 def as_tensor_variable(
     11     x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
     12 ) -> "TensorVariable":
     13     """Convert `x` into an equivalent `TensorVariable`.
     14 
     15     This function can be used to turn ndarrays, numbers, `Scalar` instances,
   (...)
     40 
     41     """
---> 42     return _as_tensor_variable(x, name, ndim, **kwargs)

File /opt/python/3.8.10/lib/python3.8/functools.py:875, in singledispatch.<locals>.wrapper(*args, **kw)
    871 if not args:
    872     raise TypeError(f'{funcname} requires at least '
    873                     '1 positional argument')
--> 875 return dispatch(args[0].__class__)(*args, **kw)

File ~/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/tensor/__init__.py:49, in _as_tensor_variable(x, name, ndim, **kwargs)
     45 @singledispatch
     46 def _as_tensor_variable(
     47     x, name: Optional[str], ndim: Optional[int], **kwargs
     48 ) -> "TensorVariable":
---> 49     raise NotImplementedError(f"Cannot convert {x} to a tensor variable.")

NotImplementedError: Cannot convert Elemwise{add,no_inplace}.0 to a tensor variable.

If I remove adstock_geometric_theano_pymc3() and saturation_hill_pymc3() from the model-building section, it runs normally, but that defeats the whole purpose of the model.

I am not very familiar with tensors, so I am not sure my bug-fixing tactics were right.
I would appreciate any guidance on how to fix these errors, or on how to use the GPU with PyMC3. I have a deadline coming up soon and am short on time.

Thank you.

Versions and main components

  • PyMC/PyMC3 Version: 4.0.0b6
  • Theano Version: 1.1.2
  • Aesara Version: 2.5.3
  • Python Version: 3.8.0
  • Operating system: working in a virtual environment at work; I believe it is Linux-based, but I am not 100% sure
  • How did you install PyMC/PyMC3: pip
@ricardoV94
Member

ricardoV94 commented Apr 8, 2022

Even though I can't see the imports, it seems like you might be using Theano/Theano-pymc functions inside those helper functions and not Aesara ones. You need to change any import theano; import theano.tensor as tt to import aesara; import aesara.tensor as at in PyMC V4, and use those.
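
One way to check which library a variable belongs to is sketched below. It assumes that graph variables are instances of classes defined inside the `theano` or `aesara` packages, which is consistent with the `theano/tensor/...` and `aesara/tensor/var.py` paths in the tracebacks above. The helper name `backend_of` is hypothetical, and the demo uses standard-library objects so that it runs without either library installed.

```python
from fractions import Fraction


def backend_of(obj):
    """Return the top-level package that defines the class of `obj`."""
    return type(obj).__module__.split(".")[0]


print(backend_of(Fraction(1, 2)))  # fractions
print(backend_of(3.0))             # builtins
```

Inside the model, calling it on a random variable such as `adstock_param` and on the output of a helper function would be expected to print the same package name. If one prints `theano` and the other `aesara`, the helper is probably still built on the old imports.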

@ricardoV94
Member

There is a short guide here that might be useful: https://www.pymc-labs.io/blog-posts/the-quickest-migration-guide-ever-from-pymc3-to-pymc-v40/

@mgendia
Author

mgendia commented Apr 8, 2022

Thank you @ricardoV94 for the clarification. It passes now that I have changed all the Theano functions to Aesara. Much appreciated.

However, when running

with model:
   trace= pm.sampling_jax.sample_numpyro_nuts(1000, tune=2000)

I get:

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "MAG_non_heirarchial_agg_channels_brick_filter_time_control.py", line 171, in <module>
    trace= pm.sampling_jax.sample_numpyro_nuts(1000, tune=2000)
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/pymc/sampling_jax.py", line 516, in sample_numpyro_nuts
    pmap_numpyro.run(
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 578, in run
    states, last_state = pmap(partial_map_fn)(map_args)
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 171, in reraise_with_filtered_traceback
    raise e.with_traceback(filtered_tb)
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 360, in _single_chain_mcmc
    init_state = self.sampler.init(
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 735, in init
    init_state = hmc_init_fn(init_params, rng_key)
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 716, in <lambda>
    hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 317, in init_kernel
    vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 279, in init_fn
    potential_energy, z_grad = _value_and_grad(
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 247, in _value_and_grad
    return value_and_grad(f)(x)
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
    return logp_fn(*x)[0]
  File "/tmp/tmpiivv26_7", line 132, in jax_funcified_fgraph
    auto_213928, auto_213929, auto_213930, auto_213931, auto_213932 = scan(auto_187703, auto_201982, auto_189830, auto_189255, auto_191105, auto_192380, auto_193655, auto_194936, auto_213016, auto_213015, auto_213014, auto_213013, auto_213012, auto_199748, auto_199769, auto_199788, auto_199807, auto_199826)
  File "/home/gendimo1/sdz_marketing_mix_model/pymc4.env/lib/python3.8/site-packages/aesara/link/jax/dispatch.py", line 420, in scan
    scan_args = ScanArgs(
TypeError: __init__() missing 1 required positional argument: 'as_while'

I'm not sure why. Should I raise another issue for it?

Thanks

@ricardoV94
Member

Yes, that's expected. Unfortunately, the JAX backend does not currently work with Scans.

@ricardoV94
Member

There is an issue for that here: aesara-devs/aesara#710

@mgendia
Author

mgendia commented Apr 8, 2022

@ricardoV94 what would you recommend as a workaround? Or is sampling_jax simply not functioning at the moment?

@ricardoV94
Member

sampling_jax works, but not for every type of model; models that use scans are among those it does not support yet.

You can write the correct jax code directly and wrap it in a new Op if that's worth the trouble for you.

There's a WIP guide here: pymc-devs/pymc-examples#302
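
For this particular recurrence there may also be a way to avoid the scan altogether. This is a different technique from the custom Op suggested above, and it is not taken from the linked guide: unrolling the geometric adstock recurrence gives the closed form x_decayed[t] = sum over s <= t of theta^(t - s) * x[s], which needs no loop-carried state. The pure-Python sketch below (function names are illustrative) checks that the closed form matches the recurrence.

```python
def adstock_recurrence(x, theta):
    """Step-by-step form: each value adds theta times the previous result."""
    decayed = [x[0]]
    for value in x[1:]:
        decayed.append(value + theta * decayed[-1])
    return decayed


def adstock_closed_form(x, theta):
    """Unrolled form: each output is a weighted sum of all earlier inputs."""
    return [
        sum(theta ** (t - s) * x[s] for s in range(t + 1))
        for t in range(len(x))
    ]


spend = [100.0, 0.0, 0.0, 50.0]
by_loop = adstock_recurrence(spend, 0.5)
by_sum = adstock_closed_form(spend, 0.5)
print(all(abs(u - v) < 1e-9 for u, v in zip(by_loop, by_sum)))  # True
```

In tensor form the closed form is the product of a lower-triangular matrix of powers of theta with the input vector, so it uses only elementwise operations and a dot product. Whether the JAX backend handles that graph in a given Aesara version, and whether the O(n^2) matrix is acceptable for long series, would need to be checked.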

@mgendia
Author

mgendia commented Apr 8, 2022

Thanks @ricardoV94, I will look into it.

I will close this issue, as you have addressed it well. I appreciate it :)

mgendia closed this as completed Apr 8, 2022