 from ._normalizations import (
     ArrayLike,
     AxisLike,
+    CastingModes,
     DTypeLike,
     NDArray,
     NotImplementedType,
@@ -39,7 +40,7 @@ def copy(
 def copyto(
     dst: NDArray,
     src: ArrayLike,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     where: NotImplementedType = None,
 ):
     (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
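
For reference, the "same_kind" default that the new CastingModes annotation documents follows NumPy's casting rules: casts within a kind (or safe casts) are allowed, kind-crossing casts are rejected. A quick sketch of that behaviour, not taken from the patch itself, using NumPy as the reference semantics (array names are illustrative):

import numpy as np

dst_f32 = np.zeros(3, dtype=np.float32)
dst_i64 = np.zeros(3, dtype=np.int64)
src = np.array([1.0, 2.5, 3.0], dtype=np.float64)

np.copyto(dst_f32, src, casting="same_kind")    # ok: float64 -> float32 stays within the float kind

try:
    np.copyto(dst_i64, src, casting="same_kind")  # rejected: float -> int crosses kinds
except TypeError as exc:
    print(exc)

np.copyto(dst_i64, src, casting="unsafe")       # allowed once casting is relaxed
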
@@ -98,7 +99,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
98
99
return tensors
99
100
100
101
101
- def _concatenate (tensors , axis = 0 , out = None , dtype = None , casting = "same_kind" ):
102
+ def _concatenate (
103
+ tensors , axis = 0 , out = None , dtype = None , casting : Optional [CastingModes ] = "same_kind"
104
+ ):
102
105
# pure torch implementation, used below and in cov/corrcoef below
103
106
tensors , axis = _util .axis_none_ravel (* tensors , axis = axis )
104
107
tensors = _concat_cast_helper (tensors , out , dtype , casting )
@@ -110,15 +113,18 @@ def concatenate(
     axis=0,
     out: Optional[OutArray] = None,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(ar_tuple, dtype, out=out)
     result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
     return result
 
 
 def vstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
@@ -129,15 +135,21 @@ def vstack(
 
 
 def hstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
     return torch.hstack(tensors)
 
 
 def dstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
     # but {h,v}stack do. Hence add them here for consistency.
@@ -147,7 +159,10 @@ def dstack(
 
 
 def column_stack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
     # but row_stack does. (because row_stack is an alias for vstack, really).
@@ -163,7 +178,7 @@ def stack(
     out: Optional[OutArray] = None,
     *,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(arrays, dtype, out=out)
 
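
The same CastingModes annotation is threaded through concatenate, vstack, hstack, dstack, column_stack and stack above. A short sketch, not part of the patch, of how the dtype/casting pair interacts in these functions, again using NumPy as the reference behaviour (NumPy >= 1.20 is assumed, since that is when concatenate gained these keywords):

import numpy as np

a = np.array([1.5, 2.5])
b = np.array([3.5])

np.concatenate([a, b], dtype=np.int64, casting="unsafe")        # ok: cast is explicitly unsafe
try:
    np.concatenate([a, b], dtype=np.int64, casting="same_kind")  # float -> int is not same-kind
except TypeError as exc:
    print(exc)
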
@@ -1166,6 +1181,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
 def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
+
+    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    a = _util.cast_if_needed(a, target_dtype)
+    b = _util.cast_if_needed(b, target_dtype)
+
     return torch.tensordot(a, b, dims=axes)
 
 
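
The tensordot change above promotes both operands to a common dtype before handing them to torch.tensordot, which otherwise rejects mixed-dtype inputs. A rough plain-torch equivalent of that promotion, illustrative only:

import torch

a = torch.arange(6, dtype=torch.int64).reshape(2, 3)
b = torch.ones(3, 4, dtype=torch.float32)

# torch.tensordot(a, b, dims=1) would fail on the dtype mismatch;
# promote both operands first, mirroring what the wrapper now does.
target = torch.result_type(a, b)            # float32 here
r = torch.tensordot(a.to(target), b.to(target), dims=1)
print(r.dtype, r.shape)                     # torch.float32, (2, 4)
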
@@ -1208,6 +1228,68 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
     return torch.outer(a, b)
 
 
+def einsum(*operands, out=None, optimize=False, **kwargs):
+    # Have to manually normalize *operands and **kwargs, following the NumPy signature
+    # >>> np.einsum?
+    # Signature: np.einsum(*operands, out=None, optimize=False, **kwargs)
+    # Docstring:
+    # einsum(subscripts, *operands, out=None, dtype=None, order='K',
+    #        casting='safe', optimize=False)
+
+    from ._normalizations import (
+        maybe_copy_to,
+        normalize_casting,
+        normalize_dtype,
+        normalize_not_implemented,
+        normalize_outarray,
+        wrap_tensors,
+    )
+
+    dtype = normalize_dtype(kwargs.pop("dtype", None))
+    casting = normalize_casting(kwargs.pop("casting", "safe"))
+
+    parm = lambda _: None  # a fake duck-typed inspect.Parameter stub
+    parm.name = "out"
+    out = normalize_outarray(out, parm=parm)
+
+    parm.default = "K"
+    parm.name = "order"
+    order = normalize_not_implemented(kwargs.pop("order", "K"), parm=parm)
+    if kwargs:
+        raise TypeError("unknown arguments: ", kwargs)
+
+    # parse arrays and normalize them
+    if isinstance(operands[0], str):
+        # ("ij->", arrays) format
+        sublist_format = False
+        subscripts, array_operands = operands[0], operands[1:]
+    else:
+        # op, str, op, str ... format: normalize every other argument
+        sublist_format = True
+        array_operands = operands[:-1][::2]
+
+    tensors = [normalize_array_like(op) for op in array_operands]
+    target_dtype = (
+        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
+        if dtype is None
+        else dtype
+    )
+    tensors = _util.typecast_tensors(tensors, target_dtype, casting)
+
+    if sublist_format:
+        # recombine operands
+        sublists = operands[1::2]
+        sublistout = (operands[-1],) if len(operands) % 2 == 1 else ()
+        operands = builtins.sum((_ for _ in zip(tensors, sublists)), ()) + sublistout
+
+        result = torch.einsum(*operands)
+    else:
+        result = torch.einsum(subscripts, *tensors)
+
+    result = maybe_copy_to(out, result)
+    return wrap_tensors(result)
+
+
 # ### sort and partition ###
 
 
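
For context, a small sketch, not part of the diff, of the two call formats the new einsum wrapper distinguishes; torch.einsum accepts both, so after normalization and casting the wrapper can forward either form directly (tensor names are illustrative):

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

# subscripts-string format: ("ij,jk->ik", arrays...) is forwarded as
# torch.einsum(subscripts, *tensors) once the tensors are normalized and cast
r1 = torch.einsum("ij,jk->ik", a, b)

# sublist (interleaved) format: tensor, index-list, tensor, index-list, ...
# with an optional trailing output list; the wrapper re-zips the cast tensors
# with their sublists before calling torch.einsum(*operands)
r2 = torch.einsum(a, [0, 1], b, [1, 2], [0, 2])

assert torch.allclose(r1, r2)
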
@@ -1798,8 +1880,6 @@ def bartlett(M):
 
 
 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors: