diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9677615206..1f6cbc798f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -82,6 +82,7 @@ jobs: install-numba: [0] install-jax: [0] install-torch: [0] + install-xarray: [0] part: - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests/scan" @@ -115,6 +116,7 @@ jobs: install-numba: 0 install-jax: 0 install-torch: 0 + install-xarray: 0 - install-numba: 1 os: "ubuntu-latest" python-version: "3.10" @@ -150,6 +152,13 @@ jobs: fast-compile: 0 float32: 0 part: "tests/link/pytorch" + - install-xarray: 1 + os: "ubuntu-latest" + python-version: "3.13" + numpy-version: ">=2.0" + fast-compile: 0 + float32: 0 + part: "tests/xtensor" - os: macos-15 python-version: "3.13" numpy-version: ">=2.0" @@ -196,6 +205,7 @@ jobs: if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi + if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi pip install pytest-sphinx pip install -e ./ @@ -212,6 +222,7 @@ jobs: INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_TORCH: ${{ matrix.install-torch}} + INSTALL_XARRAY: ${{ matrix.install-xarray }} OS: ${{ matrix.os}} - name: Run tests diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index b9e9c3164d..deedf13e93 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -4551,7 +4551,7 @@ def ix_(*args): new = as_tensor(new) if new.ndim != 1: raise ValueError("Cross index must be 1 dimensional") - new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1)) + new = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1))) out.append(new) return tuple(out) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 7a1bc75b0b..dc92238010 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -473,24 +473,6 @@ def cumprod(x, axis=None): return CumOp(axis=axis, mode="mul")(x) -class CumsumOp(Op): - __props__ = ("axis",) - - def __new__(typ, *args, **kwargs): - obj = object.__new__(CumOp, *args, **kwargs) - obj.mode = "add" - return obj - - -class CumprodOp(Op): - __props__ = ("axis",) - - def __new__(typ, *args, **kwargs): - obj = object.__new__(CumOp, *args, **kwargs) - obj.mode = "mul" - return obj - - def diff(x, n=1, axis=-1): """Calculate the `n`-th order discrete difference along the given `axis`. 
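The `ix_` change above replaces a `reshape` with an equivalent `dimshuffle`, putting each 1-d cross index on its own broadcastable axis without adding a reshape node to the graph. A minimal sketch of the orthogonal-indexing behaviour this is meant to preserve, shown with plain NumPy (which `ix_` is modeled on; the array values are made up for illustration):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
rows, cols = np.ix_([0, 2], [1, 3])
# rows has shape (2, 1) and cols has shape (1, 2); broadcasting them against
# each other selects the outer product of the two index sets.
assert x[rows, cols].tolist() == [[1, 3], [9, 11]]
```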
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 278d1e8da6..99ae67af9b 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -3021,12 +3021,7 @@ def make_node(self, x, y, *inputs): return Apply( self, (x, y, *new_inputs), - [ - tensor( - dtype=x.type.dtype, - shape=tuple(1 if s == 1 else None for s in x.type.shape), - ) - ], + [x.type()], ) def perform(self, node, inputs, out_): diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py new file mode 100644 index 0000000000..06265e40de --- /dev/null +++ b/pytensor/xtensor/__init__.py @@ -0,0 +1,16 @@ +import warnings + +import pytensor.xtensor.rewriting +from pytensor.xtensor import ( + linalg, + special, +) +from pytensor.xtensor.shape import concat +from pytensor.xtensor.type import ( + as_xtensor, + xtensor, + xtensor_constant, +) + + +warnings.warn("xtensor module is experimental and full of bugs") diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py new file mode 100644 index 0000000000..ab035be346 --- /dev/null +++ b/pytensor/xtensor/basic.py @@ -0,0 +1,104 @@ +from collections.abc import Sequence + +from pytensor.compile import ViewOp +from pytensor.graph import Apply, Op +from pytensor.link.c.op import COp +from pytensor.tensor.type import TensorType +from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor + + +class XOp(Op): + """A base class for XOps that shouldn't be materialized""" + + def perform(self, node, inputs, outputs): + raise NotImplementedError( + f"xtensor operation {self} must be lowered to equivalent tensor operations" + ) + + +class XTypeCastOp(COp): + """Base class for Ops that type cast between TensorType and XTensorType. + + This is like a `ViewOp` but without the expectation the input and output have identical types. 
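+ The value itself is passed through untouched (the output is a view of the input); only the static type changes.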
+ """ + + view_map = {0: [0]} + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] + + def c_code(self, node, nodename, inp, out, sub): + (iname,) = inp + (oname,) = out + fail = sub["fail"] + + code, _ = ViewOp.c_code_and_version[TensorType] + return code % locals() + + def c_code_cache_version(self): + _, version = ViewOp.c_code_and_version[TensorType] + return (version,) + + +class TensorFromXTensor(XTypeCastOp): + __props__ = () + + def make_node(self, x): + if not isinstance(x.type, XTensorType): + raise TypeError(f"x must be have an XTensorType, got {type(x.type)}") + output = TensorType(x.type.dtype, shape=x.type.shape)() + return Apply(self, [x], [output]) + + +tensor_from_xtensor = TensorFromXTensor() + + +class XTensorFromTensor(XTypeCastOp): + __props__ = ("dims",) + + def __init__(self, dims: Sequence[str]): + super().__init__() + self.dims = tuple(dims) + + def make_node(self, x): + if not isinstance(x.type, TensorType): + raise TypeError(f"x must be an TensorType type, got {type(x.type)}") + output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape) + return Apply(self, [x], [output]) + + +def xtensor_from_tensor(x, dims): + return XTensorFromTensor(dims=dims)(x) + + +class Rename(XTypeCastOp): + __props__ = ("new_dims",) + + def __init__(self, new_dims: tuple[str, ...]): + super().__init__() + self.new_dims = new_dims + + def make_node(self, x): + x = as_xtensor(x) + output = x.type.clone(dims=self.new_dims)() + return Apply(self, [x], [output]) + + +def rename(x, name_dict: dict[str, str] | None = None, **names: str): + if name_dict is not None: + if names: + raise ValueError("Cannot use both positional and keyword names in rename") + names = name_dict + + x = as_xtensor(x) + old_names = x.type.dims + new_names = list(old_names) + for old_name, new_name in names.items(): + try: + new_names[old_names.index(old_name)] = new_name + except IndexError: + raise ValueError( + f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}" + ) + + return Rename(tuple(new_names))(x) diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py new file mode 100644 index 0000000000..01517db55d --- /dev/null +++ b/pytensor/xtensor/indexing.py @@ -0,0 +1,219 @@ +# HERE LIE DRAGONS +# Useful links to make sense of all the numpy/xarray complexity +# https://numpy.org/devdocs//user/basics.indexing.html +# https://numpy.org/neps/nep-0021-advanced-indexing.html +# https://docs.xarray.dev/en/latest/user-guide/indexing.html +# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html +from typing import Literal + +from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.scalar.basic import discrete_dtypes +from pytensor.tensor.basic import as_tensor +from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice +from pytensor.xtensor.basic import XOp, xtensor_from_tensor +from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor + + +def as_idx_variable(idx, indexed_dim: str): + if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)): + raise TypeError( + "XTensors do not support indexing with None (np.newaxis), use expand_dims instead" + ) + if isinstance(idx, slice): + idx = make_slice(idx) + elif isinstance(idx, Variable) and isinstance(idx.type, SliceType): + pass + elif ( + isinstance(idx, tuple) + and len(idx) == 2 + and ( + isinstance(idx[0], str) + or ( + isinstance(idx[0], tuple | list) + and all(isinstance(d, str) for d in idx[0]) + ) + ) 
+ ): + # Special case for ("x", array) that xarray supports + dim, idx = idx + if isinstance(idx, Variable) and isinstance(idx.type, XTensorType): + raise IndexError( + f"Giving a dimension name to an XTensorVariable indexer is not supported: {(dim, idx)}. " + "Use .rename() instead." + ) + if isinstance(dim, str): + dims = (dim,) + else: + dims = tuple(dim) + idx = as_xtensor(as_tensor(idx), dims=dims) + else: + # Must be integer / boolean indices, we already counted for None and slices + try: + idx = as_xtensor(idx) + except TypeError: + idx = as_tensor(idx) + if idx.type.ndim > 1: + # Same error that xarray raises + raise IndexError( + "Unlabeled multi-dimensional array cannot be used for indexing" + ) + # This is implicitly an XTensorVariable with dim matching the indexed one + idx = xtensor_from_tensor(idx, dims=(indexed_dim,)[: idx.type.ndim]) + + if idx.type.dtype == "bool": + if idx.type.ndim != 1: + # xarray allaws `x[True]`, but I think it is a bug: https://github.com/pydata/xarray/issues/10379 + # Otherwise, it is always restricted to 1d boolean indexing arrays + raise NotImplementedError( + "Only 1d boolean indexing arrays are supported" + ) + if idx.type.dims != (indexed_dim,): + raise IndexError( + "Boolean indexer should be unlabeled or on the same dimension to the indexed array. " + f"Indexer is on {idx.type.dims} but the target dimension is {indexed_dim}." + ) + + # Convert to nonzero indices + idx = as_xtensor(idx.values.nonzero()[0], dims=idx.type.dims) + + elif idx.type.dtype not in discrete_dtypes: + raise TypeError("Numerical indices must be integers or boolean") + return idx + + +def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None: + if dim_length is None: + return None + if isinstance(slc, Constant): + d = slc.data + start, stop, step = d.start, d.stop, d.step + elif slc.owner is None: + # It's a root variable no way of knowing what we're getting + return None + else: + # It's a MakeSliceOp + start, stop, step = slc.owner.inputs + if isinstance(start, Constant): + start = start.data + else: + return None + if isinstance(stop, Constant): + stop = stop.data + else: + return None + if isinstance(step, Constant): + step = step.data + else: + return None + return len(range(*slice(start, stop, step).indices(dim_length))) + + +class Index(XOp): + __props__ = () + + def make_node(self, x, *idxs): + x = as_xtensor(x) + + if any(idx is Ellipsis for idx in idxs): + if idxs.count(Ellipsis) > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Convert intermediate Ellipsis to slice(None) + ellipsis_loc = idxs.index(Ellipsis) + n_implied_none_slices = x.type.ndim - (len(idxs) - 1) + idxs = ( + *idxs[:ellipsis_loc], + *((slice(None),) * n_implied_none_slices), + *idxs[ellipsis_loc + 1 :], + ) + + x_ndim = x.type.ndim + x_dims = x.type.dims + x_shape = x.type.shape + out_dims = [] + out_shape = [] + + def combine_dim_info(idx_dim, idx_dim_shape): + if idx_dim not in out_dims: + # First information about the dimension length + out_dims.append(idx_dim) + out_shape.append(idx_dim_shape) + else: + # Dim already introduced in output by a previous index + # Update static shape or raise if incompatible + out_dim_pos = out_dims.index(idx_dim) + out_dim_shape = out_shape[out_dim_pos] + if out_dim_shape is None: + # We don't know the size of the dimension yet + out_shape[out_dim_pos] = idx_dim_shape + elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape: + raise IndexError( + f"Dimension of indexers mismatch for 
dim {idx_dim}" + ) + + if len(idxs) > x_ndim: + raise IndexError("Too many indices") + + idxs = [ + as_idx_variable(idx, dim) for idx, dim in zip(idxs, x_dims, strict=False) + ] + + for i, idx in enumerate(idxs): + if isinstance(idx.type, SliceType): + idx_dim = x_dims[i] + idx_dim_shape = get_static_slice_length(idx, x_shape[i]) + combine_dim_info(idx_dim, idx_dim_shape) + else: + if idx.type.ndim == 0: + # Scalar index, dimension is dropped + continue + + assert isinstance(idx.type, XTensorType) + + idx_dims = idx.type.dims + for idx_dim in idx_dims: + idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)] + combine_dim_info(idx_dim, idx_dim_shape) + + for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]): + # Add back any unindexed dimensions + if dim_i not in out_dims: + # If the dimension was not indexed, we keep it as is + combine_dim_info(dim_i, shape_i) + + output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x, *idxs], [output]) + + +index = Index() + + +class IndexUpdate(XOp): + __props__ = ("mode",) + + def __init__(self, mode: Literal["set", "inc"]): + if mode not in ("set", "inc"): + raise ValueError("mode must be 'set' or 'inc'") + self.mode = mode + + def make_node(self, x, y, *idxs): + # Call Index on (x, *idxs) to process inputs and infer output type + x_view_node = index.make_node(x, *idxs) + x, *idxs = x_view_node.inputs + [x_view] = x_view_node.outputs + + try: + y = as_xtensor(y) + except TypeError: + y = as_xtensor(as_tensor(y), dims=x_view.type.dims) + + if not set(y.type.dims).issubset(x_view.type.dims): + raise ValueError( + f"Value dimensions {y.type.dims} must be a subset of the indexed dimensions {x_view.type.dims}" + ) + + out = x.type() + return Apply(self, [x, y, *idxs], [out]) + + +index_assignment = IndexUpdate("set") +index_increment = IndexUpdate("inc") diff --git a/pytensor/xtensor/linalg.py b/pytensor/xtensor/linalg.py new file mode 100644 index 0000000000..1ed7abf9d9 --- /dev/null +++ b/pytensor/xtensor/linalg.py @@ -0,0 +1,74 @@ +from collections.abc import Sequence +from typing import Literal + +from pytensor.tensor.slinalg import Cholesky, Solve +from pytensor.xtensor.type import as_xtensor +from pytensor.xtensor.vectorization import XBlockwise + + +def cholesky( + x, + lower: bool = True, + *, + check_finite: bool = False, + overwrite_a: bool = False, + on_error: Literal["raise", "nan"] = "raise", + dims: Sequence[str], +): + if len(dims) != 2: + raise ValueError(f"Cholesky needs two dims, got {len(dims)}") + + core_op = Cholesky( + lower=lower, + check_finite=check_finite, + overwrite_a=overwrite_a, + on_error=on_error, + ) + core_dims = ( + ((dims[0], dims[1]),), + ((dims[0], dims[1]),), + ) + x_op = XBlockwise(core_op, signature=core_op.gufunc_signature, core_dims=core_dims) + return x_op(x) + + +def solve( + a, + b, + dims: Sequence[str], + assume_a="gen", + lower: bool = False, + check_finite: bool = False, +): + a, b = as_xtensor(a), as_xtensor(b) + input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]] + output_core_dims: tuple[tuple[str] | tuple[str, str]] + if len(dims) == 2: + b_ndim = 1 + [m1_dim] = [dim for dim in dims if dim not in b.type.dims] + m2_dim = dims[0] if dims[0] != m1_dim else dims[1] + input_core_dims = ((m1_dim, m2_dim), (m2_dim,)) + output_core_dims = ((m2_dim,),) + elif len(dims) == 3: + b_ndim = 2 + [n_dim] = [dim for dim in dims if dim not in a.type.dims] + [m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim] + input_core_dims = ((m1_dim, m2_dim), 
(m2_dim, n_dim)) + output_core_dims = ( + ( + m2_dim, + n_dim, + ), + ) + else: + raise ValueError("Solve dims must have length 2 or 3") + + core_op = Solve( + b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite + ) + x_op = XBlockwise( + core_op, + signature=core_op.gufunc_signature, + core_dims=(input_core_dims, output_core_dims), + ) + return x_op(a, b) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py new file mode 100644 index 0000000000..4fe0ca8106 --- /dev/null +++ b/pytensor/xtensor/math.py @@ -0,0 +1,136 @@ +import sys + +import numpy as np + +import pytensor.scalar as ps +from pytensor import config +from pytensor.scalar import ScalarOp +from pytensor.scalar.basic import _cast_mapping +from pytensor.xtensor.basic import as_xtensor +from pytensor.xtensor.vectorization import XElemwise + + +this_module = sys.modules[__name__] + + +def _as_xelemwise(core_op: ScalarOp) -> XElemwise: + out = XElemwise(core_op) + out.__doc__ = f"Ufunc version of {core_op} for XTensorVariables" + return out + + +abs = _as_xelemwise(ps.abs) +add = _as_xelemwise(ps.add) +logical_and = bitwise_and = and_ = _as_xelemwise(ps.and_) +angle = _as_xelemwise(ps.angle) +arccos = _as_xelemwise(ps.arccos) +arccosh = _as_xelemwise(ps.arccosh) +arcsin = _as_xelemwise(ps.arcsin) +arcsinh = _as_xelemwise(ps.arcsinh) +arctan = _as_xelemwise(ps.arctan) +arctan2 = _as_xelemwise(ps.arctan2) +arctanh = _as_xelemwise(ps.arctanh) +betainc = _as_xelemwise(ps.betainc) +betaincinv = _as_xelemwise(ps.betaincinv) +ceil = _as_xelemwise(ps.ceil) +clip = _as_xelemwise(ps.clip) +complex = _as_xelemwise(ps.complex) +conjugate = conj = _as_xelemwise(ps.conj) +cos = _as_xelemwise(ps.cos) +cosh = _as_xelemwise(ps.cosh) +deg2rad = _as_xelemwise(ps.deg2rad) +equal = eq = _as_xelemwise(ps.eq) +erf = _as_xelemwise(ps.erf) +erfc = _as_xelemwise(ps.erfc) +erfcinv = _as_xelemwise(ps.erfcinv) +erfcx = _as_xelemwise(ps.erfcx) +erfinv = _as_xelemwise(ps.erfinv) +exp = _as_xelemwise(ps.exp) +exp2 = _as_xelemwise(ps.exp2) +expm1 = _as_xelemwise(ps.expm1) +floor = _as_xelemwise(ps.floor) +floor_divide = floor_div = int_div = _as_xelemwise(ps.int_div) +gamma = _as_xelemwise(ps.gamma) +gammainc = _as_xelemwise(ps.gammainc) +gammaincc = _as_xelemwise(ps.gammaincc) +gammainccinv = _as_xelemwise(ps.gammainccinv) +gammaincinv = _as_xelemwise(ps.gammaincinv) +gammal = _as_xelemwise(ps.gammal) +gammaln = _as_xelemwise(ps.gammaln) +gammau = _as_xelemwise(ps.gammau) +greater_equal = ge = _as_xelemwise(ps.ge) +greater = gt = _as_xelemwise(ps.gt) +hyp2f1 = _as_xelemwise(ps.hyp2f1) +i0 = _as_xelemwise(ps.i0) +i1 = _as_xelemwise(ps.i1) +identity = _as_xelemwise(ps.identity) +imag = _as_xelemwise(ps.imag) +logical_not = bitwise_invert = bitwise_not = invert = _as_xelemwise(ps.invert) +isinf = _as_xelemwise(ps.isinf) +isnan = _as_xelemwise(ps.isnan) +iv = _as_xelemwise(ps.iv) +ive = _as_xelemwise(ps.ive) +j0 = _as_xelemwise(ps.j0) +j1 = _as_xelemwise(ps.j1) +jv = _as_xelemwise(ps.jv) +kve = _as_xelemwise(ps.kve) +less_equal = le = _as_xelemwise(ps.le) +log = _as_xelemwise(ps.log) +log10 = _as_xelemwise(ps.log10) +log1mexp = _as_xelemwise(ps.log1mexp) +log1p = _as_xelemwise(ps.log1p) +log2 = _as_xelemwise(ps.log2) +less = lt = _as_xelemwise(ps.lt) +mod = _as_xelemwise(ps.mod) +multiply = mul = _as_xelemwise(ps.mul) +negative = neg = _as_xelemwise(ps.neg) +not_equal = neq = _as_xelemwise(ps.neq) +logical_or = bitwise_or = or_ = _as_xelemwise(ps.or_) +owens_t = _as_xelemwise(ps.owens_t) +polygamma = _as_xelemwise(ps.polygamma) +power = 
pow = _as_xelemwise(ps.pow) +psi = _as_xelemwise(ps.psi) +rad2deg = _as_xelemwise(ps.rad2deg) +real = _as_xelemwise(ps.real) +reciprocal = _as_xelemwise(ps.reciprocal) +round = _as_xelemwise(ps.round_half_to_even) +maximum = _as_xelemwise(ps.scalar_maximum) +minimum = _as_xelemwise(ps.scalar_minimum) +second = _as_xelemwise(ps.second) +sigmoid = _as_xelemwise(ps.sigmoid) +sign = _as_xelemwise(ps.sign) +sin = _as_xelemwise(ps.sin) +sinh = _as_xelemwise(ps.sinh) +softplus = _as_xelemwise(ps.softplus) +square = sqr = _as_xelemwise(ps.sqr) +sqrt = _as_xelemwise(ps.sqrt) +subtract = sub = _as_xelemwise(ps.sub) +where = switch = _as_xelemwise(ps.switch) +tan = _as_xelemwise(ps.tan) +tanh = _as_xelemwise(ps.tanh) +tri_gamma = _as_xelemwise(ps.tri_gamma) +true_divide = true_div = _as_xelemwise(ps.true_div) +trunc = _as_xelemwise(ps.trunc) +logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor) + +_xelemwise_cast_op: dict[str, XElemwise] = {} + + +def cast(x, dtype): + if dtype == "floatX": + dtype = config.floatX + else: + dtype = np.dtype(dtype).name + + x = as_xtensor(x) + if x.type.dtype == dtype: + return x + if x.type.dtype.startswith("complex") and not dtype.startswith("complex"): + raise TypeError( + "Casting from complex to real is ambiguous: consider" + " real(), imag(), angle() or abs()" + ) + + if dtype not in _xelemwise_cast_op: + _xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype]) + return _xelemwise_cast_op[dtype](x) diff --git a/pytensor/xtensor/readme.md b/pytensor/xtensor/readme.md new file mode 100644 index 0000000000..b3511f56ad --- /dev/null +++ b/pytensor/xtensor/readme.md @@ -0,0 +1,69 @@ +# XTensor Module + +This module implements as abstraction layer on regular tensor operations, that behaves like Xarray. + +A new type `XTensorType`, generalizes the `TensorType` with the addition of a `dims` attribute, +that labels the dimensions of the tensor. + +Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects. + +The module implements several PyTensor operations `XOp`s, whose signature mimics that of xarray (and xarray_einstants) DataArray operations. +These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into +a regular tensor graph that can itself be evaluated as usual. + +Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray. +If the existing XOps can be composed to produce the desired result, then we can use them directly. + +## Coordinates +For now, there's no analogous of xarray coordinates, so you won't be able to do coordinate operations like `.sel`. +The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor. +Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor. + +## Example + +```python +import pytensor.tensor as pt +import pytensor.xtensor as px + +a = pt.tensor("a", shape=(3,)) +b = pt.tensor("b", shape=(4,)) + +ax = px.as_xtensor(a, dims=["x"]) +bx = px.as_xtensor(b, dims=["y"]) + +zx = ax + bx +assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4)) + +z = zx.values +z.dprint() +# TensorFromXTensor [id A] +# └─ XElemwise{scalar_op=Add()} [id B] +# ├─ XTensorFromTensor{dims=('x',)} [id C] +# │ └─ a [id D] +# └─ XTensorFromTensor{dims=('y',)} [id E] +# └─ b [id F] +``` + +Once we compile the graph, no `XOp`s are left. 
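+Lowering is driven by the rewrites registered in the `lower_xtensor` database (see `pytensor/xtensor/rewriting/`), which run as part of the regular optimization passes; enabling `optimizer_verbose` shows each `XOp` being replaced: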
+ +```python +import pytensor + +with pytensor.config.change_flags(optimizer_verbose=True): + fn = pytensor.function([a, b], z) + +# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0) +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0) + +fn.dprint() +# Add [id A] 2 +# ├─ ExpandDims{axis=1} [id B] 1 +# │ └─ a [id C] +# └─ ExpandDims{axis=0} [id D] 0 +# └─ b [id E] +``` + + + diff --git a/pytensor/xtensor/reduction.py b/pytensor/xtensor/reduction.py new file mode 100644 index 0000000000..aa7662d7cd --- /dev/null +++ b/pytensor/xtensor/reduction.py @@ -0,0 +1,125 @@ +import typing +from collections.abc import Sequence +from functools import partial +from types import EllipsisType + +import pytensor.scalar as ps +from pytensor.graph.basic import Apply +from pytensor.tensor.math import variadic_mul +from pytensor.xtensor.basic import XOp +from pytensor.xtensor.math import neq, sqrt +from pytensor.xtensor.math import sqr as square +from pytensor.xtensor.type import as_xtensor, xtensor + + +REDUCE_DIM = str | Sequence[str] | EllipsisType | None + + +class XReduce(XOp): + __slots__ = ("binary_op", "dims") + + def __init__(self, binary_op, dims: Sequence[str]): + super().__init__() + self.binary_op = binary_op + # Order of reduce dims doens't change the behavior of the Op + self.dims = tuple(sorted(dims)) + + def make_node(self, x): + x = as_xtensor(x) + x_dims = x.type.dims + x_dims_set = set(x_dims) + reduce_dims_set = set(self.dims) + if x_dims_set == reduce_dims_set: + out_dims, out_shape = [], [] + else: + if not reduce_dims_set.issubset(x_dims_set): + raise ValueError( + f"Reduced dims {self.dims} not found in array dimensions {x_dims}." 
+ ) + out_dims, out_shape = zip( + *[ + (d, s) + for d, s in zip(x_dims, x.type.shape) + if d not in reduce_dims_set + ] + ) + output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x], [output]) + + +def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]: + if isinstance(dim, str): + return (dim,) + elif dim is None or dim is Ellipsis: + x = as_xtensor(x) + return typing.cast(tuple[str], x.type.dims) + return dim + + +def reduce(x, dim: REDUCE_DIM = None, *, binary_op): + dims = _process_user_dims(x, dim) + return XReduce(binary_op=binary_op, dims=dims)(x) + + +sum = partial(reduce, binary_op=ps.add) +prod = partial(reduce, binary_op=ps.mul) +max = partial(reduce, binary_op=ps.scalar_maximum) +min = partial(reduce, binary_op=ps.scalar_minimum) + + +def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op): + x = as_xtensor(x) + if x.type.dtype != "bool": + x = neq(x, 0) + return reduce(x, dim=dim, binary_op=binary_op) + + +all = partial(bool_reduce, binary_op=ps.and_) +any = partial(bool_reduce, binary_op=ps.or_) + + +def _infer_reduced_size(original_var, reduced_var): + reduced_dims = reduced_var.dims + return variadic_mul( + *[size for dim, size in original_var.sizes if dim not in reduced_dims] + ) + + +def mean(x, dim: REDUCE_DIM): + x = as_xtensor(x) + sum_x = sum(x, dim) + n = _infer_reduced_size(x, sum_x) + return sum_x / n + + +def var(x, dim: REDUCE_DIM, *, ddof: int = 0): + x = as_xtensor(x) + x_mean = mean(x, dim) + n = _infer_reduced_size(x, x_mean) + return square(x - x_mean) / (n - ddof) + + +def std(x, dim: REDUCE_DIM, *, ddof: int = 0): + return sqrt(var(x, dim, ddof=ddof)) + + +class XCumReduce(XOp): + __props__ = ("binary_op", "dims") + + def __init__(self, binary_op, dims: Sequence[str]): + self.binary_op = binary_op + self.dims = tuple(sorted(dims)) # Order doesn't matter + + def make_node(self, x): + x = as_xtensor(x) + out = x.type() + return Apply(self, [x], [out]) + + +def cumreduce(x, dim: REDUCE_DIM, *, binary_op): + dims = _process_user_dims(x, dim) + return XCumReduce(dims=dims, binary_op=binary_op)(x) + + +cumsum = partial(cumreduce, binary_op=ps.add) +cumprod = partial(cumreduce, binary_op=ps.mul) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py new file mode 100644 index 0000000000..a65ad0db85 --- /dev/null +++ b/pytensor/xtensor/rewriting/__init__.py @@ -0,0 +1,5 @@ +import pytensor.xtensor.rewriting.basic +import pytensor.xtensor.rewriting.indexing +import pytensor.xtensor.rewriting.reduction +import pytensor.xtensor.rewriting.shape +import pytensor.xtensor.rewriting.vectorization diff --git a/pytensor/xtensor/rewriting/basic.py b/pytensor/xtensor/rewriting/basic.py new file mode 100644 index 0000000000..be93101426 --- /dev/null +++ b/pytensor/xtensor/rewriting/basic.py @@ -0,0 +1,62 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor.basic import register_infer_shape +from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless +from pytensor.xtensor.basic import ( + Rename, + TensorFromXTensor, + XTensorFromTensor, + xtensor_from_tensor, +) +from pytensor.xtensor.rewriting.utils import register_lower_xtensor + + +@register_infer_shape +@register_useless +@register_canonicalize +@register_lower_xtensor +@node_rewriter(tracks=[TensorFromXTensor]) +def useless_tensor_from_xtensor(fgraph, node): + """TensorFromXTensor(XTensorFromTensor(x)) -> x""" + [x] = node.inputs + if x.owner and isinstance(x.owner.op, XTensorFromTensor): + return 
[x.owner.inputs[0]] + + +@register_infer_shape +@register_useless +@register_canonicalize +@register_lower_xtensor +@node_rewriter(tracks=[XTensorFromTensor]) +def useless_xtensor_from_tensor(fgraph, node): + """XTensorFromTensor(TensorFromXTensor(x)) -> x""" + [x] = node.inputs + if x.owner and isinstance(x.owner.op, TensorFromXTensor): + return [x.owner.inputs[0]] + + +@register_lower_xtensor +@node_rewriter(tracks=[TensorFromXTensor]) +def useless_tensor_from_xtensor_of_rename(fgraph, node): + """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)""" + [renamed_x] = node.inputs + if renamed_x.owner and isinstance(renamed_x.owner.op, Rename): + [x] = renamed_x.owner.inputs + return node.op(x, return_list=True) + + +@register_lower_xtensor +@node_rewriter(tracks=[Rename]) +def useless_rename(fgraph, node): + """ + + Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims) + Rename(X, XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFrom_tensor(x, outer_dims) + """ + [renamed_x] = node.inputs + if renamed_x.owner: + if isinstance(renamed_x.owner.op, Rename): + [x] = renamed_x.owner.inputs + return [node.op(x)] + elif isinstance(renamed_x.owner.op, TensorFromXTensor): + [x] = renamed_x.owner.inputs + return [xtensor_from_tensor(x, dims=node.op.new_dims)] diff --git a/pytensor/xtensor/rewriting/indexing.py b/pytensor/xtensor/rewriting/indexing.py new file mode 100644 index 0000000000..25a0f80dd4 --- /dev/null +++ b/pytensor/xtensor/rewriting/indexing.py @@ -0,0 +1,212 @@ +from itertools import zip_longest + +from pytensor import as_symbolic +from pytensor.graph import Constant, node_rewriter +from pytensor.tensor import TensorType, arange, specify_shape +from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor +from pytensor.tensor.type_other import NoneTypeT, SliceType +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.indexing import Index, IndexUpdate, index +from pytensor.xtensor.rewriting.utils import register_lower_xtensor +from pytensor.xtensor.type import XTensorType + + +def to_basic_idx(idx): + if isinstance(idx.type, SliceType): + if isinstance(idx, Constant): + return idx.data + elif idx.owner: + # MakeSlice Op + # We transform NoneConsts to regular None so that basic Subtensor can be used if possible + return slice( + *[ + None if isinstance(i.type, NoneTypeT) else i + for i in idx.owner.inputs + ] + ) + else: + return idx + if ( + isinstance(idx.type, XTensorType) + and idx.type.ndim == 0 + and idx.type.dtype != bool + ): + return idx.values + raise TypeError("Cannot convert idx to basic idx") + + +def _lower_index(node): + """Lower XTensorVariable indexing to regular TensorVariable indexing. + + xarray-like indexing has two modes: + 1. Orthogonal indexing: Indices of different output labeled dimensions are combined to produce all combinations of indices. + 2. Vectorized indexing: Indices of the same output labeled dimension are combined point-wise like in regular numpy advanced indexing. + + An Index Op can combine both modes. + To achieve orthogonal indexing using numpy semantics we must use multidimensional advanced indexing. + We expand the dims of each index so they are as large as the number of output dimensions, place the indices that + belong to the same output dimension in the same axis, and those that belong to different output dimensions in different axes. 
+ + For instance to do an outer 2x2 indexing we can select x[arange(x.shape[0])[:, None], arange(x.shape[1])[None, :]], + This is a generalization of `np.ix_` that allows combining some dimensions, and not others, as well as have + indices that have more than one dimension at the start. + + In addition, xarray basic index (slices), can be vectorized with other advanced indices (if they act on the same output dimension). + However, in numpy, basic indices are always orthogonal to advanced indices. To make them behave like vectorized indices + we have to convert the slices to equivalent advanced indices. + We do this by creating an `arange` tensor that matches the shape of the dimension being indexed, + and then indexing it with the original slice. This index is then handled as a regular advanced index. + + Finally, the location of views resulting from advanced indices follows two distinct behaviors in numpy. + When all advanced indices are consecutive, the respective view is located in the "original" location. + However, if advanced indices are separated by basic indices (slices in our case), the output views + always show up at the front of the array. This information is returned as the second output of this function, + which labels the final position of the indexed dimensions under this rule. + """ + + assert isinstance(node.op, Index) + + x, *idxs = node.inputs + [out] = node.outputs + x_tensor_indexed_dims = out.type.dims + x_tensor = tensor_from_xtensor(x) + + if all( + ( + isinstance(idx.type, SliceType) + or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0) + ) + for idx in idxs + ): + # Special case having just basic indexing + x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)] + + else: + # General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing + # May need to convert basic indexing to advanced indexing if it acts on a dimension that is also indexed by an advanced index + x_dims = x.type.dims + x_shape = tuple(x.shape) + out_ndim = out.type.ndim + out_dims = out.type.dims + aligned_idxs = [] + basic_idx_axis = [] + # zip_longest adds the implicit slice(None) + for i, (idx, x_dim) in enumerate( + zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None))) + ): + if isinstance(idx.type, SliceType): + if not any( + ( + isinstance(other_idx.type, XTensorType) + and x_dim in other_idx.dims + ) + for j, other_idx in enumerate(idxs) + if j != i + ): + # We can use basic indexing directly if no other index acts on this dimension + # This is an optimization that avoids creating an unnecessary arange tensor + # and facilitates the use of the specialized AdvancedSubtensor1 when possible + aligned_idxs.append(idx) + basic_idx_axis.append(out_dims.index(x_dim)) + else: + # Otherwise we need to convert the basic index into an equivalent advanced indexing + # And align it so it interacts correctly with the other advanced indices + adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)] + ds_order = ["x"] * out_ndim + ds_order[out_dims.index(x_dim)] = 0 + aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order)) + else: + assert isinstance(idx.type, XTensorType) + if idx.type.ndim == 0: + # Scalar index, we can use it directly + aligned_idxs.append(idx.values) + else: + # Vector index, we need to align the indexing dimensions with the base_dims + ds_order = ["x"] * out_ndim + for j, idx_dim in enumerate(idx.dims): + ds_order[out_dims.index(idx_dim)] = j + aligned_idxs.append(idx.values.dimshuffle(ds_order)) + + # 
Squeeze indexing dimensions that were not used because we kept basic indexing slices + if basic_idx_axis: + aligned_idxs = [ + idx.squeeze(axis=basic_idx_axis) + if (isinstance(idx.type, TensorType) and idx.type.ndim > 0) + else idx + for idx in aligned_idxs + ] + + x_tensor_indexed = x_tensor[tuple(aligned_idxs)] + + if basic_idx_axis and _non_consecutive_adv_indexing(aligned_idxs): + # Numpy moves advanced indexing dimensions to the front when they are not consecutive + # We need to transpose them back to the expected output order + x_tensor_indexed_basic_dims = [out_dims[axis] for axis in basic_idx_axis] + x_tensor_indexed_dims = [ + dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims + ] + x_tensor_indexed_basic_dims + + return x_tensor_indexed, x_tensor_indexed_dims + + +@register_lower_xtensor +@node_rewriter(tracks=[Index]) +def lower_index(fgraph, node): + """Lower XTensorVariable indexing to regular TensorVariable indexing. + + The bulk of the work is done by `_lower_index`, except for special logic to control the + location of non-consecutive advanced indices, and to preserve static shape information. + """ + + [out] = node.outputs + out_dims = out.type.dims + + x_tensor_indexed, x_tensor_indexed_dims = _lower_index(node) + if x_tensor_indexed_dims != out_dims: + # Numpy moves advanced indexing dimensions to the front when they are not consecutive + # We need to transpose them back to the expected output order + transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims] + x_tensor_indexed = x_tensor_indexed.transpose(transpose_order) + + # Add lost shape information + x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape) + + new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.dims) + return [new_out] + + +@register_lower_xtensor +@node_rewriter(tracks=[IndexUpdate]) +def lower_index_update(fgraph, node): + """Lower XTensorVariable index update to regular TensorVariable indexing update. + + This rewrite requires converting the index view to a tensor-based equivalent expression, + just like `lower_index`. It then requires aligning the dimensions of y with the + dimensions of the index view, with special care for non-consecutive dimensions being + pulled to the front axis according to numpy rules. 
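+ After alignment, the update itself is applied with `inc_subtensor` (using `set_instead_of_inc` when the op mode is "set").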
+ """ + x, y, *idxs = node.inputs + + # Lower the indexing part first + indexed_node = index.make_node(x, *idxs) + x_tensor_indexed, x_tensor_indexed_dims = _lower_index(indexed_node) + y_tensor = tensor_from_xtensor(y) + + # Align dimensions of y with those of the indexed tensor x + y_dims = y.type.dims + y_dims_set = set(y_dims) + y_order = tuple( + y_dims.index(x_dim) if x_dim in y_dims_set else "x" + for x_dim in x_tensor_indexed_dims + ) + # Remove useless left expand_dims + while len(y_order) > 0 and y_order[0] == "x": + y_order = y_order[1:] + if y_order != tuple(range(y_tensor.type.ndim)): + y_tensor = y_tensor.dimshuffle(y_order) + + x_tensor_updated = inc_subtensor( + x_tensor_indexed, y_tensor, set_instead_of_inc=node.op.mode == "set" + ) + new_out = xtensor_from_tensor(x_tensor_updated, dims=x.type.dims) + return [new_out] diff --git a/pytensor/xtensor/rewriting/reduction.py b/pytensor/xtensor/rewriting/reduction.py new file mode 100644 index 0000000000..e43be81e73 --- /dev/null +++ b/pytensor/xtensor/rewriting/reduction.py @@ -0,0 +1,72 @@ +from functools import partial + +import pytensor.scalar as ps +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor.extra_ops import CumOp +from pytensor.tensor.math import All, Any, CAReduce, Max, Min, Prod, Sum +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.reduction import XCumReduce, XReduce +from pytensor.xtensor.rewriting.utils import register_lower_xtensor + + +@register_lower_xtensor +@node_rewriter(tracks=[XReduce]) +def lower_reduce(fgraph, node): + [x] = node.inputs + [out] = node.outputs + x_dims = x.type.dims + reduce_dims = node.op.dims + reduce_axis = [x_dims.index(dim) for dim in reduce_dims] + + if not reduce_axis: + return [x] + + match node.op.binary_op: + case ps.add: + tensor_op_class = Sum + case ps.mul: + tensor_op_class = Prod + case ps.and_: + tensor_op_class = All + case ps.or_: + tensor_op_class = Any + case ps.scalar_maximum: + tensor_op_class = Max + case ps.scalar_minimum: + tensor_op_class = Min + case _: + # Case without known/predefined Ops + tensor_op_class = partial(CAReduce, scalar_op=node.op.binary_op) + + x_tensor = tensor_from_xtensor(x) + out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor) + new_out = xtensor_from_tensor(out_tensor, out.type.dims) + return [new_out] + + +@register_lower_xtensor +@node_rewriter(tracks=[XCumReduce]) +def lower_cumreduce(fgraph, node): + [x] = node.inputs + x_dims = x.type.dims + reduce_dims = node.op.dims + reduce_axis = [x_dims.index(dim) for dim in reduce_dims] + + if not reduce_axis: + return [x] + + match node.op.binary_op: + case ps.add: + tensor_op_class = partial(CumOp, mode="add") + case ps.mul: + tensor_op_class = partial(CumOp, mode="mul") + case _: + # We don't know how to convert an arbitrary binary cum/reduce Op + return None + + # Each dim corresponds to an application of Cumsum/Cumprod + out_tensor = tensor_from_xtensor(x) + for axis in reduce_axis: + out_tensor = tensor_op_class(axis=axis)(out_tensor) + out = xtensor_from_tensor(out_tensor, x.type.dims) + return [out] diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py new file mode 100644 index 0000000000..b0ca4f3bd4 --- /dev/null +++ b/pytensor/xtensor/rewriting/shape.py @@ -0,0 +1,134 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor import ( + broadcast_to, + join, + moveaxis, + specify_shape, + squeeze, +) +from pytensor.xtensor.basic import tensor_from_xtensor, 
xtensor_from_tensor +from pytensor.xtensor.rewriting.basic import register_lower_xtensor +from pytensor.xtensor.shape import ( + Concat, + Squeeze, + Stack, + Transpose, + UnStack, +) + + +@register_lower_xtensor +@node_rewriter(tracks=[Stack]) +def lower_stack(fgraph, node): + [x] = node.inputs + batch_ndim = x.type.ndim - len(node.op.stacked_dims) + stacked_axes = [ + i for i, dim in enumerate(x.type.dims) if dim in node.op.stacked_dims + ] + end = tuple(range(-len(stacked_axes), 0)) + + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = moveaxis(x_tensor, source=stacked_axes, destination=end) + if batch_ndim == (x.type.ndim - 1): + # This happens when we stack a "single" dimension, in this case all we need is the transpose + # Note: If we have meaningful rewrites before lowering, consider canonicalizing this as a Transpose + Rename + final_tensor = x_tensor_transposed + else: + final_shape = (*tuple(x_tensor_transposed.shape)[:batch_ndim], -1) + final_tensor = x_tensor_transposed.reshape(final_shape) + + new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) + return [new_out] + + +@register_lower_xtensor +@node_rewriter(tracks=[UnStack]) +def lower_unstack(fgraph, node): + x = node.inputs[0] + unstacked_lengths = node.inputs[1:] + axis_to_unstack = x.type.dims.index(node.op.old_dim_name) + + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1]) + final_tensor = x_tensor_transposed.reshape( + (*x_tensor_transposed.shape[:-1], *unstacked_lengths) + ) + # Reintroduce any static shape information that was lost during the reshape + final_tensor = specify_shape(final_tensor, node.outputs[0].type.shape) + + new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) + return [new_out] + + +@register_lower_xtensor +@node_rewriter(tracks=[Concat]) +def lower_concat(fgraph, node): + out_dims = node.outputs[0].type.dims + concat_dim = node.op.dim + concat_axis = out_dims.index(concat_dim) + + # Convert input XTensors to Tensors and align batch dimensions + tensor_inputs = [] + for inp in node.inputs: + inp_dims = inp.type.dims + order = [ + inp_dims.index(out_dim) if out_dim in inp_dims else "x" + for out_dim in out_dims + ] + tensor_inp = tensor_from_xtensor(inp).dimshuffle(order) + tensor_inputs.append(tensor_inp) + + # Broadcast non-concatenated dimensions of each input + non_concat_shape = [None] * len(out_dims) + for tensor_inp in tensor_inputs: + # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime + # I'm running this as "shape_unsafe" to simplify the logic / returned graph + for i, (bcast, sh) in enumerate( + zip(tensor_inp.type.broadcastable, tensor_inp.shape) + ): + if bcast or i == concat_axis or non_concat_shape[i] is not None: + continue + non_concat_shape[i] = sh + + assert non_concat_shape.count(None) == 1 + + bcast_tensor_inputs = [] + for tensor_inp in tensor_inputs: + # We modify the concat_axis in place, as we don't need the list anywhere else + non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis] + bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape)) + + joined_tensor = join(concat_axis, *bcast_tensor_inputs) + new_out = xtensor_from_tensor(joined_tensor, dims=out_dims) + return [new_out] + + +@register_lower_xtensor +@node_rewriter(tracks=[Transpose]) +def lower_transpose(fgraph, node): + [x] = node.inputs + # Use the final dimensions that were already computed in make_node + out_dims = 
node.outputs[0].type.dims + in_dims = x.type.dims + + # Compute the permutation based on the final dimensions + perm = tuple(in_dims.index(d) for d in out_dims) + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = x_tensor.transpose(perm) + new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) + return [new_out] + + +@register_lower_xtensor +@node_rewriter([Squeeze]) +def local_squeeze_reshape(fgraph, node): + """Rewrite Squeeze to tensor.squeeze.""" + [x] = node.inputs + x_tensor = tensor_from_xtensor(x) + x_dims = x.type.dims + dims_to_remove = node.op.dims + axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove) + x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze) + + new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) + return [new_out] diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py new file mode 100644 index 0000000000..ea1c7ab4b4 --- /dev/null +++ b/pytensor/xtensor/rewriting/utils.py @@ -0,0 +1,35 @@ +from pytensor.compile import optdb +from pytensor.graph.rewriting.basic import NodeRewriter +from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase + + +lower_xtensor_db = EquilibriumDB(ignore_newtrees=False) + +optdb.register( + "lower_xtensor", + lower_xtensor_db, + "fast_run", + "fast_compile", + "xtensor", + position=0, +) + + +def register_lower_xtensor( + node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: RewriteDatabase | NodeRewriter): + return register_lower_xtensor( + inner_rewriter, node_rewriter, *tags, **kwargs + ) + + return register + + else: + name = kwargs.pop("name", None) or node_rewriter.__name__ # type: ignore + lower_xtensor_db.register( + name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs + ) + return node_rewriter diff --git a/pytensor/xtensor/rewriting/vectorization.py b/pytensor/xtensor/rewriting/vectorization.py new file mode 100644 index 0000000000..9b52022f07 --- /dev/null +++ b/pytensor/xtensor/rewriting/vectorization.py @@ -0,0 +1,64 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import Elemwise +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.rewriting.utils import register_lower_xtensor +from pytensor.xtensor.vectorization import XBlockwise, XElemwise + + +@register_lower_xtensor +@node_rewriter(tracks=[XElemwise]) +def lower_elemwise(fgraph, node): + out_dims = node.outputs[0].type.dims + + # Convert input XTensors to Tensors and align batch dimensions + tensor_inputs = [] + for inp in node.inputs: + inp_dims = inp.type.dims + order = [ + inp_dims.index(out_dim) if out_dim in inp_dims else "x" + for out_dim in out_dims + ] + tensor_inp = tensor_from_xtensor(inp).dimshuffle(order) + tensor_inputs.append(tensor_inp) + + tensor_outs = Elemwise(scalar_op=node.op.scalar_op)( + *tensor_inputs, return_list=True + ) + + # Convert output Tensors to XTensors + new_outs = [ + xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs + ] + return new_outs + + +@register_lower_xtensor +@node_rewriter(tracks=[XBlockwise]) +def lower_blockwise(fgraph, node): + op: XBlockwise = node.op + batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0]) + batch_dims = node.outputs[0].type.dims[:batch_ndim] + + # Convert input Tensors to XTensors, align batch dimensions and place core dimension at 
the end + tensor_inputs = [] + for inp, core_dims in zip(node.inputs, op.core_dims[0]): + inp_dims = inp.type.dims + # Align the batch dims of the input, and place the core dims on the right + batch_order = [ + inp_dims.index(batch_dim) if batch_dim in inp_dims else "x" + for batch_dim in batch_dims + ] + core_order = [inp_dims.index(core_dim) for core_dim in core_dims] + tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order) + tensor_inputs.append(tensor_inp) + + tensor_op = Blockwise(core_op=node.op.core_op, signature=op.signature) + tensor_outs = tensor_op(*tensor_inputs, return_list=True) + + # Convert output Tensors to XTensors + new_outs = [ + xtensor_from_tensor(tensor_out, dims=old_out.type.dims) + for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True) + ] + return new_outs diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py new file mode 100644 index 0000000000..cd0f024e56 --- /dev/null +++ b/pytensor/xtensor/shape.py @@ -0,0 +1,382 @@ +import warnings +from collections.abc import Sequence +from types import EllipsisType +from typing import Literal + +from pytensor.graph import Apply +from pytensor.scalar import discrete_dtypes, upcast +from pytensor.tensor import as_tensor, get_scalar_constant_value +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.xtensor.basic import XOp +from pytensor.xtensor.type import as_xtensor, xtensor + + +class Stack(XOp): + __props__ = ("new_dim_name", "stacked_dims") + + def __init__(self, new_dim_name: str, stacked_dims: tuple[str, ...]): + super().__init__() + if new_dim_name in stacked_dims: + raise ValueError( + f"Stacking dim {new_dim_name} must not be in {stacked_dims}" + ) + if not stacked_dims: + raise ValueError(f"Stacking dims must not be empty: got {stacked_dims}") + self.new_dim_name = new_dim_name + self.stacked_dims = stacked_dims + + def make_node(self, x): + x = as_xtensor(x) + if not (set(self.stacked_dims) <= set(x.type.dims)): + raise ValueError( + f"Stacking dims {self.stacked_dims} must be a subset of {x.type.dims}" + ) + if self.new_dim_name in x.type.dims: + raise ValueError( + f"Stacking dim {self.new_dim_name} must not be in {x.type.dims}" + ) + if len(self.stacked_dims) == x.type.ndim: + batch_dims, batch_shape = (), () + else: + batch_dims, batch_shape = zip( + *( + (dim, shape) + for dim, shape in zip(x.type.dims, x.type.shape) + if dim not in self.stacked_dims + ) + ) + stack_shape = 1 + for dim, shape in zip(x.type.dims, x.type.shape): + if dim in self.stacked_dims: + if shape is None: + stack_shape = None + break + else: + stack_shape *= shape + output = xtensor( + dtype=x.type.dtype, + shape=(*batch_shape, stack_shape), + dims=(*batch_dims, self.new_dim_name), + ) + return Apply(self, [x], [output]) + + +def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]): + if dim is not None: + if dims: + raise ValueError("Cannot use both positional dim and keyword dims in stack") + dims = dim + + y = x + for new_dim_name, stacked_dims in dims.items(): + if isinstance(stacked_dims, str): + raise TypeError( + f"Stacking dims must be a sequence of strings, got a single string: {stacked_dims}" + ) + y = Stack(new_dim_name, tuple(stacked_dims))(y) + return y + + +class UnStack(XOp): + __props__ = ("old_dim_name", "unstacked_dims") + + def __init__( + self, + old_dim_name: str, + unstacked_dims: tuple[str, ...], + ): + super().__init__() + if old_dim_name in unstacked_dims: + raise ValueError( + f"Dim to be unstacked {old_dim_name} 
can't be in {unstacked_dims}" + ) + if not unstacked_dims: + raise ValueError("Dims to unstack into can't be empty.") + if len(unstacked_dims) == 1: + raise ValueError("Only one dimension to unstack into, use rename instead") + self.old_dim_name = old_dim_name + self.unstacked_dims = unstacked_dims + + def make_node(self, x, *unstacked_length): + x = as_xtensor(x) + if self.old_dim_name not in x.type.dims: + raise ValueError( + f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}" + ) + if not set(self.unstacked_dims).isdisjoint(x.type.dims): + raise ValueError( + f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}" + ) + + if len(unstacked_length) != len(self.unstacked_dims): + raise ValueError( + f"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}" + ) + unstacked_lengths = [as_tensor(length, ndim=0) for length in unstacked_length] + if not all(length.dtype in discrete_dtypes for length in unstacked_lengths): + raise TypeError("Unstacked lengths must be discrete dtypes.") + + if x.type.ndim == 1: + batch_dims, batch_shape = (), () + else: + batch_dims, batch_shape = zip( + *( + (dim, shape) + for dim, shape in zip(x.type.dims, x.type.shape) + if dim != self.old_dim_name + ) + ) + + static_unstacked_lengths = [None] * len(unstacked_lengths) + for i, length in enumerate(unstacked_lengths): + try: + static_length = get_scalar_constant_value(length) + except NotScalarConstantError: + pass + else: + static_unstacked_lengths[i] = int(static_length) + + output = xtensor( + dtype=x.type.dtype, + shape=(*batch_shape, *static_unstacked_lengths), + dims=(*batch_dims, *self.unstacked_dims), + ) + return Apply(self, [x, *unstacked_lengths], [output]) + + +def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]): + if dim is not None: + if dims: + raise ValueError( + "Cannot use both positional dim and keyword dims in unstack" + ) + dims = dim + + y = x + for old_dim_name, unstacked_dict in dims.items(): + y = UnStack(old_dim_name, tuple(unstacked_dict.keys()))( + y, *tuple(unstacked_dict.values()) + ) + return y + + +class Transpose(XOp): + __props__ = ("dims",) + + def __init__( + self, + dims: tuple[str | EllipsisType, ...], + ): + super().__init__() + if dims.count(...) > 1: + raise ValueError("an index can only have a single ellipsis ('...')") + self.dims = dims + + def make_node(self, x): + x = as_xtensor(x) + + transpose_dims = self.dims + x_dims = x.type.dims + + if transpose_dims == () or transpose_dims == (...,): + out_dims = tuple(reversed(x_dims)) + elif ... in transpose_dims: + # Handle ellipsis expansion + ellipsis_idx = transpose_dims.index(...) + pre = transpose_dims[:ellipsis_idx] + post = transpose_dims[ellipsis_idx + 1 :] + middle = [d for d in x_dims if d not in pre + post] + out_dims = (*pre, *middle, *post) + if set(out_dims) != set(x_dims): + raise ValueError(f"{out_dims} must be a permuted list of {x_dims}") + else: + out_dims = transpose_dims + if set(out_dims) != set(x_dims): + raise ValueError( + f"{out_dims} must be a permuted list of {x_dims}, unless `...` is included" + ) + + output = xtensor( + dtype=x.type.dtype, + shape=tuple(x.type.shape[x.type.dims.index(d)] for d in out_dims), + dims=out_dims, + ) + return Apply(self, [x], [output]) + + +def transpose( + x, + *dims: str | EllipsisType, + missing_dims: Literal["raise", "warn", "ignore"] = "raise", +): + """Transpose dimensions of the tensor. 
+ + Parameters + ---------- + x : XTensorVariable + Input tensor to transpose. + *dims : str + Dimensions to transpose to. Can include ellipsis (...) to represent + remaining dimensions in their original order. + missing_dims : {"raise", "warn", "ignore"}, optional + How to handle dimensions that don't exist in the input tensor: + - "raise": Raise an error if any dimensions don't exist (default) + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist + + Returns + ------- + XTensorVariable + Transposed tensor with reordered dimensions. + + Raises + ------ + ValueError + If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise". + """ + # Validate dimensions + x = as_xtensor(x) + all_dims = x.type.dims + invalid_dims = set(dims) - {..., *all_dims} + if invalid_dims: + if missing_dims != "ignore": + msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {all_dims}" + if missing_dims == "raise": + raise ValueError(msg) + else: + warnings.warn(msg) + # Handle missing dimensions if not raising + dims = tuple(d for d in dims if d in all_dims or d is ...) + + return Transpose(dims)(x) + + +class Concat(XOp): + __props__ = ("dim",) + + def __init__(self, dim: str): + self.dim = dim + super().__init__() + + def make_node(self, *inputs): + inputs = [as_xtensor(inp) for inp in inputs] + concat_dim = self.dim + + dims_and_shape: dict[str, int | None] = {} + for inp in inputs: + for dim, dim_length in zip(inp.type.dims, inp.type.shape): + if dim not in dims_and_shape: + dims_and_shape[dim] = dim_length + else: + if dim == concat_dim: + if dim_length is None: + dims_and_shape[dim] = None + elif dims_and_shape[dim] is not None: + dims_and_shape[dim] += dim_length + elif dim_length is not None: + # Check for conflicting in non-concatenated shapes + if (dims_and_shape[dim] is not None) and ( + dims_and_shape[dim] != dim_length + ): + raise ValueError( + f"Non-concatenated dimension {dim} has conflicting shapes" + ) + # Keep the non-None shape + dims_and_shape[dim] = dim_length + + if concat_dim not in dims_and_shape: + # It's a new dim, that should be located at the start + dims_and_shape = {concat_dim: len(inputs)} | dims_and_shape + elif dims_and_shape[concat_dim] is not None: + # We need to add +1 for every input that doesn't have this dimension + for inp in inputs: + if concat_dim not in inp.type.dims: + dims_and_shape[concat_dim] += 1 + + dims, shape = zip(*dims_and_shape.items()) + dtype = upcast(*[x.type.dtype for x in inputs]) + output = xtensor(dtype=dtype, dims=dims, shape=shape) + return Apply(self, inputs, [output]) + + +def concat(xtensors, dim: str): + return Concat(dim=dim)(*xtensors) + + +class Squeeze(XOp): + """Remove specified dimensions from an XTensorVariable. + + Only dimensions that are known statically to be size 1 will be removed. + Symbolic dimensions must be explicitly specified, and are assumed safe. + + Parameters + ---------- + dim : tuple of str + The names of the dimensions to remove. 
+ """ + + __props__ = ("dims",) + + def __init__(self, dims): + self.dims = tuple(sorted(set(dims))) + + def make_node(self, x): + x = as_xtensor(x) + + # Validate that dims exist and are size-1 if statically known + dims_to_remove = [] + x_dims = x.type.dims + x_shape = x.type.shape + for d in self.dims: + if d not in x_dims: + raise ValueError(f"Dimension {d} not found in {x.type.dims}") + idx = x_dims.index(d) + dim_size = x_shape[idx] + if dim_size is not None and dim_size != 1: + raise ValueError(f"Dimension {d} has static size {dim_size}, not 1") + dims_to_remove.append(idx) + + new_dims = tuple( + d for i, d in enumerate(x.type.dims) if i not in dims_to_remove + ) + new_shape = tuple( + s for i, s in enumerate(x.type.shape) if i not in dims_to_remove + ) + + out = xtensor( + dtype=x.type.dtype, + shape=new_shape, + dims=new_dims, + ) + return Apply(self, [x], [out]) + + +def squeeze(x, dim=None): + """Remove dimensions of size 1 from an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + The input tensor + dim : str or None or iterable of str, optional + The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 + (known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime. + + Returns + ------- + XTensorVariable + A new tensor with the specified dimension(s) removed. + """ + x = as_xtensor(x) + + if dim is None: + dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1) + elif isinstance(dim, str): + dims = (dim,) + else: + dims = tuple(dim) + + if not dims: + return x # no-op if nothing to squeeze + + return Squeeze(dims=dims)(x) diff --git a/pytensor/xtensor/special.py b/pytensor/xtensor/special.py new file mode 100644 index 0000000000..b6d5fc99f9 --- /dev/null +++ b/pytensor/xtensor/special.py @@ -0,0 +1,7 @@ +from pytensor.xtensor.math import exp +from pytensor.xtensor.reduction import REDUCE_DIM + + +def softmax(x, dim: REDUCE_DIM = None): + exp_x = exp(x) + return exp_x / exp_x.sum(dim=dim) # type: ignore diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py new file mode 100644 index 0000000000..fd601df018 --- /dev/null +++ b/pytensor/xtensor/type.py @@ -0,0 +1,717 @@ +import typing +import warnings +from types import EllipsisType + +from pytensor.compile import ( + DeepCopyOp, + ViewOp, + register_deep_copy_op_c_code, + register_view_op_c_code, +) +from pytensor.tensor import TensorType +from pytensor.tensor.math import variadic_mul + + +try: + import xarray as xr + + XARRAY_AVAILABLE = True +except ModuleNotFoundError: + XARRAY_AVAILABLE = False + +from collections.abc import Sequence +from typing import Any, Literal, TypeVar + +import numpy as np + +import pytensor.xtensor as px +from pytensor import _as_symbolic, config +from pytensor.graph import Apply, Constant +from pytensor.graph.basic import OptionalApplyType, Variable +from pytensor.graph.type import HasDataType, HasShape, Type +from pytensor.tensor.basic import constant as tensor_constant +from pytensor.tensor.utils import hash_from_ndarray +from pytensor.tensor.variable import TensorVariable + + +class XTensorType(Type, HasDataType, HasShape): + """A `Type` for Xtensors (Xarray-like tensors with dims).""" + + __props__ = ("dtype", "shape", "dims") + + def __init__( + self, + dtype: str | np.dtype, + *, + dims: Sequence[str], + shape: Sequence[int | None] | None = None, + name: str | None = None, + ): + if dtype == "floatX": + self.dtype = config.floatX + else: + self.dtype = 
np.dtype(dtype).name + + self.dims = tuple(dims) + if len(set(dims)) < len(dims): + raise ValueError(f"Dimensions must be unique. Found duplicates in {dims}: ") + if shape is None: + self.shape = (None,) * len(self.dims) + else: + self.shape = tuple(shape) + if len(self.shape) != len(self.dims): + raise ValueError( + f"Shape {self.shape} must have the same length as dims {self.dims}" + ) + self.ndim = len(self.dims) + self.name = name + + def clone( + self, + dtype=None, + dims=None, + shape=None, + **kwargs, + ): + if dtype is None: + dtype = self.dtype + if dims is None: + dims = self.dims + if shape is None: + shape = self.shape + return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs) + + def filter(self, value, strict=False, allow_downcast=None): + # TODO implement this + return value + + def convert_variable(self, var): + # TODO: Implement this + return var + + def __repr__(self): + return f"XTensorType({self.dtype}, {self.dims}, {self.shape})" + + def __hash__(self): + return hash((type(self), self.dtype, self.shape, self.dims)) + + def __eq__(self, other): + return ( + type(self) is type(other) + and self.dims == other.dims + and self.shape == other.shape + ) + + def is_super(self, otype): + if type(self) is not type(otype): + return False + if self.dtype != otype.dtype: + return False + if self.dims != otype.dims: + return False + if any( + s_dim_length is not None and s_dim_length != o_dim_length + for s_dim_length, o_dim_length in zip(self.shape, otype.shape) + ): + return False + return True + + +def xtensor( + name: str | None = None, + *, + dims: Sequence[str], + shape: Sequence[int | None] | None = None, + dtype: str | np.dtype = "floatX", +): + return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name) + + +_XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType) + + +class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): + # These can't work because Python requires native output types + def __bool__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python boolean. " + "Call `.astype(bool)` for the symbolic equivalent." + ) + + def __index__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python integer. " + "Call `.astype(int)` for the symbolic equivalent." + ) + + def __int__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python integer. " + "Call `.astype(int)` for the symbolic equivalent." + ) + + def __float__(self): + raise TypeError( + "XTensorVariables cannot be converted to Python float. " + "Call `.astype(float)` for the symbolic equivalent." + ) + + def __complex__(self): + raise TypeError( + "XTensorVariables cannot be converted to Python complex number. " + "Call `.astype(complex)` for the symbolic equivalent." 
+ ) + + # Python valid overloads + def __abs__(self): + return px.math.abs(self) + + def __neg__(self): + return px.math.neg(self) + + def __lt__(self, other): + return px.math.lt(self, other) + + def __le__(self, other): + return px.math.le(self, other) + + def __gt__(self, other): + return px.math.gt(self, other) + + def __ge__(self, other): + return px.math.ge(self, other) + + def __invert__(self): + return px.math.invert(self) + + def __and__(self, other): + return px.math.and_(self, other) + + def __or__(self, other): + return px.math.or_(self, other) + + def __xor__(self, other): + return px.math.xor(self, other) + + def __rand__(self, other): + return px.math.and_(other, self) + + def __ror__(self, other): + return px.math.or_(other, self) + + def __rxor__(self, other): + return px.math.xor(other, self) + + def __add__(self, other): + return px.math.add(self, other) + + def __sub__(self, other): + return px.math.sub(self, other) + + def __mul__(self, other): + return px.math.mul(self, other) + + def __div__(self, other): + return px.math.div(self, other) + + def __pow__(self, other): + return px.math.pow(self, other) + + def __mod__(self, other): + return px.math.mod(self, other) + + def __divmod__(self, other): + return px.math.divmod(self, other) + + def __truediv__(self, other): + return px.math.true_div(self, other) + + def __floordiv__(self, other): + return px.math.floor_div(self, other) + + def __rtruediv__(self, other): + return px.math.true_div(other, self) + + def __rfloordiv__(self, other): + return px.math.floor_div(other, self) + + def __radd__(self, other): + return px.math.add(other, self) + + def __rsub__(self, other): + return px.math.sub(other, self) + + def __rmul__(self, other): + return px.math.mul(other, self) + + def __rdiv__(self, other): + return px.math.div_proxy(other, self) + + def __rmod__(self, other): + return px.math.mod(other, self) + + def __rdivmod__(self, other): + return px.math.divmod(other, self) + + def __rpow__(self, other): + return px.math.pow(other, self) + + def __ceil__(self): + return px.math.ceil(self) + + def __floor__(self): + return px.math.floor(self) + + def __trunc__(self): + return px.math.trunc(self) + + # DataArray-like attributes + # https://docs.xarray.dev/en/latest/api.html#id1 + @property + def values(self) -> TensorVariable: + return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self)) + + # Can't provide property data because that's already taken by Constants! 
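+    # Constant.data already stores the constant's value, so the underlying
+    # array is exposed only through .values here.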
+ # data = values + + @property + def coords(self): + raise NotImplementedError("coords not implemented for XTensorVariable") + + @property + def dims(self) -> tuple[str, ...]: + return self.type.dims + + @property + def sizes(self) -> dict[str, TensorVariable]: + return dict(zip(self.dims, self.shape)) + + @property + def as_numpy(self): + # No-op, since the underlying data is always a numpy array + return self + + # ndarray attributes + # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes + @property + def ndim(self) -> int: + return self.type.ndim + + @property + def shape(self) -> tuple[TensorVariable, ...]: + return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore + + @property + def size(self) -> TensorVariable: + return typing.cast(TensorVariable, variadic_mul(*self.shape)) + + @property + def dtype(self): + return self.type.dtype + + # DataArray contents + # https://docs.xarray.dev/en/latest/api.html#dataarray-contents + def rename(self, new_name_or_name_dict=None, **names): + if isinstance(new_name_or_name_dict, str): + new_name = new_name_or_name_dict + name_dict = None + else: + new_name = None + name_dict = new_name_or_name_dict + new_out = px.basic.rename(self, name_dict, **names) + new_out.name = new_name + return new_out + + def copy(self, name: str | None = None): + out = px.math.identity(self) + out.name = name # type: ignore + return out + + def astype(self, dtype): + return px.math.cast(self, dtype) + + def item(self): + raise NotImplementedError("item not implemented for XTensorVariable") + + # Indexing + # https://docs.xarray.dev/en/latest/api.html#id2 + def __setitem__(self, idx, value): + raise TypeError( + "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead." + ) + + @property + def loc(self): + raise NotImplementedError("loc not implemented for XTensorVariable") + + def sel(self, *args, **kwargs): + raise NotImplementedError("sel not implemented for XTensorVariable") + + def __getitem__(self, idx): + if isinstance(idx, dict): + return self.isel(idx) + + if not isinstance(idx, tuple): + idx = (idx,) + + return px.indexing.index(self, *idx) + + def isel( + self, + indexers: dict[str, Any] | None = None, + drop: bool = False, # Unused by PyTensor + missing_dims: Literal["raise", "warn", "ignore"] = "raise", + **indexers_kwargs, + ): + if indexers_kwargs: + if indexers is not None: + raise ValueError( + "Cannot pass both indexers and indexers_kwargs to isel" + ) + indexers = indexers_kwargs + + if not indexers: + # No-op + return self + + if missing_dims not in {"raise", "warn", "ignore"}: + raise ValueError( + f"Unrecognized options {missing_dims} for missing_dims argument" + ) + + # Sort indices and pass them to index + dims = self.type.dims + indices = [slice(None)] * self.type.ndim + for key, idx in indexers.items(): + if idx is Ellipsis: + # Xarray raises a less informative error, suggesting indices must be integer + # But slices are also fine + raise TypeError("Ellipsis (...) is an invalid labeled index") + try: + indices[dims.index(key)] = idx + except IndexError: + if missing_dims == "raise": + raise ValueError( + f"Dimension {key} does not exist. Expected one of {dims}" + ) + elif missing_dims == "warn": + warnings.warn( + f"Dimension {key} does not exist. 
Expected one of {dims}", + UserWarning, + ) + + return px.indexing.index(self, *indices) + + def set(self, value): + if not ( + self.owner is not None and isinstance(self.owner.op, px.indexing.Index) + ): + raise ValueError( + f"set can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}" + ) + + x, *idxs = self.owner.inputs + return px.indexing.index_assignment(x, value, *idxs) + + def inc(self, value): + if not ( + self.owner is not None and isinstance(self.owner.op, px.indexing.Index) + ): + raise ValueError( + f"inc can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}" + ) + + x, *idxs = self.owner.inputs + return px.indexing.index_increment(x, value, *idxs) + + def _head_tail_or_thin( + self, + indexers: dict[str, Any] | int | None, + indexers_kwargs: dict[str, Any], + *, + kind: Literal["head", "tail", "thin"], + ): + if indexers_kwargs: + if indexers is not None: + raise ValueError( + "Cannot pass both indexers and indexers_kwargs to head" + ) + indexers = indexers_kwargs + + if indexers is None: + if kind == "thin": + raise TypeError( + "thin() indexers must be either dict-like or a single integer" + ) + else: + # Default to 5 for head and tail + indexers = {dim: 5 for dim in self.type.dims} + + elif not isinstance(indexers, dict): + indexers = {dim: indexers for dim in self.type.dims} + + if kind == "head": + indices = {dim: slice(None, value) for dim, value in indexers.items()} + elif kind == "tail": + sizes = self.sizes + # Can't use slice(-value, None), in case value is zero + indices = { + dim: slice(sizes[dim] - value, None) for dim, value in indexers.items() + } + elif kind == "thin": + indices = {dim: slice(None, None, value) for dim, value in indexers.items()} + return self.isel(indices) + + def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head") + + def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail") + + def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin") + + def squeeze( + self, + dim: Sequence[str] | str | None = None, + drop: bool = False, + axis: int | Sequence[int] | None = None, + ): + if axis is not None: + raise NotImplementedError("Squeeze with axis not Implemented") + return px.shape.squeeze(self, dim) + + # ndarray methods + # https://docs.xarray.dev/en/latest/api.html#id7 + def clip(self, min, max): + return px.math.clip(self, min, max) + + def conj(self): + return px.math.conj(self) + + @property + def imag(self): + return px.math.imag(self) + + @property + def real(self): + return px.math.real(self) + + @property + def T(self): + """Return the full transpose of the tensor. + + This is equivalent to calling transpose() with no arguments. + + Returns + ------- + XTensorVariable + Fully transposed tensor. 
+ """ + return self.transpose() + + # Aggregation + # https://docs.xarray.dev/en/latest/api.html#id6 + def all(self, dim): + return px.reduction.all(self, dim) + + def any(self, dim): + return px.reduction.any(self, dim) + + def max(self, dim): + return px.reduction.max(self, dim) + + def min(self, dim): + return px.reduction.min(self, dim) + + def mean(self, dim): + return px.reduction.mean(self, dim) + + def prod(self, dim): + return px.reduction.prod(self, dim) + + def sum(self, dim): + return px.reduction.sum(self, dim) + + def std(self, dim): + return px.reduction.std(self, dim) + + def var(self, dim): + return px.reduction.var(self, dim) + + def cumsum(self, dim): + return px.reduction.cumsum(self, dim) + + def cumprod(self, dim): + return px.reduction.cumprod(self, dim) + + def diff(self, dim, n=1): + """Compute the n-th discrete difference along the given dimension.""" + slice1 = {dim: slice(1, None)} + slice2 = {dim: slice(None, -1)} + x = self + for _ in range(n): + x = x[slice1] - x[slice2] + return x + + # Reshaping and reorganizing + # https://docs.xarray.dev/en/latest/api.html#id8 + def transpose( + self, + *dims: str | EllipsisType, + missing_dims: Literal["raise", "warn", "ignore"] = "raise", + ): + """Transpose dimensions of the tensor. + + Parameters + ---------- + *dims : str | Ellipsis + Dimensions to transpose. If empty, performs a full transpose. + Can use ellipsis (...) to represent remaining dimensions. + missing_dims : {"raise", "warn", "ignore"}, default="raise" + How to handle dimensions that don't exist in the tensor: + - "raise": Raise an error if any dimensions don't exist + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist + + Returns + ------- + XTensorVariable + Transposed tensor with reordered dimensions. + + Raises + ------ + ValueError + If missing_dims="raise" and any dimensions don't exist. + If multiple ellipsis are provided. + """ + return px.shape.transpose(self, *dims, missing_dims=missing_dims) + + def stack(self, dim, **dims): + return px.shape.stack(self, dim, **dims) + + def unstack(self, dim, **dims): + return px.shape.unstack(self, dim, **dims) + + +class XTensorConstantSignature(tuple): + def __eq__(self, other): + if type(self) is not type(other): + return False + + (ttype0, data0), (ttype1, data1) = self, other + if ttype0 != ttype1 or data0.shape != data1.shape: + return False + + # TODO: Cash sum and use it in hash like TensorConstant does + return (data0 == data1).all() + + def __ne__(self, other): + return not self == other + + def __hash__(self): + (ttype, data) = self + return hash((type(self), ttype, data.shape)) + + def pytensor_hash(self): + _, data = self + return hash_from_ndarray(data) + + +class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]): + def __init__(self, type: _XTensorTypeType, data, name=None): + # TODO: Add checks that type and data are compatible + Constant.__init__(self, type, data, name) + + def signature(self): + return XTensorConstantSignature((self.type, self.data)) + + +XTensorType.variable_type = XTensorVariable # type: ignore +XTensorType.constant_type = XTensorConstant # type: ignore + + +def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): + x_dims: tuple[str, ...] 
+ if isinstance(x, xr.DataArray): + xarray_dims = x.dims + if not all(isinstance(dim, str) for dim in xarray_dims): + raise NotImplementedError( + "DataArray can only be converted to xtensor_constant if all dims are of string type" + ) + x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims)) + x_data = x.values + + if dims is not None and dims != x_dims: + raise ValueError( + f"xr.DataArray dims {x_dims} don't match requested specified {dims}. " + "Use transpose or rename" + ) + else: + x_data = tensor_constant(x).data + if dims is not None: + x_dims = tuple(dims) + else: + if x_data.ndim == 0: + x_dims = () + else: + raise TypeError( + "Cannot convert TensorLike constant to XTensorConstant without specifying dims." + ) + try: + return XTensorConstant( + XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape), + x_data, + name=name, + ) + except TypeError: + raise TypeError(f"Could not convert {x} to XTensorType") + + +if XARRAY_AVAILABLE: + + @_as_symbolic.register(xr.DataArray) + def as_symbolic_xarray(x, **kwargs): + return xtensor_constant(x, **kwargs) + + +def as_xtensor(x, name=None, dims: Sequence[str] | None = None): + if isinstance(x, Apply): + if len(x.outputs) != 1: + raise ValueError( + "It is ambiguous which output of a multi-output Op has to be fetched.", + x, + ) + else: + x = x.outputs[0] + + if isinstance(x, Variable): + if isinstance(x.type, XTensorType): + return x + if isinstance(x.type, TensorType): + if x.type.ndim > 0 and dims is None: + raise TypeError( + "non-scalar TensorVariable cannot be converted to XTensorVariable without dims." + ) + return px.basic.xtensor_from_tensor(x, dims) + else: + raise TypeError( + "Variable with type {x.type} cannot be converted to XTensorVariable." + ) + try: + return xtensor_constant(x, name=name, dims=dims) + except TypeError as err: + raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err + + +register_view_op_c_code( + XTensorType, + # XTensorType is just TensorType under the hood + *ViewOp.c_code_and_version[TensorType], +) + +register_deep_copy_op_c_code( + XTensorType, + # XTensorType is just TensorType under the hood + *DeepCopyOp.c_code_and_version[TensorType], +) diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py new file mode 100644 index 0000000000..1fe7dd99d7 --- /dev/null +++ b/pytensor/xtensor/vectorization.py @@ -0,0 +1,122 @@ +from itertools import chain + +from pytensor import scalar as ps +from pytensor.graph import Apply, Op +from pytensor.tensor import tensor +from pytensor.tensor.utils import _parse_gufunc_signature +from pytensor.xtensor.basic import XOp +from pytensor.xtensor.type import as_xtensor, xtensor + + +class XElemwise(XOp): + __props__ = ("scalar_op",) + + def __init__(self, scalar_op): + super().__init__() + self.scalar_op = scalar_op + + def make_node(self, *inputs): + inputs = [as_xtensor(inp) for inp in inputs] + if (self.scalar_op.nin != -1) and (len(inputs) != self.scalar_op.nin): + raise ValueError( + f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}" + ) + + dims_and_shape: dict[str, int | None] = {} + for inp in inputs: + for dim, dim_length in zip(inp.type.dims, inp.type.shape): + if dim not in dims_and_shape: + dims_and_shape[dim] = dim_length + elif dim_length is not None: + # Check for conflicting shapes + if (dims_and_shape[dim] is not None) and ( + dims_and_shape[dim] != dim_length + ): + raise ValueError(f"Dimension {dim} has conflicting shapes") + # Keep the non-None shape + 
dims_and_shape[dim] = dim_length + + if dims_and_shape: + output_dims, output_shape = zip(*dims_and_shape.items()) + else: + output_dims, output_shape = (), () + + dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs] + output_dtypes = [ + out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs + ] + outputs = [ + xtensor(dtype=output_dtype, dims=output_dims, shape=output_shape) + for output_dtype in output_dtypes + ] + return Apply(self, inputs, outputs) + + +class XBlockwise(XOp): + __props__ = ("core_op", "signature", "core_dims") + + def __init__( + self, + core_op: Op, + signature: str, + core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]], + ): + super().__init__() + self.core_op = core_op + self.signature = signature + self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) + self.core_dims = core_dims + + def make_node(self, *inputs): + inputs = [as_xtensor(i) for i in inputs] + if len(inputs) != len(self.inputs_sig): + raise ValueError( + f"Wrong number of inputs, expected {len(self.inputs_sig)}, got {len(inputs)}" + ) + + dims_and_shape: dict[str, int | None] = {} + for inp in inputs: + for dim, dim_length in zip(inp.type.dims, inp.type.shape): + if dim not in dims_and_shape: + dims_and_shape[dim] = dim_length + elif dim_length is not None: + # Check for conflicting shapes + if (dims_and_shape[dim] is not None) and ( + dims_and_shape[dim] != dim_length + ): + raise ValueError(f"Dimension {dim} has conflicting shapes") + # Keep the non-None shape + dims_and_shape[dim] = dim_length + + core_inputs_dims, core_outputs_dims = self.core_dims + # TODO: Avoid intermediate dict + core_dims = set(chain.from_iterable(core_inputs_dims)) + batched_dims_and_shape = { + k: v for k, v in dims_and_shape.items() if k not in core_dims + } + batch_dims, batch_shape = zip(*batched_dims_and_shape.items()) + + dummy_core_inputs = [] + for inp, core_inp_dims in zip(inputs, core_inputs_dims): + try: + core_static_shape = [ + inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims + ] + except IndexError: + raise ValueError( + f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}" + ) + dummy_core_inputs.append( + tensor(dtype=inp.type.dtype, shape=core_static_shape) + ) + core_node = self.core_op.make_node(*dummy_core_inputs) + + outputs = [ + xtensor( + dtype=core_out.type.dtype, + shape=batch_shape + core_out.type.shape, + dims=batch_dims + core_out_dims, + ) + for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims) + ] + return Apply(self, inputs, outputs) diff --git a/tests/xtensor/__init__.py b/tests/xtensor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py new file mode 100644 index 0000000000..b36873683b --- /dev/null +++ b/tests/xtensor/test_indexing.py @@ -0,0 +1,512 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") + +import re + +import numpy as np +from xarray import DataArray + +from pytensor.tensor import tensor +from pytensor.xtensor import xtensor +from tests.xtensor.util import ( + xr_arange_like, + xr_assert_allclose, + xr_function, + xr_random_like, +) + + +@pytest.mark.parametrize( + "indices", + [ + (0,), + (slice(1, None),), + (slice(None, -1),), + (slice(None, None, -1),), + (0, slice(None), -1, slice(1, None)), + (..., 0, -1), + (0, ..., -1), + (0, -1, ...), + ], +) +@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", 
"labeled"]) +def test_basic_indexing(labeled, indices): + if ... in indices and labeled: + pytest.skip("Ellipsis not supported with labeled indexing") + + dims = ("a", "b", "c", "d") + x = xtensor(dims=dims, shape=(2, 3, 5, 7)) + + if labeled: + shufled_dims = tuple(np.random.permutation(dims)) + indices = dict(zip(shufled_dims, indices, strict=False)) + out = x[indices] + + fn = xr_function([x], out) + x_test_values = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape( + x.type.shape + ) + x_test = DataArray(x_test_values, dims=x.type.dims) + res = fn(x_test) + expected_res = x_test[indices] + xr_assert_allclose(res, expected_res) + + +def test_single_vector_indexing_on_existing_dim(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx = tensor("idx", dtype=int, shape=(4,)) + xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",)) + + x_test = xr_arange_like(x) + idx_test = np.array([0, 1, 0, 2], dtype=int) + xidx_test = DataArray(idx_test, dims=("a",)) + + # Equivalent ways of indexing a->a + y = x[idx] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[idx_test] + xr_assert_allclose(res, expected_res) + + y = x[(("a", idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[(("a", idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[((("a",), idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[((("a",), idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[xidx] + fn = xr_function([x, xidx], y) + res = fn(x_test, xidx_test) + expected_res = x_test[xidx_test] + xr_assert_allclose(res, expected_res) + + +def test_single_vector_indexing_on_new_dim(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx = tensor("idx", dtype=int, shape=(4,)) + xidx = xtensor("idx", dtype=int, shape=(4,), dims=("new_a",)) + + x_test = xr_arange_like(x) + idx_test = np.array([0, 1, 0, 2], dtype=int) + xidx_test = DataArray(idx_test, dims=("new_a",)) + + # Equivalent ways of indexing a->new_a + y = x[(("new_a", idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[(("new_a", idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[((["new_a"], idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[((["new_a"], idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[xidx] + fn = xr_function([x, xidx], y) + res = fn(x_test, xidx_test) + expected_res = x_test[xidx_test] + xr_assert_allclose(res, expected_res) + + +def test_single_vector_indexing_interacting_with_existing_dim(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx = tensor("idx", dtype=int, shape=(4,)) + xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",)) + + x_test = xr_arange_like(x) + idx_test = np.array([0, 1, 0, 2], dtype=int) + xidx_test = DataArray(idx_test, dims=("a",)) + + # Two equivalent ways of indexing a->b + # By labeling the index on a, as "b", we cause pointwise indexing between the two dimensions. 
+ y = x[("b", idx), 1:] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[("b", idx_test), 1:] + xr_assert_allclose(res, expected_res) + + y = x[xidx.rename(a="b"), 1:] + fn = xr_function([x, xidx], y) + res = fn(x_test, xidx_test) + expected_res = x_test[xidx_test.rename(a="b"), 1:] + xr_assert_allclose(res, expected_res) + + +@pytest.mark.parametrize( + "dims_order", + [ + ("a", "b", "ar", "br", "o"), + ("o", "br", "ar", "b", "a"), + ("a", "b", "o", "ar", "br"), + ("a", "o", "ar", "b", "br"), + ], +) +def test_multiple_vector_indexing(dims_order): + x = xtensor(dims=dims_order, shape=(5, 7, 11, 13, 17)) + idx_a = xtensor("idx_a", dtype=int, shape=(4,), dims=("a",)) + idx_b = xtensor("idx_b", dtype=int, shape=(3,), dims=("b",)) + + idxs = [slice(None)] * 5 + idxs[x.type.dims.index("a")] = idx_a + idxs[x.type.dims.index("b")] = idx_b + idxs[x.type.dims.index("ar")] = idx_a[::-1] + idxs[x.type.dims.index("br")] = idx_b[::-1] + + out = x[tuple(idxs)] + fn = xr_function([x, idx_a, idx_b], out) + + x_test = xr_arange_like(x) + idx_a_test = DataArray(np.array([0, 1, 0, 2], dtype=int), dims=("a",)) + idx_b_test = DataArray(np.array([1, 3, 0], dtype=int), dims=("b",)) + res = fn(x_test, idx_a_test, idx_b_test) + idxs_test = [slice(None)] * 5 + idxs_test[x.type.dims.index("a")] = idx_a_test + idxs_test[x.type.dims.index("b")] = idx_b_test + idxs_test[x.type.dims.index("ar")] = idx_a_test[::-1] + idxs_test[x.type.dims.index("br")] = idx_b_test[::-1] + expected_res = x_test[tuple(idxs_test)] + xr_assert_allclose(res, expected_res) + + +def test_matrix_indexing(): + x = xtensor(dims=("a", "b", "c"), shape=(3, 5, 7)) + idx_ab = xtensor("idx_ab", dtype=int, shape=(4, 2), dims=("a", "b")) + idx_cd = xtensor("idx_cd", dtype=int, shape=(4, 3), dims=("c", "d")) + + out = x[idx_ab, slice(1, 3), idx_cd] + fn = xr_function([x, idx_ab, idx_cd], out) + + x_test = xr_arange_like(x) + idx_ab_test = DataArray( + np.array([[0, 1], [1, 2], [0, 2], [-1, -2]], dtype=int), dims=("a", "b") + ) + idx_cd_test = DataArray( + np.array([[1, 2, 3], [0, 4, 5], [2, 6, -1], [3, -2, 0]], dtype=int), + dims=("c", "d"), + ) + res = fn(x_test, idx_ab_test, idx_cd_test) + expected_res = x_test[idx_ab_test, slice(1, 3), idx_cd_test] + xr_assert_allclose(res, expected_res) + + +def test_assign_multiple_out_dims(): + x = xtensor("x", shape=(5, 7), dims=("a", "b")) + idx1 = tensor("idx1", dtype=int, shape=(4, 3)) + idx2 = tensor("idx2", dtype=int, shape=(3, 2)) + out = x[(("out1", "out2"), idx1), (["out2", "out3"], idx2)] + + fn = xr_function([x, idx1, idx2], out) + + rng = np.random.default_rng() + x_test = xr_arange_like(x) + idx1_test = rng.binomial(n=4, p=0.5, size=(4, 3)) + idx2_test = rng.binomial(n=4, p=0.5, size=(3, 2)) + res = fn(x_test, idx1_test, idx2_test) + expected_res = x_test[(("out1", "out2"), idx1_test), (["out2", "out3"], idx2_test)] + xr_assert_allclose(res, expected_res) + + +def test_assign_indexer_dims_fails(): + # Test cases where the implicit naming of the indexer dimensions is not allowed. + x = xtensor("x", shape=(5, 7), dims=("a", "b")) + idx1 = xtensor("idx1", dtype=int, shape=(4,), dims=("c",)) + + with pytest.raises( + IndexError, + match=re.escape( + "Giving a dimension name to an XTensorVariable indexer is not supported: ('d', idx1). " + "Use .rename() instead." + ), + ): + x[("d", idx1),] + + with pytest.raises( + IndexError, + match=re.escape( + "Boolean indexer should be unlabeled or on the same dimension to the indexed array. 
" + "Indexer is on ('c',) but the target dimension is a." + ), + ): + x[idx1.astype("bool")] + + +class TestVectorizedIndexingNotAllowedToBroadcast: + def test_compile_time_error(self): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx_a = xtensor("idx_a", dtype=int, shape=(4,), dims=("b",)) + idx_b = xtensor("idx_b", dtype=int, shape=(1,), dims=("b",)) + with pytest.raises( + IndexError, match="Dimension of indexers mismatch for dim b" + ): + x[idx_a, idx_b] + + @pytest.mark.xfail( + reason="Check that lowered indexing is not allowed to broadcast not implemented yet" + ) + def test_runtime_error(self): + """ + Test that, unlike in numpy, indices with different shapes cannot act on the same dimension, + even if the shapes could broadcast as per numpy semantics. + """ + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx_a = xtensor("idx_a", dtype=int, shape=(None,), dims=("b",)) + idx_b = xtensor("idx_b", dtype=int, shape=(None,), dims=("b",)) + out = x[idx_a, idx_b] + + fn = xr_function([x, idx_a, idx_b], out) + + x_test = xr_arange_like(x) + valid_idx_a_test = DataArray(np.array([0], dtype=int), dims=("b",)) + idx_b_test = DataArray(np.array([1], dtype=int), dims=("b",)) + xr_assert_allclose( + fn(x_test, valid_idx_a_test, idx_b_test), + x_test[valid_idx_a_test, idx_b_test], + ) + + invalid_idx_a_test = DataArray(np.array([0, 1, 0, 1], dtype=int), dims=("b",)) + with pytest.raises(ValueError): + fn(x_test, invalid_idx_a_test, idx_b_test) + + +@pytest.mark.parametrize( + "dims_order", + [ + ("a", "b", "c", "d"), + ("d", "c", "b", "a"), + ("c", "a", "b", "d"), + ], +) +def test_scalar_integer_indexing(dims_order): + x = xtensor(dims=dims_order, shape=(3, 5, 7, 11)) + scalar_idx = xtensor("scalar_idx", dtype=int, shape=(), dims=()) + vec_idx1 = xtensor("vec_idx", dtype=int, shape=(4,), dims=("a",)) + vec_idx2 = xtensor("vec_idx2", dtype=int, shape=(4,), dims=("c",)) + + idxs = [None] * 4 + idxs[x.type.dims.index("a")] = scalar_idx + idxs[x.type.dims.index("b")] = vec_idx1 + idxs[x.type.dims.index("c")] = vec_idx2 + idxs[x.type.dims.index("d")] = -scalar_idx + out1 = x[tuple(idxs)] + + idxs[x.type.dims.index("a")] = vec_idx1.rename(a="c") + out2 = x[tuple(idxs)] + + fn = xr_function([x, scalar_idx, vec_idx1, vec_idx2], (out1, out2)) + + x_test = xr_arange_like(x) + scalar_idx_test = DataArray(np.array(1, dtype=int), dims=()) + vec_idx_test1 = DataArray(np.array([0, 1, 0, 2], dtype=int), dims=("a",)) + vec_idx_test2 = DataArray(np.array([0, 2, 2, 1], dtype=int), dims=("c",)) + res1, res2 = fn(x_test, scalar_idx_test, vec_idx_test1, vec_idx_test2) + idxs = [None] * 4 + idxs[x.type.dims.index("a")] = scalar_idx_test + idxs[x.type.dims.index("b")] = vec_idx_test1 + idxs[x.type.dims.index("c")] = vec_idx_test2 + idxs[x.type.dims.index("d")] = -scalar_idx_test + expected_res1 = x_test[tuple(idxs)] + idxs[x.type.dims.index("a")] = vec_idx_test1.rename(a="c") + expected_res2 = x_test[tuple(idxs)] + xr_assert_allclose(res1, expected_res1) + xr_assert_allclose(res2, expected_res2) + + +def test_unsupported_boolean_indexing(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + + mat_idx = xtensor("idx", dtype=bool, shape=(4, 2), dims=("a", "b")) + scalar_idx = mat_idx.isel(a=0, b=1) + + for idx in (mat_idx, scalar_idx, scalar_idx.values): + with pytest.raises( + NotImplementedError, + match="Only 1d boolean indexing arrays are supported", + ): + x[idx] + + +def test_boolean_indexing(): + x = xtensor("x", shape=(8, 7), dims=("a", "b")) + bool_idx = xtensor("bool_idx", dtype=bool, shape=(8,), dims=("a",)) + 
int_idx = xtensor("int_idx", dtype=int, shape=(4, 3), dims=("a", "new_dim")) + + out_vectorized = x[bool_idx, int_idx] + out_orthogonal = x[bool_idx, int_idx.rename(a="b")] + fn = xr_function([x, bool_idx, int_idx], [out_vectorized, out_orthogonal]) + + x_test = xr_arange_like(x) + bool_idx_test = DataArray(np.array([True, False] * 4, dtype=bool), dims=("a",)) + int_idx_test = DataArray( + np.random.binomial(n=4, p=0.5, size=(4, 3)), + dims=("a", "new_dim"), + ) + res1, res2 = fn(x_test, bool_idx_test, int_idx_test) + expected_res1 = x_test[bool_idx_test, int_idx_test] + expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")] + xr_assert_allclose(res1, expected_res1) + xr_assert_allclose(res2, expected_res2) + + +@pytest.mark.parametrize("mode", ("set", "inc")) +def test_basic_index_update(mode): + x = xtensor("x", shape=(11, 7), dims=("a", "b")) + y = xtensor("y", shape=(7, 5), dims=("a", "b")) + x_indexed = x[2:-2, 2:] + update_method = getattr(x_indexed, mode) + + x_updated = [ + update_method(y), + update_method(y.T), + update_method(y.isel(a=-1)), + update_method(y.isel(b=-1)), + update_method(y.isel(a=-2, b=-2)), + ] + + fn = xr_function([x, y], x_updated) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + results = fn(x_test, y_test) + + def update_fn(y): + x = x_test.copy() + if mode == "set": + x[2:-2, 2:] = y + elif mode == "inc": + x[2:-2, 2:] += y + return x + + expected_results = [ + update_fn(y_test), + update_fn(y_test.T), + update_fn(y_test.isel(a=-1)), + update_fn(y_test.isel(b=-1)), + update_fn(y_test.isel(a=-2, b=-2)), + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +@pytest.mark.parametrize("mode", ("set", "inc")) +@pytest.mark.parametrize("idx_dtype", (int, bool)) +def test_adv_index_update(mode, idx_dtype): + x = xtensor("x", shape=(5, 5), dims=("a", "b")) + y = xtensor("y", shape=(3,), dims=("b",)) + idx = xtensor("idx", dtype=idx_dtype, shape=(None,), dims=("a",)) + + orthogonal_update1 = getattr(x[idx, -3:], mode)(y) + orthogonal_update2 = getattr(x[idx, -3:], mode)(y.rename(b="a")) + if idx_dtype is not bool: + # Vectorized booling indexing/update is not allowed + vectorized_update = getattr(x[idx.rename(a="b"), :3], mode)(y) + else: + with pytest.raises( + IndexError, + match="Boolean indexer should be unlabeled or on the same dimension to the indexed array.", + ): + getattr(x[idx.rename(a="b"), :3], mode)(y) + vectorized_update = x + + outs = [orthogonal_update1, orthogonal_update2, vectorized_update] + + fn = xr_function([x, idx, y], outs) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + if idx_dtype is int: + idx_test = DataArray([0, 1, 2], dims=("a",)) + else: + idx_test = DataArray([True, False, True, True, False], dims=("a",)) + results = fn(x_test, idx_test, y_test) + + def update_fn(x, idx, y): + x = x.copy() + if mode == "set": + x[idx] = y + else: + x[idx] += y + return x + + expected_results = [ + update_fn(x_test, (idx_test, slice(-3, None)), y_test), + update_fn( + x_test, + (idx_test, slice(-3, None)), + y_test.rename(b="a"), + ), + update_fn(x_test, (idx_test.rename(a="b"), slice(None, 3)), y_test) + if idx_dtype is not bool + else x_test, + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +@pytest.mark.parametrize("mode", ("set", "inc")) +def test_non_consecutive_idx_update(mode): + x = xtensor("x", shape=(2, 3, 5, 7), dims=("a", "b", "c", "d")) + y = xtensor("y", shape=(5, 4), 
dims=("c", "b")) + x_indexed = x[:, [0, 1, 2, 2], :, ("b", [0, 1, 1, 2])] + out = getattr(x_indexed, mode)(y) + + fn = xr_function([x, y], out) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + + result = fn(x_test, y_test) + expected_result = x_test.copy() + # xarray fails inplace operation with the "tuple trick" + # https://github.com/pydata/xarray/issues/10387 + d_indexer = DataArray([0, 1, 1, 2], dims=("b",)) + if mode == "set": + expected_result[:, [0, 1, 2, 2], :, d_indexer] = y_test + else: + expected_result[:, [0, 1, 2, 2], :, d_indexer] += y_test + xr_assert_allclose(result, expected_result) + + +def test_indexing_renames_into_update_variable(): + x = xtensor("x", shape=(5, 5), dims=("a", "b")) + y = xtensor("y", shape=(3,), dims=("d",)) + idx = xtensor("idx", dtype=int, shape=(None,), dims=("d",)) + + # define "d" dimension by slicing the "a" dimension so we can set y into x + orthogonal_update1 = x[idx].set(y) + fn = xr_function([x, idx, y], orthogonal_update1) + + x_test = np.abs(xr_random_like(x)) + y_test = -np.abs(xr_random_like(y)) + idx_test = DataArray([0, 2, 3], dims=("d",)) + + result = fn(x_test, idx_test, y_test) + expected_result = x_test.copy() + expected_result[idx_test] = y_test + xr_assert_allclose(result, expected_result) + + +@pytest.mark.parametrize("n", ["implicit", 1, 2]) +@pytest.mark.parametrize("dim", ["a", "b"]) +def test_diff(dim, n): + x = xtensor(dims=("a", "b"), shape=(7, 11)) + if n == "implicit": + out = x.diff(dim) + else: + out = x.diff(dim, n=n) + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + res = fn(x_test) + if n == "implicit": + expected_res = x_test.diff(dim) + else: + expected_res = x_test.diff(dim, n=n) + xr_assert_allclose(res, expected_res) diff --git a/tests/xtensor/test_linalg.py b/tests/xtensor/test_linalg.py new file mode 100644 index 0000000000..407867070d --- /dev/null +++ b/tests/xtensor/test_linalg.py @@ -0,0 +1,83 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") +pytest.importorskip("xarray_einstats") + +import numpy as np +from xarray import DataArray +from xarray_einstats.linalg import ( + cholesky as xr_cholesky, +) +from xarray_einstats.linalg import ( + solve as xr_solve, +) + +from pytensor import function +from pytensor.xtensor.linalg import cholesky, solve +from pytensor.xtensor.type import xtensor + + +def test_cholesky(): + x = xtensor("x", dims=("a", "batch", "b"), shape=(4, 3, 4)) + y = cholesky(x, dims=["b", "a"]) + assert y.type.dims == ("batch", "b", "a") + assert y.type.shape == (3, 4, 4) + + fn = function([x], y) + rng = np.random.default_rng(25) + x_ = rng.random(size=(4, 3, 3)) + x_ = x_ @ x_.mT + x_test = DataArray(x_.transpose(1, 0, 2), dims=x.type.dims) + np.testing.assert_allclose( + fn(x_test.values), + xr_cholesky(x_test, dims=["b", "a"]).values, + ) + + +def test_solve_vector_b(): + a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1)) + b = xtensor("b", dims=("city", "planet"), shape=(None, 2)) + x = solve(a, b, dims=["country", "city"]) + assert x.type.dims == ("galaxy", "planet", "city") + assert x.type.shape == ( + 1, + 2, + None, + ) # Core Solve doesn't make use of the fact A must be square in the static shape + + fn = function([a, b], x) + + rng = np.random.default_rng(25) + a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims) + b_test = DataArray(rng.random(size=(4, 2)), dims=b.type.dims) + + np.testing.assert_allclose( + fn(a_test.values, b_test.values), + xr_solve(a_test, b_test, dims=["country", 
"city"]).values, + ) + + +def test_solve_matrix_b(): + a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1)) + b = xtensor("b", dims=("district", "city", "planet"), shape=(5, None, 2)) + x = solve(a, b, dims=["country", "city", "district"]) + assert x.type.dims == ("galaxy", "planet", "city", "district") + assert x.type.shape == ( + 1, + 2, + None, + 5, + ) # Core Solve doesn't make use of the fact A must be square in the static shape + + fn = function([a, b], x) + + rng = np.random.default_rng(25) + a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims) + b_test = DataArray(rng.random(size=(5, 4, 2)), dims=b.type.dims) + + np.testing.assert_allclose( + fn(a_test.values, b_test.values), + xr_solve(a_test, b_test, dims=["country", "city", "district"]).values, + ) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py new file mode 100644 index 0000000000..210bfe9a80 --- /dev/null +++ b/tests/xtensor/test_math.py @@ -0,0 +1,153 @@ +# ruff: noqa: E402 +import inspect + +import pytest + +from pytensor.scalar import ScalarOp + + +pytest.importorskip("xarray") # + +import numpy as np +from xarray import DataArray + +import pytensor.scalar as ps +import pytensor.xtensor.math as pxm +from pytensor import function +from pytensor.xtensor.basic import rename +from pytensor.xtensor.math import add, exp +from pytensor.xtensor.type import xtensor +from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function + + +def test_all_scalar_ops_are_wrapped(): + # This ignores wrapper functions + pxm_members = {name for name, _ in inspect.getmembers(pxm)} + for name, op in inspect.getmembers(ps): + if name in { + "complex_from_polar", + "inclosedrange", + "inopenrange", + "round_half_away_from_zero", + "round_half_to_even", + "scalar_abs", + "scalar_maximum", + "scalar_minimum", + } or name.startswith("convert_to_"): + # These are not regular numpy functions or are unusual alias + continue + if isinstance(op, ScalarOp) and name not in pxm_members: + raise NotImplementedError(f"ScalarOp {name} not wrapped in xtensor.math") + + +def test_scalar_case(): + x = xtensor("x", dims=(), shape=()) + y = xtensor("y", dims=(), shape=()) + out = add(x, y) + + fn = function([x, y], out) + + x_test = DataArray(2.0, dims=()) + y_test = DataArray(3.0, dims=()) + np.testing.assert_allclose(fn(x_test.values, y_test.values), 5.0) + + +def test_dimension_alignment(): + x = xtensor("x", dims=("city", "country", "planet"), shape=(2, 3, 4)) + y = xtensor( + "y", + dims=("galaxy", "country", "city"), + shape=(5, 3, 2), + ) + z = xtensor("z", dims=("universe",), shape=(1,)) + out = add(x, y, z) + assert out.type.dims == ("city", "country", "planet", "galaxy", "universe") + + fn = function([x, y, z], out) + + rng = np.random.default_rng(41) + test_x, test_y, test_z = ( + DataArray(rng.normal(size=inp.type.shape), dims=inp.type.dims) + for inp in [x, y, z] + ) + np.testing.assert_allclose( + fn(test_x.values, test_y.values, test_z.values), + (test_x + test_y + test_z).values, + ) + + +def test_renamed_dimension_alignment(): + x = xtensor("x", dims=("a", "b1", "b2"), shape=(2, 3, 3)) + y = rename(x, b1="b2", b2="b1") + z = rename(x, b2="b3") + assert y.type.dims == ("a", "b2", "b1") + assert z.type.dims == ("a", "b1", "b3") + + out1 = add(x, x) # self addition + assert out1.type.dims == ("a", "b1", "b2") + out2 = add(x, y) # transposed addition + assert out2.type.dims == ("a", "b1", "b2") + out3 = add(x, z) # outer addition + assert out3.type.dims == ("a", "b1", "b2", "b3") + + fn 
= xr_function([x], [out1, out2, out3]) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + results = fn(x_test) + expected_results = [ + x_test + x_test, + x_test + x_test.rename(b1="b2", b2="b1"), + x_test + x_test.rename(b2="b3"), + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +def test_chained_operations(): + x = xtensor("x", dims=("city",), shape=(None,)) + y = xtensor("y", dims=("country",), shape=(4,)) + z = add(exp(x), exp(y)) + assert z.type.dims == ("city", "country") + assert z.type.shape == (None, 4) + + fn = function([x, y], z) + + x_test = DataArray(np.zeros(3), dims="city") + y_test = DataArray(np.ones(4), dims="country") + + np.testing.assert_allclose( + fn(x_test.values, y_test.values), + (np.exp(x_test) + np.exp(y_test)).values, + ) + + +def test_multiple_constant(): + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + out = exp(x * 2) + 2 + + fn = function([x], out) + + x_test = np.zeros((2, 3), dtype=x.type.dtype) + res = fn(x_test) + expected_res = np.exp(x_test * 2) + 2 + np.testing.assert_allclose(res, expected_res) + + +def test_cast(): + x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32") + yf64 = x.astype("float64") + yi16 = x.astype("int16") + ybool = x.astype("bool") + + fn = xr_function([x], [yf64, yi16, ybool]) + x_test = xr_arange_like(x) + res_f64, res_i16, res_bool = fn(x_test) + xr_assert_allclose(res_f64, x_test.astype("float64")) + xr_assert_allclose(res_i16, x_test.astype("int16")) + xr_assert_allclose(res_bool, x_test.astype("bool")) + + yc64 = x.astype("complex64") + with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"): + yc64.astype("float64") diff --git a/tests/xtensor/test_reduction.py b/tests/xtensor/test_reduction.py new file mode 100644 index 0000000000..7cc9a674f1 --- /dev/null +++ b/tests/xtensor/test_reduction.py @@ -0,0 +1,27 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") + +from pytensor.xtensor.type import xtensor +from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function + + +@pytest.mark.parametrize( + "dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"] +) +@pytest.mark.parametrize( + "method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:] +) +def test_reduction(method, dim): + x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7)) + out = getattr(x, method)(dim=dim) + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + + xr_assert_allclose( + fn(x_test), + getattr(x_test, method)(dim=dim), + ) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py new file mode 100644 index 0000000000..f5db72bf1f --- /dev/null +++ b/tests/xtensor/test_shape.py @@ -0,0 +1,371 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") + +import re +from itertools import chain, combinations + +import numpy as np +import pytest +from xarray import DataArray +from xarray import concat as xr_concat + +from pytensor.xtensor.shape import ( + concat, + squeeze, + stack, + transpose, + unstack, +) +from pytensor.xtensor.type import xtensor +from tests.xtensor.util import ( + xr_arange_like, + xr_assert_allclose, + xr_function, + xr_random_like, +) + + +pytest.importorskip("xarray") + + +def powerset(iterable, min_group_size=0): + "Subsequences of the iterable from shortest to longest." 
+ # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) + s = list(iterable) + return chain.from_iterable( + combinations(s, r) for r in range(min_group_size, len(s) + 1) + ) + + +def test_transpose(): + a, b, c, d, e = "abcde" + + x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11)) + permutations = [ + (a, b, c, d, e), # identity + (e, d, c, b, a), # full tranpose + (), # eqivalent to full transpose + (a, b, c, e, d), # swap last two dims + (..., d, c), # equivalent to (a, b, e, d, c) + (b, a, ..., e, d), # equivalent to (b, a, c, d, e) + (c, a, ...), # equivalent to (c, a, b, d, e) + ] + outs = [transpose(x, *perm) for perm in permutations] + + fn = xr_function([x], outs) + x_test = xr_arange_like(x) + res = fn(x_test) + expected_res = [x_test.transpose(*perm) for perm in permutations] + for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): + xr_assert_allclose(res_i, expected_res_i) + + +def test_xtensor_variable_transpose(): + """Test the transpose() method of XTensorVariable.""" + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + + # Test basic transpose + out = x.transpose() + fn = xr_function([x], out) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test.transpose()) + + # Test transpose with specific dimensions + out = x.transpose("c", "a", "b") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b")) + + # Test transpose with ellipsis + out = x.transpose("c", ...) + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + # Test error cases + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')" + ), + ): + x.transpose("d") + + with pytest.raises(ValueError, match="an index can only have a single ellipsis"): + x.transpose("a", ..., "b", ...) 
+ + # Test missing_dims parameter + # Test ignore + out = x.transpose("c", ..., "d", missing_dims="ignore") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + # Test warn + with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"): + out = x.transpose("c", ..., "d", missing_dims="warn") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + +def test_xtensor_variable_T(): + """Test the T property of XTensorVariable.""" + # Test T property with 3D tensor + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + out = x.T + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test.T) + + +def test_stack(): + dims = ("a", "b", "c", "d") + x = xtensor("x", dims=dims, shape=(2, 3, 5, 7)) + outs = [ + stack(x, new_dim=dims_to_stack) + for dims_to_stack in powerset(dims, min_group_size=2) + ] + + fn = xr_function([x], outs) + x_test = xr_arange_like(x) + res = fn(x_test) + + expected_res = [ + x_test.stack(new_dim=dims_to_stack) + for dims_to_stack in powerset(dims, min_group_size=2) + ] + for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): + xr_assert_allclose(res_i, expected_res_i) + + +def test_stack_single_dim(): + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 5)) + out = stack(x, {"d": ["a"]}) + assert out.type.dims == ("b", "c", "d") + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + res = fn(x_test) + expected_res = x_test.stack(d=["a"]) + xr_assert_allclose(res, expected_res) + + +def test_multiple_stacks(): + x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7)) + out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d")) + + fn = xr_function([x], [out]) + x_test = xr_arange_like(x) + res = fn(x_test) + expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d")) + xr_assert_allclose(res[0], expected_res) + + +def test_unstack_constant_size(): + x = xtensor("x", dims=("a", "bc", "d"), shape=(2, 3 * 5, 7)) + y = unstack(x, bc=dict(b=3, c=5)) + assert y.type.dims == ("a", "d", "b", "c") + assert y.type.shape == (2, 7, 3, 5) + + fn = xr_function([x], y) + + x_test = xr_arange_like(x) + x_np = x_test.values + res = fn(x_test) + expected = ( + DataArray(x_np.reshape(2, 3, 5, 7), dims=("a", "b", "c", "d")) + .stack(bc=("b", "c")) + .unstack("bc") + ) + xr_assert_allclose(res, expected) + + +def test_unstack_symbolic_size(): + x = xtensor(dims=("a", "b", "c")) + y = stack(x, bc=("b", "c")) + y = y / y.sum("bc") + z = unstack(y, bc={"b": x.sizes["b"], "c": x.sizes["c"]}) + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5))) + fn = xr_function([x], z) + res = fn(x_test) + expected_res = x_test / x_test.sum(["b", "c"]) + xr_assert_allclose(res, expected_res) + + +def test_stack_unstack(): + x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7)) + stack_x = stack(x, bd=("b", "d")) + unstack_x = unstack(stack_x, bd=dict(b=3, d=7)) + + x_test = xr_arange_like(x) + fn = xr_function([x], unstack_x) + res = fn(x_test) + expected_res = x_test.transpose("a", "c", "b", "d") + xr_assert_allclose(res, expected_res) + + +@pytest.mark.parametrize("dim", ("a", "b", "new")) +def test_concat(dim): + rng = np.random.default_rng(sum(map(ord, dim))) + + x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3)) + x2 = xtensor("x2", dims=("b", "a"), shape=(3, 2)) + + x3_shape0 = 4 if dim == "a" else 2 + x3_shape1 = 5 if dim == "b" else 3 + x3 = xtensor("x3", dims=("a", "b"), shape=(x3_shape0, x3_shape1)) + + out = 
concat([x1, x2, x3], dim=dim) + + fn = xr_function([x1, x2, x3], out) + x1_test = xr_random_like(x1, rng) + x2_test = xr_random_like(x2, rng) + x3_test = xr_random_like(x3, rng) + + res = fn(x1_test, x2_test, x3_test) + expected_res = xr_concat([x1_test, x2_test, x3_test], dim=dim) + xr_assert_allclose(res, expected_res) + + +@pytest.mark.parametrize("dim", ("a", "b", "c", "d", "new")) +def test_concat_with_broadcast(dim): + rng = np.random.default_rng(sum(map(ord, dim)) + 1) + + x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3)) + x2 = xtensor("x2", dims=("b", "c"), shape=(3, 5)) + x3 = xtensor("x3", dims=("c", "d"), shape=(5, 7)) + x4 = xtensor("x4", dims=(), shape=()) + + out = concat([x1, x2, x3, x4], dim=dim) + + fn = xr_function([x1, x2, x3, x4], out) + + x1_test = xr_random_like(x1, rng) + x2_test = xr_random_like(x2, rng) + x3_test = xr_random_like(x3, rng) + x4_test = xr_random_like(x4, rng) + res = fn(x1_test, x2_test, x3_test, x4_test) + expected_res = xr_concat([x1_test, x2_test, x3_test, x4_test], dim=dim) + xr_assert_allclose(res, expected_res) + + +def test_concat_scalar(): + x1 = xtensor("x1", dims=(), shape=()) + x2 = xtensor("x2", dims=(), shape=()) + + out = concat([x1, x2], dim="new_dim") + + fn = xr_function([x1, x2], out) + + x1_test = xr_random_like(x1) + x2_test = xr_random_like(x2) + res = fn(x1_test, x2_test) + expected_res = xr_concat([x1_test, x2_test], dim="new_dim") + xr_assert_allclose(res, expected_res) + + +def test_squeeze_explicit_dims(): + """Test squeeze with explicit dimension(s).""" + + # Single dimension + x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1)) + y1 = squeeze(x1, "country") + fn1 = xr_function([x1], y1) + x1_test = xr_arange_like(x1) + xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country")) + + # Multiple dimensions + x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3)) + y2 = squeeze(x2, ["b", "c"]) + fn2 = xr_function([x2], y2) + x2_test = xr_arange_like(x2) + xr_assert_allclose(fn2(x2_test), x2_test.squeeze(["b", "c"])) + + # Order independence + x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1)) + y3a = squeeze(x3, ["b", "c"]) + y3b = squeeze(x3, ["c", "b"]) + fn3a = xr_function([x3], y3a) + fn3b = xr_function([x3], y3b) + x3_test = xr_arange_like(x3) + xr_assert_allclose(fn3a(x3_test), fn3b(x3_test)) + + # Redundant dimensions + y3c = squeeze(x3, ["b", "b"]) + fn3c = xr_function([x3], y3c) + xr_assert_allclose(fn3c(x3_test), x3_test.squeeze(["b", "b"])) + + # Empty list = no-op + y3d = squeeze(x3, []) + fn3d = xr_function([x3], y3d) + xr_assert_allclose(fn3d(x3_test), x3_test) + + +def test_squeeze_implicit_dims(): + """Test squeeze with implicit dim=None (all size-1 dimensions).""" + + # All dimensions size 1 + x1 = xtensor("x1", dims=("a", "b"), shape=(1, 1)) + y1 = squeeze(x1) + fn1 = xr_function([x1], y1) + x1_test = xr_arange_like(x1) + xr_assert_allclose(fn1(x1_test), x1_test.squeeze()) + + # No dimensions size 1 = no-op + x2 = xtensor("x2", dims=("row", "col", "batch"), shape=(2, 3, 4)) + y2 = squeeze(x2) + fn2 = xr_function([x2], y2) + x2_test = xr_arange_like(x2) + xr_assert_allclose(fn2(x2_test), x2_test) + + # Symbolic shape where runtime shape is 1 → should squeeze + x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown + y3 = squeeze(x3, "b") + x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3))) + fn3 = xr_function([x3], y3) + xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b")) + + # Mixed static + symbolic shapes, where symbolic shape is 1 + x4 = xtensor("x4", 
dims=("a", "b", "c"), shape=(None, 1, 3)) + y4 = squeeze(x4, "b") + x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3))) + fn4 = xr_function([x4], y4) + xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b")) + + """ + This test documents that we intentionally don't squeeze dimensions with symbolic shapes + (static_shape=None) even when they are 1 at runtime, while xarray does squeeze them. + """ + # Create a tensor with a symbolic dimension that will be 1 at runtime + x = xtensor("x", dims=("a", "b", "c")) # shape unknown + y = squeeze(x) # implicit dim=None should not squeeze symbolic dimensions + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3))) + fn = xr_function([x], y) + res = fn(x_test) + + # Our implementation should not squeeze the symbolic dimension + assert "b" in res.dims + # While xarray would squeeze it + assert "b" not in x_test.squeeze().dims + + +def test_squeeze_errors(): + """Test error cases for squeeze.""" + + # Non-existent dimension + x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1)) + with pytest.raises(ValueError, match="Dimension .* not found"): + squeeze(x1, "time") + + # Dimension size > 1 + with pytest.raises(ValueError, match="has static size .* not 1"): + squeeze(x1, "city") + + # Symbolic shape: dim is not 1 at runtime → should raise + x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown + y2 = squeeze(x2, "b") + x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3))) + fn2 = xr_function([x2], y2) + with pytest.raises(Exception): + fn2(x2_test) diff --git a/tests/xtensor/util.py b/tests/xtensor/util.py new file mode 100644 index 0000000000..81dc98a75c --- /dev/null +++ b/tests/xtensor/util.py @@ -0,0 +1,60 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") + +import numpy as np +from xarray import DataArray +from xarray.testing import assert_allclose + +from pytensor import function +from pytensor.xtensor.type import XTensorType + + +def xr_function(*args, **kwargs): + """Compile and wrap a PyTensor function to return xarray DataArrays.""" + fn = function(*args, **kwargs) + symbolic_outputs = fn.maker.fgraph.outputs + assert all( + isinstance(out.type, XTensorType) for out in symbolic_outputs + ), "All outputs must be xtensor" + + def xfn(*xr_inputs): + np_inputs = [ + inp.values if isinstance(inp, DataArray) else inp for inp in xr_inputs + ] + np_outputs = fn(*np_inputs) + if not isinstance(np_outputs, tuple | list): + return DataArray(np_outputs, dims=symbolic_outputs[0].type.dims) + else: + return tuple( + DataArray(res, dims=out.type.dims) + for res, out in zip(np_outputs, symbolic_outputs) + ) + + xfn.fn = fn + return xfn + + +def xr_assert_allclose(x, y, *args, **kwargs): + # Assert that two xarray DataArrays are close, ignoring coordinates + x = x.drop_vars(x.coords) + y = y.drop_vars(y.coords) + assert_allclose(x, y, *args, **kwargs) + + +def xr_arange_like(x): + return DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + + +def xr_random_like(x, rng=None): + if rng is None: + rng = np.random.default_rng() + + return DataArray( + rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims + )