Commit 78c333a

ENH: add einsum
1 parent b0ade4e commit 78c333a

6 files changed: +139 -52 lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 4 deletions

@@ -220,10 +220,6 @@ def ediff1d(ary, to_end=None, to_begin=None):
     raise NotImplementedError
 
 
-def einsum(*operands, out=None, optimize=False, **kwargs):
-    raise NotImplementedError
-
-
 def einsum_path(*operands, optimize="greedy", einsum_call=False):
     raise NotImplementedError

torch_np/_funcs.py

Lines changed: 4 additions & 1 deletion

@@ -31,10 +31,13 @@
     func = getattr(_funcs_impl, name)
     if name in ["percentile", "quantile", "median"]:
         decorated = normalizer(func, promote_scalar_result=True)
+    elif name == "einsum":
+        # normalized manually
+        decorated = func
     else:
         decorated = normalizer(func)
 
-    decorated.__qualname__ = name  # XXX: is this really correct?
+    decorated.__qualname__ = name
     decorated.__name__ = name
     vars()[name] = decorated
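
The special case above exists because normalizer drives argument conversion off per-parameter type annotations, while einsum takes a free-form *operands sequence whose positional arguments can be arrays or subscript specifications. A small illustration of the two NumPy calling conventions the wrapper must accept (illustrative usage only, not part of this commit; it assumes the package is imported as torch_np):

    import torch_np as np

    a = np.arange(6).reshape(2, 3)
    b = np.arange(12).reshape(3, 4)

    # subscripts-string form: the spec comes first, then the arrays
    np.einsum("ij,jk->ik", a, b)

    # sublist form: arrays alternate with axis lists,
    # optionally ending with an output axis list
    np.einsum(a, [0, 1], b, [1, 2], [0, 2])

Because the array positions differ between the two forms, the implementation in _funcs_impl.py normalizes the operands by hand instead of relying on the decorator.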

torch_np/_funcs_impl.py

Lines changed: 97 additions & 10 deletions

@@ -8,6 +8,7 @@
 from __future__ import annotations
 
 import builtins
+import itertools
 import math
 import operator
 from typing import Optional, Sequence
@@ -20,6 +21,7 @@
 from ._normalizations import (
     ArrayLike,
     AxisLike,
+    CastingModes,
     DTypeLike,
     NDArray,
     NotImplementedType,
@@ -39,7 +41,7 @@ def copy(
 def copyto(
     dst: NDArray,
     src: ArrayLike,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     where: NotImplementedType = None,
 ):
     (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
@@ -98,7 +100,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
     return tensors
 
 
-def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
+def _concatenate(
+    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
+):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
@@ -110,15 +114,18 @@ def concatenate(
     axis=0,
     out: Optional[OutArray] = None,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(ar_tuple, dtype, out=out)
     result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
     return result
 
 
 def vstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
@@ -129,15 +136,21 @@ def vstack(
 
 
 def hstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
     return torch.hstack(tensors)
 
 
 def dstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
     # but {h,v}stack do. Hence add them here for consistency.
@@ -147,7 +160,10 @@ def dstack(
 
 
 def column_stack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
     # but row_stack does. (because row_stack is an alias for vstack, really).
@@ -163,7 +179,7 @@ def stack(
     out: Optional[OutArray] = None,
     *,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(arrays, dtype, out=out)
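
A note on the effect of these signature changes: annotating casting as Optional[CastingModes] means the value is validated by normalize_casting before the concatenation helpers run. A hedged sketch of the resulting behaviour (not part of the commit):

    import torch_np as np

    a = np.ones(3)                   # float64
    b = np.zeros(3, dtype=np.int64)

    np.concatenate((a, b))                     # inputs promote to float64 as before
    np.concatenate((a, b), casting="unsafe")   # any of the five valid modes is passed through
    np.concatenate((a, b), casting="bogus")    # now rejected with a ValueError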

@@ -1166,6 +1182,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
 def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
+
+    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    a = _util.cast_if_needed(a, target_dtype)
+    b = _util.cast_if_needed(b, target_dtype)
+
     return torch.tensordot(a, b, dims=axes)
 
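
The two cast_if_needed calls added here promote both operands to their common result type before handing them to torch.tensordot, which, unlike NumPy, does not generally promote mixed dtypes on its own. A rough usage sketch (illustrative, not from the commit):

    import torch_np as np

    a = np.arange(6).reshape(2, 3)   # integer
    b = np.ones((3, 4))              # float

    # both operands are cast to the common dtype (float64 here)
    # before torch.tensordot is invoked
    np.tensordot(a, b, axes=1)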

@@ -1208,6 +1229,74 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
     return torch.outer(a, b)
 
 
+def einsum(*operands, out=None, optimize=False, **kwargs):
+    # Have to manually normalize *operands and **kwargs, following the NumPy signature
+    # >>> np.einsum?
+    # Signature: np.einsum(*operands, out=None, optimize=False, **kwargs)
+    # Docstring:
+    # einsum(subscripts, *operands, out=None, dtype=None, order='K',
+    #        casting='safe', optimize=False)
+
+    from ._normalizations import (
+        maybe_copy_to,
+        normalize_casting,
+        normalize_dtype,
+        normalize_not_implemented,
+        normalize_outarray,
+        wrap_tensors,
+    )
+
+    dtype = normalize_dtype(kwargs.pop("dtype", None))
+    casting = normalize_casting(kwargs.pop("casting", "safe"))
+
+    parm = lambda _: None  # a fake duck-typed inspect.Parameter stub
+    parm.name = "out"
+    out = normalize_outarray(out, parm=parm)
+
+    parm.default = "K"
+    parm.name = "order"
+    order = normalize_not_implemented(kwargs.pop("order", "K"), parm=parm)
+    if kwargs:
+        raise TypeError("unknown arguments: ", kwargs)
+
+    # parse arrays and normalize them
+    if isinstance(operands[0], str):
+        # ("ij->", arrays) format
+        sublist_format = False
+        subscripts, array_operands = operands[0], operands[1:]
+    else:
+        # op, str, op, str ... format: normalize every other argument
+        sublist_format = True
+        array_operands = operands[:-1][::2]
+
+    tensors = [normalize_array_like(op) for op in array_operands]
+    target_dtype = (
+        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
+        if dtype is None
+        else dtype
+    )
+    tensors = _util.typecast_tensors(tensors, target_dtype, casting)
+
+    if sublist_format:
+        # recombine operands
+        sublists = operands[1::2]
+        has_sublistout = len(operands) % 2 == 1
+        if has_sublistout:
+            sublistout = operands[-1]
+        operands = list(itertools.chain(*zip(tensors, sublists)))
+        if has_sublistout:
+            operands.append(sublistout)
+
+        result = torch.einsum(*operands)
+    else:
+        result = torch.einsum(subscripts, *tensors)
+
+    result = maybe_copy_to(out, result)
+    return wrap_tensors(result)
+
+
 # ### sort and partition ###
 
@@ -1798,8 +1887,6 @@ def bartlett(M):
 
 
 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors:
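
Taken together, the new einsum normalizes dtype, casting, out, and order by hand, promotes the operands to a common (or explicitly requested) dtype under the given casting rule, and dispatches either calling convention to torch.einsum. A hedged usage sketch, not part of the commit:

    import torch_np as np

    a = np.arange(6).reshape(2, 3)   # int64
    b = np.ones((3, 4))              # float64

    # operands are promoted to a common dtype before torch.einsum runs
    c = np.einsum("ij,jk->ik", a, b)

    # an explicit dtype= needs a casting rule that allows the conversion
    d = np.einsum("ij,jk->ik", a, b, dtype=np.float32, casting="unsafe")

    # the result is copied into out= when one is provided
    out = np.empty((2, 4))
    np.einsum("ij,jk->ik", a, b, out=out)

The optimize argument is accepted for NumPy compatibility but is not used, and a non-default order= is routed through normalize_not_implemented.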

torch_np/_normalizations.py

Lines changed: 10 additions & 0 deletions

@@ -14,6 +14,7 @@
 DTypeLike = typing.TypeVar("DTypeLike")
 AxisLike = typing.TypeVar("AxisLike")
 NDArray = typing.TypeVar("NDarray")
+CastingModes = typing.TypeVar("CastingModes")
 
 # OutArray is to annotate the out= array argument.
 #
@@ -101,6 +102,14 @@ def normalize_outarray(arg, parm=None):
     return arg
 
 
+def normalize_casting(arg, parm=None):
+    if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
+        raise ValueError(
+            f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
+        )
+    return arg
+
+
 normalizers = {
     "ArrayLike": normalize_array_like,
     "Optional[ArrayLike]": normalize_optional_array_like,
@@ -111,6 +120,7 @@ def normalize_outarray(arg, parm=None):
     "Optional[DTypeLike]": normalize_dtype,
     "AxisLike": normalize_axis_like,
     "NotImplementedType": normalize_not_implemented,
+    "Optional[CastingModes]": normalize_casting,
 }
 
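
Registering the new annotation string in normalizers is what wires everything up: the normalizer decorator looks up each parameter's annotation in this dict and applies the matching function to the incoming argument, so every parameter annotated Optional[CastingModes] now passes through normalize_casting. Its behaviour in isolation (illustrative call, not part of the commit):

    from torch_np._normalizations import normalize_casting

    normalize_casting("same_kind")   # returns the value unchanged
    normalize_casting("bogus")       # raises ValueError naming the five valid modes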

torch_np/_ufuncs.py

Lines changed: 5 additions & 4 deletions

@@ -7,6 +7,7 @@
 from . import _binary_ufuncs_impl, _dtypes_impl, _helpers, _unary_ufuncs_impl, _util
 from ._normalizations import (
     ArrayLike,
+    CastingModes,
     DTypeLike,
     NotImplementedType,
     OutArray,
@@ -54,7 +55,7 @@ def wrapped(
         out: Optional[OutArray] = None,
         *,
         where=True,
-        casting="same_kind",
+        casting: Optional[CastingModes] = "same_kind",
         order="K",
         dtype: Optional[DTypeLike] = None,
         subok: NotImplementedType = False,
@@ -87,7 +88,7 @@ def matmul(
     /,
     out: Optional[OutArray] = None,
     *,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order: NotImplementedType = "K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,
@@ -118,7 +119,7 @@ def divmod(
     out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
     *,
     where: NotImplementedType = True,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order="K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,
@@ -190,7 +191,7 @@ def wrapped(
         out: Optional[OutArray] = None,
         *,
         where=True,
-        casting="same_kind",
+        casting: Optional[CastingModes] = "same_kind",
         order="K",
         dtype: Optional[DTypeLike] = None,
         subok: NotImplementedType = False,
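
For the ufunc wrappers the change is signature-only: casting keeps its "same_kind" default, but the annotation means an invalid mode is rejected by normalize_casting before any torch call is made. Illustrative only, not part of the commit:

    import torch_np as np

    a = np.ones(3)
    b = np.ones(3)

    np.add(a, b, casting="same_kind")   # behaviour unchanged
    np.add(a, b, casting="invalid")     # now raises ValueError up front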
