Commit 96ac341

Merge pull request #126 from Quansight-Labs/einsum_tests
ENH: add einsum + its numpy tests
2 parents 9c627db + a16475b commit 96ac341

7 files changed: +1266 -19 lines

autogen/numpy_api_dump.py

Lines changed: 0 additions & 4 deletions
@@ -220,10 +220,6 @@ def ediff1d(ary, to_end=None, to_begin=None):
     raise NotImplementedError


-def einsum(*operands, out=None, optimize=False, **kwargs):
-    raise NotImplementedError
-
-
 def einsum_path(*operands, optimize="greedy", einsum_call=False):
     raise NotImplementedError


torch_np/_funcs.py

Lines changed: 4 additions & 1 deletion
@@ -24,10 +24,13 @@
     func = getattr(_funcs_impl, name)
     if name in ["percentile", "quantile", "median"]:
         decorated = normalizer(func, promote_scalar_result=True)
+    elif name == "einsum":
+        # normalized manually
+        decorated = func
     else:
         decorated = normalizer(func)

-    decorated.__qualname__ = name  # XXX: is this really correct?
+    decorated.__qualname__ = name
     decorated.__name__ = name
     vars()[name] = decorated


torch_np/_funcs_impl.py

Lines changed: 107 additions & 9 deletions
@@ -8,6 +8,7 @@
 from __future__ import annotations

 import builtins
+import itertools
 import operator
 from typing import Optional, Sequence

@@ -16,14 +17,21 @@
 from . import _dtypes_impl
 from . import _reductions as _impl
 from . import _util
-from ._normalizations import (
+
+# these imports are for einsum only
+from ._normalizations import (  # isort: skip
     ArrayLike,
     AxisLike,
+    CastingModes,
     DTypeLike,
     NDArray,
     NotImplementedType,
     OutArray,
+    maybe_copy_to,
     normalize_array_like,
+    normalize_casting,
+    normalize_dtype,
+    wrap_tensors,
 )

 # ###### array creation routines
@@ -38,7 +46,7 @@ def copy(
 def copyto(
     dst: NDArray,
     src: ArrayLike,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     where: NotImplementedType = None,
 ):
     (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
@@ -97,7 +105,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
     return tensors


-def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
+def _concatenate(
+    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
+):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
@@ -109,15 +119,18 @@ def concatenate(
     axis=0,
     out: Optional[OutArray] = None,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(ar_tuple, dtype, out=out)
     result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
     return result


 def vstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
@@ -128,15 +141,21 @@ def vstack(


 def hstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(tup, dtype, out=None)
     tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
     return torch.hstack(tensors)


 def dstack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
     # but {h,v}stack do. Hence add them here for consistency.
@@ -146,7 +165,10 @@ def dstack(


 def column_stack(
-    tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
+    tup: Sequence[ArrayLike],
+    *,
+    dtype: Optional[DTypeLike] = None,
+    casting: Optional[CastingModes] = "same_kind",
 ):
     # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
     # but row_stack does. (because row_stack is an alias for vstack, really).
@@ -162,7 +184,7 @@ def stack(
     out: Optional[OutArray] = None,
     *,
     dtype: Optional[DTypeLike] = None,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
 ):
     _concat_check(arrays, dtype, out=out)

@@ -1152,6 +1174,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
 def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
+
+    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    a = _util.cast_if_needed(a, target_dtype)
+    b = _util.cast_if_needed(b, target_dtype)
+
     return torch.tensordot(a, b, dims=axes)


@@ -1194,6 +1221,77 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
     return torch.outer(a, b)


+def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
+    # Have to manually normalize *operands and **kwargs, following the NumPy signature
+
+    from ._ndarray import ndarray
+
+    dtype = normalize_dtype(dtype)
+    casting = normalize_casting(casting)
+    if out is not None and not isinstance(out, ndarray):
+        raise TypeError("'out' must be an array")
+    if order != "K":
+        raise NotImplementedError("'order' parameter is not supported.")
+
+    # parse arrays and normalize them
+    sublist_format = not isinstance(operands[0], str)
+    if sublist_format:
+        # op, str, op, str ... [sublistout] format: normalize every other argument
+
+        # - if sublistout is not given, the length of operands is even, and we pick
+        #   odd-numbered elements, which are arrays.
+        # - if sublistout is given, the length of operands is odd, we peel off
+        #   the last one, and pick odd-numbered elements, which are arrays.
+        #   Without [:-1], we would have picked sublistout, too.
+        array_operands = operands[:-1][::2]
+    else:
+        # ("ij->", arrays) format
+        subscripts, array_operands = operands[0], operands[1:]
+
+    tensors = [normalize_array_like(op) for op in array_operands]
+    target_dtype = (
+        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
+        if dtype is None
+        else dtype
+    )
+
+    # work around 'bmm' not implemented for 'Half' etc
+    is_half = target_dtype == torch.float16
+    if is_half:
+        target_dtype = torch.float32
+
+    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
+    if is_short_int:
+        target_dtype = torch.int64
+
+    tensors = _util.typecast_tensors(tensors, target_dtype, casting)
+
+    try:
+        # set the global state to handle the optimize=... argument, restore on exit
+        old_strategy = torch.backends.opt_einsum.strategy
+        torch.backends.opt_einsum.strategy = optimize
+
+        if sublist_format:
+            # recombine operands
+            sublists = operands[1::2]
+            has_sublistout = len(operands) % 2 == 1
+            if has_sublistout:
+                sublistout = operands[-1]
+            operands = list(itertools.chain(*zip(tensors, sublists)))
+            if has_sublistout:
+                operands.append(sublistout)
+
+            result = torch.einsum(*operands)
+        else:
+            result = torch.einsum(subscripts, *tensors)
+
+    finally:
+        torch.backends.opt_einsum.strategy = old_strategy
+
+    result = maybe_copy_to(out, result)
+    return wrap_tensors(result)
+
+
 # ### sort and partition ###

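For orientation, here is a small usage sketch of the einsum wrapper added above, exercising both calling conventions its *operands parsing distinguishes: the subscript-string form and the operand/sublist form. This is an illustration only, assuming the package is importable as torch_np and that arange/reshape and array comparison mirror their NumPy counterparts.

import torch_np as tnp

a = tnp.arange(6).reshape(2, 3)
b = tnp.arange(3)

# subscript-string form: "ij,j->i" is a matrix-vector product
r1 = tnp.einsum("ij,j->i", a, b)

# operand/sublist form of the same contraction, with an explicit output sublist;
# the wrapper re-interleaves the normalized tensors with their sublists before
# handing everything to torch.einsum
r2 = tnp.einsum(a, [0, 1], b, [1], [0])

assert (r1 == r2).all()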
torch_np/_normalizations.py

Lines changed: 10 additions & 0 deletions
@@ -15,6 +15,7 @@
 DTypeLike = typing.TypeVar("DTypeLike")
 AxisLike = typing.TypeVar("AxisLike")
 NDArray = typing.TypeVar("NDarray")
+CastingModes = typing.TypeVar("CastingModes")

 # OutArray is to annotate the out= array argument.
 #
@@ -97,6 +98,14 @@ def normalize_outarray(arg, parm=None):
     return arg


+def normalize_casting(arg, parm=None):
+    if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
+        raise ValueError(
+            f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
+        )
+    return arg
+
+
 normalizers = {
     "ArrayLike": normalize_array_like,
     "Optional[ArrayLike]": normalize_optional_array_like,
@@ -107,6 +116,7 @@ def normalize_outarray(arg, parm=None):
     "Optional[DTypeLike]": normalize_dtype,
     "AxisLike": normalize_axis_like,
     "NotImplementedType": normalize_not_implemented,
+    "Optional[CastingModes]": normalize_casting,
 }

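A quick illustration of the new normalizer's contract, as a sketch assuming the module layout shown above: valid casting modes are returned unchanged, anything else raises ValueError (the same path taken when a parameter is annotated Optional[CastingModes]).

from torch_np._normalizations import normalize_casting

normalize_casting("same_kind")      # returned unchanged
normalize_casting("unsafe")         # returned unchanged

try:
    normalize_casting("sideways")   # not a recognized casting mode
except ValueError as exc:
    print(exc)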
torch_np/_ufuncs.py

Lines changed: 5 additions & 4 deletions
@@ -7,6 +7,7 @@
 from . import _binary_ufuncs_impl, _dtypes_impl, _helpers, _unary_ufuncs_impl, _util
 from ._normalizations import (
     ArrayLike,
+    CastingModes,
     DTypeLike,
     NotImplementedType,
     OutArray,
@@ -54,7 +55,7 @@ def wrapped(
         out: Optional[OutArray] = None,
         *,
         where=True,
-        casting="same_kind",
+        casting: Optional[CastingModes] = "same_kind",
         order="K",
         dtype: Optional[DTypeLike] = None,
         subok: NotImplementedType = False,
@@ -87,7 +88,7 @@ def matmul(
     /,
     out: Optional[OutArray] = None,
     *,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order: NotImplementedType = "K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,
@@ -118,7 +119,7 @@ def divmod(
     out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
     *,
     where: NotImplementedType = True,
-    casting="same_kind",
+    casting: Optional[CastingModes] = "same_kind",
     order: NotImplementedType = "K",
     dtype: Optional[DTypeLike] = None,
     subok: NotImplementedType = False,
@@ -190,7 +191,7 @@ def wrapped(
         out: Optional[OutArray] = None,
         *,
         where=True,
-        casting="same_kind",
+        casting: Optional[CastingModes] = "same_kind",
         order="K",
         dtype: Optional[DTypeLike] = None,
         subok: NotImplementedType = False,
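The practical effect of annotating casting with Optional[CastingModes] is that the ufunc wrappers now route the argument through normalize_casting. A hedged sketch of the expected behaviour, assuming torch_np is importable:

import torch_np as tnp

a = tnp.asarray([1, 2, 3])
b = tnp.asarray([4.0, 5.0, 6.0])

tnp.add(a, b, casting="same_kind")     # a valid casting mode, works as before

try:
    tnp.add(a, b, casting="sideways")  # now rejected by normalize_casting
except ValueError as exc:
    print(exc)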
