@@ -1229,17 +1229,16 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
     return torch.outer(a, b)


-def einsum(*operands, out=None, dtype=None, order='K',
-           casting='safe', optimize=False):
+def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
     # Have to manually normalize *operands and **kwargs, following the NumPy signature

+    from ._ndarray import ndarray
     from ._normalizations import (
         maybe_copy_to,
         normalize_casting,
         normalize_dtype,
         wrap_tensors,
     )
-    from ._ndarray import ndarray

     dtype = normalize_dtype(dtype)
     casting = normalize_casting(casting)
@@ -1251,7 +1250,13 @@ def einsum(*operands, out=None, dtype=None, order='K',
     # parse arrays and normalize them
     sublist_format = not isinstance(operands[0], str)
     if sublist_format:
-        # op, str, op, str ... format: normalize every other argument
+        # op, str, op, str ... [sublistout] format: normalize every other argument
+
+        # - if sublistout is not given, the length of operands is even, and we pick
+        #   odd-numbered elements, which are arrays.
+        # - if sublistout is given, the length of operands is odd; we peel off
+        #   the last one and pick odd-numbered elements, which are arrays.
+        #   Without [:-1], we would have picked sublistout, too.
         array_operands = operands[:-1][::2]
     else:
         # ("ij->", arrays) format
@@ -1263,6 +1268,16 @@ def einsum(*operands, out=None, dtype=None, order='K',
         if dtype is None
         else dtype
     )
+
+    # work around 'bmm' not implemented for 'Half' etc.
+    is_half = target_dtype == torch.float16
+    if is_half:
+        target_dtype = torch.float32
+
+    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
+    if is_short_int:
+        target_dtype, result_dtype = torch.int64, target_dtype
+
     tensors = _util.typecast_tensors(tensors, target_dtype, casting)

     if sublist_format:
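For reference, the intent of the new dtype-promotion block can be sketched as a standalone helper. This is a hypothetical illustration, not the patched function: the helper name and the use of functools.reduce with torch.promote_types are assumptions, and result_dtype is presumably used further down in einsum (outside this hunk) to cast the result back.

```python
import functools
import torch

def einsum_promoted(subscripts, *tensors):
    # Hypothetical sketch of the workaround above: some backends lack
    # 'bmm' kernels for float16 and small integer dtypes, so compute in a
    # wider dtype and, for integers, cast the result back afterwards.
    target_dtype = functools.reduce(torch.promote_types, (t.dtype for t in tensors))
    result_dtype = None

    if target_dtype == torch.float16:
        target_dtype = torch.float32  # compute in float32 instead of half
    elif target_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
        result_dtype = target_dtype   # remember the dtype the caller expects
        target_dtype = torch.int64    # compute in int64

    out = torch.einsum(subscripts, *(t.to(target_dtype) for t in tensors))
    return out if result_dtype is None else out.to(result_dtype)
```

The promotion only changes the dtype the contraction runs in; the einsum computation itself is unchanged.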