import warnings

import mlx.core as mx
import numpy as np

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import (
    Alloc,
    # ... remaining tensor.basic imports elided in this hunk ...
)
from pytensor.tensor.exceptions import NotScalarConstantError

@mlx_funcify.register(Join)
def mlx_funcify_Join(op, **kwargs):
    def join(axis, *tensors):
        view = op.view
        # View shortcut: if every input except tensors[view] is empty along
        # the join axis, the result is just tensors[view] itself.
        if (view != -1) and all(
            tensors[i].shape[axis] == 0
            for i in list(range(view)) + list(range(view + 1, len(tensors)))
        ):
            return tensors[view]

        return mx.concatenate(tensors, axis=axis)

    return join

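A minimal usage sketch, not part of the file: `SimpleNamespace` is a hypothetical stand-in for a real `Join` op, assuming only its `view` attribute is read (as in the closure above).

# --- editor's sketch: exercising the funcified Join ---
from types import SimpleNamespace

join_fn = mlx_funcify_Join(SimpleNamespace(view=-1))  # view=-1 disables the shortcut
print(join_fn(0, mx.array([1, 2]), mx.array([3, 4])))  # -> array([1, 2, 3, 4], dtype=int32)
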
@mlx_funcify.register(Split)
def mlx_funcify_Split(op: Split, node, **kwargs):
    _, axis_sym, splits_sym = node.inputs

@@ -84,7 +84,7 @@ def split(x, axis, splits):
            cumsum_splits = np.cumsum(splits[:-1])
        else:
            # Dynamic split sizes: keep the computation in the MLX graph.
            splits_arr = mx.array(splits)
            cumsum_splits = mx.cumsum(
                splits_arr[:-1]
            ).tolist()  # Python list, as mx.split requires

@@ -98,29 +98,27 @@ def split(x, axis, splits):
        if np.any(np.asarray(splits) < 0):
            raise ValueError("Split sizes cannot be negative.")

        return mx.split(x, cumsum_splits, axis=axis)

    return split

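The cumulative sum above matters because PyTensor's `Split` carries split *sizes*, while `mx.split` expects *cut indices*. A standalone sketch of that translation:

# --- editor's sketch: sizes -> cut indices ---
x = mx.arange(10)
sizes = [3, 3, 4]                       # what PyTensor's Split carries
cuts = np.cumsum(sizes[:-1]).tolist()   # [3, 6] -- what mx.split expects
parts = mx.split(x, cuts, axis=0)
print([p.tolist() for p in parts])      # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
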
@mlx_funcify.register(ExtractDiag)
def mlx_funcify_ExtractDiag(op, **kwargs):
    offset, axis1, axis2 = op.offset, op.axis1, op.axis2

    def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)

    return extract_diag

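Here `mx.diagonal` behaves like `np.diagonal`: `offset` selects a super- or subdiagonal and `axis1`/`axis2` choose the plane to extract from. For example:

# --- editor's sketch ---
x = mx.arange(9).reshape(3, 3)
print(mx.diagonal(x).tolist())            # [0, 4, 8] -- main diagonal
print(mx.diagonal(x, offset=1).tolist())  # [1, 5]    -- first superdiagonal
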
@mlx_funcify.register(Eye)
def mlx_funcify_Eye(op, **kwargs):
    dtype = convert_dtype_to_mlx(op.dtype)

    def eye(N, M, k):
        return mx.eye(int(N), int(M), int(k), dtype=dtype)

    return eye

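The `int(...)` casts let the symbolic inputs arrive as 0-d arrays or NumPy scalars; `k` offsets the diagonal exactly as in `np.eye`. For instance:

# --- editor's sketch ---
print(mx.eye(2, 3, 1, dtype=mx.float32))  # -> [[0, 1, 0], [0, 0, 1]]
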
@@ -166,37 +164,33 @@ def convert_dtype_to_mlx(dtype_str):
    return dtype_str

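Only the fall-through `return dtype_str` of `convert_dtype_to_mlx` is visible in this hunk; from its call sites in this file it evidently maps PyTensor dtype strings to `mx.Dtype` objects. A hypothetical sketch of that shape (the real mapping table is elided from the diff):

# --- editor's sketch; the actual table lives in the elided lines above ---
_DTYPE_TABLE = {"float32": mx.float32, "int64": mx.int64, "bool": mx.bool_}

def convert_dtype_to_mlx_sketch(dtype_str):
    # Fall through to the raw string when no explicit mapping exists,
    # mirroring the `return dtype_str` tail visible above.
    return _DTYPE_TABLE.get(dtype_str, dtype_str)
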
@mlx_funcify.register(MakeVector)
def mlx_funcify_MakeVector(op, **kwargs):
    dtype = convert_dtype_to_mlx(op.dtype)

    def makevector(*x):
        return mx.array(x, dtype=dtype)

    return makevector

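The `makevector` closure packs its scalar arguments into a 1-d array of the op's dtype. A sketch with a hypothetical stand-in op, assuming `convert_dtype_to_mlx("float32")` resolves to `mx.float32`:

# --- editor's sketch ---
from types import SimpleNamespace

mv = mlx_funcify_MakeVector(SimpleNamespace(dtype="float32"))
print(mv(1.0, 2.0, 3.0))  # -> array([1, 2, 3], dtype=float32)
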
@mlx_funcify.register(TensorFromScalar)
def mlx_funcify_TensorFromScalar(op, **kwargs):
    def tensor_from_scalar(x):
        return x  # already an MLX array / scalar

    return tensor_from_scalar

@mlx_funcify.register(ScalarFromTensor)
def mlx_funcify_ScalarFromTensor(op, **kwargs):
    def scalar_from_tensor(x):
        # Flatten and take the first element to get a 0-d (scalar) array.
        return mx.array(x).reshape(-1)[0]

    return scalar_from_tensor

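The flatten-then-index idiom yields a 0-d array regardless of the input's shape:

# --- editor's sketch ---
v = mx.array([[7]])
print(mx.array(v).reshape(-1)[0])  # -> array(7, dtype=int32)
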
@mlx_funcify.register(Tri)
def mlx_funcify_Tri(op, node, **kwargs):
    # node.inputs -> N, M, k
    const_args = [getattr(inp, "data", None) for inp in node.inputs]

@@ -208,7 +202,7 @@ def tri(*args):
        # Prefer compile-time constants over the runtime arguments.
        args = [
            arg if const_a is None else const_a
            for arg, const_a in zip(args, const_args, strict=True)
        ]
        return mx.tri(*args, dtype=dtype)

    return tri
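Constant inputs are baked in through `const_args`, while `None` entries remain runtime arguments; either way the call lowers to `mx.tri`. A direct sketch of that final call:

# --- editor's sketch ---
print(mx.tri(3, 3, 0, dtype=mx.float32))  # -> [[1, 0, 0], [1, 1, 0], [1, 1, 1]]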