 from __future__ import annotations
 
 import builtins
+import itertools
 import math
 import operator
 from typing import Optional, Sequence
@@ -20,6 +21,7 @@
 from ._normalizations import (
     ArrayLike,
     AxisLike,
+    CastingModes,
     DTypeLike,
     NDArray,
     NotImplementedType,
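
The casting annotations added throughout this patch rely on the new CastingModes import above. Its definition is not part of this diff; presumably _normalizations.py declares it as a typing.Literal over NumPy's casting rule names. A minimal sketch of that assumption (illustrative only, not part of the patch):

import typing

# Hypothetical definition -- mirrors the casting modes accepted by
# np.can_cast and by the `casting=` keyword of concatenate/stack.
CastingModes = typing.Literal["no", "equiv", "safe", "same_kind", "unsafe"]
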
@@ -39,7 +41,7 @@ def copy(
 def copyto(
     dst: NDArray,
     src: ArrayLike,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     where: NotImplementedType = None,
 ):
     (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
@@ -98,7 +100,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
     return tensors
 
 
-def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
+def _concatenate(
+    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
+):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
@@ -110,15 +114,18 @@ def concatenate(
     axis=0,
     out: Optional[OutArray] = None,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(ar_tuple, dtype, out=out)
     result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
     return result
 
 
 def vstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
@@ -129,15 +136,21 @@ def vstack(
 
 
 def hstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
     return torch.hstack(tensors)
 
 
 def dstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
     # but {h,v}stack do. Hence add them here for consistency.
@@ -147,7 +160,10 @@ def dstack(
 
 
 def column_stack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
     # but row_stack does. (because row_stack is an alias for vstack, really).
@@ -163,7 +179,7 @@ def stack(
     out: Optional[OutArray] = None,
     *,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(arrays, dtype, out=out)
 
@@ -1166,6 +1182,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
 def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
+
+    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    a = _util.cast_if_needed(a, target_dtype)
+    b = _util.cast_if_needed(b, target_dtype)
+
     return torch.tensordot(a, b, dims=axes)
 
 
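
The new tensordot lines promote both operands to a common dtype before calling torch.tensordot, since the underlying torch contraction generally expects matching dtypes, whereas NumPy promotes mixed inputs automatically. A small sketch of the NumPy behaviour being matched (illustrative usage only, not part of the patch):

import numpy as np

a = np.arange(6, dtype=np.int64).reshape(2, 3)
b = np.ones((3, 2), dtype=np.float32)

# NumPy promotes to the common result type (float64 here) before contracting;
# the hunk above reproduces this by casting both tensors to
# result_type_impl((a.dtype, b.dtype)) up front.
print(np.tensordot(a, b, axes=1).dtype)  # float64
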
@@ -1208,6 +1229,74 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
     return torch.outer(a, b)
 
 
+def einsum(*operands, out=None, optimize=False, **kwargs):
+    # Have to manually normalize *operands and **kwargs, following the NumPy signature
+    # >>> np.einsum?
+    # Signature: np.einsum(*operands, out=None, optimize=False, **kwargs)
+    # Docstring:
+    #     einsum(subscripts, *operands, out=None, dtype=None, order='K',
+    #            casting='safe', optimize=False)
+
+    from ._normalizations import (
+        maybe_copy_to,
+        normalize_casting,
+        normalize_dtype,
+        normalize_not_implemented,
+        normalize_outarray,
+        wrap_tensors,
+    )
+
+    dtype = normalize_dtype(kwargs.pop("dtype", None))
+    casting = normalize_casting(kwargs.pop("casting", "safe"))
+
+    parm = lambda _: None  # a fake duck-typed inspect.Parameter stub
+    parm.name = "out"
+    out = normalize_outarray(out, parm=parm)
+
+    parm.default = "K"
+    parm.name = "order"
+    order = normalize_not_implemented(kwargs.pop("order", "K"), parm=parm)
+    if kwargs:
+        raise TypeError("unknown arguments: ", kwargs)
+
+    # parse arrays and normalize them
+    if isinstance(operands[0], str):
+        # ("ij->", arrays) format
+        sublist_format = False
+        subscripts, array_operands = operands[0], operands[1:]
+    else:
+        # op, str, op, str ... format: normalize every other argument
+        sublist_format = True
+        array_operands = operands[:-1][::2]
+
+    tensors = [normalize_array_like(op) for op in array_operands]
+    target_dtype = (
+        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
+        if dtype is None
+        else dtype
+    )
+    tensors = _util.typecast_tensors(tensors, target_dtype, casting)
+
+    if sublist_format:
+        # recombine operands
+        sublists = operands[1::2]
+        has_sublistout = len(operands) % 2 == 1
+        if has_sublistout:
+            sublistout = operands[-1]
+        operands = list(itertools.chain(*zip(tensors, sublists)))
+        if has_sublistout:
+            operands.append(sublistout)
+
+        result = torch.einsum(*operands)
+    else:
+        result = torch.einsum(subscripts, *tensors)
+
+
+
+    result = maybe_copy_to(out, result)
+    return wrap_tensors(result)
+
+
 # ### sort and partition ###
 
 
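
The einsum wrapper normalizes its arguments by hand because the decorator used elsewhere in this module cannot express the *operands signature; the `parm` lambda is a duck-typed stand-in for the inspect.Parameter object the normalization helpers normally receive. The sublist_format branch distinguishes NumPy's two calling conventions and re-interleaves the cast tensors with their axis lists before handing everything to torch.einsum. A short sketch of the two conventions (illustrative usage only, not part of the patch):

import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)

# Subscripts-string form: handled by the `subscripts, *tensors` branch.
c1 = np.einsum("ij,jk->ik", a, b)

# Sublist form: array, axis-list, array, axis-list, ..., with an optional
# output axis-list last -- this is what `sublist_format` detects and
# re-interleaves after the dtype cast.
c2 = np.einsum(a, [0, 1], b, [1, 2], [0, 2])

assert (c1 == c2).all()
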
@@ -1798,8 +1887,6 @@ def bartlett(M):
 
 
 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors: