@@ -7,56 +7,96 @@ this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba im
This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will
focus specifically on the JAX case, but the same mechanisms are used for Numba as well.

- Step 1: Identify the PyTensor :class:`Op` you’d like to implement in JAX
+ Step 1: Identify the PyTensor :class:`Op` you'd like to implement in JAX
------------------------------------------------------------------------

- Find the source for the PyTensor :class:`Op` you’d like to be supported in JAX, and
- identify the function signature and return values. These can be determined by
- looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
+ Find the source for the PyTensor :class:`Op` you'd like to be supported in JAX, and
+ identify the function signature and return values. These can be determined by
+ looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
with PyTensor :class:`Op`\s in order to provide a conversion implementation, so first read
:ref:`creating_an_op` if you are not familiar.

- For example, the :class:`Eye`\ :class:`Op` current has an :meth:`Op.make_node` as follows:
+ For example, suppose you want to extend support for :class:`CumsumOp`:

.. code:: python

-     def make_node(self, n, m, k):
-         n = as_tensor_variable(n)
-         m = as_tensor_variable(m)
-         k = as_tensor_variable(k)
-         assert n.ndim == 0
-         assert m.ndim == 0
-         assert k.ndim == 0
-         return Apply(
-             self,
-             [n, m, k],
-             [TensorType(dtype=self.dtype, shape=(None, None))()],
-         )
+     class CumsumOp(Op):
+         __props__ = ("axis",)
+
+         def __new__(typ, *args, **kwargs):
+             obj = object.__new__(CumOp, *args, **kwargs)
+             obj.mode = "add"
+             return obj
+
+
+ :class:`CumsumOp` turns out to be a variant of :class:`CumOp`\ :class:`Op`
+ which currently has an :meth:`Op.make_node` as follows:
+
+ .. code:: python
+
+     def make_node(self, x):
+         x = ptb.as_tensor_variable(x)
+         out_type = x.type()
+
+         if self.axis is None:
+             out_type = vector(dtype=x.dtype)  # Flatten
+         elif self.axis >= x.ndim or self.axis < -x.ndim:
+             raise ValueError(f"axis(={self.axis}) out of bounds")
+
+         return Apply(self, [x], [out_type])

The :class:`Apply` instance that's returned specifies the exact types of inputs that
our JAX implementation will receive and the exact types of outputs it's expected to
- return--both in terms of their data types and number of dimensions.
+ return--both in terms of their data types and number of dimensions/shapes.
The actual inputs our implementation will receive are necessarily numeric values
or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
general signature of the underlying computation.

- More specifically, the :class:`Apply` implies that the inputs come from values that are
- automatically converted to PyTensor variables via :func:`as_tensor_variable`, and
- the ``assert``\s that follow imply that they must be scalars. According to this
- logic, the inputs could have any data type (e.g. floats, ints), so our JAX
- implementation must be able to handle all the possible data types.
+ More specifically, the :class:`Apply` implies that there is one input, which is
+ automatically converted to a PyTensor variable via :func:`as_tensor_variable`.
+ There is another parameter, `axis`, that determines the direction of the
+ operation and hence the shape of the output. The check that follows implies that
+ `axis` must refer to a dimension of the input tensor. The input's elements
+ could also have any data type (e.g. floats, ints), so our JAX implementation
+ must be able to handle all the possible data types.


It also tells us that there's only one return value, that it has a data type
- determined by :attr:`Eye.dtype`, and that it has two non-broadcastable
- dimensions. The latter implies that the result is necessarily a matrix. The
- former implies that our JAX implementation will need to access the :attr:`dtype`
- attribute of the PyTensor :class:`Eye`\ :class:`Op` it's converting.
+ determined by :meth:`x.type()`, i.e. the same data type (and, unless ``axis`` is
+ ``None``, the same shape) as the original input tensor.
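
If it helps to see these types concretely, you can build a small graph and inspect the
:class:`Apply` node that :meth:`Op.make_node` returns (a quick interactive sketch, not part
of the original tutorial code):

.. code:: python

    import pytensor.tensor as pt

    a = pt.matrix("a")

    # ``pt.cumsum`` builds the graph through ``CumOp.make_node``; the resulting
    # ``Apply`` node records the input/output types our JAX function must handle.
    node = pt.cumsum(a, axis=0).owner

    print(node.op)               # the CumOp instance (axis=0, mode="add")
    print(node.inputs[0].type)   # the input's TensorType
    print(node.outputs[0].type)  # the output's TensorType -- same as the input here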

- Next, we can look at the :meth:`Op.perform` implementation to see exactly
- how the inputs and outputs are used to compute the outputs for an :class:`Op`
- in Python. This method is effectively what needs to be implemented in JAX.
+ Some :class:`Op`\s may have more complex behavior. For example, the :class:`CumOp`\ :class:`Op`
+ also has another variant, :class:`CumprodOp`\ :class:`Op`, with exactly the same signature
+ as :class:`CumsumOp`\ :class:`Op`. The difference lies in the `mode` attribute of the
+ :class:`CumOp` definition:
+
+ .. code:: python
+
+     class CumOp(COp):
+         # See function cumsum/cumprod for docstring
+
+         __props__ = ("axis", "mode")
+         check_input = False
+         params_type = ParamsType(
+             c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
+         )
+
+         def __init__(self, axis: int | None = None, mode="add"):
+             if mode not in ("add", "mul"):
+                 raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
+             self.axis = axis
+             self.mode = mode
+
+         c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)
+
+ `__props__` is used to parametrize the general behavior of the :class:`Op`. One needs to
+ pay attention to this to decide whether the JAX implementation should support all variants
+ or raise an explicit ``NotImplementedError`` for cases that are not supported, e.g. when
+ :class:`CumsumOp` (i.e. ``CumOp(mode="add")``) is supported but :class:`CumprodOp`
+ (i.e. ``CumOp(mode="mul")``) is not.
+
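
As an aside, `__props__` also drives the :class:`Op`'s equality and hashing, so two instances
with the same `axis` and `mode` behave as the same :class:`Op`. A minimal illustration, using
the ``CumOp`` import that appears later in this tutorial:

.. code:: python

    from pytensor.tensor.extra_ops import CumOp

    # Instances whose ``__props__`` values match compare equal, so ``axis`` and
    # ``mode`` are all the dispatcher needs in order to reproduce the Op's behavior.
    assert CumOp(axis=0, mode="add") == CumOp(axis=0, mode="add")
    assert CumOp(axis=0, mode="add") != CumOp(axis=0, mode="mul")
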
+ Next, we look at the :meth:`Op.perform` implementation to see exactly
+ how the inputs and outputs are used to compute the outputs for an :class:`Op`
+ in Python. This method is effectively what needs to be implemented in JAX.

Step 2: Find the relevant JAX method (or something close)
---------------------------------------------------------
@@ -82,47 +122,83 @@ Here's an example for :class:`IfElse`:
    )
    return res if n_outs > 1 else res[0]

+ In this case, :class:`CumOp` is implemented with NumPy's :func:`numpy.cumsum`
+ and :func:`numpy.cumprod`, which have JAX equivalents: :func:`jax.numpy.cumsum`
+ and :func:`jax.numpy.cumprod`.
+
+ .. code:: python
+
+     def perform(self, node, inputs, output_storage):
+         x = inputs[0]
+         z = output_storage[0]
+         if self.mode == "add":
+             z[0] = np.cumsum(x, axis=self.axis)
+         else:
+             z[0] = np.cumprod(x, axis=self.axis)
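
Before writing the dispatch function, a quick sanity check (informal, not part of the
PyTensor test suite) can confirm that the JAX functions really are drop-in replacements
for the NumPy calls used in :meth:`Op.perform`:

.. code:: python

    import numpy as np
    import jax.numpy as jnp

    x = np.arange(9, dtype="float32").reshape((3, 3))

    # The NumPy calls used in ``CumOp.perform`` and their JAX counterparts
    # should agree on the same inputs.
    np.testing.assert_allclose(np.cumsum(x, axis=0), jnp.cumsum(x, axis=0))
    np.testing.assert_allclose(np.cumprod(x, axis=1), jnp.cumprod(x, axis=1))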

Step 3: Register the function with the `jax_funcify` dispatcher
---------------------------------------------------------------

- With the PyTensor `Op` replicated in JAX, we’ll need to register the
+ With the PyTensor `Op` replicated in JAX, we'll need to register the
function with the PyTensor JAX `Linker`. This is done through the use of
`singledispatch`. If you don't know how `singledispatch` works, see the
`Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.
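
For intuition, here is a minimal, self-contained sketch of how `singledispatch` selects an
implementation based on the type of its first argument (the ``Exp`` and ``Log`` classes below
are purely illustrative, not PyTensor :class:`Op`\s):

.. code:: python

    from functools import singledispatch


    class Exp: ...
    class Log: ...


    @singledispatch
    def funcify(op):
        # Fallback used when nothing is registered for ``type(op)``
        raise NotImplementedError(f"No conversion registered for {type(op).__name__}")


    @funcify.register(Exp)
    def funcify_Exp(op):
        return "exp implementation"


    print(funcify(Exp()))  # -> "exp implementation"
    # funcify(Log())       # -> raises NotImplementedError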

The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and
:func:`pytensor.link.jax.dispatch.jax_funcify`.

- Here’s an example for the `Eye`\ `Op`:
+ Here's an example for the `CumOp`\ `Op`:

.. code:: python

    import jax.numpy as jnp

-     from pytensor.tensor.basic import Eye
+     from pytensor.tensor.extra_ops import CumOp
    from pytensor.link.jax.dispatch import jax_funcify


-     @jax_funcify.register(Eye)
-     def jax_funcify_Eye(op):
+     @jax_funcify.register(CumOp)
+     def jax_funcify_CumOp(op, **kwargs):
+         axis = op.axis
+         mode = op.mode

-         # Obtain necessary "static" attributes from the Op being converted
-         dtype = op.dtype
+         def cumop(x, axis=axis, mode=mode):
+             if mode == "add":
+                 return jnp.cumsum(x, axis=axis)
+             else:
+                 return jnp.cumprod(x, axis=axis)

-         # Create a JAX jit-able function that implements the Op
-         def eye(N, M, k):
-             return jnp.eye(N, M, k, dtype=dtype)
+         return cumop

-         return eye
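
With the dispatch function registered, a quick smoke test is to compile a small graph with
PyTensor's predefined ``"JAX"`` compilation mode and run it (this assumes a working ``jax``
install; the formal tests come in Step 4):

.. code:: python

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    a = pt.matrix("a")
    out = pt.cumsum(a, axis=0)

    # Compiling with the "JAX" mode routes the graph through the JAX linker,
    # which uses the ``jax_funcify`` registration above.
    fn = pytensor.function([a], out, mode="JAX")
    print(fn(np.arange(9, dtype=pytensor.config.floatX).reshape((3, 3))))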

+ Suppose `jnp.cumprod` did not exist; we would then need to register the function as follows:
+
+ .. code:: python
+
+     import jax.numpy as jnp
+
+     from pytensor.tensor.extra_ops import CumOp
+     from pytensor.link.jax.dispatch import jax_funcify
+
+
+     @jax_funcify.register(CumOp)
+     def jax_funcify_CumOp(op, **kwargs):
+         axis = op.axis
+         mode = op.mode
+
+         def cumop(x, axis=axis, mode=mode):
+             if mode == "add":
+                 return jnp.cumsum(x, axis=axis)
+             else:
+                 raise NotImplementedError("JAX does not support cumprod function at the moment.")
+
+         return cumop
+

Step 4: Write tests
-------------------

Test that your registered `Op` is working correctly by adding tests to the
- appropriate test suites in PyTensor (e.g. in ``tests.link.test_jax`` and one of
- the modules in ``tests.link.numba.dispatch``). The tests should ensure that your implementation can
+ appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of
+ the modules in ``tests.link.numba``). The tests should ensure that your implementation can
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
Check the existing tests for the general outline of these kinds of tests. In
most cases, a helper function can be used to easily verify the correspondence
@@ -131,23 +207,79 @@ between a JAX/Numba implementation and its `Op`.
For example, the :func:`compare_jax_and_py` function streamlines the steps
involved in making comparisons with `Op.perform`.

- Here's a small example of a test for :class:`Eye`:
+ Here's a small example of a test for the :class:`CumOp` above:

.. code:: python
+
+     import numpy as np
+     import pytensor.tensor as pt
+     from pytensor.configdefaults import config
+     from tests.link.jax.test_basic import compare_jax_and_py
+     from pytensor.graph import FunctionGraph
+     from pytensor.graph.op import get_test_value
+
+     def test_jax_CumOp():
+         """Test JAX conversion of the `CumOp` `Op`."""
+
+         # Create a symbolic input for the first input of `CumOp`
+         a = pt.matrix("a")

-     import pytensor.tensor as pt
+         # Create test value tag for a
+         a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
+
+         # Create the output variable
+         out = pt.cumsum(a, axis=0)
+
+         # Create a PyTensor `FunctionGraph`
+         fgraph = FunctionGraph([a], [out])
+
+         # Pass the graph and inputs to the testing function
+         compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+         # For the second mode of CumOp
+         out = pt.cumprod(a, axis=1)
+         fgraph = FunctionGraph([a], [out])
+         compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+
244
+ If the variant :class: `CumprodOp ` is not implemented, we can add a test for it as follows:
245
+
246
+ .. code :: python
247
+
248
+ import pytest
249
+
250
+ def test_jax_CumOp ():
251
+ """ Test JAX conversion of the `CumOp` `Op`."""
252
+ a = pt.matrix(" a" )
253
+ a.tag.test_value = np.arange(9 , dtype = config.floatX).reshape((3 , 3 ))
254
+
255
+ with pytest.raises(NotImplementedError ):
256
+ out = pt.cumprod(a, axis = 1 )
257
+ fgraph = FunctionGraph([a], [out])
258
+ compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
259
+
+ Note
+ ----
+ In our previous example of extending JAX, the :class:`Eye`\ :class:`Op` was used with the test function as follows:
+
+ .. code:: python
+
+     def test_jax_Eye():
+         """Test JAX conversion of the `Eye` `Op`."""

-     def test_jax_Eye():
-         """Test JAX conversion of the `Eye` `Op`. """
+         # Create a symbolic input for `Eye`
+         x_at = pt.scalar()

-     # Create a symbolic input for `Eye`
-     x_at = pt.scalar()
+         # Create a variable that is the output of an `Eye` `Op`
+         eye_var = pt.eye(x_at)

-     # Create a variable that is the output of an `Eye` `Op`
-     eye_var = pt.eye(x_at)
+         # Create an PyTensor `FunctionGraph`
+         out_fg = FunctionGraph(outputs=[eye_var])

-     # Create an PyTensor `FunctionGraph`
-     out_fg = FunctionGraph(outputs=[eye_var])
+         # Pass the graph and any inputs to the testing function
+         compare_jax_and_py(out_fg, [3])

-     # Pass the graph and any inputs to the testing function
-     compare_jax_and_py(out_fg, [3])

+ This test now fails due to new restrictions in JAX + JIT,
+ as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
+ All jitted functions now must have constant shape, which means a graph like the
+ one of :class:`Eye` can never be translated to JAX, since it's fundamentally a
+ function with dynamic shapes. In other words, only PyTensor graphs with static shapes
+ can be translated to JAX at the moment.
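
The underlying JAX restriction is easy to reproduce directly: under :func:`jax.jit`, an output
shape may not depend on the value of a traced argument. A minimal, pure-JAX illustration
(separate from the PyTensor machinery above):

.. code:: python

    import jax
    import jax.numpy as jnp


    @jax.jit
    def make_eye(n):
        # Under ``jit``, ``n`` is a tracer without a concrete value, so JAX
        # cannot determine the output shape and raises an error at trace time.
        return jnp.eye(n)


    # make_eye(3)  # fails: the output shape depends on a traced value

    # With the shape fixed at trace time, the same computation jits fine:
    make_eye_static = jax.jit(lambda: jnp.eye(3))
    print(make_eye_static())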