Labeled tensors #1411

Open · wants to merge 19 commits into main from labeled_tensors

Conversation

@ricardoV94 (Member) commented May 22, 2025

import numpy as np

from pytensor import function
from pytensor.xtensor.basic import add, exp
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("city",), shape=(None,))
y = xtensor("y", dims=("country",), shape=(4,))
z = add(exp(x), exp(y))
assert z.type.dims == ("city", "country")
assert z.type.shape == (None, 4)

fn = function([x, y], z)
fn.dprint(print_type=True)
# XTensorFromTensor{dims=('city', 'country')} [id A] <XTensorType{dtype='float64', shape=(None, 4), dims=('city', 'country')}> 7
#  └─ Add [id B] <Matrix(float64, shape=(?, 4))> 6
#     ├─ Exp [id C] <Matrix(float64, shape=(?, 1))> 5
#     │  └─ ExpandDims{axis=1} [id D] <Matrix(float64, shape=(?, 1))> 3
#     │     └─ TensorFromXTensor [id E] <Vector(float64, shape=(?,))> 1
#     │        └─ x [id F] <XTensorType{dtype='float64', shape=(None,), dims=('city',)}>
#     └─ Exp [id G] <Matrix(float64, shape=(1, 4))> 4
#        └─ ExpandDims{axis=0} [id H] <Matrix(float64, shape=(1, 4))> 2
#           └─ TensorFromXTensor [id I] <Vector(float64, shape=(4,))> 0
#              └─ y [id J] <XTensorType{dtype='float64', shape=(4,), dims=('country',)}>

np.testing.assert_allclose(
    fn(x=np.zeros(3), y=np.zeros(4)),
    np.full((3, 4), 2.0),
)

Strategy

We implement xarray-like dummy Ops that respect / propagate dims semantics, and lower them to regular PyTensor graphs with rewrites.

Note that in the example above the dummy TensorFromXTensor and XTensorFromTensor Ops remain in the final graph. If we had instead created a function with Tensor inputs and outputs, converting them (symbolically) to and from XTensor only inside the graph, the final graph would show no sign of the dimension operations, other than in how it was constructed.

I suggest registering those rewrites in an xtensor_lowering database.
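
A minimal sketch of what one such lowering could look like. The XExp dummy Op, the tensor_from_xtensor/xtensor_from_tensor helpers, and the database wiring here are illustrative assumptions, not the final API:

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor  # assumed helpers

class XExp(Op):
    # Dummy Op: only propagates dims/shape at the type level, never executes
    def make_node(self, x):
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        raise NotImplementedError("dummy Op, should have been lowered away")

# Hypothetical lowering database, per the suggestion above
xtensor_lowering = EquilibriumDB()

@node_rewriter([XExp])
def lower_xexp(fgraph, node):
    [x] = node.inputs
    # "Box" the real tensor Op between the conversion Ops,
    # so the replacement is valid in terms of types
    out = pt.exp(tensor_from_xtensor(x))
    return [xtensor_from_tensor(out, dims=node.outputs[0].type.dims)]

xtensor_lowering.register("lower_xexp", lower_xexp, "xtensor")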

Coordinates

For now I'm playing with how far we can get without coordinates. This keeps the graphs produced by the xarray-like syntax much more amenable to the numpy-like backend of PyTensor. Otherwise it involves a lot of Pandas-like machinery (e.g., MultiIndex) that we don't really have. It may be feasible, especially if nothing is symbolic, but I fear a rabbit hole of edge cases.

Gradients

These Ops are currently not differentiable, but one can lower the graph and then call the gradient. I do want to try the lazy grad approach from #788.
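
A hedged sketch of that workflow, assuming the lowering rewrites are reachable through rewrite_graph under an "xtensor_lowering" tag, and that the conversion helpers exist as named:

import pytensor.tensor as pt
from pytensor import grad
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.xtensor.basic import exp, tensor_from_xtensor, xtensor_from_tensor  # helpers assumed

# Build from a plain tensor input, converted symbolically (see Strategy above)
x_tensor = pt.vector("x")
x = xtensor_from_tensor(x_tensor, dims=("city",))
z_tensor = tensor_from_xtensor(exp(x))

# Lower the dummy Ops away, then differentiate the resulting tensor graph
z_lowered = rewrite_graph(z_tensor, include=("xtensor_lowering",))
g = grad(z_lowered.sum(), wrt=x_tensor)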

Help implementing more Ops is welcome, so we have an MVP to try out with PyMC next. The sections below list the Ops we need.

Open a PR on top of this branch, and I'll try to merge quickly! Try to keep it clean (one commit per Op, unless it's something like a factory of related Ops).

Implementing means:

  1. Create a dummy Op.
  2. Create a rewrite that lowers the dummy Op to real tensor operations. The rewrite should "box" the lowered tensor operations between TensorFromXTensor and XTensorFromTensor calls, so that the replacement is valid in terms of types. There are rewrites that remove chains of useless TensorFromXTensor/XTensorFromTensor, which should clean up everything in the middle of the graph.
  3. Add a test that compares with xarray/xarray_einstats and proves it's correct (see the sketch after this list).
  4. If you really want, test the error checks (I haven't been doing that).
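
For step 3, a test could look roughly like this sketch, which uses the same imports as the example at the top and xarray as the reference implementation:

import numpy as np
import xarray as xr
from pytensor import function
from pytensor.xtensor.basic import exp
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("city",), shape=(3,))
fn = function([x], exp(x))

x_val = np.random.uniform(size=3)
expected = np.exp(xr.DataArray(x_val, dims=("city",)))  # xarray as reference
np.testing.assert_allclose(fn(x_val), expected.values)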

Interplay between XTensorTypes and TensorTypes / weakly typed inputs

  • Symbolic conversion between XTensor and Tensor
  • Make sure MetaOps accept non-XTensorType scalar inputs
  • Make MetaOps (Elemwise / Blockwise) "cast" regular numpy/TensorVariable inputs to XTensorVariable so they behave like xarray does (dims are matched positionally; try it out)
  • Operators as methods (__add__ and the like, so you can do x + x); see the sketch below
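
A sketch of the operator wiring, with as_xtensor as an assumed coercion helper that wraps numpy arrays and TensorVariables:

from pytensor.xtensor.basic import add

class XTensorOperatorsMixin:
    # Mixed into XTensorVariable so that `x + y` works like xarray
    def __add__(self, other):
        # as_xtensor: assumed coercion helper for numpy / TensorVariable inputs
        return add(self, as_xtensor(other))

    __radd__ = __add__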

Meta Ops

  • Elemwise (automatically generated, some may fail)
  • Blockwise (each Op needs manual curation though)
  • XTensorVariable.where (1 or 2 arguments; can be implemented with math.switch, but probably can't support drop=True)
  • xtensor.where (3 inputs; just an alias to xtensor.math.switch that should be available at the module level; see the sketch after this list)
  • CAReduce (Sum, All, Mean, ...)
  • Einsum (probably low priority)
  • Scan (a thin wrapper around Scan should be fine, just need to add the time dim to the outputs, and perhaps use that to also align the sequences)
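
The three-input where could be as thin as this sketch, assuming switch ends up in pytensor.xtensor.math:

from pytensor.xtensor.math import switch  # assumed location

def where(cond, x, y):
    # xarray-style three-input where; no drop=True support
    return switch(cond, x, y)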

Math stuff

  • Cast (it's a parametrized ScalarOp so the general XElemwise logic won't suffice)
  • Dot
  • Mean/Std/Variance (there's no CAReduce Op corresponding to those)
  • Everything that is a Blockwise in vanilla PyTensor (like all of linalg)
  • Argmax / Argmin
  • Sort / Argsort?

Shape stuff

  • Rename
  • Transpose
  • ExpandDims
  • Squeeze
  • swap_dims (should just be a call to Rename)
  • Stack (still missing as a method)
  • Unstack (Add unstack for xtensors #1412)
  • Concat
  • Broadcast_arrays (chaining Elemwise second should achieve this; see the sketch below)
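
The Broadcast_arrays idea from the last item could look like this sketch, assuming an xtensor-level second (the elemwise that returns its second argument broadcast against its first):

from functools import reduce

def broadcast_arrays(*arrays):
    # second(a, b): elemwise returning b broadcast against a's dims
    # (assumed to exist at the xtensor level). Folding each array through
    # all the others gives every output the union of all input dims.
    def broadcast_one(x, others):
        return reduce(lambda acc, other: second(other, acc), others, x)

    return [
        broadcast_one(x, [y for j, y in enumerate(arrays) if j != i])
        for i, x in enumerate(arrays)
    ]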

Array creation stuff

  • ZerosLike / OnesLike (just return self.x * 0 and self.x * 0 + 1? PyTensor will do the right thing when these get lowered); see the sketch below
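
A sketch of those two, relying on the operator methods from the Interplay section being in place:

def zeros_like(x):
    # `x * 0` lowers to a plain tensor graph that PyTensor already simplifies
    return x * 0

def ones_like(x):
    return x * 0 + 1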

Indexing stuff

  • __getitem__ + isel (Implement indexing operations for XTensorVariables #1429)
  • __getitem__ + isel with boolean indices (should work fine; just need to test it and lift the error currently raised)
  • Indexing update (aka set and inc_subtensor)
    It probably makes sense to convert non-XTensor indices to XTensor indices when they can be rendered equivalent, to reduce the logic needed.

RandomVariables

This is quite important, as we'll need these for PyMC models! They are a mix of Blockwise plus a size argument (which may or may not be redundant).

Graph transformations

  • grad (and jacobian and all that)
  • vectorize_graph

📚 Documentation preview 📚: https://pytensor--1411.org.readthedocs.build/en/1411/

@ricardoV94 ricardoV94 added the enhancement label on May 22, 2025
@ricardoV94 ricardoV94 force-pushed the labeled_tensors branch 5 times, most recently from d6a3ddf to 177a4c2 on May 26, 2025 at 17:36

codecov bot commented Jun 2, 2025

Codecov Report

Attention: Patch coverage is 75.86957% with 222 lines in your changes missing coverage. Please review.

Project coverage is 82.02%. Comparing base (2d414d4) to head (81317d8).
Report is 10 commits behind head on main.

Files with missing lines                      Patch %   Lines
pytensor/xtensor/type.py                      60.81%    101 Missing and 15 partials ⚠️
pytensor/xtensor/shape.py                     73.93%    21 Missing and 22 partials ⚠️
pytensor/xtensor/reduction.py                 82.27%    12 Missing and 2 partials ⚠️
pytensor/xtensor/rewriting/reduction.py       79.31%    7 Missing and 5 partials ⚠️
pytensor/xtensor/vectorization.py             81.53%    8 Missing and 4 partials ⚠️
pytensor/xtensor/basic.py                     79.62%    8 Missing and 3 partials ⚠️
pytensor/xtensor/rewriting/basic.py           75.75%    8 Missing ⚠️
pytensor/xtensor/linalg.py                    86.66%    2 Missing and 2 partials ⚠️
pytensor/xtensor/special.py                   60.00%    2 Missing ⚠️

❌ Your patch check has failed because the patch coverage (75.86%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1411      +/-   ##
==========================================
- Coverage   82.12%   82.02%   -0.10%     
==========================================
  Files         211      225      +14     
  Lines       49722    50665     +943     
  Branches     8820     8942     +122     
==========================================
+ Hits        40832    41556     +724     
- Misses       6710     6878     +168     
- Partials     2180     2231      +51     
Files with missing lines                      Coverage Δ
pytensor/tensor/basic.py                      91.69% <ø> (+0.02%) ⬆️
pytensor/tensor/extra_ops.py                  88.88% <ø> (+0.67%) ⬆️
pytensor/xtensor/math.py                      100.00% <100.00%> (ø)
pytensor/xtensor/rewriting/__init__.py        100.00% <100.00%> (ø)
pytensor/xtensor/rewriting/shape.py           100.00% <100.00%> (ø)
pytensor/xtensor/rewriting/utils.py           100.00% <100.00%> (ø)
pytensor/xtensor/rewriting/vectorization.py   100.00% <100.00%> (ø)
pytensor/xtensor/special.py                   60.00% <60.00%> (ø)
pytensor/xtensor/linalg.py                    86.66% <86.66%> (ø)
pytensor/xtensor/rewriting/basic.py           75.75% <75.75%> (ø)
... and 6 more

... and 13 files with indirect coverage changes

