Labeled tensors #1411


Open. Wants to merge 23 commits into base: main.
Changes from all commits (23):
2610862 Use DimShuffle instead of Reshape in `ix_` (ricardoV94, May 22, 2025)
2054ad8 WIP Basic labeled tensor functionality (ricardoV94, Aug 2, 2023)
c3da710 Implement Elemwise and Blockwise operations for XTensorVariables (ricardoV94, May 26, 2025)
3c53271 Implement reduction operations for XTensorVariables (ricardoV94, May 25, 2025)
5a7b23c Implement xarray-like Concat (ricardoV94, May 26, 2025)
2d4899a Lint whitespace changes (ricardoV94, May 29, 2025)
f8566b3 Use xr_arange_like in existing tests (ricardoV94, May 29, 2025)
3f5194f Implement transpose for XTensorVariables (AllenDowney, May 28, 2025)
6821b93 Add stack method (ricardoV94, May 30, 2025)
ec3d700 Implement unstack operation for XTensorVariables (OriolAbril, May 22, 2025)
1ab3e09 Remove dprint from test (ricardoV94, Jun 2, 2025)
692c53c Test xtensor module (ricardoV94, Jun 2, 2025)
81317d8 Don't allow repeated dims (ricardoV94, Jun 3, 2025)
28a4d86 Fix scalar case in XElemwise (ricardoV94, May 29, 2025)
ff63514 Check shape length matches dims in XTensorType (ricardoV94, May 29, 2025)
f8dbd5c Fix bug in `xtensor_constant` (ricardoV94, Jun 2, 2025)
61aedd7 Implement casting for XTensorVariables (ricardoV94, Jun 3, 2025)
e68729c Implement index operations for XTensorVariables (ricardoV94, May 21, 2025)
9ecc772 Fix import (ricardoV94, Jun 3, 2025)
e703d37 Tweak test name (ricardoV94, Jun 3, 2025)
9971ca3 Don't lose static shape in AdvancedIncSubtensor (ricardoV94, Jun 2, 2025)
ea690e6 Implement index update for XTensorVariables (ricardoV94, Jun 2, 2025)
1905904 Add diff method to XTensorVariables (ricardoV94, May 26, 2025)
11 changes: 11 additions & 0 deletions .github/workflows/test.yml
@@ -82,6 +82,7 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
+ install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
@@ -115,6 +116,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
+ install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
+ - install-xarray: 1
+   os: "ubuntu-latest"
+   python-version: "3.13"
+   numpy-version: ">=2.0"
+   fast-compile: 0
+   float32: 0
+   part: "tests/xtensor"
- os: macos-15
python-version: "3.13"
numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
+ if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install pytest-sphinx

pip install -e ./
@@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
+ INSTALL_XARRAY: ${{ matrix.install-xarray }}
OS: ${{ matrix.os}}

- name: Run tests
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
@@ -4539,7 +4539,7 @@ def ix_(*args):
        new = as_tensor(new)
        if new.ndim != 1:
            raise ValueError("Cross index must be 1 dimensional")
-       new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
+       new = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1)))
        out.append(new)
    return tuple(out)

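A note on this change: both forms produce the same broadcastable result, but `DimShuffle` only inserts length-1 axes around the existing dimension, while `Reshape` had to build a new shape from the symbolic `new.size`. A minimal sketch of the equivalence (the values of `k` and `nd` are illustrative, not from the PR):

```python
import pytensor.tensor as pt

v = pt.vector("v")  # one cross-index vector
k, nd = 1, 3        # position k of an nd-dimensional open grid

# Old form: reshape using the symbolic size of `v`.
via_reshape = v.reshape((1,) * k + (v.size,) + (1,) * (nd - k - 1))

# New form: insert broadcastable ("x") axes directly, with no size computation.
via_dimshuffle = v.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1)))

print(via_reshape.type)     # e.g. TensorType(float64, shape=(1, ?, 1))
print(via_dimshuffle.type)  # same static shape, simpler graph
```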
18 changes: 0 additions & 18 deletions pytensor/tensor/extra_ops.py
@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
return CumOp(axis=axis, mode="mul")(x)


- class CumsumOp(Op):
-     __props__ = ("axis",)
-
-     def __new__(typ, *args, **kwargs):
-         obj = object.__new__(CumOp, *args, **kwargs)
-         obj.mode = "add"
-         return obj
-
-
- class CumprodOp(Op):
-     __props__ = ("axis",)
-
-     def __new__(typ, *args, **kwargs):
-         obj = object.__new__(CumOp, *args, **kwargs)
-         obj.mode = "mul"
-         return obj

def diff(x, n=1, axis=-1):
"""Calculate the `n`-th order discrete difference along the given `axis`.

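The removed `CumsumOp`/`CumprodOp` were never real Ops: their `__new__` returned a `CumOp` instance with the mode pre-set, purely as a compatibility shim. Code that still dispatches on those names can test `CumOp.mode` instead; a sketch of that check (assumed downstream usage, not part of the diff):

```python
from pytensor.tensor.extra_ops import CumOp

def is_cumsum(op) -> bool:
    # A CumsumOp was always a CumOp with mode="add" under the hood,
    # so checking the mode is equivalent to the old isinstance check.
    return isinstance(op, CumOp) and op.mode == "add"
```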
7 changes: 1 addition & 6 deletions pytensor/tensor/subtensor.py
@@ -3021,12 +3021,7 @@ def make_node(self, x, y, *inputs):
        return Apply(
            self,
            (x, y, *new_inputs),
-           [
-               tensor(
-                   dtype=x.type.dtype,
-                   shape=tuple(1 if s == 1 else None for s in x.type.shape),
-               )
-           ],
+           [x.type()],
        )

def perform(self, node, inputs, out_):
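Returning `[x.type()]` means the output now inherits the full static shape of `x`, where the old code only preserved length-1 dimensions and turned everything else into `None`. A sketch of the difference, with illustrative shapes and indices:

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, 5))
# Two integer index arrays route through the general AdvancedIncSubtensor.
y = pt.set_subtensor(x[[0, 2], [1, 3]], 0.0)

print(y.type.shape)  # now (3, 5); previously (None, None)
```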
16 changes: 16 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -0,0 +1,16 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor import (
linalg,
special,
)
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
as_xtensor,
xtensor,
xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")
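For orientation, a minimal sketch of what the exported names enable. Dim handling follows xarray semantics, and the exact signatures here are assumptions based on the code in this PR (the module itself warns that it is experimental):

```python
from pytensor.xtensor import xtensor

# Variables carry named dims; shape entries may be None (unknown).
a = xtensor("a", dims=("time", "city"), shape=(None, 3))
b = xtensor("b", dims=("city",), shape=(3,))

# Elemwise ops align on dim names rather than axis positions,
# broadcasting `b` across "time" (see the XElemwise commits).
c = a + b
print(c.type.dims)  # ("time", "city")
```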
86 changes: 86 additions & 0 deletions pytensor/xtensor/basic.py
@@ -0,0 +1,86 @@
from collections.abc import Sequence

from pytensor.graph import Apply, Op
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


class XOp(Op):
    """A base class for XOps that shouldn't be materialized."""

    def perform(self, node, inputs, outputs):
        raise NotImplementedError(
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )


class XViewOp(Op):
    # Make this a View Op with C-implementation
    view_map = {0: [0]}

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0]
class TensorFromXTensor(XViewOp):
    __props__ = ()

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, XTensorType):
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(XViewOp):
    __props__ = ("dims",)

    def __init__(self, dims: Sequence[str]):
        super().__init__()
        self.dims = tuple(dims)

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, TensorType):
            raise TypeError(f"x must be a TensorType, got {type(x.type)}")
        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
        return Apply(self, [x], [output])


def xtensor_from_tensor(x, dims):
    return XTensorFromTensor(dims=dims)(x)
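These two Ops bridge the labeled and plain tensor worlds; lowering rewrites use them to strip dims, apply ordinary tensor ops, and reattach dims. A hedged round-trip sketch (the variable names are illustrative):

```python
import pytensor.tensor as pt
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor

t = pt.tensor("t", shape=(2, 3))
xt = xtensor_from_tensor(t, dims=("row", "col"))  # attach dim labels
t_again = tensor_from_xtensor(xt)                 # drop labels; a view, not a copy

print(xt.type.dims)        # ("row", "col")
print(t_again.type.shape)  # (2, 3) -- static shape is preserved
```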


class Rename(XViewOp):
    __props__ = ("new_dims",)

    def __init__(self, new_dims: tuple[str, ...]):
        super().__init__()
        self.new_dims = new_dims

    def make_node(self, x):
        x = as_xtensor(x)
        output = x.type.clone(dims=self.new_dims)()
        return Apply(self, [x], [output])


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
    if name_dict is not None:
        if names:
            raise ValueError("Cannot pass both a name dict and keyword names to rename")
        names = name_dict

    x = as_xtensor(x)
    old_names = x.type.dims
    new_names = list(old_names)
    for old_name, new_name in names.items():
        try:
            new_names[old_names.index(old_name)] = new_name
        except ValueError:
            # tuple.index raises ValueError, not IndexError, when the name is absent
            raise ValueError(
                f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
            )

    return Rename(tuple(new_names))(x)
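Usage mirrors `xarray.DataArray.rename`: pass either a mapping or keyword arguments, but not both. A small sketch (variable names are illustrative):

```python
from pytensor.xtensor import xtensor
from pytensor.xtensor.basic import rename

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y1 = rename(x, {"a": "rows"})  # mapping form
y2 = rename(x, a="rows")       # keyword form
print(y1.type.dims, y2.type.dims)  # ("rows", "b") ("rows", "b")
```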