
Commit 602f0ed

Pre commit changes

1 parent 6a2b774 commit 602f0ed

File tree: 4 files changed, +97 -89 lines

pytensor/link/mlx/dispatch/core.py

Lines changed: 19 additions & 25 deletions
@@ -13,10 +13,10 @@
 
 import warnings
 
-import mlx.core as mx
+import mlx.core as mx
 import numpy as np
 
-from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
     Alloc,
@@ -34,22 +34,22 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 
 
-@mlx_funcify.register(Join)
+@mlx_funcify.register(Join)
 def mlx_funcify_Join(op, **kwargs):
     def join(axis, *tensors):
         view = op.view
         if (view != -1) and all(
-            tensors[i].shape[axis] == 0
+            tensors[i].shape[axis] == 0
             for i in list(range(view)) + list(range(view + 1, len(tensors)))
         ):
             return tensors[view]
 
-        return mx.concatenate(tensors, axis=axis)
+        return mx.concatenate(tensors, axis=axis)
 
     return join
 
 
-@mlx_funcify.register(Split)
+@mlx_funcify.register(Split)
 def mlx_funcify_Split(op: Split, node, **kwargs):
     _, axis_sym, splits_sym = node.inputs
 
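For context, a minimal sketch (mine, not part of the diff) of the identity that the view shortcut in mlx_funcify_Join exploits: when every tensor other than tensors[view] is empty along the join axis, concatenation is a no-op on tensors[view], so the handler can return it without calling mx.concatenate. The example data is hypothetical.

import mlx.core as mx

a = mx.zeros((0, 3))            # empty along axis 0
b = mx.arange(6).reshape(2, 3)

# Concatenating with only empty companions leaves b unchanged...
full = mx.concatenate([a, b], axis=0)

# ...so returning b directly (the "view" branch above) is equivalent.
assert mx.array_equal(full, b)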
@@ -84,7 +84,7 @@ def split(x, axis, splits):
             cumsum_splits = np.cumsum(splits[:-1])
         else:
             # dynamic - keep in graph
-            splits_arr = mx.array(splits)
+            splits_arr = mx.array(splits)
             cumsum_splits = mx.cumsum(
                 splits_arr[:-1]
             ).tolist()  # python list for mx.split
@@ -98,29 +98,27 @@ def split(x, axis, splits):
         if np.any(np.asarray(splits) < 0):
             raise ValueError("Split sizes cannot be negative.")
 
-        return mx.split(x, cumsum_splits, axis=axis)
+        return mx.split(x, cumsum_splits, axis=axis)
 
     return split
 
 
-
-@mlx_funcify.register(ExtractDiag)
+@mlx_funcify.register(ExtractDiag)
 def mlx_funcify_ExtractDiag(op, **kwargs):
     offset, axis1, axis2 = op.offset, op.axis1, op.axis2
 
     def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
-        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
+        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
 
     return extract_diag
 
 
-
-@mlx_funcify.register(Eye)
+@mlx_funcify.register(Eye)
 def mlx_funcify_Eye(op, **kwargs):
     dtype = convert_dtype_to_mlx(op.dtype)
 
     def eye(N, M, k):
-        return mx.eye(int(N), int(M), int(k), dtype=dtype)
+        return mx.eye(int(N), int(M), int(k), dtype=dtype)
 
     return eye
 
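A small illustration (mine, not from the commit) of why the Split handler accumulates sizes: mx.split takes the indices at which to cut the array, not the size of each piece, so sizes like [2, 3, 1] become cut points [2, 5] via a cumulative sum over all but the last size. The input below is hypothetical.

import mlx.core as mx
import numpy as np

x = mx.arange(6)
sizes = [2, 3, 1]  # per-piece sizes, as PyTensor's Split op provides them

# Drop the last size and accumulate to get the cut indices: [2, 5]
cut_points = np.cumsum(sizes[:-1]).tolist()

pieces = mx.split(x, cut_points, axis=0)
print([p.tolist() for p in pieces])  # [[0, 1], [2, 3, 4], [5]]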
@@ -166,37 +164,33 @@ def convert_dtype_to_mlx(dtype_str):
     return dtype_str
 
 
-
-@mlx_funcify.register(MakeVector)
+@mlx_funcify.register(MakeVector)
 def mlx_funcify_MakeVector(op, **kwargs):
     dtype = convert_dtype_to_mlx(op.dtype)
 
     def makevector(*x):
-        return mx.array(x, dtype=dtype)
+        return mx.array(x, dtype=dtype)
 
     return makevector
 
 
-
-@mlx_funcify.register(TensorFromScalar)
+@mlx_funcify.register(TensorFromScalar)
 def mlx_funcify_TensorFromScalar(op, **kwargs):
     def tensor_from_scalar(x):
         return x  # already an MLX array / scalar
 
     return tensor_from_scalar
 
 
-
-@mlx_funcify.register(ScalarFromTensor)
+@mlx_funcify.register(ScalarFromTensor)
 def mlx_funcify_ScalarFromTensor(op, **kwargs):
     def scalar_from_tensor(x):
-        return mx.array(x).reshape(-1)[0]
+        return mx.array(x).reshape(-1)[0]
 
     return scalar_from_tensor
 
 
-
-@mlx_funcify.register(Tri)
+@mlx_funcify.register(Tri)
 def mlx_funcify_Tri(op, node, **kwargs):
     # node.inputs -> N, M, k
     const_args = [getattr(inp, "data", None) for inp in node.inputs]
@@ -208,7 +202,7 @@ def tri(*args):
             arg if const_a is None else const_a
             for arg, const_a in zip(args, const_args, strict=True)
         ]
-        return mx.tri(*args, dtype=dtype)
+        return mx.tri(*args, dtype=dtype)
 
     return tri
 
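As a usage sketch (not part of the diff): each @mlx_funcify.register(Op) above installs an Op-specific closure in PyTensor's MLX linker, mirroring the single-dispatch pattern of the existing JAX and Numba backends. Assuming the in-progress MLX backend is exposed under the mode name "MLX", compiling a graph that contains a Join node would route through mlx_funcify_Join:

import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")
out = pt.concatenate([a, b])  # builds a Join node in the graph

# Hypothetical: assumes the MLX linker is registered as mode="MLX".
fn = pytensor.function([a, b], out, mode="MLX")
print(fn([1.0, 2.0], [3.0]))  # expected: the concatenation [1., 2., 3.]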
pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 13 additions & 4 deletions
@@ -1,6 +1,7 @@
+from functools import singledispatch
+
 import mlx.core as mx
 import numpy as np
-from functools import singledispatch
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
@@ -38,48 +39,56 @@ def dimshuffle(x):
 # Second-level dispatch for scalar operations in CAReduce
 @singledispatch
 def mlx_funcify_CAReduce_scalar_op(scalar_op):
-    raise NotImplementedError(f"MLX does not support CAReduce with scalar op {scalar_op}")
+    raise NotImplementedError(
+        f"MLX does not support CAReduce with scalar op {scalar_op}"
+    )
 
 
 @mlx_funcify_CAReduce_scalar_op.register(Add)
 def _(scalar_op):
     def sum_reduce(x, axis):
         return mx.sum(x, axis=axis)
+
     return sum_reduce
 
 
 @mlx_funcify_CAReduce_scalar_op.register(Mul)
 def _(scalar_op):
     def prod_reduce(x, axis):
         return mx.prod(x, axis=axis)
+
     return prod_reduce
 
 
 @mlx_funcify_CAReduce_scalar_op.register(AND)
 def _(scalar_op):
     def all_reduce(x, axis):
         return x.all(axis=axis)
+
     return all_reduce
 
 
 @mlx_funcify_CAReduce_scalar_op.register(OR)
 def _(scalar_op):
     def any_reduce(x, axis):
         return mx.any(x, axis=axis)
+
     return any_reduce
 
 
 @mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum)
 def _(scalar_op):
     def max_reduce(x, axis):
         return mx.max(x, axis=axis)
+
     return max_reduce
 
 
 @mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum)
 def _(scalar_op):
     def min_reduce(x, axis):
         return mx.min(x, axis=axis)
+
     return min_reduce
 
 
@@ -88,10 +97,10 @@ def mlx_funcify_CAReduce(op, **kwargs):
     # Dispatch to the appropriate scalar op handler
     scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op(op.scalar_op)
     axis = op.axis
-
+
     def reduce(x):
         return scalar_reduce_fn(x, axis)
-
+
     return reduce
 
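The CAReduce change above completes a two-level dispatch: mlx_funcify dispatches on the Op class, and mlx_funcify_CAReduce_scalar_op uses functools.singledispatch a second time on the reduction's scalar_op to pick the matching MLX reduction. A self-contained sketch of that pattern, with hypothetical stand-in classes instead of PyTensor's scalar ops:

from functools import singledispatch
from math import prod


class Add: ...  # stand-ins for pytensor's scalar Add / Mul
class Mul: ...


@singledispatch
def reduce_fn_for(scalar_op):
    raise NotImplementedError(f"no reduction for {scalar_op}")


@reduce_fn_for.register(Add)
def _(scalar_op):
    return lambda x, axis: sum(x)  # mx.sum(x, axis=axis) in the real handler


@reduce_fn_for.register(Mul)
def _(scalar_op):
    return lambda x, axis: prod(x)  # mx.prod(x, axis=axis) in the real handler


# singledispatch resolves the handler from the instance's class, which is
# how mlx_funcify_CAReduce resolves op.scalar_op:
print(reduce_fn_for(Add())([1, 2, 3], axis=None))  # 6
print(reduce_fn_for(Mul())([1, 2, 3], axis=None))  # 6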