@@ -13,10 +13,10 @@

 import warnings

-import mlx.core as mx  # MLX
+import mlx.core as mx
 import numpy as np

-from pytensor.link.mlx.dispatch.basic import mlx_funcify  # MLX
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
     Alloc,
@@ -34,28 +34,22 @@
 from pytensor.tensor.exceptions import NotScalarConstantError


-# ------------------------------------------------------------------
-# Join
-# ------------------------------------------------------------------
-@mlx_funcify.register(Join)  # MLX
+@mlx_funcify.register(Join)
 def mlx_funcify_Join(op, **kwargs):
     def join(axis, *tensors):
         view = op.view
         if (view != -1) and all(
-            tensors[i].shape[axis] == 0  # MLX
+            tensors[i].shape[axis] == 0
             for i in list(range(view)) + list(range(view + 1, len(tensors)))
         ):
             return tensors[view]

-        return mx.concatenate(tensors, axis=axis)  # MLX
+        return mx.concatenate(tensors, axis=axis)

     return join


-# ------------------------------------------------------------------
-# Split
-# ------------------------------------------------------------------
-@mlx_funcify.register(Split)  # MLX
+@mlx_funcify.register(Split)
 def mlx_funcify_Split(op: Split, node, **kwargs):
     _, axis_sym, splits_sym = node.inputs

@@ -90,7 +84,7 @@ def split(x, axis, splits):
             cumsum_splits = np.cumsum(splits[:-1])
         else:
             # dynamic - keep in graph
-            splits_arr = mx.array(splits)  # MLX
+            splits_arr = mx.array(splits)
             cumsum_splits = mx.cumsum(
                 splits_arr[:-1]
             ).tolist()  # python list for mx.split
@@ -104,33 +98,29 @@ def split(x, axis, splits):
         if np.any(np.asarray(splits) < 0):
             raise ValueError("Split sizes cannot be negative.")

-        return mx.split(x, cumsum_splits, axis=axis)  # MLX
+        return mx.split(x, cumsum_splits, axis=axis)

     return split


-# ------------------------------------------------------------------
-# ExtractDiag
-# ------------------------------------------------------------------
-@mlx_funcify.register(ExtractDiag)  # MLX
+
+@mlx_funcify.register(ExtractDiag)
 def mlx_funcify_ExtractDiag(op, **kwargs):
     offset, axis1, axis2 = op.offset, op.axis1, op.axis2

     def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
-        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)  # MLX
+        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)

     return extract_diag


-# ------------------------------------------------------------------
-# Eye
-# ------------------------------------------------------------------
-@mlx_funcify.register(Eye)  # MLX
+
+@mlx_funcify.register(Eye)
 def mlx_funcify_Eye(op, **kwargs):
     dtype = convert_dtype_to_mlx(op.dtype)

     def eye(N, M, k):
-        return mx.eye(int(N), int(M), int(k), dtype=dtype)  # MLX
+        return mx.eye(int(N), int(M), int(k), dtype=dtype)

     return eye

@@ -176,45 +166,37 @@ def convert_dtype_to_mlx(dtype_str):
     return dtype_str


-# ------------------------------------------------------------------
-# MakeVector
-# ------------------------------------------------------------------
-@mlx_funcify.register(MakeVector)  # MLX
+
+@mlx_funcify.register(MakeVector)
 def mlx_funcify_MakeVector(op, **kwargs):
     dtype = convert_dtype_to_mlx(op.dtype)

     def makevector(*x):
-        return mx.array(x, dtype=dtype)  # MLX
+        return mx.array(x, dtype=dtype)

     return makevector


-# ------------------------------------------------------------------
-# TensorFromScalar (identity for MLX)
-# ------------------------------------------------------------------
-@mlx_funcify.register(TensorFromScalar)  # MLX
+
+@mlx_funcify.register(TensorFromScalar)
 def mlx_funcify_TensorFromScalar(op, **kwargs):
     def tensor_from_scalar(x):
         return x  # already an MLX array / scalar

     return tensor_from_scalar


-# ------------------------------------------------------------------
-# ScalarFromTensor
-# ------------------------------------------------------------------
-@mlx_funcify.register(ScalarFromTensor)  # MLX
+
+@mlx_funcify.register(ScalarFromTensor)
 def mlx_funcify_ScalarFromTensor(op, **kwargs):
     def scalar_from_tensor(x):
-        return mx.array(x).reshape(-1)[0]  # MLX
+        return mx.array(x).reshape(-1)[0]

     return scalar_from_tensor


-# ------------------------------------------------------------------
-# Tri
-# ------------------------------------------------------------------
-@mlx_funcify.register(Tri)  # MLX
+
+@mlx_funcify.register(Tri)
 def mlx_funcify_Tri(op, node, **kwargs):
     # node.inputs -> N, M, k
     const_args = [getattr(inp, "data", None) for inp in node.inputs]
@@ -226,7 +208,7 @@ def tri(*args):
             arg if const_a is None else const_a
             for arg, const_a in zip(args, const_args, strict=True)
         ]
-        return mx.tri(*args, dtype=dtype)  # MLX
+        return mx.tri(*args, dtype=dtype)

     return tri

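The registrations in this diff each reduce to a single `mlx.core` call, so the op-to-primitive mapping can be sanity-checked without building a PyTensor graph. Below is a minimal sketch, assuming only that the `mlx` package is installed; the comments name the PyTensor op each call backs:

import mlx.core as mx

x = mx.arange(6, dtype=mx.float32).reshape(2, 3)

mx.concatenate([x, x], axis=0)               # Join
mx.split(x, [1], axis=1)                     # Split, at cumulative index 1
mx.diagonal(x, offset=0, axis1=0, axis2=1)   # ExtractDiag
mx.eye(2, 3, 1, dtype=mx.float32)            # Eye, with diagonal offset k=1
mx.array((1.0, 2.0, 3.0), dtype=mx.float32)  # MakeVector
mx.array(x).reshape(-1)[0]                   # ScalarFromTensor
mx.tri(2, 3, 0, dtype=mx.float32)            # Tri

Note that `mx.split` takes cumulative split indices rather than chunk sizes, which is why the Split dispatch above converts sizes with `cumsum` (and `.tolist()`) before calling it.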
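For readers new to the `mlx_funcify` mechanism: the `.register` calls follow the `functools.singledispatch` pattern, where the translator for each op class returns a closure that captures the op's static parameters at translation time and receives the runtime inputs as arguments. A self-contained sketch of that pattern under this assumption, with a stand-in `Eye` class rather than the real PyTensor op:

from functools import singledispatch

import mlx.core as mx

# Stand-in for the op class; the real one lives in pytensor.tensor.basic.
class Eye:
    def __init__(self, dtype):
        self.dtype = dtype

@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No MLX translation for {type(op).__name__}")

@funcify.register(Eye)
def funcify_eye(op, **kwargs):
    dtype = op.dtype  # static op parameter, captured at translation time

    def eye(N, M, k):  # runtime inputs arrive as arguments
        return mx.eye(int(N), int(M), int(k), dtype=dtype)

    return eye

fn = funcify(Eye(mx.float32))
print(fn(3, 3, 0))  # 3x3 identity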