-
Notifications
You must be signed in to change notification settings - Fork 132
Rewriting the kron function using JAX implementation #684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
so based on the compute pattern in the JAX implementation, our function should look like this if (a.ndim < b.ndim):
a = ptb.expand_dims(a, range(b.ndim - a.ndim))
elif (b.ndim < a.ndim):
b = ptb.expand_dims(b, range(a.ndim - b.ndim))
a_reshaped = ptb.expand_dims(a, range(1, 2*a.ndim,2))
b_reshaped = ptb.expand_dims(b, range(1, 2*b.ndim,2))
out_shape = tuple(np.multiply(a.shape, b.shape))
output_out_of_shape = np.multiply(a_reshaped, b_reshaped)
output_reshaped = output_out_of_shape.reshape(out_shape)
return output_reshaped I am slightly confused whether I should be using numpy's functions (since it is now in nlinalg.py) or JAX functions (because we are following their implementation) |
Neither! You want to use all pytensor functions. As you've identified, we want to use |
Pytensor's
|
Also make sure you move the kron tests from |
…o the nlinalg file
could you please check if i have made the changes correctly. about the testing, is there any documentation where I could read about how all the tests work for pytensor. that would also be helpful for me in general |
Tests are always going to be idiosyncratic, there's not really a fixed format. In general, we compute the pytensor version then make sure the outputs match the reference implementation. In this case, we should compare against numpy. Basically, you first declare symbolic placeholder tensors like this:
Next, it computes the kron symbolically and compiles the graph into a python function. This is all done in one go. The signature for
So this line actually tests your pytensor implementation:
Then we just verify that this is the same as the numpy output. |
Also, make sure you are running pre-commit before every push. You can do |
this time i ran it with the pre-commit installed. i also went through the class we made for testing and could understand the general structure behind it. |
hmmm i am not quite sure why the pre-commit checks are still failing. what exactly is ruff and ruff-format? |
It's one of the tools that pre-commit runs. You can try doing |
oh cool. that worked now. you also mentioned this
should i try adding this to the test class in a new function similar to test_numpy_2d? apart from that, is everything done as required? |
You need to import
I'd personally do the 2nd one, and remove the reshape import. |
Actually it looks like it's already doing and 3d and a 4d test. You can delete the numpy test entirely, since it's now redundant with the first test. I'd just request that you test that matrix inverse is commutative with the new kron implementation, since that's the bug that set off this whole issue. Make a new test function like |
cool i'll try doing that |
I just authorized the full testing run on the CI, so you'll see if they pass here now. Before it was just checking that you ran pre-commit and that the docs built. You should run the tests locally before you push, because the CI is quite slow. If you're in pycharm you can click the little green arrow to run a single test. If you're not, you can do |
for the inverse computation, which function do i use since there's a few of them : |
ah great i will do the testing locally as well now because i can see its taking a long time here |
Use |
hmmm |
pytensor/pytensor/tensor/nlinalg.py Line 159 in f97d9ea
Oh sorry I'm a pepega, you're in tests. You want to use |
Yep I have to write a proposal which is due on the 2nd of April. I think if I can take today to familiarise myself with the problem, then take 2 days to write a proposal which you can review, and finally make an updated proposal, I should be good before the deadline |
Oh this is because you can't invert a non-square matrix. Use |
great! that fixed it. everything should be working now |
btw, as i was looking through the rewrites, i saw that it is still importing KroneckerProduct from slinalg, and i have only added a function should i make changes to that as well and how would the import be updated across all files |
Your tests are working but you have a small precision issue when computing at half-precision. Have a look here for how I handle this. For choosing the actual tolerance, you can set pytensor to float32 mode by putting |
No don't worry about that in this PR. We just want to make this one targeted change here. |
alright. some questions that i have about this
|
No it's commuting, because the first case passes. It even works on batches, because the 2nd case passes. I think there's just some precision issues with |
i think we should be good now, can you authorise all tests? |
tests/tensor/test_nlinalg.py
Outdated
np_val = np.kron(a, b) | ||
np.testing.assert_allclose(out, np_val) | ||
|
||
def test_kron_commutes_with_inv(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since no code is reused, why not parametrize?
tests/tensor/test_nlinalg.py
Outdated
for i, (shp0, shp1) in enumerate( | ||
zip([(2, 3), (2, 3), (2, 4, 3)], [(6, 7), (4, 3, 5), (4, 3, 5)]) | ||
): | ||
if (pytensor.config.floatX == "float32") & (i == 2): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just skip everything with float32?
tests/tensor/test_nlinalg.py
Outdated
super().setup_method() | ||
|
||
def test_perform(self): | ||
for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here? Use pytest.mark.parametrize?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hey! could you briefly explain how pytest.mark.parametrize
works. It'll anyways be useful in writing future tests and I can update these as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a decorator that passes keyword arguments into a test function. For example:
@pytest.mark.parameterize('a', [1,2,3])
def test(a):
assert a < 2
This will run 3 tests, one for each value of a
. The important thing is that the arguments into the test function match the strings you give in the decorator.
You can pass multiple parameters like this:
@pytest.mark.parameterize('a, b', [(1, 2) ,(2, 3), (3, 4)])
def test(a, b):
assert a < b
This will make 3 tests, all of which will pass.
Or you can make the full cartesian product between parameters by stacking decorators:
@pytest.mark.parameterize('a', [1, 2, 3])
@pytest.mark.parameterize('b', [3, 4, 5])
def test(a, b):
assert a < b
This will make 3 * 3 = 9
tests, and you should get one failure (for the 3 < 3
case)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank you this is really helpful. safe to say these are just fancy for loops? 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but with better outputs when you run pytest. It splits each parameterization into its own test, and so you can know exactly which combination of parameters is failing. If you have a loop, it just tells you pass/fail, not at which step of the loop.
Also it means you get more green dots when you run pytest, which is obviously extremely important
PS: my comments are nitpicks, feel free to ignore. I imagine it was just following the old test template |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests are passing, everything looks good to me now. You can refactor the test to use parameterize if you'd like, but I'm not going to make it a blocker.
this is so much better, i'll be sure to use these while writing tests |
Look great! Thanks for taking this on |
amazing! btw this is probably not the best place to ask but is there a discord or slack channel for messaging. i was having some questions about grahh rewriting and they would be better answered in a discussion instead of comments |
The discourse is a good place, and the public discussion is a good record for future people who need similar help. |
Description
Related Issue
linalg.kron
is not correct for matrices ofndims > 2
#640Checklist
Type of change