
Commit cd5f74a

Merge pull request #127 from Quansight-Labs/einsum
ENH: add einsum
2 parents b0ade4e + 7e9f49c commit cd5f74a


6 files changed (+157, -55 lines)


autogen/numpy_api_dump.py

Lines changed: 0 additions & 4 deletions
@@ -220,10 +220,6 @@ def ediff1d(ary, to_end=None, to_begin=None):
     raise NotImplementedError


-def einsum(*operands, out=None, optimize=False, **kwargs):
-    raise NotImplementedError
-
-
 def einsum_path(*operands, optimize="greedy", einsum_call=False):
     raise NotImplementedError

torch_np/_funcs.py

Lines changed: 4 additions & 1 deletion
@@ -31,10 +31,13 @@
     func = getattr(_funcs_impl, name)
     if name in ["percentile", "quantile", "median"]:
         decorated = normalizer(func, promote_scalar_result=True)
+    elif name == "einsum":
+        # normalized manually
+        decorated = func
     else:
         decorated = normalizer(func)

-    decorated.__qualname__ = name  # XXX: is this really correct?
+    decorated.__qualname__ = name
     decorated.__name__ = name
     vars()[name] = decorated
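
The elif branch above exists because einsum cannot go through the annotation-based normalizer: its *operands are either a subscript string followed by arrays, or arrays interleaved with integer sublists, so the argument types depend on the call format and the function normalizes them itself (see _funcs_impl.einsum below). A minimal usage sketch of the two formats, assuming the package is importable as torch_np and exposes the usual arange/reshape API:

import torch_np as np

a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)

# 1. subscript-string format: the spec comes first
c1 = np.einsum("ij,jk->ik", a, b)

# 2. sublist format: arrays interleaved with axis-label lists, optional output sublist last
c2 = np.einsum(a, [0, 1], b, [1, 2], [0, 2])

assert (c1 == c2).all()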

torch_np/_funcs_impl.py

Lines changed: 107 additions & 11 deletions
@@ -8,6 +8,7 @@
 from __future__ import annotations

 import builtins
+import itertools
 import math
 import operator
 from typing import Optional, Sequence
@@ -17,14 +18,21 @@
 from . import _dtypes_impl
 from . import _reductions as _impl
 from . import _util
-from ._normalizations import (
+
+# these imports are for einsum only
+from ._normalizations import (  # isort: skip
     ArrayLike,
     AxisLike,
+    CastingModes,
     DTypeLike,
     NDArray,
     NotImplementedType,
     OutArray,
+    maybe_copy_to,
     normalize_array_like,
+    normalize_casting,
+    normalize_dtype,
+    wrap_tensors,
 )

 # ###### array creation routines
@@ -39,7 +47,7 @@ def copy(
 def copyto(
     dst: NDArray,
     src: ArrayLike,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     where: NotImplementedType = None,
 ):
     (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
@@ -98,7 +106,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
     return tensors


-def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
+def _concatenate(
+    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
+):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
@@ -110,15 +120,18 @@ def concatenate(
     axis=0,
     out: Optional[OutArray] = None,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(ar_tuple, dtype, out=out)
     result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
     return result


 def vstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
@@ -129,15 +142,21 @@ def vstack(


 def hstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
     return torch.hstack(tensors)


 def dstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
     # but {h,v}stack do. Hence add them here for consistency.
@@ -147,7 +166,10 @@ def dstack(


 def column_stack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
     # but row_stack does. (because row_stack is an alias for vstack, really).
@@ -163,7 +185,7 @@ def stack(
     out: Optional[OutArray] = None,
     *,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(arrays, dtype, out=out)

@@ -1166,6 +1188,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
 def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
+
+    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    a = _util.cast_if_needed(a, target_dtype)
+    b = _util.cast_if_needed(b, target_dtype)
+
     return torch.tensordot(a, b, dims=axes)
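
The three added lines promote both operands to their common result type before the call, because torch.tensordot expects matching dtypes while the NumPy API allows mixed ones. A small sketch of what this enables, assuming the package imports as torch_np and its arange accepts a dtype= keyword:

import torch_np as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.arange(12, dtype=np.float64).reshape(3, 4)

# both operands are cast to the common result type (float64) before torch.tensordot
c = np.tensordot(a, b, axes=1)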

@@ -1208,6 +1235,77 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
     return torch.outer(a, b)


+def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
+    # Have to manually normalize *operands and **kwargs, following the NumPy signature
+
+    from ._ndarray import ndarray
+
+    dtype = normalize_dtype(dtype)
+    casting = normalize_casting(casting)
+    if out is not None and not isinstance(out, ndarray):
+        raise TypeError("'out' must be an array")
+    if order != "K":
+        raise NotImplementedError("'order' parameter is not supported.")
+
+    # parse arrays and normalize them
+    sublist_format = not isinstance(operands[0], str)
+    if sublist_format:
+        # op, str, op, str ... [sublistout] format: normalize every other argument
+
+        # - if sublistout is not given, the length of operands is even, and we pick
+        #   odd-numbered elements, which are arrays.
+        # - if sublistout is given, the length of operands is odd, we peel off
+        #   the last one, and pick odd-numbered elements, which are arrays.
+        #   Without [:-1], we would have picked sublistout, too.
+        array_operands = operands[:-1][::2]
+    else:
+        # ("ij->", arrays) format
+        subscripts, array_operands = operands[0], operands[1:]
+
+    tensors = [normalize_array_like(op) for op in array_operands]
+    target_dtype = (
+        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
+        if dtype is None
+        else dtype
+    )
+
+    # work around 'bmm' not implemented for 'Half' etc
+    is_half = target_dtype == torch.float16
+    if is_half:
+        target_dtype = torch.float32
+
+    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
+    if is_short_int:
+        target_dtype = torch.int64
+
+    tensors = _util.typecast_tensors(tensors, target_dtype, casting)
+
+    try:
+        # set the global state to handle the optimize=... argument, restore on exit
+        old_strategy = torch.backends.opt_einsum.strategy
+        torch.backends.opt_einsum.strategy = optimize
+
+        if sublist_format:
+            # recombine operands
+            sublists = operands[1::2]
+            has_sublistout = len(operands) % 2 == 1
+            if has_sublistout:
+                sublistout = operands[-1]
+            operands = list(itertools.chain(*zip(tensors, sublists)))
+            if has_sublistout:
+                operands.append(sublistout)
+
+            result = torch.einsum(*operands)
+        else:
+            result = torch.einsum(subscripts, *tensors)
+
+    finally:
+        torch.backends.opt_einsum.strategy = old_strategy
+
+    result = maybe_copy_to(out, result)
+    return wrap_tensors(result)
+
+
 # ### sort and partition ###
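
Besides the two operand formats, the new einsum routes dtype= and casting= through normalize_dtype / normalize_casting and _util.typecast_tensors, temporarily maps optimize= onto torch.backends.opt_einsum.strategy, rejects order= values other than "K", and honours out= via maybe_copy_to. A usage sketch, assuming the package imports as torch_np and provides arange/empty:

import torch_np as np

a = np.arange(6).reshape(2, 3)

# dtype= upcasts the operands before the contraction (default casting="safe")
row_sums = np.einsum("ij->i", a, dtype=np.float64)

# out= must be a torch_np ndarray; the result is copied into it
out = np.empty(2, dtype=np.float64)
np.einsum("ij->i", a, out=out, dtype=np.float64)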

@@ -1798,8 +1896,6 @@ def bartlett(M):


 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors:

torch_np/_normalizations.py

Lines changed: 10 additions & 0 deletions
@@ -14,6 +14,7 @@
 DTypeLike = typing.TypeVar("DTypeLike")
 AxisLike = typing.TypeVar("AxisLike")
 NDArray = typing.TypeVar("NDarray")
+CastingModes = typing.TypeVar("CastingModes")

 # OutArray is to annotate the out= array argument.
 #
@@ -101,6 +102,14 @@ def normalize_outarray(arg, parm=None):
     return arg


+def normalize_casting(arg, parm=None):
+    if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
+        raise ValueError(
+            f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
+        )
+    return arg
+
+
 normalizers = {
     "ArrayLike": normalize_array_like,
     "Optional[ArrayLike]": normalize_optional_array_like,
@@ -111,6 +120,7 @@ def normalize_outarray(arg, parm=None):
     "Optional[DTypeLike]": normalize_dtype,
     "AxisLike": normalize_axis_like,
     "NotImplementedType": normalize_not_implemented,
+    "Optional[CastingModes]": normalize_casting,
 }
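
normalize_casting backs the new Optional[CastingModes] annotation registered just above, so every function that gained it in this commit (copyto, concatenate, the *stack family, the ufunc wrappers, einsum) now rejects unknown casting modes up front instead of passing them through to torch. A sketch of the resulting behaviour, assuming the package imports as torch_np:

import torch_np as np

a = np.ones(3)
b = np.zeros(3)

np.concatenate((a, b), casting="same_kind")  # accepted

try:
    np.concatenate((a, b), casting="sideways")
except ValueError as exc:
    print(exc)  # casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got 'sideways')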

torch_np/_ufuncs.py

Lines changed: 5 additions & 4 deletions
@@ -7,6 +7,7 @@
 from . import _binary_ufuncs_impl, _dtypes_impl, _helpers, _unary_ufuncs_impl, _util
 from ._normalizations import (
     ArrayLike,
+    CastingModes,
     DTypeLike,
     NotImplementedType,
     OutArray,
@@ -54,7 +55,7 @@ def wrapped(
         out: Optional[OutArray] = None,
         *,
         where=True,
-        casting="same_kind",
+        casting: Optional[CastingModes] = "same_kind",
         order="K",
         dtype: Optional[DTypeLike] = None,
         subok: NotImplementedType = False,
@@ -87,7 +88,7 @@ def matmul(
     /,
     out: Optional[OutArray] = None,
     *,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order: NotImplementedType = "K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,
@@ -118,7 +119,7 @@ def divmod(
     out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
     *,
     where: NotImplementedType = True,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order: NotImplementedType = "K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,
@@ -190,7 +191,7 @@ def wrapped(
         out: Optional[OutArray] = None,
         *,
         where=True,
-        casting="same_kind",
+        casting: Optional[CastingModes] = "same_kind",
         order="K",
         dtype: Optional[DTypeLike] = None,
         subok: NotImplementedType = False,
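
For the ufunc wrappers, matmul and divmod the change is the same annotation swap, so casting= is validated by normalize_casting before the wrapped torch call runs. A brief sketch under the same torch_np import assumption:

import torch_np as np

x = np.arange(3.0)
y = np.arange(3)

np.add(x, y, casting="same_kind")  # accepted as before

try:
    np.add(x, y, casting="nope")
except ValueError as exc:
    print(exc)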
