 from __future__ import annotations

 import builtins
+import itertools
 import math
 import operator
 from typing import Optional, Sequence
@@ -17,14 +18,21 @@
 from . import _dtypes_impl
 from . import _reductions as _impl
 from . import _util
-from ._normalizations import (
+
+# these imports are for einsum only
+from ._normalizations import (  # isort: skip
     ArrayLike,
     AxisLike,
+    CastingModes,
     DTypeLike,
     NDArray,
     NotImplementedType,
     OutArray,
+    maybe_copy_to,
     normalize_array_like,
+    normalize_casting,
+    normalize_dtype,
+    wrap_tensors,
 )

 # ###### array creation routines
@@ -39,7 +47,7 @@ def copy(
 def copyto(
     dst: NDArray,
     src: ArrayLike,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     where: NotImplementedType = None,
 ):
     (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
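
A minimal usage sketch of the now-typed `casting` argument (illustration only, not part of the patch; it assumes the package is importable as `torch_np` and follows NumPy's casting rules):

import torch_np as np

dst = np.empty(3, dtype=np.float64)
src = np.asarray([1, 2, 3])                # int64
np.copyto(dst, src, casting="same_kind")   # int -> float is allowed under "same_kind"
# copying floats into an integer destination under "same_kind" is rejected;
# casting="unsafe" would be needed for that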
@@ -98,7 +106,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
     return tensors


-def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
+def _concatenate(
+    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
+):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
@@ -110,15 +120,18 @@ def concatenate(
     axis=0,
     out: Optional[OutArray] = None,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(ar_tuple, dtype, out=out)
     result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
     return result


 def vstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
@@ -129,15 +142,21 @@ def vstack(


 def hstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
     return torch.hstack(tensors)


 def dstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
     # but {h,v}stack do. Hence add them here for consistency.
@@ -147,7 +166,10 @@ def dstack(


 def column_stack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
     # but row_stack does. (because row_stack is an alias for vstack, really).
@@ -163,7 +185,7 @@ def stack(
     out: Optional[OutArray] = None,
     *,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(arrays, dtype, out=out)

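
For reference, a hedged sketch of how the now-uniform `dtype`/`casting` keywords are used across the concatenation family (illustration only; assumes the package imports as `torch_np`):

import torch_np as np

a = np.ones(3, dtype=np.int64)
b = np.zeros(3, dtype=np.float32)

# cast both inputs to the requested dtype; "same_kind" permits
# int64 -> float64 and float32 -> float64
r = np.hstack((a, b), dtype=np.float64, casting="same_kind")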
@@ -1166,6 +1188,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
 def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
+
+    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    a = _util.cast_if_needed(a, target_dtype)
+    b = _util.cast_if_needed(b, target_dtype)
+
     return torch.tensordot(a, b, dims=axes)


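A sketch of what the added promotion does (illustration only, not part of the patch; `torch.result_type` is used here as a stand-in for `_dtypes_impl.result_type_impl`, which follows NumPy's promotion rules rather than torch's):

import torch

a = torch.arange(6, dtype=torch.int64).reshape(2, 3)
b = torch.ones((3, 2), dtype=torch.float32)

# give both operands a common dtype before handing them to torch.tensordot
target = torch.result_type(a, b)
result = torch.tensordot(a.to(target), b.to(target), dims=1)
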
@@ -1208,6 +1235,77 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
     return torch.outer(a, b)


+def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
+    # Have to manually normalize *operands and **kwargs, following the NumPy signature
+
+    from ._ndarray import ndarray
+
+    dtype = normalize_dtype(dtype)
+    casting = normalize_casting(casting)
+    if out is not None and not isinstance(out, ndarray):
+        raise TypeError("'out' must be an array")
+    if order != "K":
+        raise NotImplementedError("'order' parameter is not supported.")
+
+    # parse arrays and normalize them
+    sublist_format = not isinstance(operands[0], str)
+    if sublist_format:
+        # op, str, op, str ... [sublistout] format: normalize every other argument
+
+        # - if sublistout is not given, the length of operands is even, and we pick
+        #   odd-numbered elements, which are arrays.
+        # - if sublistout is given, the length of operands is odd, we peel off
+        #   the last one, and pick odd-numbered elements, which are arrays.
+        #   Without [:-1], we would have picked sublistout, too.
+        array_operands = operands[:-1][::2]
+    else:
+        # ("ij->", arrays) format
+        subscripts, array_operands = operands[0], operands[1:]
+
+    tensors = [normalize_array_like(op) for op in array_operands]
+    target_dtype = (
+        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
+        if dtype is None
+        else dtype
+    )
+
+    # work around 'bmm' not implemented for 'Half' etc
+    is_half = target_dtype == torch.float16
+    if is_half:
+        target_dtype = torch.float32
+
+    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
+    if is_short_int:
+        target_dtype = torch.int64
+
+    tensors = _util.typecast_tensors(tensors, target_dtype, casting)
+
+    try:
+        # set the global state to handle the optimize=... argument, restore on exit
+        old_strategy = torch.backends.opt_einsum.strategy
+        torch.backends.opt_einsum.strategy = optimize
+
+        if sublist_format:
+            # recombine operands
+            sublists = operands[1::2]
+            has_sublistout = len(operands) % 2 == 1
+            if has_sublistout:
+                sublistout = operands[-1]
+            operands = list(itertools.chain(*zip(tensors, sublists)))
+            if has_sublistout:
+                operands.append(sublistout)
+
+            result = torch.einsum(*operands)
+        else:
+            result = torch.einsum(subscripts, *tensors)
+
+    finally:
+        torch.backends.opt_einsum.strategy = old_strategy
+
+    result = maybe_copy_to(out, result)
+    return wrap_tensors(result)
+
+
 # ### sort and partition ###


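A usage sketch of the two calling conventions the parsing above distinguishes (illustration only; assumes the package imports as `torch_np`):

import torch_np as np

a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)

# subscripts-string form, handled by the `else` branch
c1 = np.einsum("ij,jk->ik", a, b)

# sublist form: operand, sublist, operand, sublist, ..., optional output sublist;
# arrays sit at operands[:-1][::2], sublists at operands[1::2]
c2 = np.einsum(a, [0, 1], b, [1, 2], [0, 2])
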
@@ -1798,8 +1896,6 @@ def bartlett(M):


 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors: