
Commit 597f84e

Almost working
1 parent 5ffc5ef commit 597f84e

2 files changed: +80 -49 lines changed


pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 78 additions & 47 deletions
@@ -5,6 +5,35 @@
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
 from pytensor.tensor.special import Softmax, SoftmaxGrad
 
+from pytensor.scalar.basic import (
+    AND,
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NEQ,
+    OR,
+    Abs,
+    Add,
+    Cast,
+    Cos,
+    Exp,
+    Log,
+    Log1p,
+    Mul,
+    Neg,
+    Pow,
+    ScalarMaximum,
+    ScalarMinimum,
+    Sign,
+    Sin,
+    Sqr,
+    Sqrt,
+    Sub,
+    Switch,
+    TrueDiv,
+)
 
 @mlx_funcify.register(DimShuffle)
 def mlx_funcify_DimShuffle(op, **kwargs):
@@ -21,55 +50,57 @@ def dimshuffle(x):
     return dimshuffle
 
 
+@mlx_funcify.register(DimShuffle)
+def mlx_funcify_DimShuffle(op, **kwargs):
+    def dimshuffle(x):
+        res = mx.transpose(x, op.transposition)
+        shape = list(res.shape[: len(op.shuffle)])
+        for augm in op.augment:
+            shape.insert(augm, 1)
+        return mx.reshape(res, shape)
+    return dimshuffle
+
 @mlx_funcify.register(CAReduce)
 def mlx_funcify_CAReduce(op, **kwargs):
-    axis = op.axis
-    op_nfunc_spec = getattr(op, "nfunc_spec", None)
-    scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
-    scalar_op_name = getattr(op.scalar_op, "name", None)
-    scalar_op_identity = getattr(op.scalar_op, "identity", None)
-    acc_dtype = getattr(op, "acc_dtype", None)
-
-    def careduce(x):
-        nonlocal \
-            axis, \
-            op_nfunc_spec, \
-            scalar_nfunc_spec, \
-            scalar_op_name, \
-            scalar_op_identity, \
-            acc_dtype
-
-        if axis is None:
-            axis = list(range(x.ndim))
-
-        if acc_dtype is None:
-            acc_dtype = x.dtype
-
-        if op_nfunc_spec:
-            mlx_op = getattr(mx, op_nfunc_spec[0])
-            return mlx_op(x, axis=axis)
-            # return mlx_op(x, axis=axis).astype(acc_dtype)
-
-        # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or
-        # there isn't one), so we use this fallback approach
-        if scalar_nfunc_spec:
-            scalar_fn_name = scalar_nfunc_spec[0]
-        elif scalar_op_name:
-            scalar_fn_name = scalar_op_name
-
-        to_reduce = sorted(axis, reverse=True)
-
-        if to_reduce:
-            raise NotImplementedError("Not implemented yet")
-            # In this case, we need to use the `jax.lax` function (if there
-            # is one), and not the `jnp` version.
-            mlx_op = getattr(mx, scalar_fn_name)
-            init_value = mx.array(scalar_op_identity, dtype=acc_dtype)
-            return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype)
-        else:
-            return x
-
-    return careduce
+    if isinstance(op.scalar_op, Add):
+
+        def sum(x):
+            return mx.sum(x, axis=op.axis)
+
+        return sum
+    elif isinstance(op.scalar_op, Mul):
+
+        def prod(x):
+            return mx.prod(x, axis=op.axis)
+
+        return prod
+    elif isinstance(op.scalar_op, AND):
+
+        def all(x):
+            return x.all(axis=op.axis)
+
+        return all
+    elif isinstance(op.scalar_op, OR):
+
+        def any(x):
+            return mx.any(x, axis=op.axis)
+
+        return any
+    elif isinstance(op.scalar_op, ScalarMaximum):
+
+        def max(x):
+            return x.max(axis=op.axis)
+
+        return max
+    elif isinstance(op.scalar_op, ScalarMinimum):
+
+        def min(x):
+            return x.min(axis=op.axis)
+
+        return min
+    else:
+        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
+
 
 
 @mlx_funcify.register(Softmax)
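
For context, a minimal standalone sketch (not part of the commit) of the MLX calls the new DimShuffle and CAReduce dispatchers emit; the array, axes, and threshold below are illustrative only, and only mlx.core is assumed to be installed:

    import mlx.core as mx

    x = mx.array([[1.0, 2.0], [3.0, 4.0]])

    # DimShuffle lowering: transpose, then insert broadcastable dims via reshape.
    res = mx.transpose(x, (1, 0))        # swap the two axes
    res = mx.reshape(res, [1, 2, 2])     # prepend a length-1 broadcast dim

    # CAReduce lowering: one MLX reduction per scalar op.
    mx.sum(x, axis=0)        # Add           -> column sums
    mx.prod(x, axis=1)       # Mul           -> row products
    (x > 2).all(axis=0)      # AND           -> per-column "all"
    mx.any(x > 2, axis=1)    # OR            -> per-row "any"
    x.max(axis=0)            # ScalarMaximum -> per-column max
    x.min(axis=1)            # ScalarMinimum -> per-row min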

pytensor/link/mlx/dispatch/math.py

Lines changed: 2 additions & 2 deletions
@@ -211,13 +211,13 @@ def any(x, y):
     elif isinstance(op.scalar_op, ScalarMaximum):
 
         def max(x):
-            return mx.max(x, axis=op.axis)
+            return x.max(axis=op.axis)
 
         return max
     elif isinstance(op.scalar_op, ScalarMinimum):
 
         def min(x):
-            return mx.min(x, axis=op.axis)
+            return x.min(axis=op.axis)
 
         return min
     elif isinstance(op.scalar_op, Cast):
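
A quick sanity check (not part of the commit, and assuming a recent mlx release where both spellings exist) that the array-method form adopted here matches the free-function form it replaces:

    import mlx.core as mx

    x = mx.array([[1.0, 5.0], [3.0, 2.0]])

    # Both spellings should produce identical reductions.
    assert mx.array_equal(mx.max(x, axis=0), x.max(axis=0)).item()  # [3., 5.]
    assert mx.array_equal(mx.min(x, axis=1), x.min(axis=1)).item()  # [1., 2.]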
