We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
StudentT
total_size
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
# Reproducer: StudentT likelihood with a minibatched observation and
# total_size set, fitted with pm.fit in NUMBA compile mode.
import pymc as pm
import numpy as np

rng = np.random.default_rng()

with pm.Model() as m:
    # Full arrays: a (1000, 5) design matrix and a (1000,) target vector.
    data = pm.Data('data', rng.normal(size=(1000, 5)))
    obs = pm.Data('obs', rng.normal(size=(1000,)))

    # Minibatch views of both arrays, batch_size=128.
    data_batch, obs_batch = pm.Minibatch(data, obs, batch_size=128)

    beta = pm.Normal('beta', size=(5,))
    mu = data_batch @ beta
    sigma = pm.Exponential('sigma', 1)

    # Per the issue report, compilation succeeds with total_size=None;
    # with total_size=1000 the NUMBA compile raises NotImplementedError
    # for t_rv.
    y_hat = pm.StudentT(
        'y_hat',
        mu=mu,
        sigma=sigma,
        nu=3,
        observed=obs_batch,
        total_size=1000,
    )

    idata = pm.fit(
        n=1_000_000,
        compile_kwargs={'mode': 'NUMBA'},
    )
--------------------------------------------------------------------------- NotImplementedError Traceback (most recent call last) Cell In[1], line 18 14 sigma = pm.Exponential('sigma', 1) 16 y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3, observed=obs_batch, total_size=1000) ---> 18 idata = pm.fit(n=1_000_000, 19 compile_kwargs={'mode':'NUMBA'}) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:775, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs) 773 else: 774 raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance") --> 775 return inference.fit(n, **kwargs) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:158, in Inference.fit(self, n, score, callbacks, progressbar, progressbar_theme, **kwargs) 156 callbacks = [] 157 score = self._maybe_score(score) --> 158 step_func = self.objective.step_function(score=score, **kwargs) 160 if score: 161 state = self._iterate_with_loss( 162 0, n, step_func, progressbar, progressbar_theme, callbacks 163 ) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs) 41 @wraps(f) 42 def res(*args, **kwargs): 43 with self: ---> 44 return f(*args, **kwargs) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/opvi.py:405, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, compile_kwargs, fn_kwargs) 403 seed = self.approx.rng.randint(2**30, dtype=np.int64) 404 if score: --> 405 step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs) 406 else: 407 step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/pytensorf.py:947, in 
compile(inputs, outputs, random_seed, mode, **kwargs) 945 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt) 946 mode = Mode(linker=mode.linker, optimizer=opt_qry) --> 947 pytensor_function = pytensor.function( 948 inputs, 949 outputs, 950 updates={**rng_updates, **kwargs.pop("updates", {})}, 951 mode=mode, 952 **kwargs, 953 ) 954 return pytensor_function File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input) 321 fn = orig_function( 322 inputs, 323 outputs, (...) 327 trust_input=trust_input, 328 ) 329 else: 330 # note: pfunc will also call orig_function -- orig_function is 331 # a choke point that all compilation must pass through --> 332 fn = pfunc( 333 params=inputs, 334 outputs=outputs, 335 mode=mode, 336 updates=updates, 337 givens=givens, 338 no_default_updates=no_default_updates, 339 accept_inplace=accept_inplace, 340 name=name, 341 rebuild_strict=rebuild_strict, 342 allow_input_downcast=allow_input_downcast, 343 on_unused_input=on_unused_input, 344 profile=profile, 345 output_keys=output_keys, 346 trust_input=trust_input, 347 ) 348 return fn File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input) 452 profile = ProfileStats(message=profile) 454 inputs, cloned_outputs = construct_pfunc_ins_and_outs( 455 params, 456 outputs, (...) 
463 fgraph=fgraph, 464 ) --> 466 return orig_function( 467 inputs, 468 cloned_outputs, 469 mode, 470 accept_inplace=accept_inplace, 471 name=name, 472 profile=profile, 473 on_unused_input=on_unused_input, 474 output_keys=output_keys, 475 fgraph=fgraph, 476 trust_input=trust_input, 477 ) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1833, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input) 1820 m = Maker( 1821 inputs, 1822 outputs, (...) 1830 trust_input=trust_input, 1831 ) 1832 with config.change_flags(compute_test_value="off"): -> 1833 fn = m.create(defaults) 1834 finally: 1835 if profile and fn: File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1717, in FunctionMaker.create(self, input_storage, storage_map) 1714 start_import_time = pytensor.link.c.cmodule.import_time 1716 with config.change_flags(traceback__limit=config.traceback__compile_limit): -> 1717 _fn, _i, _o = self.linker.make_thunk( 1718 input_storage=input_storage_lists, storage_map=storage_map 1719 ) 1721 end_linker = time.perf_counter() 1723 linker_time = end_linker - start_linker File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs) 238 def make_thunk( 239 self, 240 input_storage: Optional["InputStorageType"] = None, (...) 
243 **kwargs, 244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]: --> 245 return self.make_all( 246 input_storage=input_storage, 247 output_storage=output_storage, 248 storage_map=storage_map, 249 )[:3] File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map) 692 for k in storage_map: 693 compute_map[k] = [k.owner is None] --> 695 thunks, nodes, jit_fn = self.create_jitable_thunk( 696 compute_map, nodes, input_storage, output_storage, storage_map 697 ) 699 [fn] = thunks 700 fn.jit_fn = jit_fn File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map) 644 # This is a bit hackish, but we only return one of the output nodes 645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1] --> 647 converted_fgraph = self.fgraph_convert( 648 self.fgraph, 649 order=order, 650 input_storage=input_storage, 651 output_storage=output_storage, 652 storage_map=storage_map, 653 ) 655 thunk_inputs = self.create_thunk_inputs(storage_map) 656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs] File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/linker.py:10, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs) 7 def fgraph_convert(self, fgraph, **kwargs): 8 from pytensor.link.numba.dispatch import numba_funcify ---> 10 return numba_funcify(fgraph, **kwargs) File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw) 908 if not args: 909 raise TypeError(f'{funcname} requires at least ' 910 '1 positional argument') --> 912 return dispatch(args[0].__class__)(*args, **kw) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:380, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, 
**kwargs) 373 @numba_funcify.register(FunctionGraph) 374 def numba_funcify_FunctionGraph( 375 fgraph, (...) 378 **kwargs, 379 ): --> 380 return fgraph_to_python( 381 fgraph, 382 numba_funcify, 383 type_conversion_fn=numba_typify, 384 fgraph_name=fgraph_name, 385 **kwargs, 386 ) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs) 734 body_assigns = [] 735 for node in order: --> 736 compiled_func = op_conversion_fn( 737 node.op, node=node, storage_map=storage_map, **kwargs 738 ) 740 # Create a local alias with a unique name 741 local_compiled_func_name = unique_name(compiled_func) File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw) 908 if not args: 909 raise TypeError(f'{funcname} requires at least ' 910 '1 positional argument') --> 912 return dispatch(args[0].__class__)(*args, **kw) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:401, in numba_funcify_RandomVariable(op, node, **kwargs) 398 core_shape_len = get_vector_length(core_shape) 399 inplace = rv_op.inplace --> 401 core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) 402 nin = 1 + len(dist_params) # rng + params 403 core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1) File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw) 908 if not args: 909 raise TypeError(f'{funcname} requires at least ' 910 '1 positional argument') --> 912 return dispatch(args[0].__class__)(*args, **kw) File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:47, in numba_core_rv_funcify(op, node) 44 @singledispatch 45 def numba_core_rv_funcify(op: Op, node: Apply) -> Callable: 46 """Return the core function for a random 
variable operation.""" ---> 47 raise NotImplementedError(f"Core implementation of {op} not implemented.") NotImplementedError: Core implementation of t_rv{"(),(),()->()"} not implemented.
Interestingly, it works fine if you change `total_size=1000` to `total_size=None`.
The text was updated successfully, but these errors were encountered:
I think we simply don't have a Numba dispatch for StudentT: the traceback ends in `numba_core_rv_funcify`, which raises `NotImplementedError` for `t_rv`.
Sorry, something went wrong.
Successfully merging a pull request may close this issue.
Uh oh!
There was an error while loading. Please reload this page.
Description
Full Traceback
Interestingly, it works fine if you change `total_size=1000` to `total_size=None`.
The text was updated successfully, but these errors were encountered: