
Commit fb8fd2f

Last PR sampling working
Working
1 parent 597f84e commit fb8fd2f

File tree

3 files changed: +87 −63 lines changed


pytensor/link/mlx/dispatch/core.py

Lines changed: 54 additions & 6 deletions
@@ -127,21 +127,64 @@ def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
 # ------------------------------------------------------------------
 @mlx_funcify.register(Eye)  # MLX
 def mlx_funcify_Eye(op, **kwargs):
-    dtype = op.dtype
+    dtype = convert_dtype_to_mlx(op.dtype)
 
     def eye(N, M, k):
         return mx.eye(int(N), int(M), int(k), dtype=dtype)  # MLX
 
     return eye
 
 
+def convert_dtype_to_mlx(dtype_str):
+    """Convert PyTensor dtype strings to MLX dtype objects.
+
+    MLX expects dtype objects rather than string literals for type conversion.
+    This function maps common dtype strings to their MLX equivalents.
+    """
+    if isinstance(dtype_str, str):
+        if dtype_str == "bool":
+            return mx.bool_
+        elif dtype_str == "int8":
+            return mx.int8
+        elif dtype_str == "int16":
+            return mx.int16
+        elif dtype_str == "int32":
+            return mx.int32
+        elif dtype_str == "int64":
+            return mx.int64
+        elif dtype_str == "uint8":
+            return mx.uint8
+        elif dtype_str == "uint16":
+            return mx.uint16
+        elif dtype_str == "uint32":
+            return mx.uint32
+        elif dtype_str == "uint64":
+            return mx.uint64
+        elif dtype_str == "float16":
+            return mx.float16
+        elif dtype_str == "float32":
+            return mx.float32
+        elif dtype_str == "float64":
+            return mx.float64
+        elif dtype_str == "bfloat16":
+            return mx.bfloat16
+        elif dtype_str == "complex64":
+            return mx.complex64
+        elif dtype_str == "complex128":
+            return mx.complex128
+    # Return as is if it's already an MLX dtype or not a recognized string
+    return dtype_str
+
+
 # ------------------------------------------------------------------
 # MakeVector
 # ------------------------------------------------------------------
 @mlx_funcify.register(MakeVector)  # MLX
 def mlx_funcify_MakeVector(op, **kwargs):
+    dtype = convert_dtype_to_mlx(op.dtype)
+
     def makevector(*x):
-        return mx.array(x, dtype=op.dtype)  # MLX
+        return mx.array(x, dtype=dtype)  # MLX
 
     return makevector
 
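For reference, the new convert_dtype_to_mlx helper accepts either a PyTensor dtype string or an object that is already an MLX dtype. A minimal usage sketch (the concrete values below are illustrative, not part of the commit):

import mlx.core as mx
from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx

# Strings are mapped to MLX dtype objects; MLX dtypes pass through unchanged.
assert convert_dtype_to_mlx("float32") == mx.float32
assert convert_dtype_to_mlx("bool") == mx.bool_
assert convert_dtype_to_mlx(mx.int64) == mx.int64
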
@@ -175,31 +218,36 @@ def scalar_from_tensor(x):
 def mlx_funcify_Tri(op, node, **kwargs):
     # node.inputs -> N, M, k
     const_args = [getattr(inp, "data", None) for inp in node.inputs]
+    dtype = convert_dtype_to_mlx(op.dtype)
 
     def tri(*args):
         # Replace args with compile-time constants when available
         args = [
             arg if const_a is None else const_a
             for arg, const_a in zip(args, const_args, strict=True)
         ]
-        return mx.tri(*args, dtype=op.dtype)  # MLX
+        return mx.tri(*args, dtype=dtype)  # MLX
 
     return tri
 
 
 @mlx_funcify.register(AllocEmpty)
 def mlx_funcify_AllocEmpty(op, **kwargs):
+    dtype = convert_dtype_to_mlx(op.dtype)
+
     def allocempty(*shape):
-        return mx.zeros(shape, dtype=op.dtype)
+        return mx.zeros(shape, dtype=dtype)
 
     return allocempty
 
 
 @mlx_funcify.register(Alloc)
 def mlx_funcify_Alloc(op, node, **kwargs):
     def alloc(x, *shape):
-        res = mx.broadcast_to(x, shape)
-        Alloc._check_runtime_broadcast(node, mx.array(x), res.shape)
+        # Convert x to an MLX array with the correct dtype if it's a scalar
+        x_array = mx.array(x)
+        res = mx.broadcast_to(x_array, shape)
+        Alloc._check_runtime_broadcast(node, x_array, res.shape)
         return res
 
     return alloc
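The Alloc change wraps the input in mx.array before broadcasting, so Python scalars reaching alloc() behave like MLX arrays. A minimal sketch of the pattern being relied on (the concrete values are illustrative, not from the commit):

import mlx.core as mx

x = 1.5                                  # Python scalar reaching alloc()
x_array = mx.array(x)                    # wrap it so downstream calls see an MLX array
res = mx.broadcast_to(x_array, (3, 4))   # broadcast to the requested shape
print(res.shape, res.dtype)              # (3, 4) mlx.core.float32
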
Lines changed: 19 additions & 51 deletions
@@ -1,65 +1,37 @@
 import mlx.core as mx
+import numpy as np
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
 from pytensor.scalar import Softplus
-from pytensor.tensor.elemwise import CAReduce, DimShuffle
-from pytensor.tensor.special import Softmax, SoftmaxGrad
-
 from pytensor.scalar.basic import (
     AND,
-    EQ,
-    GE,
-    GT,
-    LE,
-    LT,
-    NEQ,
     OR,
-    Abs,
     Add,
     Cast,
-    Cos,
-    Exp,
-    Log,
-    Log1p,
     Mul,
-    Neg,
-    Pow,
-    ScalarMaximum,
-    ScalarMinimum,
-    Sign,
-    Sin,
-    Sqr,
-    Sqrt,
-    Sub,
-    Switch,
-    TrueDiv,
 )
+from pytensor.tensor.elemwise import CAReduce, DimShuffle
+from pytensor.tensor.special import Softmax, SoftmaxGrad
+
 
 @mlx_funcify.register(DimShuffle)
 def mlx_funcify_DimShuffle(op, **kwargs):
     def dimshuffle(x):
+        # Convert scalar to array if needed
+        if isinstance(x, int | float) or (
+            isinstance(x, np.number) and not isinstance(x, np.ndarray)
+        ):
+            x = mx.array(x)
         res = mx.transpose(x, op.transposition)
-
         shape = list(res.shape[: len(op.shuffle)])
-
         for augm in op.augment:
             shape.insert(augm, 1)
-
         return mx.reshape(res, shape)
 
     return dimshuffle
 
 
-@mlx_funcify.register(DimShuffle)
-def mlx_funcify_DimShuffle(op, **kwargs):
-    def dimshuffle(x):
-        res = mx.transpose(x, op.transposition)
-        shape = list(res.shape[: len(op.shuffle)])
-        for augm in op.augment:
-            shape.insert(augm, 1)
-        return mx.reshape(res, shape)
-    return dimshuffle
-
 @mlx_funcify.register(CAReduce)
 def mlx_funcify_CAReduce(op, **kwargs):
     if isinstance(op.scalar_op, Add):
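The dimshuffle closure now promotes bare Python/NumPy scalars to MLX arrays before transposing. A rough sketch of the transpose/augment steps for one concrete case (the transposition, shuffle, and augment values below are made-up stand-ins for the op attributes):

import mlx.core as mx

x = mx.zeros((2, 3))
transposition = (1, 0)   # stand-in for op.transposition
shuffle = (1, 0)         # stand-in for op.shuffle
augment = [0]            # stand-in for op.augment: add a new leading broadcast axis

res = mx.transpose(x, transposition)      # shape (3, 2)
shape = list(res.shape[: len(shuffle)])   # [3, 2]
for augm in augment:
    shape.insert(augm, 1)                 # [1, 3, 2]
print(mx.reshape(res, shape).shape)       # (1, 3, 2)
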
@@ -86,23 +58,10 @@ def any(x):
             return mx.any(x, axis=op.axis)
 
         return any
-    elif isinstance(op.scalar_op, ScalarMaximum):
-
-        def max(x):
-            return x.max(axis=op.axis)
-
-        return max
-    elif isinstance(op.scalar_op, ScalarMinimum):
-
-        def min(x):
-            return x.min(axis=op.axis)
-
-        return min
     else:
         raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
 
 
-
 @mlx_funcify.register(Softmax)
 def mlx_funcify_Softmax(op, **kwargs):
     axis = op.axis
@@ -142,3 +101,12 @@ def softplus(x):
         )
 
     return softplus
+
+
+@mlx_funcify.register(Cast)
+def mlx_funcify_Cast(op, **kwargs):
+    def cast(x):
+        dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype)
+        return x.astype(dtype)
+
+    return cast
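The new Cast dispatch resolves the target dtype through convert_dtype_to_mlx and then relies on MLX's array.astype. A small sketch of just that conversion step (inputs are illustrative):

import mlx.core as mx

x = mx.array([1, 2, 3])    # int32 input
dtype = mx.float32         # e.g. what convert_dtype_to_mlx("float32") returns
y = x.astype(dtype)        # MLX arrays accept a Dtype object here
print(y.dtype)             # mlx.core.float32
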

pytensor/link/mlx/dispatch/math.py

Lines changed: 14 additions & 6 deletions
@@ -1,6 +1,7 @@
 import mlx.core as mx
 
-from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify
+from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
 from pytensor.scalar import Softplus
 from pytensor.scalar.basic import (
     AND,
@@ -36,6 +37,12 @@
 from pytensor.tensor.math import Dot
 
 
+@mlx_typify.register(int)
+@mlx_typify.register(float)
+def mlx_typify_python_scalar(data, **kwargs):
+    return mx.array(data)
+
+
 @mlx_funcify.register(Dot)
 def mlx_funcify_Dot(op, **kwargs):
     def dot(x, y):
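Registering int and float with mlx_typify means Python scalars handed to a compiled MLX function are wrapped as MLX arrays on the way in. A hedged sketch of the dispatch pattern in isolation (the standalone singledispatch registry below is a stand-in for pytensor's actual mlx_typify, shown only to illustrate the mechanism):

from functools import singledispatch

import mlx.core as mx

@singledispatch
def mlx_typify(data, **kwargs):
    # Default: pass anything we don't recognise straight through.
    return data

@mlx_typify.register(int)
@mlx_typify.register(float)
def mlx_typify_python_scalar(data, **kwargs):
    return mx.array(data)

print(mlx_typify(3).dtype)    # Python int -> MLX array
print(mlx_typify(2.5).dtype)  # Python float -> MLX array
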
@@ -210,20 +217,21 @@ def any(x, y):
         return any
     elif isinstance(op.scalar_op, ScalarMaximum):
 
-        def max(x):
-            return x.max(axis=op.axis)
+        def max(x, y):
+            return mx.maximum(x, y)
 
         return max
     elif isinstance(op.scalar_op, ScalarMinimum):
 
-        def min(x):
-            return x.min(axis=op.axis)
+        def min(x, y):
+            return mx.minimum(x, y)
 
         return min
     elif isinstance(op.scalar_op, Cast):
 
         def cast(x):
-            return mx.cast(x, op.dtype)
+            dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype)
+            return x.astype(dtype)
 
         return cast
     elif isinstance(op.scalar_op, Sign):
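The ScalarMaximum/ScalarMinimum branches previously closed over a single argument and reduced along op.axis, which is the CAReduce behaviour rather than the Elemwise one; the fix treats them as binary element-wise ops. A minimal sketch of the corrected semantics (values are illustrative):

import mlx.core as mx

a = mx.array([1.0, 5.0, 3.0])
b = mx.array([4.0, 2.0, 3.0])
print(mx.maximum(a, b))   # element-wise maximum: [4, 5, 3]
print(mx.minimum(a, b))   # element-wise minimum: [1, 2, 3]
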
