
Rewriting the kron function using JAX implementation #684


Merged · 14 commits · Mar 27, 2024
Changes from 1 commit
31 changes: 15 additions & 16 deletions pytensor/tensor/nlinalg.py
@@ -9,6 +9,7 @@
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
+from pytensor.tensor import reshape
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
@@ -1031,7 +1032,7 @@ def tensorsolve(a, b, axes=None):
 def kron(a, b):
     """Kronecker product.
 
-    Same as scipy.linalg.kron(a, b).
+    Uses the JAX implementation for kron.
 
     Parameters
     ----------
@@ -1048,6 +1049,8 @@ def kron(a, b):
     They don't have the same shape and order when
     a.ndim != b.ndim != 2.
 
+    This function now also works for inputs with ndim > 2.
+
     """
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
@@ -1056,18 +1059,14 @@
             "kron: inputs dimensions must sum to 3 or more. "
             f"You passed {int(a.ndim)} and {int(b.ndim)}."
         )
-    o = ptm.outer(a, b)
-    o = o.reshape(ptb.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim)
-    shf = o.dimshuffle(0, 2, 1, *range(3, o.ndim))
-    if shf.ndim == 3:
-        shf = o.dimshuffle(1, 0, 2)
-        o = shf.flatten()
-    else:
-        o = shf.reshape(
-            (
-                o.shape[0] * o.shape[2],
-                o.shape[1] * o.shape[3],
-                *(o.shape[i] for i in range(4, o.ndim)),
-            )
-        )
-    return o
+    if a.ndim < b.ndim:
+        a = ptb.expand_dims(a, tuple(range(b.ndim - a.ndim)))
+    elif b.ndim < a.ndim:
+        b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim)))
+    a_reshaped = ptb.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
+    b_reshaped = ptb.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
+    out_shape = tuple(a.shape * b.shape)
+    output_out_of_shape = a_reshaped * b_reshaped
+    output_reshaped = reshape(output_out_of_shape, out_shape)
+    return output_reshaped
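
For intuition, the expand_dims/broadcast/reshape trick in the new implementation can be checked against np.kron directly. A minimal NumPy sketch (the shapes are illustrative, not taken from this PR; the PyTensor code does the same thing symbolically):

import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
b = np.arange(5 * 6 * 7).reshape(5, 6, 7)

# Put a's axes at the even positions and b's axes at the odd positions,
# multiply by broadcasting, then collapse each (a_dim, b_dim) pair into
# a single axis of length a_dim * b_dim, which is the Kronecker layout.
a_r = a.reshape(2, 1, 3, 1, 4, 1)
b_r = b.reshape(1, 5, 1, 6, 1, 7)
out = (a_r * b_r).reshape(2 * 5, 3 * 6, 4 * 7)

assert np.array_equal(out, np.kron(a, b))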
40 changes: 40 additions & 0 deletions tests/tensor/test_nlinalg.py
@@ -29,6 +29,7 @@
     tensorinv,
     tensorsolve,
     trace,
+    kron,
 )
 from pytensor.tensor.type import (
     lmatrix,
@@ -580,3 +581,42 @@ def test_eval(self):
         t_binv1 = tf_b1(self.b1)
         assert _allclose(t_binv, n_binv)
         assert _allclose(t_binv1, n_binv1)
+
+
+class TestKron(utt.InferShapeTester):
+    rng = np.random.default_rng(43)
+
+    def setup_method(self):
+        self.op = kron
+        super().setup_method()
+
+    def test_perform(self):
+        for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
Member:

Same here? Use pytest.mark.parametrize?

Contributor Author:

Hey! Could you briefly explain how pytest.mark.parametrize works? It'll be useful for writing future tests anyway, and I can update these as well.

Member:

It's a decorator that passes keyword arguments into a test function. For example:

@pytest.mark.parametrize('a', [1, 2, 3])
def test(a):
    assert a < 2

This will run 3 tests, one for each value of a. The important thing is that the test function's argument names match the strings you give in the decorator.

You can pass multiple parameters like this:

@pytest.mark.parametrize('a, b', [(1, 2), (2, 3), (3, 4)])
def test(a, b):
    assert a < b

This will make 3 tests, all of which will pass.

Or you can make the full Cartesian product between parameters by stacking decorators:

@pytest.mark.parametrize('a', [1, 2, 3])
@pytest.mark.parametrize('b', [3, 4, 5])
def test(a, b):
    assert a < b

This will make 3 * 3 = 9 tests, and you should get one failure (for the 3 < 3 case).

Contributor Author:

Thank you, this is really helpful. Safe to say these are just fancy for loops? 😄

Member (@jessegrabowski, Mar 27, 2024):

Yes, but with better output when you run pytest. It splits each parametrization into its own test, so you know exactly which combination of parameters is failing. With a loop, pytest just tells you pass/fail, not which step of the loop failed.

Also it means you get more green dots when you run pytest, which is obviously extremely important

x = tensor(dtype="floatX", shape=(None,) * len(shp0))
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
if len(shp0) + len(shp1) == 2:
continue
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
f = function([x, y], kron(x, y))
b = self.rng.random(shp1).astype(config.floatX)
out = f(a, b)
# Newer versions of scipy want 4 dimensions at least,
# so we have to add a dimension to a and flatten the result.
if len(shp0) + len(shp1) == 3:
scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
else:
scipy_val = scipy.linalg.kron(a, b)
np.testing.assert_allclose(out, scipy_val)

+    def test_numpy_2d(self):
+        for shp0 in [(2, 3)]:
+            x = tensor(dtype="floatX", shape=(None,) * len(shp0))
+            a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
+            for shp1 in [(6, 7)]:
+                if len(shp0) + len(shp1) == 2:
+                    continue
+                y = tensor(dtype="floatX", shape=(None,) * len(shp1))
+                f = function([x, y], kron(x, y))
+                b = self.rng.random(shp1).astype(config.floatX)
+                out = f(a, b)
+                assert np.allclose(out, np.kron(a, b))
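
As suggested in the review thread, the loop-based test above could also be expressed with pytest.mark.parametrize. A rough sketch, assuming the module's existing imports plus import pytest (not part of this commit):

@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
def test_kron_perform_parametrized(shp0, shp1):
    if len(shp0) + len(shp1) == 2:
        pytest.skip("kron requires input dimensions to sum to 3 or more")
    rng = np.random.default_rng(43)
    x = tensor(dtype="floatX", shape=(None,) * len(shp0))
    y = tensor(dtype="floatX", shape=(None,) * len(shp1))
    f = function([x, y], kron(x, y))
    a = rng.random(shp0).astype(config.floatX)
    b = rng.random(shp1).astype(config.floatX)
    out = f(a, b)
    # Same scipy workaround as in the loop version above.
    if len(shp0) + len(shp1) == 3:
        scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
    else:
        scipy_val = scipy.linalg.kron(a, b)
    np.testing.assert_allclose(out, scipy_val)

Each parametrization runs as its own test, so a failing shape combination shows up by name in the pytest output.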
42 changes: 0 additions & 42 deletions tests/tensor/test_slinalg.py
@@ -20,7 +20,6 @@
     cholesky,
     eigvalsh,
     expm,
-    kron,
     solve,
     solve_continuous_lyapunov,
     solve_discrete_are,
@@ -511,47 +510,6 @@ def test_expm_grad_3():
 
     utt.verify_grad(expm, [A], rng=rng)
 
-
-class TestKron(utt.InferShapeTester):
-    rng = np.random.default_rng(43)
-
-    def setup_method(self):
-        self.op = kron
-        super().setup_method()
-
-    def test_perform(self):
-        for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
-            x = tensor(dtype="floatX", shape=(None,) * len(shp0))
-            a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
-            for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
-                if len(shp0) + len(shp1) == 2:
-                    continue
-                y = tensor(dtype="floatX", shape=(None,) * len(shp1))
-                f = function([x, y], kron(x, y))
-                b = self.rng.random(shp1).astype(config.floatX)
-                out = f(a, b)
-                # Newer versions of scipy want 4 dimensions at least,
-                # so we have to add a dimension to a and flatten the result.
-                if len(shp0) + len(shp1) == 3:
-                    scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
-                else:
-                    scipy_val = scipy.linalg.kron(a, b)
-                np.testing.assert_allclose(out, scipy_val)
-
-    def test_numpy_2d(self):
-        for shp0 in [(2, 3)]:
-            x = tensor(dtype="floatX", shape=(None,) * len(shp0))
-            a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
-            for shp1 in [(6, 7)]:
-                if len(shp0) + len(shp1) == 2:
-                    continue
-                y = tensor(dtype="floatX", shape=(None,) * len(shp1))
-                f = function([x, y], kron(x, y))
-                b = self.rng.random(shp1).astype(config.floatX)
-                out = f(a, b)
-                assert np.allclose(out, np.kron(a, b))
-
-
 def test_solve_discrete_lyapunov_via_direct_real():
     N = 5
     rng = np.random.default_rng(utt.fetch_seed())