ADVI errors in numba mode for StudentT likelihood when total_size is set #7778


Closed · Fixed by pymc-devs/pytensor#1402
jessegrabowski opened this issue on May 8, 2025 · 1 comment
Labels: bug, numba, VI Variational Inference

Comments

jessegrabowski (Member) commented on May 8, 2025

Description

import pymc as pm
import numpy as np

rng = np.random.default_rng()

with pm.Model() as m:
    # Full dataset as shared data, streamed in minibatches of 128 rows
    data = pm.Data('data', rng.normal(size=(1000, 5)))
    obs = pm.Data('obs', rng.normal(size=(1000,)))

    data_batch, obs_batch = pm.Minibatch(data, obs, batch_size=128)

    beta = pm.Normal('beta', size=(5,))
    mu = data_batch @ beta
    sigma = pm.Exponential('sigma', 1)

    # total_size rescales the minibatch logp to the full dataset;
    # setting it is what triggers the error below
    y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3,
                        observed=obs_batch, total_size=1000)

    idata = pm.fit(n=1_000_000,
                   compile_kwargs={'mode': 'NUMBA'})
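
The failing op can likely be isolated from the VI machinery entirely. A minimal sketch (assuming pytensor's pt.random.t constructor; this snippet is an illustration, not part of the report) that should raise the same NotImplementedError at compile time:

import pytensor
import pytensor.tensor as pt

# Any graph that draws from the StudentT RandomVariable should hit the
# same missing numba dispatch as the full ADVI model above.
x = pt.random.t(df=3, loc=0.0, scale=1.0)
fn = pytensor.function([], x, mode='NUMBA')
fn()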
Full Traceback
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 18
     14 sigma = pm.Exponential('sigma', 1)
     16 y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3, observed=obs_batch, total_size=1000)
---> 18 idata = pm.fit(n=1_000_000, 
     19                compile_kwargs={'mode':'NUMBA'})

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:775, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    773 else:
    774     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 775 return inference.fit(n, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:158, in Inference.fit(self, n, score, callbacks, progressbar, progressbar_theme, **kwargs)
    156     callbacks = []
    157 score = self._maybe_score(score)
--> 158 step_func = self.objective.step_function(score=score, **kwargs)
    160 if score:
    161     state = self._iterate_with_loss(
    162         0, n, step_func, progressbar, progressbar_theme, callbacks
    163     )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/opvi.py:405, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, compile_kwargs, fn_kwargs)
    403 seed = self.approx.rng.randint(2**30, dtype=np.int64)
    404 if score:
--> 405     step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs)
    406 else:
    407     step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/pytensorf.py:947, in compile(inputs, outputs, random_seed, mode, **kwargs)
    945 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
    946 mode = Mode(linker=mode.linker, optimizer=opt_qry)
--> 947 pytensor_function = pytensor.function(
    948     inputs,
    949     outputs,
    950     updates={**rng_updates, **kwargs.pop("updates", {})},
    951     mode=mode,
    952     **kwargs,
    953 )
    954 return pytensor_function

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
    321     fn = orig_function(
    322         inputs,
    323         outputs,
   (...)    327         trust_input=trust_input,
    328     )
    329 else:
    330     # note: pfunc will also call orig_function -- orig_function is
    331     #      a choke point that all compilation must pass through
--> 332     fn = pfunc(
    333         params=inputs,
    334         outputs=outputs,
    335         mode=mode,
    336         updates=updates,
    337         givens=givens,
    338         no_default_updates=no_default_updates,
    339         accept_inplace=accept_inplace,
    340         name=name,
    341         rebuild_strict=rebuild_strict,
    342         allow_input_downcast=allow_input_downcast,
    343         on_unused_input=on_unused_input,
    344         profile=profile,
    345         output_keys=output_keys,
    346         trust_input=trust_input,
    347     )
    348 return fn

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
    452     profile = ProfileStats(message=profile)
    454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    455     params,
    456     outputs,
   (...)    463     fgraph=fgraph,
    464 )
--> 466 return orig_function(
    467     inputs,
    468     cloned_outputs,
    469     mode,
    470     accept_inplace=accept_inplace,
    471     name=name,
    472     profile=profile,
    473     on_unused_input=on_unused_input,
    474     output_keys=output_keys,
    475     fgraph=fgraph,
    476     trust_input=trust_input,
    477 )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1833, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
   1820     m = Maker(
   1821         inputs,
   1822         outputs,
   (...)   1830         trust_input=trust_input,
   1831     )
   1832     with config.change_flags(compute_test_value="off"):
-> 1833         fn = m.create(defaults)
   1834 finally:
   1835     if profile and fn:

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1717, in FunctionMaker.create(self, input_storage, storage_map)
   1714 start_import_time = pytensor.link.c.cmodule.import_time
   1716 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1717     _fn, _i, _o = self.linker.make_thunk(
   1718         input_storage=input_storage_lists, storage_map=storage_map
   1719     )
   1721 end_linker = time.perf_counter()
   1723 linker_time = end_linker - start_linker

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    238 def make_thunk(
    239     self,
    240     input_storage: Optional["InputStorageType"] = None,
   (...)    243     **kwargs,
    244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245     return self.make_all(
    246         input_storage=input_storage,
    247         output_storage=output_storage,
    248         storage_map=storage_map,
    249     )[:3]

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    692 for k in storage_map:
    693     compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
    696     compute_map, nodes, input_storage, output_storage, storage_map
    697 )
    699 [fn] = thunks
    700 fn.jit_fn = jit_fn

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/linker.py:10, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
      7 def fgraph_convert(self, fgraph, **kwargs):
      8     from pytensor.link.numba.dispatch import numba_funcify
---> 10     return numba_funcify(fgraph, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:380, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    373 @numba_funcify.register(FunctionGraph)
    374 def numba_funcify_FunctionGraph(
    375     fgraph,
   (...)    378     **kwargs,
    379 ):
--> 380     return fgraph_to_python(
    381         fgraph,
    382         numba_funcify,
    383         type_conversion_fn=numba_typify,
    384         fgraph_name=fgraph_name,
    385         **kwargs,
    386     )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
    734 body_assigns = []
    735 for node in order:
--> 736     compiled_func = op_conversion_fn(
    737         node.op, node=node, storage_map=storage_map, **kwargs
    738     )
    740     # Create a local alias with a unique name
    741     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:401, in numba_funcify_RandomVariable(op, node, **kwargs)
    398 core_shape_len = get_vector_length(core_shape)
    399 inplace = rv_op.inplace
--> 401 core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
    402 nin = 1 + len(dist_params)  # rng + params
    403 core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:47, in numba_core_rv_funcify(op, node)
     44 @singledispatch
     45 def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
     46     """Return the core function for a random variable operation."""
---> 47     raise NotImplementedError(f"Core implementation of {op} not implemented.")

NotImplementedError: Core implementation of t_rv{"(),(),()->()"} not implemented.

Interestingly, it works fine if you change total_size to None.
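
For contrast, a sketch of the working variant with total_size left unset (note that the minibatch logp is then no longer rescaled to the full dataset, so the objective changes):

with pm.Model() as m_ok:
    data = pm.Data('data', rng.normal(size=(1000, 5)))
    obs = pm.Data('obs', rng.normal(size=(1000,)))
    data_batch, obs_batch = pm.Minibatch(data, obs, batch_size=128)

    beta = pm.Normal('beta', size=(5,))
    sigma = pm.Exponential('sigma', 1)

    # total_size omitted: compiles and fits in NUMBA mode, but the
    # likelihood is weighted as if the 128-row batch were the whole dataset
    y_hat = pm.StudentT('y_hat', mu=data_batch @ beta, sigma=sigma, nu=3,
                        observed=obs_batch)

    idata = pm.fit(n=1_000_000, compile_kwargs={'mode': 'NUMBA'})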

jessegrabowski added the bug, VI Variational Inference, and numba labels on May 8, 2025
ricardoV94 (Member) commented:

I think we simply don't have a dispatch for StudentT in numba.
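
The dispatch in question is the singledispatch numba_core_rv_funcify visible at the bottom of the traceback. A minimal sketch of what a StudentT registration might look like (the actual fix landed in pymc-devs/pytensor#1402; the class name StudentTRV and the use of rng.standard_t follow pytensor's dispatch pattern and numba's Generator support, and may differ from the merged code):

import numba
from pytensor.link.numba.dispatch.random import numba_core_rv_funcify
from pytensor.tensor.random.basic import StudentTRV

@numba_core_rv_funcify.register(StudentTRV)
def numba_core_StudentTRV(op, node):
    @numba.njit
    def random(rng, df, loc, scale):
        # numba supports np.random.Generator.standard_t, so the
        # location-scale Student-t is a shift and scale of a standard draw
        return loc + scale * rng.standard_t(df)

    return random

With a registration along these lines, the t_rv{"(),(),()->()"} node compiles and the ADVI step function above should build in NUMBA mode with total_size set.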
