
Commit 47f7637

rlouf authored and ricardoV94 committed
Dispatch some Ops to Python operators when inputs are scalars
1 parent efb4996 · commit 47f7637

File tree

2 files changed: +129 −2 lines

pytensor/link/jax/dispatch/scalar.py

Lines changed: 93 additions & 2 deletions
@@ -5,14 +5,57 @@
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.scalar import Softplus
-from pytensor.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
+from pytensor.scalar.basic import (
+    Add,
+    Cast,
+    Clip,
+    Composite,
+    Identity,
+    IntDiv,
+    Mod,
+    Mul,
+    ScalarOp,
+    Second,
+    Sub,
+)
 from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
 
 
+def check_if_inputs_scalars(node):
+    """Check whether all the inputs of an `Elemwise` are scalar values.
+
+    `jax.lax` and `jax.numpy` functions systematically return `TracedArray`s,
+    while the corresponding Python operators return concrete values when passed
+    concrete values. To compile as many graphs as possible we need to preserve
+    concrete values whenever we can, so we dispatch PyTensor operators
+    differently depending on whether their inputs are scalars.
+
+    """
+    ndims_input = [inp.type.ndim for inp in node.inputs]
+    are_inputs_scalars = True
+    for ndim in ndims_input:
+        try:
+            if ndim > 0:
+                are_inputs_scalars = False
+        except TypeError:
+            are_inputs_scalars = False
+
+    return are_inputs_scalars
+
+
 @jax_funcify.register(ScalarOp)
-def jax_funcify_ScalarOp(op, **kwargs):
+def jax_funcify_ScalarOp(op, node, **kwargs):
     func_name = op.nfunc_spec[0]
 
+    # We dispatch some PyTensor operators to Python operators
+    # whenever the inputs are all scalars.
+    are_inputs_scalars = check_if_inputs_scalars(node)
+    if are_inputs_scalars:
+        elemwise = elemwise_scalar(op)
+        if elemwise is not None:
+            return elemwise
+
     if "." in func_name:
         jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
     else:
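
The docstring above captures the motivation for the whole change: under `jax.jit`, `jax.numpy` calls are staged out and return abstract tracers even for constant inputs, while Python operators applied to concrete Python scalars stay concrete, which is what shape arguments such as the one taken by `jnp.ones` require. A minimal standalone sketch of that failure mode, not part of the commit and assuming a recent JAX with omnistaging:

import jax
import jax.numpy as jnp


@jax.jit
def via_python_add(x):
    size = 2 + 3  # Python `+` on concrete ints stays a concrete int
    return x + jnp.ones(size)  # fine: the shape is known while tracing


@jax.jit
def via_jnp_add(x):
    size = jnp.add(2, 3)  # staged out under jit -> abstract tracer
    return x + jnp.ones(size)  # ConcretizationTypeError: shape not concrete


via_python_add(jnp.zeros(5))  # works
# via_jnp_add(jnp.zeros(5))   # raises while tracing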
@@ -38,6 +81,54 @@ def elemwise(*args):
         return jnp_func
 
 
+@functools.singledispatch
+def elemwise_scalar(op):
+    return None
+
+
+@elemwise_scalar.register(Add)
+def elemwise_scalar_add(op):
+    def elemwise(*inputs):
+        return sum(inputs)
+
+    return elemwise
+
+
+@elemwise_scalar.register(Mul)
+def elemwise_scalar_mul(op):
+    import operator
+    from functools import reduce
+
+    def elemwise(*inputs):
+        return reduce(operator.mul, inputs, 1)
+
+    return elemwise
+
+
+@elemwise_scalar.register(Sub)
+def elemwise_scalar_sub(op):
+    def elemwise(x, y):
+        return x - y
+
+    return elemwise
+
+
+@elemwise_scalar.register(IntDiv)
+def elemwise_scalar_intdiv(op):
+    def elemwise(x, y):
+        return x // y
+
+    return elemwise
+
+
+@elemwise_scalar.register(Mod)
+def elemwise_scalar_mod(op):
+    def elemwise(x, y):
+        return x % y
+
+    return elemwise
+
+
 @jax_funcify.register(Cast)
 def jax_funcify_Cast(op, **kwargs):
     def cast(x):
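
As a reading aid, here is a hypothetical illustration, not part of the commit, of how the `elemwise_scalar` dispatch table behaves: registered scalar `Op`s map to plain Python callables, while anything unregistered falls through to `None`, in which case `jax_funcify_ScalarOp` keeps the `jnp` function. The `add` and `clip` names are the scalar `Op` singletons from `pytensor.scalar.basic`:

from pytensor.link.jax.dispatch.scalar import elemwise_scalar
from pytensor.scalar.basic import add, clip

py_add = elemwise_scalar(add)  # `Add` is registered -> plain Python callable
assert py_add(2, 3) == 5       # concrete inputs yield a concrete output

assert elemwise_scalar(clip) is None  # `Clip` is not registered, so the
                                      # caller falls back to the jnp function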

tests/link/jax/test_scalar.py

Lines changed: 36 additions & 0 deletions
@@ -161,6 +161,42 @@ def test_jax_variadic_Scalar():
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
 
+def test_add_scalars():
+    x = at.matrix("x")
+    size = x.shape[0] + x.shape[0] + x.shape[1]
+    out = at.ones(size).astype(config.floatX)
+
+    out_fg = FunctionGraph([x], [out])
+    compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)])
+
+
+def test_mul_scalars():
+    x = at.matrix("x")
+    size = x.shape[0] * x.shape[0] * x.shape[1]
+    out = at.ones(size).astype(config.floatX)
+
+    out_fg = FunctionGraph([x], [out])
+    compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)])
+
+
+def test_div_scalars():
+    x = at.matrix("x")
+    size = x.shape[0] // x.shape[1]
+    out = at.ones(size).astype(config.floatX)
+
+    out_fg = FunctionGraph([x], [out])
+    compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)])
+
+
+def test_mod_scalars():
+    x = at.matrix("x")
+    size = x.shape[0] % x.shape[1]
+    out = at.ones(size).astype(config.floatX)
+
+    out_fg = FunctionGraph([x], [out])
+    compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)])
+
+
 def test_jax_multioutput():
     x = vector("x")
     x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
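
Each test follows the same pattern: the size passed to `at.ones` is computed from scalar shape values, so the new dispatch makes `Add`, `Mul`, `IntDiv`, and `Mod` run as Python operators and the shape stays concrete under `jax.jit`. A hypothetical end-to-end sketch, not part of the commit, of what these tests exercise:

import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor import config

x = at.matrix("x")
size = x.shape[0] + x.shape[1]  # scalar inputs -> dispatched to Python `+`
out = at.ones(size).astype(config.floatX)

fn = pytensor.function([x], out, mode="JAX")
print(fn(np.ones((2, 3), dtype=config.floatX)).shape)  # (5,)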
