diff --git a/docs/api-reference.md b/docs/api-reference.md index 9cb4ff0b..c81ef90d 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -7,4 +7,6 @@ :toctree: generated atleast_nd + expand_dims + kron ``` diff --git a/pixi.lock b/pixi.lock index 891d58e4..d69f607e 100644 --- a/pixi.lock +++ b/pixi.lock @@ -442,7 +442,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.18.0-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-cov-5.0.0-pyhd8ed1ab_0.conda @@ -558,7 +558,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.18.0-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-cov-5.0.0-pyhd8ed1ab_0.conda @@ -680,7 +680,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.18.0-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyh0701188_6.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-cov-5.0.0-pyhd8ed1ab_0.conda @@ -982,8 +982,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/dill-0.3.8-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/distlib-0.3.8-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.16.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/identify-2.6.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-24_linux64_openblas.conda @@ -1011,11 +1013,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/nodeenv-1.9.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.1.1-py312h58c1407_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.2-hb9d3cd8_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-24.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pre-commit-3.8.0-pyha770c72_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/psutil-6.0.0-py312h66e93f0_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.6-hc5c86c4_1_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.12-5_cp312.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.2-py312h66e93f0_1.conda @@ -1041,8 +1046,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/dill-0.3.8-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/distlib-0.3.8-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.16.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/identify-2.6.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libblas-3.9.0-24_osxarm64_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcblas-3.9.0-24_osxarm64_openblas.conda @@ -1063,11 +1070,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/nodeenv-1.9.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.1.1-py312h801f5e3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.3.2-h8359307_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-24.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pre-commit-3.8.0-pyha770c72_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/psutil-6.0.0-py312h024a12e_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.12.6-h739c21a_1_cpython.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python_abi-3.12-5_cp312.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.2-py312h024a12e_1.conda @@ -1093,8 +1103,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/dill-0.3.8-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/distlib-0.3.8-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.16.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/identify-2.6.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/intel-openmp-2024.2.1-h57928b3_1083.conda - conda: https://conda.anaconda.org/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/libblas-3.9.0-24_win64_mkl.conda @@ -1121,13 +1133,16 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/nodeenv-1.9.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/numpy-2.1.1-py312h49bc9c5_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/openssl-3.3.2-h2466b09_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-24.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pre-commit-3.8.0-pyha770c72_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/psutil-6.0.0-py312h4389bb4_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pthread-stubs-0.4-hcd874cb_1001.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/win-64/pthreads-win32-2.9.1-hfa6e2cd_3.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/python-3.12.6-hce54a09_1_cpython.conda - conda: https://conda.anaconda.org/conda-forge/win-64/python_abi-3.12-5_cp312.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pyyaml-6.0.2-py312h4389bb4_1.conda @@ -1347,7 +1362,7 @@ packages: name: array-api-extra version: 0.1.2.dev0 path: . - sha256: 1aea0df8782a134f9f7a0342e3ed376c70345be2c918236f27b1f64a8cee81a8 + sha256: 5f444b16d4b2888d478d4c2f6a540cd298091e14cf25bc5b3f89bf971522bf30 requires_dist: - furo>=2023.8.17 ; extra == 'docs' - myst-parser>=0.13 ; extra == 'docs' @@ -4018,15 +4033,15 @@ packages: timestamp: 1714846885370 - kind: conda name: pylint - version: 3.3.0 + version: 3.3.1 build: pyhd8ed1ab_0 subdir: noarch noarch: python - url: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.0-pyhd8ed1ab_0.conda - sha256: a96156b6e2820eee1ef6b8a703b1ed226e9a0bb515105f2697f53f1eebe7a82f - md5: 7d2b9f7c59cbbd1c28c8631a9a63bbee + url: https://conda.anaconda.org/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda + sha256: 35c0b0f3b8b0585fee0966f5a09b7bd43519a5bca58a9f29f502dd8442a9b14c + md5: 2a3426f75e2172c932131f4e3d51bcf4 depends: - - astroid >=3.3.3,<3.4.0-dev0 + - astroid >=3.3.4,<3.4.0-dev0 - colorama >=0.4.5 - dill >=0.3.7 - isort >=4.2.5,<6,!=5.13.0 @@ -4037,11 +4052,10 @@ packages: - tomlkit >=0.10.1 - typing_extensions >=3.10.0 license: GPL-2.0-or-later - license_family: GPL purls: - pkg:pypi/pylint?source=hash-mapping - size: 353136 - timestamp: 1726928350439 + size: 352873 + timestamp: 1727266530261 - kind: conda name: pysocks version: 1.7.1 diff --git a/pyproject.toml b/pyproject.toml index cd5eb8ed..20f9ef47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ pylint = "*" # import dependencies for mypy: array-api-strict = "*" numpy = "*" +pytest = "*" [tool.pixi.feature.lint.tasks] pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" } diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index a3a3f838..b26d27cd 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations -from ._funcs import atleast_nd +from ._funcs import atleast_nd, expand_dims, kron __version__ = "0.1.2.dev0" -__all__ = ["__version__", "atleast_nd"] +__all__ = ["__version__", "atleast_nd", "expand_dims", "kron"] diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 98c175bb..234617f3 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -5,10 +5,10 @@ if TYPE_CHECKING: from ._typing import Array, ModuleType -__all__ = ["atleast_nd"] +__all__ = ["atleast_nd", "expand_dims", "kron"] -def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array: +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: """ Recursively expand the dimension of an array to at least `ndim`. @@ -46,3 +46,193 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array: x = xp.expand_dims(x, axis=0) x = atleast_nd(x, ndim=ndim, xp=xp) return x + + +def expand_dims( + a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType +) -> Array: + """ + Expand the shape of an array. + + Insert (a) new axis/axes that will appear at the position(s) specified by + `axis` in the expanded array shape. + + This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*. + Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays. + + Parameters + ---------- + a : array + axis : int or tuple of ints, optional + Position(s) in the expanded axes where the new axis (or axes) is/are placed. + If multiple positions are provided, they should be unique (note that a position + given by a positive index could also be referred to by a negative index - + that will also result in an error). + Default: ``(0,)``. + xp : array_namespace + The standard-compatible namespace for `a`. + + Returns + ------- + res : array + `a` with an expanded shape. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([1, 2]) + >>> x.shape + (2,) + + The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``: + + >>> y = xpx.expand_dims(x, axis=0, xp=xp) + >>> y + Array([[1, 2]], dtype=array_api_strict.int64) + >>> y.shape + (1, 2) + + The following is equivalent to ``x[:, xp.newaxis]``: + + >>> y = xpx.expand_dims(x, axis=1, xp=xp) + >>> y + Array([[1], + [2]], dtype=array_api_strict.int64) + >>> y.shape + (2, 1) + + ``axis`` may also be a tuple: + + >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp) + >>> y + Array([[[1, 2]]], dtype=array_api_strict.int64) + + >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp) + >>> y + Array([[[1], + [2]]], dtype=array_api_strict.int64) + + """ + if not isinstance(axis, tuple): + axis = (axis,) + ndim = a.ndim + len(axis) + if axis != () and (min(axis) < -ndim or max(axis) >= ndim): + err_msg = ( + f"a provided axis position is out of bounds for array of dimension {a.ndim}" + ) + raise IndexError(err_msg) + axis = tuple(dim % ndim for dim in axis) + if len(set(axis)) != len(axis): + err_msg = "Duplicate dimensions specified in `axis`." + raise ValueError(err_msg) + for i in sorted(axis): + a = xp.expand_dims(a, axis=i) + return a + + +def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array: + """ + Kronecker product of two arrays. + + Computes the Kronecker product, a composite array made of blocks of the + second array scaled by the first. + + Equivalent to ``numpy.kron`` for NumPy arrays. + + Parameters + ---------- + a, b : array + xp : array_namespace + The standard-compatible namespace for `a` and `b`. + + Returns + ------- + res : array + The Kronecker product of `a` and `b`. + + Notes + ----- + The function assumes that the number of dimensions of `a` and `b` + are the same, if necessary prepending the smallest with ones. + If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``, + the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``. + The elements are products of elements from `a` and `b`, organized + explicitly by:: + + kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN] + + where:: + + kt = it * st + jt, t = 0,...,N + + In the common 2-D case (N=1), the block structure can be visualized:: + + [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], + [ ... ... ], + [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] + + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp) + Array([ 5, 6, 7, 50, 60, 70, 500, + 600, 700], dtype=array_api_strict.int64) + + >>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp) + Array([ 5, 50, 500, 6, 60, 600, 7, + 70, 700], dtype=array_api_strict.int64) + + >>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp) + Array([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]], dtype=array_api_strict.float64) + + + >>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5)) + >>> b = xp.reshape(xp.arange(24), (2, 3, 4)) + >>> c = xpx.kron(a, b, xp=xp) + >>> c.shape + (2, 10, 6, 20) + >>> I = (1, 3, 0, 2) + >>> J = (0, 2, 1) + >>> J1 = (0,) + J # extend to ndim=4 + >>> S1 = (1,) + b.shape + >>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1)) + >>> c[K] == a[I]*b[J] + Array(True, dtype=array_api_strict.bool) + + """ + + b = xp.asarray(b) + singletons = (1,) * (b.ndim - a.ndim) + a = xp.broadcast_to(xp.asarray(a), singletons + a.shape) + + nd_b, nd_a = b.ndim, a.ndim + nd_max = max(nd_b, nd_a) + if nd_a == 0 or nd_b == 0: + return xp.multiply(a, b) + + a_shape = a.shape + b_shape = b.shape + + # Equalise the shapes by prepending smaller one with 1s + a_shape = (1,) * max(0, nd_b - nd_a) + a_shape + b_shape = (1,) * max(0, nd_a - nd_b) + b_shape + + # Insert empty dimensions + a_arr = expand_dims(a, axis=tuple(range(nd_b - nd_a)), xp=xp) + b_arr = expand_dims(b, axis=tuple(range(nd_a - nd_b)), xp=xp) + + # Compute the product + a_arr = expand_dims(a_arr, axis=tuple(range(1, nd_max * 2, 2)), xp=xp) + b_arr = expand_dims(b_arr, axis=tuple(range(0, nd_max * 2, 2)), xp=xp) + result = xp.multiply(a_arr, b_arr) + + # Reshape back and return + a_shape = xp.asarray(a_shape) + b_shape = xp.asarray(b_shape) + return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape))) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index e3412a19..6ed32784 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1,10 +1,17 @@ from __future__ import annotations +import contextlib +from typing import TYPE_CHECKING, Any + # array-api-strict#6 import array_api_strict as xp # type: ignore[import-untyped] -from numpy.testing import assert_array_equal +import pytest +from numpy.testing import assert_array_equal, assert_equal + +from array_api_extra import atleast_nd, expand_dims, kron -from array_api_extra import atleast_nd +if TYPE_CHECKING: + Array = Any # To be changed to a Protocol later (see array-api#589) class TestAtLeastND: @@ -67,3 +74,117 @@ def test_5D(self): y = atleast_nd(x, ndim=9, xp=xp) assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) + + +class TestKron: + def test_basic(self): + # Using 0-dimensional array + a = xp.asarray(1) + b = xp.asarray([[1, 2], [3, 4]]) + k = xp.asarray([[1, 2], [3, 4]]) + assert_array_equal(kron(a, b, xp=xp), k) + a = xp.asarray([[1, 2], [3, 4]]) + b = xp.asarray(1) + assert_array_equal(kron(a, b, xp=xp), k) + + # Using 1-dimensional array + a = xp.asarray([3]) + b = xp.asarray([[1, 2], [3, 4]]) + k = xp.asarray([[3, 6], [9, 12]]) + assert_array_equal(kron(a, b, xp=xp), k) + a = xp.asarray([[1, 2], [3, 4]]) + b = xp.asarray([3]) + assert_array_equal(kron(a, b, xp=xp), k) + + # Using 3-dimensional array + a = xp.asarray([[[1]], [[2]]]) + b = xp.asarray([[1, 2], [3, 4]]) + k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) + assert_array_equal(kron(a, b, xp=xp), k) + a = xp.asarray([[1, 2], [3, 4]]) + b = xp.asarray([[[1]], [[2]]]) + k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) + assert_array_equal(kron(a, b, xp=xp), k) + + def test_kron_smoke(self): + a = xp.ones([3, 3]) + b = xp.ones([3, 3]) + k = xp.ones([9, 9]) + + assert_array_equal(kron(a, b, xp=xp), k) + + @pytest.mark.parametrize( + ("shape_a", "shape_b"), + [ + ((1, 1), (1, 1)), + ((1, 2, 3), (4, 5, 6)), + ((2, 2), (2, 2, 2)), + ((1, 0), (1, 1)), + ((2, 0, 2), (2, 2)), + ((2, 0, 0, 2), (2, 0, 2)), + ], + ) + def test_kron_shape(self, shape_a, shape_b): + a = xp.ones(shape_a) + b = xp.ones(shape_b) + normalised_shape_a = xp.asarray( + (1,) * max(0, len(shape_b) - len(shape_a)) + shape_a + ) + normalised_shape_b = xp.asarray( + (1,) * max(0, len(shape_a) - len(shape_b)) + shape_b + ) + expected_shape = tuple( + int(dim) for dim in xp.multiply(normalised_shape_a, normalised_shape_b) + ) + + k = kron(a, b, xp=xp) + assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron") + + +class TestExpandDims: + def test_functionality(self): + def _squeeze_all(b: Array) -> Array: + """Mimics `np.squeeze(b)`. `xpx.squeeze`?""" + for axis in range(b.ndim): + with contextlib.suppress(ValueError): + b = xp.squeeze(b, axis=axis) + return b + + s = (2, 3, 4, 5) + a = xp.empty(s) + for axis in range(-5, 4): + b = expand_dims(a, axis=axis, xp=xp) + assert b.shape[axis] == 1 + assert _squeeze_all(b).shape == s + + def test_axis_tuple(self): + a = xp.empty((3, 3, 3)) + assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3) + assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1) + assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1) + assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3) + + def test_axis_out_of_range(self): + s = (2, 3, 4, 5) + a = xp.empty(s) + with pytest.raises(IndexError, match="out of bounds"): + expand_dims(a, axis=-6, xp=xp) + with pytest.raises(IndexError, match="out of bounds"): + expand_dims(a, axis=5, xp=xp) + + a = xp.empty((3, 3, 3)) + with pytest.raises(IndexError, match="out of bounds"): + expand_dims(a, axis=(0, -6), xp=xp) + with pytest.raises(IndexError, match="out of bounds"): + expand_dims(a, axis=(0, 5), xp=xp) + + def test_repeated_axis(self): + a = xp.empty((3, 3, 3)) + with pytest.raises(ValueError, match="Duplicate dimensions"): + expand_dims(a, axis=(1, 1), xp=xp) + + def test_positive_negative_repeated(self): + # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817 + a = xp.empty((2, 3, 4, 5)) + with pytest.raises(ValueError, match="Duplicate dimensions"): + expand_dims(a, axis=(3, -3), xp=xp)