From 78c333ab93a8442a521b5185509d07682e698350 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 24 Apr 2023 13:50:15 +0300 Subject: [PATCH 1/6] ENH: add einsum --- autogen/numpy_api_dump.py | 4 - torch_np/_funcs.py | 5 +- torch_np/_funcs_impl.py | 107 ++++++++++++++++-- torch_np/_normalizations.py | 10 ++ torch_np/_ufuncs.py | 9 +- .../tests/numpy_tests/core/test_einsum.py | 56 ++++----- 6 files changed, 139 insertions(+), 52 deletions(-) diff --git a/autogen/numpy_api_dump.py b/autogen/numpy_api_dump.py index 450fef1b..4ae0df25 100644 --- a/autogen/numpy_api_dump.py +++ b/autogen/numpy_api_dump.py @@ -220,10 +220,6 @@ def ediff1d(ary, to_end=None, to_begin=None): raise NotImplementedError -def einsum(*operands, out=None, optimize=False, **kwargs): - raise NotImplementedError - - def einsum_path(*operands, optimize="greedy", einsum_call=False): raise NotImplementedError diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 5fab34ef..c62b735b 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -31,10 +31,13 @@ func = getattr(_funcs_impl, name) if name in ["percentile", "quantile", "median"]: decorated = normalizer(func, promote_scalar_result=True) + elif name == "einsum": + # normalized manually + decorated = func else: decorated = normalizer(func) - decorated.__qualname__ = name # XXX: is this really correct? + decorated.__qualname__ = name decorated.__name__ = name vars()[name] = decorated diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py index e1171391..72a0ceea 100644 --- a/torch_np/_funcs_impl.py +++ b/torch_np/_funcs_impl.py @@ -8,6 +8,7 @@ from __future__ import annotations import builtins +import itertools import math import operator from typing import Optional, Sequence @@ -20,6 +21,7 @@ from ._normalizations import ( ArrayLike, AxisLike, + CastingModes, DTypeLike, NDArray, NotImplementedType, @@ -39,7 +41,7 @@ def copy( def copyto( dst: NDArray, src: ArrayLike, - casting="same_kind", + casting: Optional[CastingModes] = "same_kind", where: NotImplementedType = None, ): (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting) @@ -98,7 +100,9 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"): return tensors -def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"): +def _concatenate( + tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind" +): # pure torch implementation, used below and in cov/corrcoef below tensors, axis = _util.axis_none_ravel(*tensors, axis=axis) tensors = _concat_cast_helper(tensors, out, dtype, casting) @@ -110,7 +114,7 @@ def concatenate( axis=0, out: Optional[OutArray] = None, dtype: Optional[DTypeLike] = None, - casting="same_kind", + casting: Optional[CastingModes] = "same_kind", ): _concat_check(ar_tuple, dtype, out=out) result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting) @@ -118,7 +122,10 @@ def concatenate( def vstack( - tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind" + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", ): _concat_check(tup, dtype, out=None) tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) @@ -129,7 +136,10 @@ def vstack( def hstack( - tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind" + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", ): _concat_check(tup, dtype, 
out=None) tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) @@ -137,7 +147,10 @@ def hstack( def dstack( - tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind" + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", ): # XXX: in numpy 1.24 dstack does not have dtype and casting keywords # but {h,v}stack do. Hence add them here for consistency. @@ -147,7 +160,10 @@ def dstack( def column_stack( - tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting="same_kind" + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", ): # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords # but row_stack does. (because row_stack is an alias for vstack, really). @@ -163,7 +179,7 @@ def stack( out: Optional[OutArray] = None, *, dtype: Optional[DTypeLike] = None, - casting="same_kind", + casting: Optional[CastingModes] = "same_kind", ): _concat_check(arrays, dtype, out=out) @@ -1166,6 +1182,11 @@ def vdot(a: ArrayLike, b: ArrayLike, /): def tensordot(a: ArrayLike, b: ArrayLike, axes=2): if isinstance(axes, (list, tuple)): axes = [[ax] if isinstance(ax, int) else ax for ax in axes] + + target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype)) + a = _util.cast_if_needed(a, target_dtype) + b = _util.cast_if_needed(b, target_dtype) + return torch.tensordot(a, b, dims=axes) @@ -1208,6 +1229,74 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): return torch.outer(a, b) +def einsum(*operands, out=None, optimize=False, **kwargs): + # Have to manually normalize *operands and **kwargs, following the NumPy signature + # >>> np.einsum? + # Signature: np.einsum(*operands, out=None, optimize=False, **kwargs) + # Docstring: + # einsum(subscripts, *operands, out=None, dtype=None, order='K', + # casting='safe', optimize=False) + + from ._normalizations import ( + maybe_copy_to, + normalize_casting, + normalize_dtype, + normalize_not_implemented, + normalize_outarray, + wrap_tensors, + ) + + dtype = normalize_dtype(kwargs.pop("dtype", None)) + casting = normalize_casting(kwargs.pop("casting", "safe")) + + parm = lambda _: None # a fake duck-typed inspect.Parameter stub + parm.name = "out" + out = normalize_outarray(out, parm=parm) + + parm.default = "K" + parm.name = "order" + order = normalize_not_implemented(kwargs.pop("order", "K"), parm=parm) + if kwargs: + raise TypeError("unknown arguments: ", kwargs) + + # parse arrays and normalize them + if isinstance(operands[0], str): + # ("ij->", arrays) format + sublist_format = False + subscripts, array_operands = operands[0], operands[1:] + else: + # op, str, op, str ... 
format: normalize every other argument + sublist_format = True + array_operands = operands[:-1][::2] + + tensors = [normalize_array_like(op) for op in array_operands] + target_dtype = ( + _dtypes_impl.result_type_impl([op.dtype for op in tensors]) + if dtype is None + else dtype + ) + tensors = _util.typecast_tensors(tensors, target_dtype, casting) + + if sublist_format: + # recombine operands + sublists = operands[1::2] + has_sublistout = len(operands) % 2 == 1 + if has_sublistout: + sublistout = operands[-1] + operands = list(itertools.chain(*zip(tensors, sublists))) + if has_sublistout: + operands.append(sublistout) + + result = torch.einsum(*operands) + else: + result = torch.einsum(subscripts, *tensors) + + + + result = maybe_copy_to(out, result) + return wrap_tensors(result) + + # ### sort and partition ### @@ -1798,8 +1887,6 @@ def bartlett(M): def common_type(*tensors: ArrayLike): - import builtins - is_complex = False precision = 0 for a in tensors: diff --git a/torch_np/_normalizations.py b/torch_np/_normalizations.py index ccfb6d0d..f83211bc 100644 --- a/torch_np/_normalizations.py +++ b/torch_np/_normalizations.py @@ -14,6 +14,7 @@ DTypeLike = typing.TypeVar("DTypeLike") AxisLike = typing.TypeVar("AxisLike") NDArray = typing.TypeVar("NDarray") +CastingModes = typing.TypeVar("CastingModes") # OutArray is to annotate the out= array argument. # @@ -101,6 +102,14 @@ def normalize_outarray(arg, parm=None): return arg +def normalize_casting(arg, parm=None): + if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]: + raise ValueError( + f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')" + ) + return arg + + normalizers = { "ArrayLike": normalize_array_like, "Optional[ArrayLike]": normalize_optional_array_like, @@ -111,6 +120,7 @@ def normalize_outarray(arg, parm=None): "Optional[DTypeLike]": normalize_dtype, "AxisLike": normalize_axis_like, "NotImplementedType": normalize_not_implemented, + "Optional[CastingModes]": normalize_casting, } diff --git a/torch_np/_ufuncs.py b/torch_np/_ufuncs.py index f05c95ab..dcb3ac8f 100644 --- a/torch_np/_ufuncs.py +++ b/torch_np/_ufuncs.py @@ -7,6 +7,7 @@ from . 
import _binary_ufuncs_impl, _dtypes_impl, _helpers, _unary_ufuncs_impl, _util from ._normalizations import ( ArrayLike, + CastingModes, DTypeLike, NotImplementedType, OutArray, @@ -54,7 +55,7 @@ def wrapped( out: Optional[OutArray] = None, *, where=True, - casting="same_kind", + casting: Optional[CastingModes] = "same_kind", order="K", dtype: Optional[DTypeLike] = None, subok: NotImplementedType = False, @@ -87,7 +88,7 @@ def matmul( /, out: Optional[OutArray] = None, *, - casting="same_kind", + casting: Optional[CastingModes] = "same_kind", order: NotImplementedType = "K", dtype: Optional[DTypeLike] = None, subok: NotImplementedType = False, @@ -118,7 +119,7 @@ def divmod( out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None), *, where: NotImplementedType = True, - casting="same_kind", + casting: Optional[CastingModes] = "same_kind", order: NotImplementedType = "K", dtype: Optional[DTypeLike] = None, subok: NotImplementedType = False, @@ -190,7 +191,7 @@ def wrapped( out: Optional[OutArray] = None, *, where=True, - casting="same_kind", + casting: Optional[CastingModes] = "same_kind", order="K", dtype: Optional[DTypeLike] = None, subok: NotImplementedType = False, diff --git a/torch_np/tests/numpy_tests/core/test_einsum.py b/torch_np/tests/numpy_tests/core/test_einsum.py index c9a50c83..8a497b35 100644 --- a/torch_np/tests/numpy_tests/core/test_einsum.py +++ b/torch_np/tests/numpy_tests/core/test_einsum.py @@ -20,11 +20,11 @@ class TestEinsum: def test_einsum_errors(self): for do_opt in [True, False]: # Need enough arguments - assert_raises((TypeError, ValueError), np.einsum, optimize=do_opt) + assert_raises((TypeError, IndexError, ValueError), np.einsum, optimize=do_opt) assert_raises((IndexError, ValueError), np.einsum, "", optimize=do_opt) # subscripts must be a string - assert_raises(TypeError, np.einsum, 0, 0, optimize=do_opt) + assert_raises((AttributeError, TypeError), np.einsum, 0, 0, optimize=do_opt) # out parameter must be an array assert_raises(TypeError, np.einsum, "", 0, out='test', @@ -47,11 +47,11 @@ def test_einsum_errors(self): optimize=do_opt) # issue 4528 revealed a segfault with this call - assert_raises((NotImplementedError, TypeError), np.einsum, *(None,)*63, optimize=do_opt) + assert_raises((RuntimeError, TypeError), np.einsum, *(None,)*63, optimize=do_opt) # number of operands must match count in subscripts string - assert_raises((TypeError, ValueError), np.einsum, "", 0, 0, optimize=do_opt) - assert_raises((TypeError, ValueError), np.einsum, ",", 0, [0], [0], + assert_raises((RuntimeError, ValueError), np.einsum, "", 0, 0, optimize=do_opt) + assert_raises((RuntimeError, ValueError), np.einsum, ",", 0, [0], [0], optimize=do_opt) assert_raises((RuntimeError, ValueError), np.einsum, ",", [0], optimize=do_opt) @@ -282,15 +282,15 @@ def check_einsum_sums(self, dtype, do_opt=False): a = np.arange(n*n, dtype=dtype).reshape(n, n) assert_equal(np.einsum("ii", a, optimize=do_opt), np.trace(a).astype(dtype)) - assert_equal(np.einsum(a, [0, 0], optimize=do_opt), + assert_equal(np.einsum(a, [0, 0], optimize=do_opt), # torch? 
np.trace(a).astype(dtype)) # gh-15961: should accept numpy int64 type in subscript list - np_array = np.asarray([0, 0]) - assert_equal(np.einsum(a, np_array, optimize=do_opt), - np.trace(a).astype(dtype)) - assert_equal(np.einsum(a, list(np_array), optimize=do_opt), - np.trace(a).astype(dtype)) + # np_array = np.asarray([0, 0]) + # assert_equal(np.einsum(a, np_array, optimize=do_opt), + # np.trace(a).astype(dtype)) + # assert_equal(np.einsum(a, list(np_array), optimize=do_opt), + # np.trace(a).astype(dtype)) # multiply(a, b) assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case @@ -329,7 +329,7 @@ def check_einsum_sums(self, dtype, do_opt=False): # Suppress the complex warnings for the 'as f8' tests with suppress_warnings() as sup: - sup.filter(np.ComplexWarning) + # sup.filter(np.ComplexWarning) # matvec(a,b) / a.dot(b) where a is matrix, b is vector for n in range(1, 17): @@ -488,15 +488,15 @@ def check_einsum_sums(self, dtype, do_opt=False): 2*np.sum(a[1:])) # An object array, summed as the data type - a = np.arange(9, dtype=object) - - b = np.einsum("i->", a, dtype=dtype, casting='unsafe') - assert_equal(b, np.sum(a)) - assert_equal(b.dtype, np.dtype(dtype)) - - b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe') - assert_equal(b, np.sum(a)) - assert_equal(b.dtype, np.dtype(dtype)) + # a = np.arange(9, dtype=object) + # + # b = np.einsum("i->", a, dtype=dtype, casting='unsafe') + # assert_equal(b, np.sum(a)) + # assert_equal(b.dtype, np.dtype(dtype)) + # + # b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe') + # assert_equal(b, np.sum(a)) + # assert_equal(b.dtype, np.dtype(dtype)) # A case which was failing (ticket #1885) p = np.arange(2) + 1 @@ -550,23 +550,15 @@ def test_einsum_sums_uint8(self): def test_einsum_sums_int16(self): self.check_einsum_sums('i2') - def test_einsum_sums_uint16(self): - self.check_einsum_sums('u2') def test_einsum_sums_int32(self): self.check_einsum_sums('i4') self.check_einsum_sums('i4', True) - def test_einsum_sums_uint32(self): - self.check_einsum_sums('u4') - self.check_einsum_sums('u4', True) def test_einsum_sums_int64(self): self.check_einsum_sums('i8') - def test_einsum_sums_uint64(self): - self.check_einsum_sums('u8') - def test_einsum_sums_float16(self): self.check_einsum_sums('f2') @@ -577,8 +569,6 @@ def test_einsum_sums_float64(self): self.check_einsum_sums('f8') self.check_einsum_sums('f8', True) - def test_einsum_sums_longdouble(self): - self.check_einsum_sums(np.longdouble) def test_einsum_sums_cfloat64(self): self.check_einsum_sums('c8') @@ -587,8 +577,6 @@ def test_einsum_sums_cfloat64(self): def test_einsum_sums_cfloat128(self): self.check_einsum_sums('c16') - def test_einsum_sums_clongdouble(self): - self.check_einsum_sums(np.clongdouble) def test_einsum_misc(self): # This call used to crash because of a bug in @@ -729,6 +717,7 @@ def test_einsum_failed_on_p9_and_s390x(self): y = tensor.trace(axis1=0, axis2=2).trace() assert_allclose(x, y) + @pytest.mark.xfail(reason="no base") def test_einsum_all_contig_non_contig_output(self): # Issue gh-5907, tests that the all contiguous special case # actually checks the contiguity of the output @@ -934,6 +923,7 @@ def test_broadcasting_dot_cases(self): g = np.arange(64).reshape(2, 4, 8) self.optimize_compare('obk,ijk->ioj', operands=[g, g]) + @pytest.mark.xfail(reason="order='F' not supported") def test_output_order(self): # Ensure output order is respected for optimize cases, the below # conraction should yield a reshaped tensor view From 344b3f7a6b7041c80bb6149b3987701cd7411929 Mon 
Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 26 Apr 2023 09:33:55 +0300 Subject: [PATCH 2/6] MAINT: address a review comment From 66d62db8ec88e81b56eeeee15acdd703e436e1b7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 26 Apr 2023 10:28:41 +0300 Subject: [PATCH 3/6] MAINT: review comments --- torch_np/_funcs_impl.py | 42 ++++++++++++++--------------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py index 72a0ceea..d106ba66 100644 --- a/torch_np/_funcs_impl.py +++ b/torch_np/_funcs_impl.py @@ -1229,45 +1229,33 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): return torch.outer(a, b) -def einsum(*operands, out=None, optimize=False, **kwargs): +def einsum(*operands, out=None, dtype=None, order='K', + casting='safe', optimize=False): # Have to manually normalize *operands and **kwargs, following the NumPy signature - # >>> np.einsum? - # Signature: np.einsum(*operands, out=None, optimize=False, **kwargs) - # Docstring: - # einsum(subscripts, *operands, out=None, dtype=None, order='K', - # casting='safe', optimize=False) from ._normalizations import ( maybe_copy_to, normalize_casting, normalize_dtype, - normalize_not_implemented, - normalize_outarray, wrap_tensors, ) + from ._ndarray import ndarray - dtype = normalize_dtype(kwargs.pop("dtype", None)) - casting = normalize_casting(kwargs.pop("casting", "safe")) - - parm = lambda _: None # a fake duck-typed inspect.Parameter stub - parm.name = "out" - out = normalize_outarray(out, parm=parm) - - parm.default = "K" - parm.name = "order" - order = normalize_not_implemented(kwargs.pop("order", "K"), parm=parm) - if kwargs: - raise TypeError("unknown arguments: ", kwargs) + dtype = normalize_dtype(dtype) + casting = normalize_casting(casting) + if out is not None and not isinstance(out, ndarray): + raise TypeError("'out' must be an array") + if order != "K": + raise NotImplementedError("'order' parameter is not supported.") # parse arrays and normalize them - if isinstance(operands[0], str): - # ("ij->", arrays) format - sublist_format = False - subscripts, array_operands = operands[0], operands[1:] - else: + sublist_format = not isinstance(operands[0], str) + if sublist_format: # op, str, op, str ... 
format: normalize every other argument - sublist_format = True array_operands = operands[:-1][::2] + else: + # ("ij->", arrays) format + subscripts, array_operands = operands[0], operands[1:] tensors = [normalize_array_like(op) for op in array_operands] target_dtype = ( @@ -1291,8 +1279,6 @@ def einsum(*operands, out=None, optimize=False, **kwargs): else: result = torch.einsum(subscripts, *tensors) - - result = maybe_copy_to(out, result) return wrap_tensors(result) From 1dd74bf92e70c2895eed928958bfd16dce2fc274 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 27 Apr 2023 11:05:04 +0300 Subject: [PATCH 4/6] MAINT: einsum: work around some short int / float limitations, xfail the rest --- torch_np/_funcs_impl.py | 23 +++++++++++++++---- .../tests/numpy_tests/core/test_einsum.py | 10 ++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py index d106ba66..f66568b0 100644 --- a/torch_np/_funcs_impl.py +++ b/torch_np/_funcs_impl.py @@ -1229,17 +1229,16 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): return torch.outer(a, b) -def einsum(*operands, out=None, dtype=None, order='K', - casting='safe', optimize=False): +def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False): # Have to manually normalize *operands and **kwargs, following the NumPy signature + from ._ndarray import ndarray from ._normalizations import ( maybe_copy_to, normalize_casting, normalize_dtype, wrap_tensors, ) - from ._ndarray import ndarray dtype = normalize_dtype(dtype) casting = normalize_casting(casting) @@ -1251,7 +1250,13 @@ def einsum(*operands, out=None, dtype=None, order='K', # parse arrays and normalize them sublist_format = not isinstance(operands[0], str) if sublist_format: - # op, str, op, str ... format: normalize every other argument + # op, str, op, str ... [sublistout] format: normalize every other argument + + # - if sublistout is not given, the length of operands is even, and we pick + # odd-numbered elements, which are arrays. + # - if sublistout is given, the length of operands is odd, we peel off + # the last one, and pick odd-numbered elements, which are arrays. + # Without [:-1], we would have picked sublistout, too. 
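+            # As a hypothetical illustration (not a call taken from the tests),
+            #     einsum(a, [0, 1], b, [1, 2], [0, 2])
+            # has five operands; the trailing [0, 2] is sublistout, and
+            # operands[:-1][::2] picks out just (a, b).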
array_operands = operands[:-1][::2] else: # ("ij->", arrays) format @@ -1263,6 +1268,16 @@ def einsum(*operands, out=None, dtype=None, order='K', if dtype is None else dtype ) + + # work around 'bmm' not implemented for 'Half' etc + is_half = target_dtype == torch.float16 + if is_half: + target_dtype = torch.float32 + + is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32] + if is_short_int: + target_dtype, result_dtype = torch.int64, target_dtype + tensors = _util.typecast_tensors(tensors, target_dtype, casting) if sublist_format: diff --git a/torch_np/tests/numpy_tests/core/test_einsum.py b/torch_np/tests/numpy_tests/core/test_einsum.py index 8a497b35..fc94bc03 100644 --- a/torch_np/tests/numpy_tests/core/test_einsum.py +++ b/torch_np/tests/numpy_tests/core/test_einsum.py @@ -541,24 +541,26 @@ def check_einsum_sums(self, dtype, do_opt=False): assert_array_equal(np.einsum("ij,i->", x, y, optimize=optimize), [2.]) # contig_stride0_outstride0_two + @pytest.mark.xfail(reason="int overflow differs in numpy and pytorch") def test_einsum_sums_int8(self): self.check_einsum_sums('i1') + @pytest.mark.xfail(reason="int overflow differs in numpy and pytorch") def test_einsum_sums_uint8(self): self.check_einsum_sums('u1') + @pytest.mark.xfail(reason="int overflow differs in numpy and pytorch") def test_einsum_sums_int16(self): self.check_einsum_sums('i2') - def test_einsum_sums_int32(self): self.check_einsum_sums('i4') self.check_einsum_sums('i4', True) - def test_einsum_sums_int64(self): self.check_einsum_sums('i8') + @pytest.mark.xfail(reason="np.float16(4641) == 4640.0") def test_einsum_sums_float16(self): self.check_einsum_sums('f2') @@ -780,6 +782,10 @@ def test_different_paths(self, dtype): # Use einsum to compare to not have difference due to sum round-offs: assert res == np.einsum('i->', scalar * arr) # contig + contig + contig -> scalar + + if dtype in ['e', 'B', 'b']: + pytest.xfail(reason='overflow differs in pytorch and numpy') + arr = np.array([0.5, 0.5, 0.25, 4.5, 3.], dtype=dtype) res = np.einsum('i,i,i->', arr, arr, arr) assert_array_equal(res, (arr * arr * arr).sum()) From 211a4611a60d652e9e52d7e234521dbe2a767497 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 27 Apr 2023 11:28:22 +0300 Subject: [PATCH 5/6] einsum: forward optimize= to torch.backends --- torch_np/_funcs_impl.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py index f66568b0..05576971 100644 --- a/torch_np/_funcs_impl.py +++ b/torch_np/_funcs_impl.py @@ -1280,19 +1280,27 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize= tensors = _util.typecast_tensors(tensors, target_dtype, casting) - if sublist_format: - # recombine operands - sublists = operands[1::2] - has_sublistout = len(operands) % 2 == 1 - if has_sublistout: - sublistout = operands[-1] - operands = list(itertools.chain(*zip(tensors, sublists))) - if has_sublistout: - operands.append(sublistout) - - result = torch.einsum(*operands) - else: - result = torch.einsum(subscripts, *tensors) + try: + # set the global state to handle the optimize=... 
argument, restore on exit + old_strategy = torch.backends.opt_einsum.strategy + torch.backends.opt_einsum.strategy = optimize + + if sublist_format: + # recombine operands + sublists = operands[1::2] + has_sublistout = len(operands) % 2 == 1 + if has_sublistout: + sublistout = operands[-1] + operands = list(itertools.chain(*zip(tensors, sublists))) + if has_sublistout: + operands.append(sublistout) + + result = torch.einsum(*operands) + else: + result = torch.einsum(subscripts, *tensors) + + finally: + torch.backends.opt_einsum.strategy = old_strategy result = maybe_copy_to(out, result) return wrap_tensors(result) From 7e9f49c7e714dfe954e6e5364db61f7ba9f53bec Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 27 Apr 2023 12:52:33 +0300 Subject: [PATCH 6/6] MAINT: remove local imports from einsum --- torch_np/_funcs_impl.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py index 05576971..85c39dcc 100644 --- a/torch_np/_funcs_impl.py +++ b/torch_np/_funcs_impl.py @@ -18,7 +18,9 @@ from . import _dtypes_impl from . import _reductions as _impl from . import _util -from ._normalizations import ( + +# these imports are for einsum only +from ._normalizations import ( # isort: skip ArrayLike, AxisLike, CastingModes, @@ -26,7 +28,11 @@ NDArray, NotImplementedType, OutArray, + maybe_copy_to, normalize_array_like, + normalize_casting, + normalize_dtype, + wrap_tensors, ) # ###### array creation routines @@ -1233,12 +1239,6 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize= # Have to manually normalize *operands and **kwargs, following the NumPy signature from ._ndarray import ndarray - from ._normalizations import ( - maybe_copy_to, - normalize_casting, - normalize_dtype, - wrap_tensors, - ) dtype = normalize_dtype(dtype) casting = normalize_casting(casting) @@ -1276,7 +1276,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize= is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32] if is_short_int: - target_dtype, result_dtype = torch.int64, target_dtype + target_dtype = torch.int64 tensors = _util.typecast_tensors(tensors, target_dtype, casting)
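
A usage sketch of the resulting wrapper (illustrative only, not part of the patches
above; it assumes torch_np is importable as tnp and exposes the usual arange/reshape
helpers). It exercises the two calling conventions parsed above, plus the dtype=,
casting= and optimize= arguments that the patches normalize:

    import torch_np as tnp

    a = tnp.arange(6).reshape(2, 3)
    b = tnp.arange(12).reshape(3, 4)

    # subscripts-string form; optimize=True is forwarded to
    # torch.backends.opt_einsum.strategy for the duration of the call
    c1 = tnp.einsum("ij,jk->ik", a, b, optimize=True)

    # sublist form: op0, sublist0, op1, sublist1, ..., optional output sublist
    c2 = tnp.einsum(a, [0, 1], b, [1, 2], [0, 2])

    # dtype= and casting= go through normalize_dtype / normalize_casting
    s = tnp.einsum("ij->", a, dtype="f8", casting="unsafe")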