from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.special import Softmax, SoftmaxGrad

+from pytensor.scalar.basic import (
+    AND,
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NEQ,
+    OR,
+    Abs,
+    Add,
+    Cast,
+    Cos,
+    Exp,
+    Log,
+    Log1p,
+    Mul,
+    Neg,
+    Pow,
+    ScalarMaximum,
+    ScalarMinimum,
+    Sign,
+    Sin,
+    Sqr,
+    Sqrt,
+    Sub,
+    Switch,
+    TrueDiv,
+)

@mlx_funcify.register(DimShuffle)
def mlx_funcify_DimShuffle(op, **kwargs):
@@ -21,55 +50,57 @@ def dimshuffle(x):
    return dimshuffle


+@mlx_funcify.register(DimShuffle)
+def mlx_funcify_DimShuffle(op, **kwargs):
+    def dimshuffle(x):
+        res = mx.transpose(x, op.transposition)
+        shape = list(res.shape[: len(op.shuffle)])
+        for augm in op.augment:
+            shape.insert(augm, 1)
+        return mx.reshape(res, shape)
+
+    return dimshuffle
+
@mlx_funcify.register(CAReduce)
def mlx_funcify_CAReduce(op, **kwargs):
-    axis = op.axis
-    op_nfunc_spec = getattr(op, "nfunc_spec", None)
-    scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
-    scalar_op_name = getattr(op.scalar_op, "name", None)
-    scalar_op_identity = getattr(op.scalar_op, "identity", None)
-    acc_dtype = getattr(op, "acc_dtype", None)
-
-    def careduce(x):
-        nonlocal \
-            axis, \
-            op_nfunc_spec, \
-            scalar_nfunc_spec, \
-            scalar_op_name, \
-            scalar_op_identity, \
-            acc_dtype
-
-        if axis is None:
-            axis = list(range(x.ndim))
-
-        if acc_dtype is None:
-            acc_dtype = x.dtype
-
-        if op_nfunc_spec:
-            mlx_op = getattr(mx, op_nfunc_spec[0])
-            return mlx_op(x, axis=axis)
-            # return mlx_op(x, axis=axis).astype(acc_dtype)
-
-        # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or
-        # there isn't one), so we use this fallback approach
-        if scalar_nfunc_spec:
-            scalar_fn_name = scalar_nfunc_spec[0]
-        elif scalar_op_name:
-            scalar_fn_name = scalar_op_name
-
-        to_reduce = sorted(axis, reverse=True)
-
-        if to_reduce:
-            raise NotImplementedError("Not implemented yet")
-            # In this case, we need to use the `jax.lax` function (if there
-            # is one), and not the `jnp` version.
-            mlx_op = getattr(mx, scalar_fn_name)
-            init_value = mx.array(scalar_op_identity, dtype=acc_dtype)
-            return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype)
-        else:
-            return x
-
-    return careduce
+    if isinstance(op.scalar_op, Add):
+
+        def sum(x):
+            return mx.sum(x, axis=op.axis)
+
+        return sum
+    elif isinstance(op.scalar_op, Mul):
+
+        def prod(x):
+            return mx.prod(x, axis=op.axis)
+
+        return prod
+    elif isinstance(op.scalar_op, AND):
+
+        def all(x):
+            return x.all(axis=op.axis)
+
+        return all
+    elif isinstance(op.scalar_op, OR):
+
+        def any(x):
+            return mx.any(x, axis=op.axis)
+
+        return any
+    elif isinstance(op.scalar_op, ScalarMaximum):
+
+        def max(x):
+            return x.max(axis=op.axis)
+
+        return max
+    elif isinstance(op.scalar_op, ScalarMinimum):
+
+        def min(x):
+            return x.min(axis=op.axis)
+
+        return min
+    else:
+        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
+


@mlx_funcify.register(Softmax)
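Usage note (not part of the diff): a minimal sketch of how the new CAReduce dispatch could be exercised directly, assuming PyTensor and MLX are installed and that `mlx_funcify_CAReduce` is importable from this module (the import path shown in the comment is an assumption). A `sum` over axis 0 yields a CAReduce op with an `Add` scalar_op, which routes to the `mx.sum` branch.

import mlx.core as mx
import pytensor.tensor as pt

# from pytensor.link.mlx.dispatch.elemwise import mlx_funcify_CAReduce  # assumed module path

x = pt.matrix("x")
op = x.sum(axis=0).owner.op           # CAReduce whose scalar_op is Add, with axis=(0,)

reduce_fn = mlx_funcify_CAReduce(op)  # dispatch returns the closure that calls mx.sum
print(reduce_fn(mx.ones((3, 4))))     # array of shape (4,), each entry equal to 3.0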