from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scalar import Softplus
- from pytensor.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
+ from pytensor.scalar.basic import (
+     Add,
+     Cast,
+     Clip,
+     Composite,
+     Identity,
+     IntDiv,
+     Mod,
+     Mul,
+     ScalarOp,
+     Second,
+     Sub,
+ )
from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi


+ def check_if_inputs_scalars(node):
+     """Check whether all the inputs of an `Elemwise` are scalar values.
+
+     `jax.lax` or `jax.numpy` functions systematically return `TracedArrays`,
+     while the corresponding Python operators return concrete values when
+     passed concrete values. To compile as many graphs as possible we need to
+     preserve concrete values whenever we can, so we dispatch PyTensor
+     operators differently depending on whether their inputs are scalars.
+
+     """
+     ndims_input = [inp.type.ndim for inp in node.inputs]
+     are_inputs_scalars = True
+     for ndim in ndims_input:
+         try:
+             # `ndim > 0` raises a `TypeError` when `ndim` is not an
+             # integer (e.g. `None`); treat that as "not a scalar" too.
+             if ndim > 0:
+                 are_inputs_scalars = False
+         except TypeError:
+             are_inputs_scalars = False
+
+     return are_inputs_scalars
+
+
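To see why the distinction in the docstring matters, here is a quick illustration (not part of the change): `jax.numpy` wraps results in JAX arrays even for plain Python scalars, while Python's own operators keep them concrete.

```python
import jax.numpy as jnp

# `jnp` functions return JAX arrays (tracers under `jax.jit`) ...
print(type(jnp.add(1.0, 2.0)))  # a JAX array type, e.g. jaxlib's ArrayImpl
# ... while the Python operator keeps concrete scalars concrete.
print(type(1.0 + 2.0))          # <class 'float'>
```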
@jax_funcify.register(ScalarOp)
- def jax_funcify_ScalarOp(op, **kwargs):
+ def jax_funcify_ScalarOp(op, node, **kwargs):
    func_name = op.nfunc_spec[0]

+     # We dispatch some PyTensor operators to Python operators
+     # whenever the inputs are all scalars.
+     are_inputs_scalars = check_if_inputs_scalars(node)
+     if are_inputs_scalars:
+         elemwise = elemwise_scalar(op)
+         if elemwise is not None:
+             return elemwise
+
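For context, a minimal usage sketch of the path this adds. This is illustrative only, not part of the commit: the graph and the `mode="JAX"` compilation below are assumptions about how the fallback gets exercised.

```python
import pytensor
import pytensor.tensor as pt

# Illustrative sketch: `x` and `y` are 0-d, so `check_if_inputs_scalars`
# is True for the `add` node and the dispatcher can return the
# plain-Python `sum`-based callable instead of `jnp.add`.
x = pt.scalar("x")
y = pt.scalar("y")
f = pytensor.function([x, y], x + y, mode="JAX")
print(f(1.0, 2.0))  # 3.0
```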
    if "." in func_name:
        jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
    else:
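For reference, the `functools.reduce(getattr, ...)` line above is a dotted-attribute lookup. A small self-contained demonstration, using `scipy.special.erf` as a representative dotted `nfunc_spec` name (the choice of `erf` here is an assumption for illustration):

```python
import functools

import jax
import jax.scipy.special  # ensure the submodule is loaded as an attribute

# "scipy.special.erf" is resolved one attribute at a time:
# jax -> jax.scipy -> jax.scipy.special -> jax.scipy.special.erf
fn = functools.reduce(getattr, [jax] + "scipy.special.erf".split("."))
assert fn is jax.scipy.special.erf
```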
@@ -38,6 +81,54 @@ def elemwise(*args):
    return jnp_func


+ @functools.singledispatch
+ def elemwise_scalar(op):
+     return None
+
+
+ @elemwise_scalar.register(Add)
+ def elemwise_scalar_add(op):
+     def elemwise(*inputs):
+         return sum(inputs)
+
+     return elemwise
+
+
+ @elemwise_scalar.register(Mul)
+ def elemwise_scalar_mul(op):
+     import operator
+     from functools import reduce
+
+     def elemwise(*inputs):
+         return reduce(operator.mul, inputs, 1)
+
+     return elemwise
+
+
+ @elemwise_scalar.register(Sub)
+ def elemwise_scalar_sub(op):
+     def elemwise(x, y):
+         return x - y
+
+     return elemwise
+
+
+ @elemwise_scalar.register(IntDiv)
+ def elemwise_scalar_intdiv(op):
+     def elemwise(x, y):
+         return x // y
+
+     return elemwise
+
+
+ @elemwise_scalar.register(Mod)
+ def elemwise_scalar_mod(op):
+     def elemwise(x, y):
+         return x % y
+
+     return elemwise
+
+
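The registry above is ordinary `functools.singledispatch`: the undecorated base function is the fallback (returning `None` means "no Python-operator specialization"), and each scalar `Op` class registers its replacement. A minimal standalone sketch of the same pattern, with a made-up `FakeAdd` class standing in for `pytensor.scalar.basic.Add`:

```python
from functools import singledispatch


@singledispatch
def scalar_impl(op):
    # Fallback: no Python-operator specialization known for this op.
    return None


class FakeAdd:  # hypothetical stand-in for `pytensor.scalar.basic.Add`
    pass


@scalar_impl.register(FakeAdd)
def _(op):
    return lambda *inputs: sum(inputs)


assert scalar_impl(object()) is None             # unregistered -> fallback
assert scalar_impl(FakeAdd())(1.0, 2.0) == 3.0   # registered -> Python `sum`
```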
@jax_funcify.register(Cast)
def jax_funcify_Cast(op, **kwargs):
    def cast(x):