8
8
from __future__ import annotations
9
9
10
10
import builtins
11
+ import itertools
11
12
import operator
12
13
from typing import Optional , Sequence
13
14
16
17
from . import _dtypes_impl
17
18
from . import _reductions as _impl
18
19
from . import _util
19
- from ._normalizations import (
20
+
21
+ # these imports are for einsum only
22
+ from ._normalizations import ( # isort: skip
20
23
ArrayLike ,
21
24
AxisLike ,
25
+ CastingModes ,
22
26
DTypeLike ,
23
27
NDArray ,
24
28
NotImplementedType ,
25
29
OutArray ,
30
+ maybe_copy_to ,
26
31
normalize_array_like ,
32
+ normalize_casting ,
33
+ normalize_dtype ,
34
+ wrap_tensors ,
27
35
)
28
36
29
37
# ###### array creation routines
@@ -38,7 +46,7 @@ def copy(
38
46
def copyto (
39
47
dst : NDArray ,
40
48
src : ArrayLike ,
41
- casting = "same_kind" ,
49
+ casting : Optional [ CastingModes ] = "same_kind" ,
42
50
where : NotImplementedType = None ,
43
51
):
44
52
(src ,) = _util .typecast_tensors ((src ,), dst .dtype , casting = casting )
@@ -97,7 +105,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
97
105
return tensors
98
106
99
107
100
- def _concatenate (tensors , axis = 0 , out = None , dtype = None , casting = "same_kind" ):
108
+ def _concatenate (
109
+ tensors , axis = 0 , out = None , dtype = None , casting : Optional [CastingModes ] = "same_kind"
110
+ ):
101
111
# pure torch implementation, used below and in cov/corrcoef below
102
112
tensors , axis = _util .axis_none_flatten (* tensors , axis = axis )
103
113
tensors = _concat_cast_helper (tensors , out , dtype , casting )
@@ -109,15 +119,18 @@ def concatenate(
109
119
axis = 0 ,
110
120
out : Optional [OutArray ] = None ,
111
121
dtype : Optional [DTypeLike ] = None ,
112
- casting = "same_kind" ,
122
+ casting : Optional [ CastingModes ] = "same_kind" ,
113
123
):
114
124
_concat_check (ar_tuple , dtype , out = out )
115
125
result = _concatenate (ar_tuple , axis = axis , out = out , dtype = dtype , casting = casting )
116
126
return result
117
127
118
128
119
129
def vstack (
120
- tup : Sequence [ArrayLike ], * , dtype : Optional [DTypeLike ] = None , casting = "same_kind"
130
+ tup : Sequence [ArrayLike ],
131
+ * ,
132
+ dtype : Optional [DTypeLike ] = None ,
133
+ casting : Optional [CastingModes ] = "same_kind" ,
121
134
):
122
135
_concat_check (tup , dtype , out = None )
123
136
tensors = _concat_cast_helper (tup , dtype = dtype , casting = casting )
@@ -128,15 +141,21 @@ def vstack(
128
141
129
142
130
143
def hstack (
131
- tup : Sequence [ArrayLike ], * , dtype : Optional [DTypeLike ] = None , casting = "same_kind"
144
+ tup : Sequence [ArrayLike ],
145
+ * ,
146
+ dtype : Optional [DTypeLike ] = None ,
147
+ casting : Optional [CastingModes ] = "same_kind" ,
132
148
):
133
149
_concat_check (tup , dtype , out = None )
134
150
tensors = _concat_cast_helper (tup , dtype = dtype , casting = casting )
135
151
return torch .hstack (tensors )
136
152
137
153
138
154
def dstack (
139
- tup : Sequence [ArrayLike ], * , dtype : Optional [DTypeLike ] = None , casting = "same_kind"
155
+ tup : Sequence [ArrayLike ],
156
+ * ,
157
+ dtype : Optional [DTypeLike ] = None ,
158
+ casting : Optional [CastingModes ] = "same_kind" ,
140
159
):
141
160
# XXX: in numpy 1.24 dstack does not have dtype and casting keywords
142
161
# but {h,v}stack do. Hence add them here for consistency.
@@ -146,7 +165,10 @@ def dstack(
146
165
147
166
148
167
def column_stack (
149
- tup : Sequence [ArrayLike ], * , dtype : Optional [DTypeLike ] = None , casting = "same_kind"
168
+ tup : Sequence [ArrayLike ],
169
+ * ,
170
+ dtype : Optional [DTypeLike ] = None ,
171
+ casting : Optional [CastingModes ] = "same_kind" ,
150
172
):
151
173
# XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
152
174
# but row_stack does. (because row_stack is an alias for vstack, really).
@@ -162,7 +184,7 @@ def stack(
162
184
out : Optional [OutArray ] = None ,
163
185
* ,
164
186
dtype : Optional [DTypeLike ] = None ,
165
- casting = "same_kind" ,
187
+ casting : Optional [ CastingModes ] = "same_kind" ,
166
188
):
167
189
_concat_check (arrays , dtype , out = out )
168
190
@@ -1152,6 +1174,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
1152
1174
def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
    """Compute the tensor dot product along the given ``axes``.

    ``axes`` may be an int (contract the last/first N dimensions) or a
    pair of axis sequences; bare ints inside a list/tuple are wrapped
    into one-element lists so torch receives uniform sequences.

    Both operands are promoted to their common result dtype before
    delegating to ``torch.tensordot``.
    """
    if isinstance(axes, (list, tuple)):
        normalized = []
        for ax in axes:
            normalized.append([ax] if isinstance(ax, int) else ax)
        axes = normalized

    # promote to the common dtype so mixed-dtype inputs behave like numpy
    common = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
    a, b = (_util.cast_if_needed(t, common) for t in (a, b))

    return torch.tensordot(a, b, dims=axes)
1156
1183
1157
1184
@@ -1194,6 +1221,77 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
1194
1221
return torch .outer (a , b )
1195
1222
1196
1223
1224
def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
    """Evaluate the Einstein summation convention on the operands.

    Mirrors ``numpy.einsum``: accepts either the subscript-string form
    ``einsum("ij,jk->ik", a, b)`` or the interleaved sublist form
    ``einsum(a, [0, 1], b, [1, 2], [0, 2])``. Because of the variadic
    signature, ``*operands`` and keyword arguments are normalized
    manually here instead of via the usual normalizer decorator.

    Parameters
    ----------
    out : ndarray, optional
        Where to place the result (copied into via ``maybe_copy_to``).
    dtype : dtype-like, optional
        Target dtype; defaults to the promoted dtype of the operands.
    order : str
        Only ``"K"`` is supported.
    casting : casting mode
        Passed to ``_util.typecast_tensors`` (default ``"safe"``).
    optimize : bool or str
        Forwarded to ``torch.backends.opt_einsum.strategy`` for the
        duration of the call.
    """
    from ._ndarray import ndarray

    dtype = normalize_dtype(dtype)
    casting = normalize_casting(casting)
    if out is not None and not isinstance(out, ndarray):
        raise TypeError("'out' must be an array")
    if order != "K":
        raise NotImplementedError("'order' parameter is not supported.")

    # parse arrays and normalize them
    sublist_format = not isinstance(operands[0], str)
    if sublist_format:
        # op, sublist, op, sublist, ... [sublistout] format.
        # - without sublistout the length is even and the arrays sit at
        #   even positions;
        # - with sublistout the length is odd; peel the last element off
        #   first so [::2] cannot pick it up.
        array_operands = operands[:-1][::2]
    else:
        # ("ij->...", arrays) format
        subscripts, array_operands = operands[0], operands[1:]

    tensors = [normalize_array_like(op) for op in array_operands]
    target_dtype = (
        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
        if dtype is None
        else dtype
    )

    # work around 'bmm' not being implemented for 'Half' etc.
    if target_dtype == torch.float16:
        target_dtype = torch.float32

    # ditto for short integer dtypes: compute in int64
    if target_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
        target_dtype = torch.int64

    tensors = _util.typecast_tensors(tensors, target_dtype, casting)

    # Read the current strategy BEFORE entering the try block: if this
    # attribute access itself raised inside the try, the finally clause
    # would hit an unbound `old_strategy` and mask the original error.
    old_strategy = torch.backends.opt_einsum.strategy
    try:
        # set the global state to handle the optimize=... argument,
        # restore on exit
        torch.backends.opt_einsum.strategy = optimize

        if sublist_format:
            # recombine normalized tensors with their sublists
            sublists = operands[1::2]
            has_sublistout = len(operands) % 2 == 1
            if has_sublistout:
                sublistout = operands[-1]
            operands = list(itertools.chain.from_iterable(zip(tensors, sublists)))
            if has_sublistout:
                operands.append(sublistout)

            result = torch.einsum(*operands)
        else:
            result = torch.einsum(subscripts, *tensors)

    finally:
        torch.backends.opt_einsum.strategy = old_strategy

    result = maybe_copy_to(out, result)
    return wrap_tensors(result)
1293
+
1294
+
1197
1295
# ### sort and partition ###
1198
1296
1199
1297
0 commit comments