
Add rewrite to lift linear algebra through certain linalg ops #622


Merged: 2 commits merged into pymc-devs:main on Apr 28, 2024

Conversation

@jessegrabowski (Member) commented Feb 2, 2024

Adds rewrites to lift certain linear algebra ops (inv, pinv, cholesky) through "compositions" of matrices (block_diag, kron).

Kron is currently broken, so it's marked as a draft.

Type of change: New feature / enhancement

@jessegrabowski (Member Author)

I tracked down the bug we hit during the hackathon. We were importing pytensor.compile.builders.function, which is a module. We wanted pytensor.compile.builders.function.function.
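A hypothetical illustration of the mix-up, based only on the paths named above (the exact module layout is as described in this comment):

import pytensor.compile.builders as builders

mod = builders.function            # resolves to a module, not a callable
fn = builders.function.function    # the actual compile function we wanted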

Now the problem is that kron isn't well defined on batches. We noted that it works for tensors of any dimension, but it doesn't define batched behavior. If we have tensors of shape (batch, a, b) and (batch, c, d), the expected output of a batched kron is shape (batch, a*c, b*d), but we get something I really don't understand. Examples:

from scipy import linalg
import numpy as np
import pytensor.tensor as pt

rng = np.random.default_rng()

# Desired
A = rng.normal(size=(4, 2, 3))
B = rng.normal(size=(4, 5, 6))

x = np.vectorize(linalg.kron, signature='(n,m),(o,p)->(q,r)')(A, B)
x.shape  # Out: (4, 10, 18)

x = pt.linalg.kron(A, B)
x.eval().shape  # Out: (12, 8, 5, 6)

So the batched tests for this rewrite currently fail for kron. Since the base case is defined on arbitrary dimensions, it seems like our options are:

  1. Have users define their own Blockwise kron as needed, and apply the test to that user-built batched kron (see the sketch after this list).
  2. Add an axis argument to kron that lets the user specify which axes are the core dimensions. It could default to None (all dimensions are core), but we'd test on axis=[-1, -2].
  3. xfail the batched kron test.
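
For illustration, option 1 might look roughly like the sketch below. This assumes pt.vectorize accepts a gufunc-style signature argument the way np.vectorize does; the exact API is an assumption.

import pytensor.tensor as pt

# a user-built batched kron: 2D core case, vectorized over leading batch dims
batched_kron = pt.vectorize(pt.linalg.kron, signature='(m,n),(o,p)->(q,r)')

A = pt.tensor3('A')      # shape (batch, m, n)
B = pt.tensor3('B')      # shape (batch, o, p)
C = batched_kron(A, B)   # expected shape (batch, m*o, n*p)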

Thoughts?

@ricardoV94 (Member) commented Feb 11, 2024

Call pt.vectorize on the test from the 2D core case?

Does the rewrite still hold as is though?

@jessegrabowski (Member Author)

The rewrite works fine in the 2D case; it just fails in the batched case because the way kron currently works breaks the "liftability" when ndim > 2.

This is actually an artifact of how we define kron. It seems like there are two conventions: np.kron is not the same as scipy.linalg.kron. Theano went with scipy; JAX went with numpy. The rewrite would work with the numpy definition:

# A, B as defined in the examples above (holds for both the 2D and batched cases)
np.allclose(np.kron(np.linalg.pinv(A), np.linalg.pinv(B)),
            np.linalg.pinv(np.kron(A, B)))
# True

@ricardoV94 (Member) commented Feb 11, 2024

Okay, so maybe let's add a kwarg: scipy_like? I wouldn't add axis, just because it's more complex for us. Numpy behavior is simpler in that the core axes are always the last two, which we can get automatically with Blockwise?

Regardless, we have to make sure we do the rewrite correctly (or restrict it to the cases where the logic is correct, although it doesn't sound like it should be too complicated to support both forms?)

@jessegrabowski (Member Author) commented Feb 11, 2024

For the scipy form, we need a way for the user to tell us which axes are batch axes and which are not. If we don't know, the rewrite will break for anything bigger than 2D. I guess if we had a mode keyword, we could just check it and skip the rewrite if it's scipy.

For the numpy form, we technically still need to know this (because Blockwise(kron(A, B)) isn't the same as kron(A, B) for ndim > 2), but the rewrite still works in all cases. If we were in this case, we could just provide kron "as is" and demand users vectorize it themselves as needed, as JAX does.

@ricardoV94 (Member) commented Feb 11, 2024

> For the scipy form, we need a way for the user to tell us which axes are batch axes and which are not. If we don't know, the rewrite will break for anything bigger than 2D. I guess if we had a mode keyword, we could just check it and skip the rewrite if it's scipy.

Isn't our current implementation already the scipy kron? We don't need to do anything. A numpy kron is a Blockwised kron with 2D core inputs, which right now users would have to create manually. For this PR you can decide whether to support only the scipy kron or also a Blockwised kron.

I don't see why the axis argument is needed. Scipy mixes everything, if I understand correctly, and numpy only the last two axes. From the presence or absence of the Blockwise (and its signature) we should know which case it is.

@jessegrabowski (Member Author)

Numpy is not blockwise; it just multiplies all the axes together in a way that still respects commutativity with matrix operations. Here's a guide to all the outputs:

from scipy import linalg
import numpy as np

rng = np.random.default_rng()

# 2d case -- both are the same
A = rng.normal(size=(2, 3))
B = rng.normal(size=(5, 6))
np.allclose(np.kron(A, B), linalg.kron(A, B))
# True

# nd case -- both are different. Note the effect on the "batch" dimension in the numpy case (becomes 4*4=16)
A = rng.normal(size=(4, 2, 3))
B = rng.normal(size=(4, 5, 6))

print(np.kron(A,B).shape)
print(linalg.kron(A,B).shape)

# (16, 10, 18)
# (12, 8, 5, 6)

# Vectorized 2d case -- both are the same again,  but not the same as the numpy nd case
vec_np_kron = np.vectorize(np.kron, signature='(n,m),(o,p)->(q,r)')
vec_sp_kron = np.vectorize(linalg.kron, signature='(n,m),(o,p)->(q,r)')

print(vec_np_kron(A,B).shape)
print(vec_sp_kron(A,B).shape)

# (4, 10, 18)
# (4, 10, 18)

@jessegrabowski (Member Author) commented Apr 8, 2024

After #684 this should be ready to go. I touched pytensor/link/c/cmodule.py because mypy was complaining about it on my end, but now mypy is complaining about something on the CI, so I'll need some guidance on what to do there.

@ricardoV94 (Member)

> After #684 this should be ready to go. I touched pytensor/link/c/cmodule.py because mypy was complaining about it on my end, but now mypy is complaining about something on the CI, so I'll need some guidance on what to do there.

Things like this have happened in the past. My best hunch is a difference between checking individual commits vs. all files: locally it's doing one and in CI the other, and mypy's opinion changes accordingly.

Or python version.

Or something else

@jessegrabowski (Member Author)

Probably the eclipse, for all I know. Mypy is something else.

ricardoV94 added the enhancement, graph rewriting, and linalg labels on Apr 8, 2024
@ricardoV94 (Member)

Nice! Proved trickier because of the scipy thing but not terrible in the end

f = f - grad(pt_sum(f), y)
f = f - grad(pt_sum(f), y)
fn = function([x, y, z], f)
g = grad(pt_sum(f), y)
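
For context, a hypothetical self-contained reading of the snippet under review (reading pt_sum as pt.sum; the definitions of x, y, z and f are assumptions, and the real test builds f differently):

import pytensor.tensor as pt
from pytensor import function, grad

x = pt.matrix('x')
y = pt.matrix('y')
z = pt.matrix('z')

f = x * y * z                # some function of all three inputs (assumed)
f = f - grad(pt.sum(f), y)   # first gradient pass
f = f - grad(pt.sum(f), y)   # second gradient pass
fn = function([x, y, z], f)
g = grad(pt.sum(f), y)       # third gradient call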
@ricardoV94 (Member) commented Apr 21, 2024

Can you add a new test for the error and leave this one unchanged? Or did this start failing? Looks like it shouldn't have?

@jessegrabowski (Member Author) commented Apr 21, 2024

Yes, it started failing.

@ricardoV94 (Member)

Hmm, why? It was working before; it may need investigation, or is it obvious?

@jessegrabowski (Member Author) commented Apr 21, 2024

It merits investigation, but it might be that the test was always flaky and now we're catching it. The duplicated input is one of the intermediate f representations -- something like *4 Matrix<?,?>. The true inputs x, y, z aren't the problem. I assume this has something to do with how the grad function works?

@ricardoV94 (Member)

Intermediate variables can be repeated just fine

@jessegrabowski (Member Author)

The test doesn't have obvious double inputs though, so what's going on there? Is it because the original f is considered an input to both grad(f.sum(), y) and grad(grad(f.sum(), y).sum(), y)?

@ricardoV94 (Member)

Let me see what the duplicate inputs are. In any case, you proved we are too restrictive with the OFG([a], [a]) example (see the sketch below).
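
A minimal sketch of the OFG([a], [a]) identity case referenced above, just to make the example concrete (whether the current input check accepts it is exactly the open question):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

a = pt.matrix('a')
identity_op = OpFromGraph([a], [a])  # the output is literally the input variable
b = pt.matrix('b')
out = identity_op(b)                 # applying it should just pass b through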

@ricardoV94 (Member) commented Apr 27, 2024

The L_op of the first grad call looks like:

pytensor.dprint(self.fgraph, print_fgraph_inputs=True)

→ *0-<Matrix(float64, shape=(?, ?))> [id A]
→ *1-<Matrix(float64, shape=(?, ?))> [id B]
→ *2-<Matrix(float64, shape=(?, ?))> [id C]
→ *3-<Matrix(float64, shape=(?, ?))> [id D]
→ *4-<Matrix(float64, shape=(?, ?))> [id E]
*4-<Matrix(float64, shape=(?, ?))> [id E]
Mul [id F] 0
 ├─ *4-<Matrix(float64, shape=(?, ?))> [id E]
 └─ *2-<Matrix(float64, shape=(?, ?))> [id C]
Mul [id G] 1
 ├─ *4-<Matrix(float64, shape=(?, ?))> [id E]
 └─ *1-<Matrix(float64, shape=(?, ?))> [id B]

You can see the last input *4 (the output_grad) is also the first output (the input grad wrt x is just the output grad). So when the second L_op is called on the first one, it will try to build an OFG that takes *4 twice: once from its inputs, and once from its outputs.

@jessegrabowski (Member Author)

I went with skipping the input check if on_unused_inputs == 'ignore'. I thought that was better than removing the error. I guess a real solution would be to check for outputs that are also inputs when we build the OpFromGraph, and clone them if so (rough sketch below)?
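
A rough, hypothetical sketch of that cloning idea (not the PR's actual code; the helper name is made up, and the assumption that Variable.copy() inserts an identity node may not be the right mechanism):

def split_aliased_outputs(inputs, outputs):
    # break input/output aliasing by copying any output that is
    # literally one of the inputs
    deduped = []
    for out in outputs:
        if out in inputs:
            out = out.copy()  # assumed to insert an identity/copy node
        deduped.append(out)
    return deduped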

@ricardoV94 (Member)

Agreed, but let's do that in a separate PR. Then we can get rid of this hacky param as well.

@jessegrabowski (Member Author)

float32 😍

codecov bot commented Apr 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.78%. Comparing base (14651fb) to head (c1da527).
Report is 7 commits behind head on main.


@@            Coverage Diff             @@
##             main     #622      +/-   ##
==========================================
+ Coverage   80.76%   80.78%   +0.02%     
==========================================
  Files         162      162              
  Lines       46707    46757      +50     
  Branches    11422    11440      +18     
==========================================
+ Hits        37723    37773      +50     
  Misses       6732     6732              
  Partials     2252     2252              
Files Coverage Δ
pytensor/compile/builders.py 77.45% <100.00%> (+0.27%) ⬆️
pytensor/tensor/nlinalg.py 94.76% <100.00%> (+0.50%) ⬆️
pytensor/tensor/rewriting/linalg.py 88.69% <100.00%> (+1.70%) ⬆️

... and 5 files with indirect coverage changes

@jessegrabowski (Member Author)

I squashed this PR down to two commits, one for the actual rewrite and one for the changes that touched OpFromGraph (that hopefully we can revert down the road). After the CI runs again I'll merge. The codecov is failing because we don't have tests for the OpFromGraph error -- do you care? I can handle this in the next PR that will deal less hackishly with the duplicated inputs/outputs problem.

@ricardoV94 (Member)

Yeah, we can merge as is.

@ricardoV94 (Member)

I wouldn't remove the check later though, just not use kwargs. So a test is not a bad idea; I just don't want to drag the PR out too much more.

@ricardoV94 (Member) left a comment

Looks good

ricardoV94 merged commit eb18f0e into pymc-devs:main on Apr 28, 2024
55 checks passed
jessegrabowski deleted the linalg-lift2 branch on April 29, 2024