
Rewriting the kron function using JAX implementation #684


Merged
merged 14 commits into pymc-devs:main on Mar 27, 2024

Conversation

tanish1729
Contributor

@tanish1729 tanish1729 commented Mar 26, 2024

Description

  • Changes the implementation of the kron function from scipy to the JAX implementation, which works for n-d arrays

Related Issue

Checklist

  • As suggested by jesse, this PR will be used for discussing the issue as I solve it. I will add a checklist later while making the changes.

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@tanish1729
Contributor Author

so based on the compute pattern in the JAX implementation, our function should look like this

if a.ndim < b.ndim:
    a = ptb.expand_dims(a, range(b.ndim - a.ndim))
elif b.ndim < a.ndim:
    b = ptb.expand_dims(b, range(a.ndim - b.ndim))
a_reshaped = ptb.expand_dims(a, range(1, 2 * a.ndim, 2))
b_reshaped = ptb.expand_dims(b, range(1, 2 * b.ndim, 2))
out_shape = tuple(np.multiply(a.shape, b.shape))
output_out_of_shape = np.multiply(a_reshaped, b_reshaped)
output_reshaped = output_out_of_shape.reshape(out_shape)
return output_reshaped

I am slightly confused whether I should be using numpy's functions (since it is now in nlinalg.py) or JAX functions (because we are following their implementation)

@jessegrabowski
Member

jessegrabowski commented Mar 26, 2024

Neither! You want to use all pytensor functions. As you've identified, we want to use ptb.expand_dims. You can replace np.multiply with just *, since it's only doing elementwise multiplication anyway. If you really like the function, you can use ptm.mul, but I don't see the point

@jessegrabowski
Member

Pytensor's expand_dims is a lot fussier about types, so you'll have to wrap the range in tuple to make it work. You also have a small typo in the b_reshaped range -- it should start with 0. Something like this:

if (a.ndim < b.ndim):
    a = ptb.expand_dims(a, tuple(range(b.ndim - a.ndim)))
elif (b.ndim < a.ndim):
    b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim)))
a_reshaped = ptb.expand_dims(a, tuple(range(1, 2*a.ndim, 2)))
b_reshaped = ptb.expand_dims(b, tuple(range(0, 2*b.ndim, 2)))
out_shape = tuple(np.multiply(a.shape, b.shape))
output_out_of_shape = a_reshaped * b_reshaped
output_reshaped = output_out_of_shape.reshape(out_shape)
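For reference, the same interleave-and-broadcast trick can be written in plain NumPy and checked against np.kron. This is a standalone illustration with a hypothetical helper name, not the pytensor code itself:

```python
import numpy as np

def kron_via_broadcasting(a, b):
    # pad the lower-rank operand with leading length-1 axes
    if a.ndim < b.ndim:
        a = np.expand_dims(a, tuple(range(b.ndim - a.ndim)))
    elif b.ndim < a.ndim:
        b = np.expand_dims(b, tuple(range(a.ndim - b.ndim)))
    # interleave new axes: a's go at the odd positions, b's at the even ones
    a_reshaped = np.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
    b_reshaped = np.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
    # broadcasting computes every pairwise product; the reshape collapses the
    # interleaved axes into the block structure of the Kronecker product
    out_shape = tuple(np.multiply(a.shape, b.shape))
    return (a_reshaped * b_reshaped).reshape(out_shape)

a = np.arange(6).reshape(2, 3)
b = np.arange(4).reshape(2, 2)
assert np.allclose(kron_via_broadcasting(a, b), np.kron(a, b))
```

The same check also passes for 3d operands, which is the case the old implementation handled incorrectly.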

@jessegrabowski
Member

jessegrabowski commented Mar 26, 2024

Also make sure you move the kron tests from test_slinalg to test_nlinalg

@tanish1729
Contributor Author

could you please check if i have made the changes correctly?
i imported the reshape function explicitly so that i am sure i use the correct thing, i hope that's alright.

about the testing, is there any documentation where i could read about how all the tests work for pytensor? that would also be helpful for me in general

@jessegrabowski
Member

jessegrabowski commented Mar 26, 2024

Tests are always going to be idiosyncratic, there's not really a fixed format. In general, we compute the pytensor version then make sure the outputs match the reference implementation. In this case, we should compare against numpy.

Basically, you first declare symbolic placeholder tensors like this:

x = tensor(dtype="floatX", shape=(None,) * len(shp0))
y = tensor(dtype="floatX", shape=(None,) * len(shp1))

shape = (None,) means a 1d tensor (vector) of arbitrary shape (we're not committing to a shape yet). shape = (None, None) is a matrix, and so on.

Next, the test computes the kron symbolically and compiles the graph into a python function. This is all done in one go. The signature for pytensor.function is (inputs, outputs). The two symbolic variables are the inputs and the output is their Kronecker product

f = function([x, y], kron(x, y))

So this line actually tests your pytensor implementation:

out = f(a, b)

Then we just verify that this is the same as the numpy output.

@jessegrabowski
Member

Also, make sure you are running pre-commit before every push. You can do [mamba / conda / pip] install pre-commit then just type pre-commit install. It will automatically set everything up, and when you push it will make sure your code is formatted correctly. You have these two check failures because you're not running it.

@tanish1729
Contributor Author

this time i ran it with the pre-commit installed. i also went through the class we made for testing and could understand the general structure behind it.

@tanish1729
Contributor Author

hmmm i am not quite sure why the pre-commit checks are still failing. what exactly is ruff and ruff-format?

@jessegrabowski
Member

It's one of the tools that pre-commit runs.

You can try doing pre-commit run --all to force it to run on all the files you've changed. You shouldn't have to do this again, though. If it makes changes, make sure to git add the changes then run it again.

@tanish1729
Contributor Author

oh cool. that worked now. you also mentioned this

Add a new test for the case of kron(3d, 3d), make sure that it 1) matches the numpy np.kron output, and 2) respects kron(inv(A), inv(B)) == inv(kron(A, B))

should i try adding this to the test class in a new function similar to test_numpy_2d?

apart from that, is everything done as required?

@tanish1729
Contributor Author

tanish1729 commented Mar 26, 2024

...ugh, now there's an issue with importing reshape

[screenshot: ImportError traceback]

it says it's most likely due to a circular import

@jessegrabowski
Member

You need to import reshape directly from pytensor.tensor.shape, or just use the method on the tensor objects like this:

output_reshaped = output_out_of_shape.reshape(out_shape)

I'd personally do the 2nd one, and remove the reshape import.

@jessegrabowski
Member

jessegrabowski commented Mar 26, 2024

should i try adding this to the test class in a new function similar to test_numpy_2d?

Actually it looks like it's already doing a 3d and a 4d test. You can delete the numpy test entirely, since it's now redundant with the first test.

I'd just request that you test that matrix inverse is commutative with the new kron implementation, since that's the bug that set off this whole issue. Make a new test function like test_kron_commutes_with_inv(self), and using a 2d case and a 3d case, test that inv(kron(a, b)) is the same as kron(inv(a), inv(b)). Use small matrices (like 2x2) so it's not computationally burdensome.
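The identity being requested can be sanity-checked directly in NumPy. This is a standalone sketch of the property, not the PR's test code:

```python
import numpy as np

# For invertible square A and B: inv(kron(A, B)) == kron(inv(A), inv(B)).
# Small 2x2 matrices keep the check cheap, as suggested above.
rng = np.random.default_rng(0)
A = rng.normal(size=(2, 2))
B = rng.normal(size=(2, 2))

left = np.linalg.inv(np.kron(A, B))
right = np.kron(np.linalg.inv(A), np.linalg.inv(B))
assert np.allclose(left, right)
```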

@tanish1729
Contributor Author

cool i'll try doing that
another general query that i have: when i have written the tests, how do i make sure that they pass? are these the same tests mentioned as "All tests" here on github, or do i have to test locally?

@jessegrabowski
Member

are these the same tests being mentioned as "All tests" here on github or do i have to test it locally

I just authorized the full testing run on the CI, so you'll see if they pass here now. Before it was just checking that you ran pre-commit and that the docs built.

You should run the tests locally before you push, because the CI is quite slow. If you're in pycharm you can click the little green arrow to run a single test. If you're not, you can do pytest tests/tensor/test_nlinalg.py::TestKron to just run the test suite you're working on (you can type any function or class after the :: to just run that test)

@tanish1729
Contributor Author

tanish1729 commented Mar 26, 2024

test that inv(kron(a, b)) is the same as kron(inv(a), inv(b))

for the inverse computation, which function do i use since there's a few of them :
MatrixInverse (matrix_inverse), TensorInv, tensorinv

@tanish1729
Contributor Author

ah great i will do the testing locally as well now because i can see its taking a long time here

@jessegrabowski
Member

Use inv, it's defined in the same file you're working in so you don't even have to import it.

@tanish1729
Contributor Author

hmmm inv is not defined anywhere in the file. could you please take a look again and see which one to use.

@jessegrabowski
Member

jessegrabowski commented Mar 26, 2024

inv = matrix_inverse = Blockwise(MatrixInverse())

Oh sorry I'm a pepega, you're in tests. You want to use pt.linalg.inv in the test context, or import it with from pytensor.tensor.linalg import inv

@tanish1729
Contributor Author

Yep I have to write a proposal which is due on the 2nd of April. I think if I can take today to familiarise myself with the problem, then take 2 days to write a proposal which you can review, and finally make an updated proposal, I should be good before the deadline

@tanish1729
Contributor Author

there is still an error while testing.

[screenshot: test failure traceback]

i tried breaking down the test into smaller ones but it's always this same error with "last 2 dimensions of the array must be square". can you point out the exact part which is triggering this?

@jessegrabowski
Member

Oh this is because you can't invert a non-square matrix. Use pt.linalg.pinv in the test instead, it should have the same commutative property.
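The pseudoinverse does keep the same property, even for non-square inputs. A quick NumPy sanity check (a sketch, not the PR's test):

```python
import numpy as np

# The Moore-Penrose pseudoinverse distributes over the Kronecker product:
# pinv(kron(A, B)) == kron(pinv(A), pinv(B)), even when A and B are not square.
rng = np.random.default_rng(42)
A = rng.normal(size=(2, 3))  # deliberately non-square
B = rng.normal(size=(3, 2))

left = np.linalg.pinv(np.kron(A, B))
right = np.kron(np.linalg.pinv(A), np.linalg.pinv(B))
assert np.allclose(left, right)
```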

@tanish1729
Contributor Author

great! that fixed it. everything should be working now

@tanish1729
Contributor Author

btw, as i was looking through the rewrites, i saw that it is still importing KroneckerProduct from slinalg, and i have only added a function kron in nlinalg, not a KroneckerProduct class like it used to be.

should i make changes to that as well and how would the import be updated across all files

@jessegrabowski
Member

Your tests are working but you have a small precision issue when computing at half-precision. Have a look here for how I handle this. For choosing the actual tolerance, you can set pytensor to float32 mode by putting pytensor.config.floatX = 'float32' at the start of your test and run it locally. Just make sure to remove that line before you commit.
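One common shape for a dtype-dependent tolerance looks like the sketch below. The helper name and the exact tolerance values are assumptions for illustration, not what the PR ends up using:

```python
import numpy as np

def assert_allclose_by_dtype(x, y):
    # float32 only carries ~7 significant digits, so loosen the absolute
    # tolerance at half precision; values here are illustrative
    atol = 1e-4 if x.dtype == np.float32 else 1e-8
    np.testing.assert_allclose(x, y, atol=atol)

a32 = np.float32([1.0, 2.0])
assert_allclose_by_dtype(a32, a32 + np.float32(1e-5))  # passes at float32 tolerance
```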

@jessegrabowski
Member

No don't worry about that in this PR. We just want to make this one targeted change here.

@tanish1729
Contributor Author

alright. some questions that i have about this

  1. the test passes locally but not here; what is the difference in the way they are run?
  2. should i add an absolute tolerance value within np.testing.assert_allclose?

@tanish1729
Contributor Author

tanish1729 commented Mar 27, 2024

[screenshot: assert_allclose mismatch report]

is it just a tolerance issue? the difference in matched elements is so large that the matrices aren't even the same. doesn't this mean the commutativity isn't holding 🤔

but again, this is only when i add the pytensor.config.floatX = 'float32' line before the test. what exactly is that doing?

@jessegrabowski
Member

jessegrabowski commented Mar 27, 2024

No it's commuting, because the first case passes. It even works on batches, because the 2nd case passes. I think there's just some precision issues with pinv on the big matrix. I think you can just add a check that skips the 3rd test case if config.floatX == 'float32'. Here's an example.

@tanish1729
Contributor Author

i think we should be good now, can you authorise all tests?

np_val = np.kron(a, b)
np.testing.assert_allclose(out, np_val)

def test_kron_commutes_with_inv(self):
Member

Since no code is reused, why not parametrize?

for i, (shp0, shp1) in enumerate(
    zip([(2, 3), (2, 3), (2, 4, 3)], [(6, 7), (4, 3, 5), (4, 3, 5)])
):
    if (pytensor.config.floatX == "float32") & (i == 2):
Member

Just skip everything with float32?

super().setup_method()

def test_perform(self):
    for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
Member

Same here? Use pytest.mark.parametrize?

Contributor Author

hey! could you briefly explain how pytest.mark.parametrize works? it'll be useful for writing future tests anyway, and i can update these as well

Member

It's a decorator that passes keyword arguments into a test function. For example:

@pytest.mark.parametrize('a', [1, 2, 3])
def test(a):
    assert a < 2

This will run 3 tests, one for each value of a. The important thing is that the arguments into the test function match the strings you give in the decorator.

You can pass multiple parameters like this:

@pytest.mark.parametrize('a, b', [(1, 2), (2, 3), (3, 4)])
def test(a, b):
    assert a < b

This will make 3 tests, all of which will pass.

Or you can make the full cartesian product between parameters by stacking decorators:

@pytest.mark.parametrize('a', [1, 2, 3])
@pytest.mark.parametrize('b', [3, 4, 5])
def test(a, b):
    assert a < b

This will make 3 * 3 = 9 tests, and you should get one failure (for the 3 < 3 case)
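Applied to this PR's situation, the shape loop could be parametrized along these lines. The test name, shapes, and the shape-only assertion are illustrative, not the PR's exact code:

```python
import numpy as np
import pytest

@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4)])
def test_kron_output_shape(shp0):
    a = np.ones(shp0)
    b = np.ones((2,) * len(shp0))
    # when the operands have equal ndim, np.kron multiplies the shapes
    # elementwise, so each output dim is 2 * the corresponding dim of a
    assert np.kron(a, b).shape == tuple(2 * s for s in shp0)
```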

Contributor Author

thank you this is really helpful. safe to say these are just fancy for loops? 😄

Member

@jessegrabowski jessegrabowski Mar 27, 2024

Yes, but with better outputs when you run pytest. It splits each parameterization into its own test, and so you can know exactly which combination of parameters is failing. If you have a loop, it just tells you pass/fail, not at which step of the loop.

Also it means you get more green dots when you run pytest, which is obviously extremely important

@ricardoV94
Member

PS: my comments are nitpicks, feel free to ignore.

I imagine it was just following the old test template

Member

@jessegrabowski jessegrabowski left a comment

Tests are passing, everything looks good to me now. You can refactor the test to use parametrize if you'd like, but I'm not going to make it a blocker.

@tanish1729
Contributor Author

this is so much better, i'll be sure to use these while writing tests

@jessegrabowski
Member

Looks great! Thanks for taking this on

@jessegrabowski jessegrabowski merged commit 378cb40 into pymc-devs:main Mar 27, 2024
@tanish1729
Contributor Author

amazing! btw this is probably not the best place to ask, but is there a discord or slack channel for messaging? i have some questions about graph rewriting and they would be better answered in a discussion instead of comments

@jessegrabowski
Member

The discourse is a good place, and the public discussion is a good record for future people who need similar help.

Development

Successfully merging this pull request may close these issues.

Bug: linalg.kron is not correct for matrices of ndims > 2