ENH: add einsum + its numpy tests #126

Merged: 11 commits, Apr 27, 2023
4 changes: 0 additions & 4 deletions autogen/numpy_api_dump.py
@@ -220,10 +220,6 @@ def ediff1d(ary, to_end=None, to_begin=None):
raise NotImplementedError


def einsum(*operands, out=None, optimize=False, **kwargs):
raise NotImplementedError


def einsum_path(*operands, optimize="greedy", einsum_call=False):
raise NotImplementedError

5 changes: 4 additions & 1 deletion torch_np/_funcs.py
@@ -24,10 +24,13 @@
func = getattr(_funcs_impl, name)
if name in ["percentile", "quantile", "median"]:
decorated = normalizer(func, promote_scalar_result=True)
elif name == "einsum":
# normalized manually
decorated = func
else:
decorated = normalizer(func)

decorated.__qualname__ = name # XXX: is this really correct?
decorated.__qualname__ = name
decorated.__name__ = name
vars()[name] = decorated

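Aside: the generic `normalizer` wraps a function based on its per-parameter annotations, which einsum's bare `*operands` signature does not provide, so the loop above registers it undecorated and the implementation normalizes its own inputs. A toy sketch of that idea, assuming a hypothetical annotation-driven wrapper (not the real normalizer):

import inspect

def toy_normalizer(func):
    # hypothetical stand-in for torch_np's normalizer, for illustration only
    params = inspect.signature(func).parameters
    takes_star_args = any(
        p.kind is inspect.Parameter.VAR_POSITIONAL for p in params.values()
    )
    if takes_star_args:
        # no per-argument annotations to drive normalization (einsum's case):
        # return the function unchanged and let it normalize manually
        return func

    def wrapped(*args, **kwds):
        # ... normalize each argument according to its annotation ...
        return func(*args, **kwds)

    return wrapped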
116 changes: 107 additions & 9 deletions torch_np/_funcs_impl.py
@@ -8,6 +8,7 @@
from __future__ import annotations

import builtins
import itertools
import operator
from typing import Optional, Sequence

@@ -16,14 +17,21 @@
from . import _dtypes_impl
from . import _reductions as _impl
from . import _util
from ._normalizations import (

# these imports are for einsum only
from ._normalizations import ( # isort: skip
ArrayLike,
AxisLike,
CastingModes,
DTypeLike,
NDArray,
NotImplementedType,
OutArray,
maybe_copy_to,
normalize_array_like,
normalize_casting,
normalize_dtype,
wrap_tensors,
)

# ###### array creation routines
Expand All @@ -38,7 +46,7 @@ def copy(
def copyto(
dst: NDArray,
src: ArrayLike,
casting="same_kind",
casting: Optional[CastingModes] = "same_kind",
where: NotImplementedType = None,
):
(src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
@@ -97,7 +105,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
return tensors


def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
def _concatenate(
tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
):
# pure torch implementation, used below and in cov/corrcoef below
tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
tensors = _concat_cast_helper(tensors, out, dtype, casting)
@@ -109,15 +119,18 @@ def concatenate(
axis=0,
out: Optional[OutArray] = None,
dtype: Optional[DTypeLike] = None,
casting="same_kind",
casting: Optional[CastingModes] = "same_kind",
):
_concat_check(ar_tuple, dtype, out=out)
result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
return result


def vstack(
tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
tup: Sequence[ArrayLike],
*,
dtype: Optional[DTypeLike] = None,
casting: Optional[CastingModes] = "same_kind",
):
_concat_check(tup, dtype, out=None)
tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
@@ -128,15 +141,21 @@ def vstack(


def hstack(
tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
tup: Sequence[ArrayLike],
*,
dtype: Optional[DTypeLike] = None,
casting: Optional[CastingModes] = "same_kind",
):
_concat_check(tup, dtype, out=None)
tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
return torch.hstack(tensors)


def dstack(
tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
tup: Sequence[ArrayLike],
*,
dtype: Optional[DTypeLike] = None,
casting: Optional[CastingModes] = "same_kind",
):
# XXX: in numpy 1.24 dstack does not have dtype and casting keywords
# but {h,v}stack do. Hence add them here for consistency.
@@ -146,7 +165,10 @@ def dstack(


def column_stack(
tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind"
tup: Sequence[ArrayLike],
*,
dtype: Optional[DTypeLike] = None,
casting: Optional[CastingModes] = "same_kind",
):
# XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
# but row_stack does. (because row_stack is an alias for vstack, really).
@@ -162,7 +184,7 @@ def stack(
out: Optional[OutArray] = None,
*,
dtype: Optional[DTypeLike] = None,
casting="same_kind",
casting: Optional[CastingModes] = "same_kind",
):
_concat_check(arrays, dtype, out=out)

@@ -1152,6 +1174,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
if isinstance(axes, (list, tuple)):
axes = [[ax] if isinstance(ax, int) else ax for ax in axes]

target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
a = _util.cast_if_needed(a, target_dtype)
b = _util.cast_if_needed(b, target_dtype)

return torch.tensordot(a, b, dims=axes)
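A hedged illustration of the cast added above: torch.tensordot is stricter than np.tensordot about mixed input dtypes, so both operands are promoted to a common dtype first. The values and the use of torch.promote_types below are illustrative only; the actual code goes through _dtypes_impl.result_type_impl.

import torch

a = torch.arange(6, dtype=torch.int32).reshape(2, 3)
b = torch.ones((3, 4), dtype=torch.float64)

# promote to a common dtype before contracting, mirroring the cast above
common = torch.promote_types(a.dtype, b.dtype)  # torch.float64
r = torch.tensordot(a.to(common), b.to(common), dims=1)  # shape (2, 4)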


@@ -1194,6 +1221,77 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
return torch.outer(a, b)


def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
# Have to manually normalize *operands and **kwargs, following the NumPy signature

from ._ndarray import ndarray

dtype = normalize_dtype(dtype)
casting = normalize_casting(casting)
if out is not None and not isinstance(out, ndarray):
raise TypeError("'out' must be an array")
if order != "K":
raise NotImplementedError("'order' parameter is not supported.")

# parse arrays and normalize them
sublist_format = not isinstance(operands[0], str)
if sublist_format:
# op, str, op, str ... [sublistout] format: normalize every other argument

# - if sublistout is not given, the length of operands is even, and we pick
# odd-numbered elements, which are arrays.
# - if sublistout is given, the length of operands is odd, we peel off
# the last one, and pick odd-numbered elements, which are arrays.
# Without [:-1], we would have picked sublistout, too.
array_operands = operands[:-1][::2]
else:
# ("ij->", arrays) format
subscripts, array_operands = operands[0], operands[1:]

tensors = [normalize_array_like(op) for op in array_operands]
target_dtype = (
_dtypes_impl.result_type_impl([op.dtype for op in tensors])
if dtype is None
else dtype
)

# work around 'bmm' not implemented for 'Half' etc
is_half = target_dtype == torch.float16
if is_half:
target_dtype = torch.float32

is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
if is_short_int:
target_dtype = torch.int64

tensors = _util.typecast_tensors(tensors, target_dtype, casting)

try:
# set the global state to handle the optimize=... argument, restore on exit
old_strategy = torch.backends.opt_einsum.strategy
torch.backends.opt_einsum.strategy = optimize

if sublist_format:
# recombine operands
sublists = operands[1::2]
has_sublistout = len(operands) % 2 == 1
if has_sublistout:
sublistout = operands[-1]
operands = list(itertools.chain(*zip(tensors, sublists)))
if has_sublistout:
operands.append(sublistout)

result = torch.einsum(*operands)
else:
result = torch.einsum(subscripts, *tensors)

finally:
torch.backends.opt_einsum.strategy = old_strategy

result = maybe_copy_to(out, result)
return wrap_tensors(result)


# ### sort and partition ###


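For reference, a usage sketch of the two calling conventions the einsum wrapper above parses. The `tnp` alias and the array values are illustrative only.

import torch_np as tnp

a = tnp.arange(6).reshape(2, 3)
b = tnp.arange(12).reshape(3, 4)

# subscripts-string format: the first operand is the subscript string
c1 = tnp.einsum("ij,jk->ik", a, b)

# sublist format: arrays interleaved with axis sublists, plus an optional
# output sublist ("sublistout") at the end
c2 = tnp.einsum(a, [0, 1], b, [1, 2], [0, 2])

# both spellings compute the same matrix product
assert (c1 == c2).all()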
10 changes: 10 additions & 0 deletions torch_np/_normalizations.py
@@ -15,6 +15,7 @@
DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDarray")
CastingModes = typing.TypeVar("CastingModes")

# OutArray is to annotate the out= array argument.
#
@@ -97,6 +98,14 @@ def normalize_outarray(arg, parm=None):
return arg


def normalize_casting(arg, parm=None):
if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
raise ValueError(
f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
)
return arg


normalizers = {
"ArrayLike": normalize_array_like,
"Optional[ArrayLike]": normalize_optional_array_like,
@@ -107,6 +116,7 @@ def normalize_outarray(arg, parm=None):
"Optional[DTypeLike]": normalize_dtype,
"AxisLike": normalize_axis_like,
"NotImplementedType": normalize_not_implemented,
"Optional[CastingModes]": normalize_casting,
}


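A short usage sketch of the validator above; the invalid value is made up.

from torch_np._normalizations import normalize_casting

normalize_casting("same_kind")  # valid modes pass through unchanged
try:
    normalize_casting("lossy")  # not a NumPy casting mode
except ValueError as exc:
    print(exc)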
9 changes: 5 additions & 4 deletions torch_np/_ufuncs.py
@@ -7,6 +7,7 @@
from . import _binary_ufuncs_impl, _dtypes_impl, _helpers, _unary_ufuncs_impl, _util
from ._normalizations import (
ArrayLike,
CastingModes,
DTypeLike,
NotImplementedType,
OutArray,
@@ -54,7 +55,7 @@ def wrapped(
out: Optional[OutArray] = None,
*,
where=True,
casting="same_kind",
casting: Optional[CastingModes] = "same_kind",
order="K",
dtype: Optional[DTypeLike] = None,
subok: NotImplementedType = False,
@@ -87,7 +88,7 @@ def matmul(
/,
out: Optional[OutArray] = None,
*,
casting="same_kind",
casting: Optional[CastingModes] = "same_kind",
order: NotImplementedType = "K",
dtype: Optional[DTypeLike] = None,
subok: NotImplementedType = False,
@@ -118,7 +119,7 @@ def divmod(
out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
*,
where: NotImplementedType = True,
casting="same_kind",
casting: Optional[CastingModes] = "same_kind",
order: NotImplementedType = "K",
dtype: Optional[DTypeLike] = None,
subok: NotImplementedType = False,
@@ -190,7 +191,7 @@ def wrapped(
out: Optional[OutArray] = None,
*,
where=True,
casting="same_kind",
casting: Optional[CastingModes] = "same_kind",
order="K",
dtype: Optional[DTypeLike] = None,
subok: NotImplementedType = False,
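A hedged behavioral sketch of what the annotation changes above buy: a casting= string passed to a ufunc now goes through normalize_casting, so an invalid mode is rejected up front. The function and values below are illustrative only.

import torch_np as tnp

a = tnp.ones(3)
tnp.add(a, a, casting="same_kind")  # accepted
try:
    tnp.add(a, a, casting="bogus")  # rejected by normalize_casting
except ValueError:
    print("invalid casting mode")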