
Commit ca7e8b8

Update example in "Adding JAX and Numba support for Ops" (#687)
* Changed example for extending JAX.
* Use CumOp in the example
* Delete new_jax_example.ipynb
1 parent 30b760f commit ca7e8b8

File tree

1 file changed: +187 -55 lines changed


doc/extending/creating_a_numba_jax_op.rst

+187 -55
@@ -7,56 +7,96 @@ this, each :class:`Op` in a PyTensor graph must have an equivalent JAX/Numba im
This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will
focus specifically on the JAX case, but the same mechanisms are used for Numba as well.

Step 1: Identify the PyTensor :class:`Op` you'd like to implement in JAX
------------------------------------------------------------------------

Find the source for the PyTensor :class:`Op` you'd like to be supported in JAX, and
identify the function signature and return values. These can be determined by
looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
with PyTensor :class:`Op`\s in order to provide a conversion implementation, so first read
:ref:`creating_an_op` if you are not familiar.
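
If you are unsure where an :class:`Op` lives, a short introspection sketch such as the
one below (shown here for :class:`CumOp`, the example used in the rest of this tutorial)
can point you at the file and methods to study:

.. code:: python

    import inspect

    from pytensor.tensor.extra_ops import CumOp

    # Locate the implementation we need to study: make_node and perform
    print(inspect.getsourcefile(CumOp))
    print(inspect.getsource(CumOp.make_node))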

For example, you want to extend support for :class:`CumsumOp`\:

.. code:: python

    class CumsumOp(Op):
        __props__ = ("axis",)

        def __new__(typ, *args, **kwargs):
            obj = object.__new__(CumOp, *args, **kwargs)
            obj.mode = "add"
            return obj


:class:`CumsumOp` turns out to be a variant of the :class:`CumOp`\ :class:`Op`,
which currently has an :meth:`Op.make_node` as follows:

.. code:: python

    def make_node(self, x):
        x = ptb.as_tensor_variable(x)
        out_type = x.type()

        if self.axis is None:
            out_type = vector(dtype=x.dtype)  # Flatten
        elif self.axis >= x.ndim or self.axis < -x.ndim:
            raise ValueError(f"axis(={self.axis}) out of bounds")

        return Apply(self, [x], [out_type])

The :class:`Apply` instance that's returned specifies the exact types of inputs that
our JAX implementation will receive and the exact types of outputs it's expected to
return--both in terms of their data types and number of dimensions/shapes.
The actual inputs our implementation will receive are necessarily numeric values
or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
general signature of the underlying computation.

More specifically, the :class:`Apply` implies that there is one input that is
automatically converted to a PyTensor variable via :func:`as_tensor_variable`.
There is another parameter, `axis`, that determines the direction of the
operation, and hence the shape of the output. The check that follows implies that
`axis` must refer to a dimension of the input tensor. The input's elements
could have any data type (e.g. floats, ints), so our JAX implementation
must be able to handle all the possible data types.

It also tells us that there's only one return value, and that its data type is
determined by :meth:`x.type()`, i.e. the data type of the original tensor.
Unless ``axis`` is ``None`` (in which case the input is flattened and the output
is a vector), the result therefore has the same shape as the input.
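
To see these types concretely, we can build a tiny graph and inspect the :class:`Apply`
node that :meth:`Op.make_node` produced. This is only an illustrative sketch; the exact
``repr`` strings vary between versions:

.. code:: python

    import pytensor.tensor as pt

    x = pt.matrix("x")
    node = pt.cumsum(x, axis=0).owner  # the Apply node built by CumOp.make_node

    print(node.op)               # the CumOp instance, with axis=0 and mode="add"
    print(node.inputs[0].type)   # a matrix TensorType, dtype taken from `x`
    print(node.outputs[0].type)  # same type as the input, per make_node above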

Some :class:`Op`\s may have more complex behavior. For example, the :class:`CumOp`\ :class:`Op`
also has another variant, :class:`CumprodOp`\ :class:`Op`, with the exact same signature
as :class:`CumsumOp`\ :class:`Op`. The difference lies in the `mode` attribute of the
:class:`CumOp` definition:

.. code:: python

    class CumOp(COp):
        # See function cumsum/cumprod for docstring

        __props__ = ("axis", "mode")
        check_input = False
        params_type = ParamsType(
            c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
        )

        def __init__(self, axis: int | None = None, mode="add"):
            if mode not in ("add", "mul"):
                raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
            self.axis = axis
            self.mode = mode

        c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)

`__props__` is used to parametrize the general behavior of the :class:`Op`. One needs to
pay attention to this to decide whether the JAX implementation should support all variants
or raise an explicit ``NotImplementedError`` for cases that are not supported, e.g. when
:class:`CumsumOp` (:class:`CumOp("add")`) is supported but :class:`CumprodOp`
(:class:`CumOp("mul")`) is not.
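
Because the dispatcher used in Step 3 selects a converter by the :class:`Op`'s *type*,
a single converter receives every variant and has to branch on these ``__props__``
values itself. A minimal sketch of how ``__props__`` behave:

.. code:: python

    from pytensor.tensor.extra_ops import CumOp

    # Instances with identical __props__ values compare (and hash) equal, which
    # is what lets PyTensor merge equivalent nodes during rewrites.
    assert CumOp(axis=0, mode="add") == CumOp(axis=0, mode="add")
    assert CumOp(axis=0, mode="add") != CumOp(axis=0, mode="mul")

    # A converter registered for CumOp therefore sees both the "add" and the
    # "mul" variants and must inspect op.mode (and op.axis) to decide what to emit.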

Next, we look at the :meth:`Op.perform` implementation to see exactly
how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented in JAX.

Step 2: Find the relevant JAX method (or something close)
---------------------------------------------------------
@@ -82,47 +122,83 @@ Here's an example for :class:`IfElse`:
        )
        return res if n_outs > 1 else res[0]

In this case, :class:`CumOp` is implemented with NumPy's :func:`numpy.cumsum`
and :func:`numpy.cumprod`, which have JAX equivalents: :func:`jax.numpy.cumsum`
and :func:`jax.numpy.cumprod`.

.. code:: python

    def perform(self, node, inputs, output_storage):
        x = inputs[0]
        z = output_storage[0]
        if self.mode == "add":
            z[0] = np.cumsum(x, axis=self.axis)
        else:
            z[0] = np.cumprod(x, axis=self.axis)
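
Before wiring anything into PyTensor, it can be worth confirming directly that the JAX
functions behave like the NumPy calls used in :meth:`Op.perform`. A standalone sketch,
assuming JAX is installed:

.. code:: python

    import numpy as np
    import jax.numpy as jnp

    x = np.arange(9, dtype="float64").reshape(3, 3)

    # The JAX results should match NumPy's for both CumOp modes
    np.testing.assert_allclose(np.cumsum(x, axis=0), jnp.cumsum(x, axis=0))
    np.testing.assert_allclose(np.cumprod(x, axis=1), jnp.cumprod(x, axis=1))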

Step 3: Register the function with the `jax_funcify` dispatcher
---------------------------------------------------------------

With the PyTensor `Op` replicated in JAX, we'll need to register the
function with the PyTensor JAX `Linker`. This is done through the use of
`singledispatch`. If you don't know how `singledispatch` works, see the
`Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.

The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and
:func:`pytensor.link.jax.dispatch.jax_funcify`.
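
If `singledispatch` is unfamiliar, the toy example below (independent of PyTensor) shows
the mechanism these dispatchers rely on: a generic function whose implementation is
selected by the type of its first argument.

.. code:: python

    from functools import singledispatch


    @singledispatch
    def funcify(op, **kwargs):
        # Fallback used when no implementation is registered for type(op)
        raise NotImplementedError(f"No conversion for {type(op).__name__}")


    class MyOp:
        pass


    @funcify.register(MyOp)
    def funcify_MyOp(op, **kwargs):
        return lambda x: x + 1


    print(funcify(MyOp())(41))  # prints 42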

Here's an example for the `CumOp`\ `Op`:

.. code:: python

    import jax.numpy as jnp

    from pytensor.tensor.extra_ops import CumOp
    from pytensor.link.jax.dispatch import jax_funcify


    @jax_funcify.register(CumOp)
    def jax_funcify_CumOp(op, **kwargs):
        axis = op.axis
        mode = op.mode

        def cumop(x, axis=axis, mode=mode):
            if mode == "add":
                return jnp.cumsum(x, axis=axis)
            else:
                return jnp.cumprod(x, axis=axis)

        return cumop
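
With the registration in place, any graph containing a :class:`CumOp` can be compiled
through the JAX backend. A quick way to try it out (again assuming JAX is installed):

.. code:: python

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    a = pt.matrix("a")
    out = pt.cumsum(a, axis=0)

    # mode="JAX" makes the linker call jax_funcify on every Op in the graph
    fn = pytensor.function([a], out, mode="JAX")
    print(fn(np.arange(9, dtype=pytensor.config.floatX).reshape(3, 3)))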

Suppose `jnp.cumprod` did not exist; we would then need to register the function as follows:

.. code:: python

    import jax.numpy as jnp

    from pytensor.tensor.extra_ops import CumOp
    from pytensor.link.jax.dispatch import jax_funcify


    @jax_funcify.register(CumOp)
    def jax_funcify_CumOp(op, **kwargs):
        axis = op.axis
        mode = op.mode

        def cumop(x, axis=axis, mode=mode):
            if mode == "add":
                return jnp.cumsum(x, axis=axis)
            else:
                raise NotImplementedError("JAX does not support cumprod function at the moment.")

        return cumop

Step 4: Write tests
-------------------

Test that your registered `Op` is working correctly by adding tests to the
appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of
the modules in ``tests.link.numba``). The tests should ensure that your implementation can
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
Check the existing tests for the general outline of these kinds of tests. In
most cases, a helper function can be used to easily verify the correspondence
@@ -131,23 +207,79 @@ between a JAX/Numba implementation and its `Op`.
For example, the :func:`compare_jax_and_py` function streamlines the steps
involved in making comparisons with `Op.perform`.

Here's a small example of a test for the :class:`CumOp` above:

.. code:: python

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.configdefaults import config
    from pytensor.graph import FunctionGraph
    from pytensor.graph.op import get_test_value
    from tests.link.jax.test_basic import compare_jax_and_py


    def test_jax_CumOp():
        """Test JAX conversion of the `CumOp` `Op`."""

        # Create a symbolic input for the first input of `CumOp`
        a = pt.matrix("a")

        # Create test value tag for a
        a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

        # Create the output variable
        out = pt.cumsum(a, axis=0)

        # Create a PyTensor `FunctionGraph`
        fgraph = FunctionGraph([a], [out])

        # Pass the graph and inputs to the testing function
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

        # For the second mode of CumOp
        out = pt.cumprod(a, axis=1)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

If the variant :class:`CumprodOp` is not implemented, we can add a test for it as follows:

.. code:: python

    import pytest


    def test_jax_CumOp():
        """Test JAX conversion of the `CumOp` `Op`."""
        a = pt.matrix("a")
        a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

        with pytest.raises(NotImplementedError):
            out = pt.cumprod(a, axis=1)
            fgraph = FunctionGraph([a], [out])
            compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

Note
----
In our previous example of extending JAX, the :class:`Eye`\ :class:`Op` was used with the test function as follows:

.. code:: python

    def test_jax_Eye():
        """Test JAX conversion of the `Eye` `Op`."""

        # Create a symbolic input for `Eye`
        x_at = pt.scalar()

        # Create a variable that is the output of an `Eye` `Op`
        eye_var = pt.eye(x_at)

        # Create a PyTensor `FunctionGraph`
        out_fg = FunctionGraph(outputs=[eye_var])

        # Pass the graph and any inputs to the testing function
        compare_jax_and_py(out_fg, [3])

This one nowadays leads to a test failure due to new restrictions in JAX + JIT,
as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
All jitted functions now must have constant shape, which means a graph like the
one of :class:`Eye` can never be translated to JAX, since it's fundamentally a
function with dynamic shapes. In other words, only PyTensor graphs with static shapes
can be translated to JAX at the moment.
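
By contrast, when the shape is a compile-time constant, the same :class:`Op` can still go
through the JAX backend. A minimal sketch, assuming the constant shape is resolved
statically (e.g. by constant folding) before the graph reaches JAX:

.. code:: python

    import pytensor
    import pytensor.tensor as pt

    # `pt.eye(3)` has a constant, static shape, unlike `pt.eye(x_at)` above,
    # so the resulting graph can still be jitted by JAX.
    fn = pytensor.function([], pt.eye(3), mode="JAX")
    print(fn())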
