
Commit 6a2b774

Requested changes by Ricardo
1 parent fb8fd2f commit 6a2b774

5 files changed: +593, -217 lines

pytensor/link/mlx/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
@@ -18,14 +18,14 @@ def mlx_typify(data, **kwargs):
 
 
 @mlx_typify.register(np.ndarray)
-@mlx_typify.register(mx.array)
 def mlx_typify_tensor(data, dtype=None, **kwargs):
     return mx.array(data, dtype=dtype)
 
 
 @mlx_typify.register(slice)
 @mlx_typify.register(NoneType)
 @mlx_typify.register(np.number)
+@mlx_typify.register(mx.array)
 def mlx_typify_no_conversion_needed(data, **kwargs):
     return data
 
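For context (an aside, not part of the commit): `mlx_typify` is a `functools.singledispatch` function, so moving the `mx.array` registration onto `mlx_typify_no_conversion_needed` means an input that is already an MLX array is returned untouched instead of being re-wrapped through `mx.array(data, dtype=dtype)`. A minimal sketch of that dispatch pattern, using a stand-in `typify` function rather than the actual PyTensor code:

from functools import singledispatch

import mlx.core as mx
import numpy as np


@singledispatch
def typify(data, **kwargs):
    return data  # fallback: pass through unchanged


@typify.register(np.ndarray)
def _(data, dtype=None, **kwargs):
    # NumPy inputs are converted into MLX arrays
    return mx.array(data, dtype=dtype)


@typify.register(mx.array)
def _(data, **kwargs):
    # Already an MLX array: no conversion (and no copy) needed
    return data


assert isinstance(typify(np.ones(3)), mx.array)
x = mx.ones(3)
assert typify(x) is x  # same object, not a re-wrapped copy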

pytensor/link/mlx/dispatch/core.py

Lines changed: 25 additions & 43 deletions
@@ -13,10 +13,10 @@
 
 import warnings
 
-import mlx.core as mx  # MLX
+import mlx.core as mx
 import numpy as np
 
-from pytensor.link.mlx.dispatch.basic import mlx_funcify  # MLX
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
     Alloc,
@@ -34,28 +34,22 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 
 
-# ------------------------------------------------------------------
-# Join
-# ------------------------------------------------------------------
-@mlx_funcify.register(Join)  # MLX
+@mlx_funcify.register(Join)
 def mlx_funcify_Join(op, **kwargs):
     def join(axis, *tensors):
         view = op.view
         if (view != -1) and all(
-            tensors[i].shape[axis] == 0  # MLX
+            tensors[i].shape[axis] == 0
             for i in list(range(view)) + list(range(view + 1, len(tensors)))
         ):
             return tensors[view]
 
-        return mx.concatenate(tensors, axis=axis)  # MLX
+        return mx.concatenate(tensors, axis=axis)
 
     return join
 
 
-# ------------------------------------------------------------------
-# Split
-# ------------------------------------------------------------------
-@mlx_funcify.register(Split)  # MLX
+@mlx_funcify.register(Split)
 def mlx_funcify_Split(op: Split, node, **kwargs):
     _, axis_sym, splits_sym = node.inputs
 
@@ -90,7 +84,7 @@ def split(x, axis, splits):
             cumsum_splits = np.cumsum(splits[:-1])
         else:
             # dynamic - keep in graph
-            splits_arr = mx.array(splits)  # MLX
+            splits_arr = mx.array(splits)
             cumsum_splits = mx.cumsum(
                 splits_arr[:-1]
             ).tolist()  # python list for mx.split
@@ -104,33 +98,29 @@ def split(x, axis, splits):
         if np.any(np.asarray(splits) < 0):
             raise ValueError("Split sizes cannot be negative.")
 
-        return mx.split(x, cumsum_splits, axis=axis)  # MLX
+        return mx.split(x, cumsum_splits, axis=axis)
 
     return split
 
 
-# ------------------------------------------------------------------
-# ExtractDiag
-# ------------------------------------------------------------------
-@mlx_funcify.register(ExtractDiag)  # MLX
+
+@mlx_funcify.register(ExtractDiag)
 def mlx_funcify_ExtractDiag(op, **kwargs):
     offset, axis1, axis2 = op.offset, op.axis1, op.axis2
 
     def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
-        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)  # MLX
+        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
 
     return extract_diag
 
 
-# ------------------------------------------------------------------
-# Eye
-# ------------------------------------------------------------------
-@mlx_funcify.register(Eye)  # MLX
+
+@mlx_funcify.register(Eye)
 def mlx_funcify_Eye(op, **kwargs):
     dtype = convert_dtype_to_mlx(op.dtype)
 
     def eye(N, M, k):
-        return mx.eye(int(N), int(M), int(k), dtype=dtype)  # MLX
+        return mx.eye(int(N), int(M), int(k), dtype=dtype)
 
     return eye
 
@@ -176,45 +166,37 @@ def convert_dtype_to_mlx(dtype_str):
     return dtype_str
 
 
-# ------------------------------------------------------------------
-# MakeVector
-# ------------------------------------------------------------------
-@mlx_funcify.register(MakeVector)  # MLX
+
+@mlx_funcify.register(MakeVector)
 def mlx_funcify_MakeVector(op, **kwargs):
     dtype = convert_dtype_to_mlx(op.dtype)
 
     def makevector(*x):
-        return mx.array(x, dtype=dtype)  # MLX
+        return mx.array(x, dtype=dtype)
 
     return makevector
 
 
-# ------------------------------------------------------------------
-# TensorFromScalar (identity for MLX)
-# ------------------------------------------------------------------
-@mlx_funcify.register(TensorFromScalar)  # MLX
+
+@mlx_funcify.register(TensorFromScalar)
 def mlx_funcify_TensorFromScalar(op, **kwargs):
     def tensor_from_scalar(x):
         return x  # already an MLX array / scalar
 
     return tensor_from_scalar
 
 
-# ------------------------------------------------------------------
-# ScalarFromTensor
-# ------------------------------------------------------------------
-@mlx_funcify.register(ScalarFromTensor)  # MLX
+
+@mlx_funcify.register(ScalarFromTensor)
 def mlx_funcify_ScalarFromTensor(op, **kwargs):
     def scalar_from_tensor(x):
-        return mx.array(x).reshape(-1)[0]  # MLX
+        return mx.array(x).reshape(-1)[0]
 
     return scalar_from_tensor
 
 
-# ------------------------------------------------------------------
-# Tri
-# ------------------------------------------------------------------
-@mlx_funcify.register(Tri)  # MLX
+
+@mlx_funcify.register(Tri)
 def mlx_funcify_Tri(op, node, **kwargs):
     # node.inputs -> N, M, k
     const_args = [getattr(inp, "data", None) for inp in node.inputs]
@@ -226,7 +208,7 @@ def tri(*args):
             arg if const_a is None else const_a
             for arg, const_a in zip(args, const_args, strict=True)
         ]
-        return mx.tri(*args, dtype=dtype)  # MLX
+        return mx.tri(*args, dtype=dtype)
 
     return tri
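A side note on the `Split` handler above (a standalone illustration, not part of the commit): `mx.split` expects cut points along the axis rather than piece sizes, which is why the handler turns `splits` into a cumulative sum (converted to a Python list in the dynamic case). A small sketch of that arithmetic:

import mlx.core as mx
import numpy as np

# Sizes [2, 3, 1] become cut points [2, 5] for mx.split.
x = mx.arange(6)
sizes = [2, 3, 1]

boundaries = np.cumsum(sizes[:-1]).tolist()  # [2, 5]
pieces = mx.split(x, boundaries, axis=0)

assert [p.size for p in pieces] == sizes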

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 53 additions & 20 deletions
@@ -1,5 +1,6 @@
 import mlx.core as mx
 import numpy as np
+from functools import singledispatch
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
@@ -10,6 +11,8 @@
     Add,
     Cast,
     Mul,
+    ScalarMaximum,
+    ScalarMinimum,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
 from pytensor.tensor.special import Softmax, SoftmaxGrad
@@ -32,34 +35,64 @@ def dimshuffle(x):
     return dimshuffle
 
 
-@mlx_funcify.register(CAReduce)
-def mlx_funcify_CAReduce(op, **kwargs):
-    if isinstance(op.scalar_op, Add):
+# Second-level dispatch for scalar operations in CAReduce
+@singledispatch
+def mlx_funcify_CAReduce_scalar_op(scalar_op):
+    raise NotImplementedError(f"MLX does not support CAReduce with scalar op {scalar_op}")
+
+
+@mlx_funcify_CAReduce_scalar_op.register(Add)
+def _(scalar_op):
+    def sum_reduce(x, axis):
+        return mx.sum(x, axis=axis)
+    return sum_reduce
+
+
+@mlx_funcify_CAReduce_scalar_op.register(Mul)
+def _(scalar_op):
+    def prod_reduce(x, axis):
+        return mx.prod(x, axis=axis)
+    return prod_reduce
 
-        def sum(x):
-            return mx.sum(x, axis=op.axis)
 
-        return sum
-    elif isinstance(op.scalar_op, Mul):
+@mlx_funcify_CAReduce_scalar_op.register(AND)
+def _(scalar_op):
+    def all_reduce(x, axis):
+        return x.all(axis=axis)
+    return all_reduce
 
-        def prod(x):
-            return mx.prod(x, axis=op.axis)
 
-        return prod
-    elif isinstance(op.scalar_op, AND):
+@mlx_funcify_CAReduce_scalar_op.register(OR)
+def _(scalar_op):
+    def any_reduce(x, axis):
+        return mx.any(x, axis=axis)
+    return any_reduce
 
-        def all(x):
-            return x.all(axis=op.axis)
 
-        return all
-    elif isinstance(op.scalar_op, OR):
+@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum)
+def _(scalar_op):
+    def max_reduce(x, axis):
+        return mx.max(x, axis=axis)
+    return max_reduce
 
-        def any(x):
-            return mx.any(x, axis=op.axis)
 
-        return any
-    else:
-        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
+@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum)
+def _(scalar_op):
+    def min_reduce(x, axis):
+        return mx.min(x, axis=axis)
+    return min_reduce
+
+
+@mlx_funcify.register(CAReduce)
+def mlx_funcify_CAReduce(op, **kwargs):
+    # Dispatch to the appropriate scalar op handler
+    scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op(op.scalar_op)
+    axis = op.axis
+
+    def reduce(x):
+        return scalar_reduce_fn(x, axis)
+
+    return reduce
 
 
 @mlx_funcify.register(Softmax)
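The idea behind the restructuring (an aside, not text from the commit): instead of a growing `if/elif isinstance(...)` chain inside `mlx_funcify_CAReduce`, the scalar op is resolved through a second-level `functools.singledispatch`, so supporting a new reduction only requires registering one more handler. A minimal, self-contained sketch of the pattern, using hypothetical stand-in op classes rather than the real `pytensor.scalar` ops:

from functools import singledispatch

import mlx.core as mx


# Toy stand-ins for scalar ops; the real code registers pytensor.scalar classes.
class Add: ...
class Mul: ...


@singledispatch
def reduce_fn_for(scalar_op):
    raise NotImplementedError(f"No reduction registered for {scalar_op}")


@reduce_fn_for.register(Add)
def _(scalar_op):
    return lambda x, axis: mx.sum(x, axis=axis)


@reduce_fn_for.register(Mul)
def _(scalar_op):
    return lambda x, axis: mx.prod(x, axis=axis)


x = mx.array([[1.0, 2.0], [3.0, 4.0]])
assert reduce_fn_for(Add())(x, 0).tolist() == [4.0, 6.0]
assert reduce_fn_for(Mul())(x, None).item() == 24.0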
