Commit 53447cd

ENH: add einsum
1 parent b0ade4e commit 53447cd

6 files changed: +132 −52 lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 4 deletions
@@ -220,10 +220,6 @@ def ediff1d(ary, to_end=None, to_begin=None):
     raise NotImplementedError
 
 
-def einsum(*operands, out=None, optimize=False, **kwargs):
-    raise NotImplementedError
-
-
 def einsum_path(*operands, optimize="greedy", einsum_call=False):
     raise NotImplementedError

torch_np/_funcs.py

Lines changed: 4 additions & 1 deletion
@@ -31,10 +31,13 @@
     func = getattr(_funcs_impl, name)
     if name in ["percentile", "quantile", "median"]:
         decorated = normalizer(func, promote_scalar_result=True)
+    elif name == "einsum":
+        # normalized manually
+        decorated = func
     else:
         decorated = normalizer(func)
 
-    decorated.__qualname__ = name  # XXX: is this really correct?
+    decorated.__qualname__ = name
     decorated.__name__ = name
     vars()[name] = decorated

torch_np/_funcs_impl.py

Lines changed: 90 additions & 10 deletions
@@ -20,6 +20,7 @@
 from ._normalizations import (
     ArrayLike,
     AxisLike,
+    CastingModes,
     DTypeLike,
     NDArray,
     NotImplementedType,
@@ -39,7 +40,7 @@ def copy(
 def copyto(
     dst: NDArray,
     src: ArrayLike,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     where: NotImplementedType = None,
 ):
     (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
@@ -98,7 +99,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
     return tensors
 
 
-def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
+def _concatenate(
+    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
+):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
@@ -110,15 +113,18 @@ def concatenate(
     axis=0,
     out: Optional[OutArray] = None,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(ar_tuple, dtype, out=out)
     result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
     return result
 
 
 def vstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
@@ -129,15 +135,21 @@ def vstack(
 
 
 def hstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
     return torch.hstack(tensors)
 
 
 def dstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
     # but {h,v}stack do. Hence add them here for consistency.
@@ -147,7 +159,10 @@ def dstack(
 
 
 def column_stack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
     # but row_stack does. (because row_stack is an alias for vstack, really).
@@ -163,7 +178,7 @@ def stack(
     out: Optional[OutArray] = None,
     *,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(arrays, dtype, out=out)
 
@@ -1166,6 +1181,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
 def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
+
+    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    a = _util.cast_if_needed(a, target_dtype)
+    b = _util.cast_if_needed(b, target_dtype)
+
     return torch.tensordot(a, b, dims=axes)
 
 
@@ -1208,6 +1228,68 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
     return torch.outer(a, b)
 
 
+def einsum(*operands, out=None, optimize=False, **kwargs):
+    # Have to manually normalize *operands and **kwargs, following the NumPy signature
+    # >>> np.einsum?
+    # Signature: np.einsum(*operands, out=None, optimize=False, **kwargs)
+    # Docstring:
+    # einsum(subscripts, *operands, out=None, dtype=None, order='K',
+    #        casting='safe', optimize=False)
+
+    from ._normalizations import (
+        maybe_copy_to,
+        normalize_casting,
+        normalize_dtype,
+        normalize_not_implemented,
+        normalize_outarray,
+        wrap_tensors,
+    )
+
+    dtype = normalize_dtype(kwargs.pop("dtype", None))
+    casting = normalize_casting(kwargs.pop("casting", "safe"))
+
+    parm = lambda _: None  # a fake duck-typed inspect.Parameter stub
+    parm.name = "out"
+    out = normalize_outarray(out, parm=parm)
+
+    parm.default = "K"
+    parm.name = "order"
+    order = normalize_not_implemented(kwargs.pop("order", "K"), parm=parm)
+    if kwargs:
+        raise TypeError("unknown arguments: ", kwargs)
+
+    # parse arrays and normalize them
+    if isinstance(operands[0], str):
+        # ("ij->", arrays) format
+        sublist_format = False
+        subscripts, array_operands = operands[0], operands[1:]
+    else:
+        # op, str, op, str ... format: normalize every other argument
+        sublist_format = True
+        array_operands = operands[:-1][::2]
+
+    tensors = [normalize_array_like(op) for op in array_operands]
+    target_dtype = (
+        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
+        if dtype is None
+        else dtype
+    )
+    tensors = _util.typecast_tensors(tensors, target_dtype, casting)
+
+    if sublist_format:
+        # recombine operands
+        sublists = operands[1::2]
+        sublistout = (operands[-1],) if len(operands) % 2 == 1 else ()
+        operands = builtins.sum((_ for _ in zip(tensors, sublists)), ()) + sublistout
+
+        result = torch.einsum(*operands)
+    else:
+        result = torch.einsum(subscripts, *tensors)
+
+    result = maybe_copy_to(out, result)
+    return wrap_tensors(result)
+
+
 # ### sort and partition ###
 
 
@@ -1798,8 +1880,6 @@ def bartlett(M):
 
 
 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors:

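For context, the new `einsum` distinguishes NumPy's two calling conventions by whether the first operand is a subscripts string, and pops `dtype`/`casting` out of `**kwargs` itself. A minimal usage sketch, assuming the `torch_np` array constructors (`arange`, `reshape`) and string dtype specifiers behave like their NumPy counterparts:

    import torch_np as tnp

    a = tnp.arange(6).reshape(2, 3)
    b = tnp.arange(12).reshape(3, 4)

    # subscripts-string form: ("ij,jk->ik", arrays...)
    c1 = tnp.einsum("ij,jk->ik", a, b)

    # sublist form: operand, subscript list, operand, subscript list, [output list]
    c2 = tnp.einsum(a, [0, 1], b, [1, 2], [0, 2])

    # dtype= and casting= are popped from **kwargs and normalized manually
    c3 = tnp.einsum("ij,jk->ik", a, b, dtype="float64")

Both routes end in `torch.einsum`; the sublist branch re-interleaves the normalized tensors with their subscript lists (plus the optional trailing output list) before the call.
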
torch_np/_normalizations.py

Lines changed: 10 additions & 0 deletions
@@ -14,6 +14,7 @@
 DTypeLike = typing.TypeVar("DTypeLike")
 AxisLike = typing.TypeVar("AxisLike")
 NDArray = typing.TypeVar("NDarray")
+CastingModes = typing.TypeVar("CastingModes")
 
 # OutArray is to annotate the out= array argument.
 #
@@ -101,6 +102,14 @@ def normalize_outarray(arg, parm=None):
     return arg
 
 
+def normalize_casting(arg, parm=None):
+    if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
+        raise ValueError(
+            f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
+        )
+    return arg
+
+
 normalizers = {
     "ArrayLike": normalize_array_like,
     "Optional[ArrayLike]": normalize_optional_array_like,
@@ -111,6 +120,7 @@ def normalize_outarray(arg, parm=None):
     "Optional[DTypeLike]": normalize_dtype,
     "AxisLike": normalize_axis_like,
     "NotImplementedType": normalize_not_implemented,
+    "Optional[CastingModes]": normalize_casting,
 }

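`normalize_casting` is registered under the `"Optional[CastingModes]"` annotation, so any parameter annotated that way is validated before the implementation runs. A quick sketch of the validator on its own (importing from the private `_normalizations` module, so this is illustrative rather than public API):

    from torch_np._normalizations import normalize_casting

    normalize_casting("same_kind")   # returned unchanged
    normalize_casting("unsafe")      # any of the five NumPy casting modes passes
    normalize_casting("sideways")    # raises ValueError listing the allowed modes
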
torch_np/_ufuncs.py

Lines changed: 5 additions & 4 deletions
@@ -7,6 +7,7 @@
 from . import _binary_ufuncs_impl, _dtypes_impl, _helpers, _unary_ufuncs_impl, _util
 from ._normalizations import (
     ArrayLike,
+    CastingModes,
     DTypeLike,
     NotImplementedType,
     OutArray,
@@ -54,7 +55,7 @@ def wrapped(
         out: Optional[OutArray] = None,
         *,
         where=True,
-        casting="same_kind",
+        casting: Optional[CastingModes] = "same_kind",
         order="K",
         dtype: Optional[DTypeLike] = None,
         subok: NotImplementedType = False,
@@ -87,7 +88,7 @@ def matmul(
     /,
     out: Optional[OutArray] = None,
     *,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order: NotImplementedType = "K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,
@@ -118,7 +119,7 @@ def divmod(
     out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
     *,
     where: NotImplementedType = True,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order: NotImplementedType = "K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,
@@ -190,7 +191,7 @@ def wrapped(
     out: Optional[OutArray] = None,
     *,
     where=True,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order="K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,

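With the ufunc wrappers annotated as `Optional[CastingModes]`, an invalid `casting=` argument should now be rejected by the normalizer up front rather than leaking into the cast helpers. A small sketch of the expected behaviour, assuming the usual decorator wiring applies these annotations to keyword arguments:

    import torch_np as tnp

    a = tnp.ones(3)
    b = tnp.ones(3)

    tnp.add(a, b, casting="same_kind")   # the default mode, accepted

    try:
        tnp.add(a, b, casting="sideways")
    except ValueError as exc:
        # "casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' ..."
        print(exc)
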