Skip to content

Rewriting the kron function using JAX implementation #684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pytensor.graph.op import Op
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor import reshape
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
Expand Down Expand Up @@ -1026,3 +1027,45 @@ def tensorsolve(a, b, axes=None):
"tensorinv",
"tensorsolve",
]


# kron was moved here from slinalg.py; this implementation follows
# numpy.kron (via the JAX-style broadcast trick) rather than scipy.linalg.kron.
def kron(a, b):
    """Kronecker product.

    Same as np.kron(a, b)

    Parameters
    ----------
    a: array_like
    b: array_like

    Returns
    -------
    array_like with max(a.ndim, b.ndim) dimensions, matching np.kron

    Notes
    -----
    numpy.kron(a, b) != scipy.linalg.kron(a, b)!
    They don't have the same shape and order when
    a.ndim != b.ndim != 2.  This function follows numpy's convention.

    """
    a = as_tensor_variable(a)
    b = as_tensor_variable(b)
    if a.ndim + b.ndim <= 2:
        raise TypeError(
            "kron: inputs dimensions must sum to 3 or more. "
            f"You passed {int(a.ndim)} and {int(b.ndim)}."
        )

    # Left-pad the lower-rank operand with broadcastable axes so both
    # operands end up with the same number of dimensions (np.kron semantics).
    if a.ndim < b.ndim:
        a = ptb.expand_dims(a, tuple(range(b.ndim - a.ndim)))
    elif b.ndim < a.ndim:
        b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim)))
    # Interleave fresh broadcastable axes (a gets odd positions, b gets even)
    # so that the elementwise product below forms every pairwise product
    # a[i...] * b[j...] laid out in Kronecker order.
    a_reshaped = ptb.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
    b_reshaped = ptb.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
    # Each output axis k has length a.shape[k] * b.shape[k]; collapsing the
    # interleaved axis pairs with a reshape yields the final result.
    out_shape = tuple(a.shape * b.shape)
    output_out_of_shape = a_reshaped * b_reshaped
    output_reshaped = reshape(output_out_of_shape, out_shape)
    return output_reshaped
46 changes: 2 additions & 44 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -559,49 +559,7 @@ def eigvalsh(a, b, lower=True):
return Eigvalsh(lower)(a, b)


def kron(a, b):
    """Kronecker product.

    Same as scipy.linalg.kron(a, b).

    Parameters
    ----------
    a: array_like
    b: array_like

    Returns
    -------
    array_like with a.ndim + b.ndim - 2 dimensions

    Notes
    -----
    numpy.kron(a, b) != scipy.linalg.kron(a, b)!
    They don't have the same shape and order when
    a.ndim != b.ndim != 2.

    """
    a = as_tensor_variable(a)
    b = as_tensor_variable(b)
    # 1-D x 1-D is rejected up front; the axis bookkeeping below needs at
    # least three total dimensions.
    if a.ndim + b.ndim <= 2:
        raise TypeError(
            "kron: inputs dimensions must sum to 3 or more. "
            f"You passed {int(a.ndim)} and {int(b.ndim)}."
        )
    # NOTE(review): relies on ptm.outer producing all pairwise products of the
    # two operands; the reshape below restores the original axes of a followed
    # by those of b — confirm against pytensor.tensor.math.outer.
    o = ptm.outer(a, b)
    o = o.reshape(ptb.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim)
    # Swap axes 1 and 2 so blocks of b are arranged per element of a, as
    # required by the Kronecker layout.
    shf = o.dimshuffle(0, 2, 1, *range(3, o.ndim))
    if shf.ndim == 3:
        # 3-D case (one operand is a vector): use a simpler shuffle and
        # flatten to the final 1-D/2-D-compatible layout.
        shf = o.dimshuffle(1, 0, 2)
        o = shf.flatten()
    else:
        # Merge each interleaved pair of axes into a single output axis of
        # length a.shape[k] * b.shape[k]; trailing axes pass through.
        o = shf.reshape(
            (
                o.shape[0] * o.shape[2],
                o.shape[1] * o.shape[3],
                *(o.shape[i] for i in range(4, o.ndim)),
            )
        )
    return o
# Removed kron function from here and moved to nlinalg.py


class Expm(Op):
Expand Down
37 changes: 37 additions & 0 deletions tests/tensor/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
det,
eig,
eigh,
kron,
lstsq,
matrix_dot,
matrix_inverse,
Expand Down Expand Up @@ -580,3 +581,39 @@ def test_eval(self):
t_binv1 = tf_b1(self.b1)
assert _allclose(t_binv, n_binv)
assert _allclose(t_binv1, n_binv1)


class TestKron(utt.InferShapeTester):
    # Shared RNG with a fixed seed so the random test data is reproducible.
    rng = np.random.default_rng(43)

    def setup_method(self):
        # InferShapeTester hooks expect self.op to reference the op under test.
        self.op = kron
        super().setup_method()

def test_perform(self):
    """Compare kron against np.kron for every rank combination in 1..4."""
    for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
        x = tensor(dtype="floatX", shape=(None,) * len(shp0))
        # rng.random already returns an ndarray; no np.asarray wrapper needed.
        a = self.rng.random(shp0).astype(config.floatX)
        for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
            if len(shp0) + len(shp1) == 2:
                # kron raises TypeError when input dims sum to 2 or less.
                continue
            y = tensor(dtype="floatX", shape=(None,) * len(shp1))
            f = function([x, y], kron(x, y))
            b = self.rng.random(shp1).astype(config.floatX)
            out = f(a, b)
            # np.kron is the reference implementation for the new kron.
            np_val = np.kron(a, b)
            np.testing.assert_allclose(out, np_val)

def test_numpy_2d(self):
    """2-D x 2-D inputs: result must match np.kron.

    The original looped over single-element shape lists and guarded on
    ``len(shp0) + len(shp1) == 2``, which can never be true for two 2-D
    shapes (2 + 2 == 4); both the loops and the dead guard are removed.
    """
    x = tensor(dtype="floatX", shape=(None, None))
    y = tensor(dtype="floatX", shape=(None, None))
    f = function([x, y], kron(x, y))
    a = self.rng.random((2, 3)).astype(config.floatX)
    b = self.rng.random((6, 7)).astype(config.floatX)
    np.testing.assert_allclose(f(a, b), np.kron(a, b))
41 changes: 0 additions & 41 deletions tests/tensor/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
cholesky,
eigvalsh,
expm,
kron,
solve,
solve_continuous_lyapunov,
solve_discrete_are,
Expand Down Expand Up @@ -512,46 +511,6 @@ def test_expm_grad_3():
utt.verify_grad(expm, [A], rng=rng)


class TestKron(utt.InferShapeTester):
    # Shared RNG with a fixed seed so the random test data is reproducible.
    rng = np.random.default_rng(43)

    def setup_method(self):
        # InferShapeTester hooks expect self.op to reference the op under test.
        self.op = kron
        super().setup_method()

    def test_perform(self):
        """Compare kron against scipy.linalg.kron for ranks 1..4 per input."""
        for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
            x = tensor(dtype="floatX", shape=(None,) * len(shp0))
            a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
            for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
                if len(shp0) + len(shp1) == 2:
                    # kron raises TypeError when input dims sum to 2 or less.
                    continue
                y = tensor(dtype="floatX", shape=(None,) * len(shp1))
                f = function([x, y], kron(x, y))
                b = self.rng.random(shp1).astype(config.floatX)
                out = f(a, b)
                # Newer versions of scipy want 4 dimensions at least,
                # so we have to add a dimension to a and flatten the result.
                if len(shp0) + len(shp1) == 3:
                    scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
                else:
                    scipy_val = scipy.linalg.kron(a, b)
                np.testing.assert_allclose(out, scipy_val)

    def test_numpy_2d(self):
        """2-D x 2-D inputs: result must also match np.kron."""
        for shp0 in [(2, 3)]:
            x = tensor(dtype="floatX", shape=(None,) * len(shp0))
            a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
            for shp1 in [(6, 7)]:
                if len(shp0) + len(shp1) == 2:
                    continue
                y = tensor(dtype="floatX", shape=(None,) * len(shp1))
                f = function([x, y], kron(x, y))
                b = self.rng.random(shp1).astype(config.floatX)
                out = f(a, b)
                assert np.allclose(out, np.kron(a, b))


def test_solve_discrete_lyapunov_via_direct_real():
N = 5
rng = np.random.default_rng(utt.fetch_seed())
Expand Down