From c06b3b01f0c0f54207bd4b8e2a4074a7f8c641a6 Mon Sep 17 00:00:00 2001 From: Justin Kunimune Date: Sat, 16 Mar 2024 14:42:37 -0400 Subject: [PATCH] Update example in creating_a_numba_jax_op This example as it was threw a type error because the registered function is called with `node` and `output_storage` keyword arguments, which were not included in the function signature. Adding a `**kwargs` argument resolves this, and is also what the existing Jaxified Ops all seem to do. --- doc/extending/creating_a_numba_jax_op.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index abf3f528bf..4779a0ac38 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -105,7 +105,7 @@ Here’s an example for the `Eye`\ `Op`: @jax_funcify.register(Eye) - def jax_funcify_Eye(op): + def jax_funcify_Eye(op, **kwargs): # Obtain necessary "static" attributes from the Op being converted dtype = op.dtype