
Rewriting the kron function using JAX implementation #684

Merged: 14 commits, Mar 27, 2024
35 changes: 35 additions & 0 deletions pytensor/tensor/nlinalg.py
@@ -1010,6 +1010,40 @@ def tensorsolve(a, b, axes=None):
    return TensorSolve(axes)(a, b)


def kron(a, b):
    """Kronecker product.

    Same as np.kron(a, b).

    Parameters
    ----------
    a: array_like
    b: array_like

    Returns
    -------
    array_like with max(a.ndim, b.ndim) dimensions
    """
    a = as_tensor_variable(a)
    b = as_tensor_variable(b)
    if a.ndim + b.ndim <= 2:
        raise TypeError(
            "kron: inputs dimensions must sum to 3 or more. "
            f"You passed {int(a.ndim)} and {int(b.ndim)}."
        )

    # Promote the lower-rank operand with leading length-1 axes, as np.kron does.
    if a.ndim < b.ndim:
        a = ptb.expand_dims(a, tuple(range(b.ndim - a.ndim)))
    elif b.ndim < a.ndim:
        b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim)))
    # Interleave fresh length-1 axes so the operands broadcast pairwise,
    # then collapse each pair of axes into one output axis.
    a_reshaped = ptb.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
    b_reshaped = ptb.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
    out_shape = tuple(a.shape * b.shape)
    output_out_of_shape = a_reshaped * b_reshaped
    output_reshaped = output_out_of_shape.reshape(out_shape)
    return output_reshaped


__all__ = [
"pinv",
"inv",
@@ -1025,4 +1059,5 @@ def tensorsolve(a, b, axes=None):
    "norm",
    "tensorinv",
    "tensorsolve",
    "kron",
]
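
The new implementation follows the JAX approach named in the PR title: rather than materializing an outer product and transposing it back into place, it interleaves fresh length-1 axes between the axes of a and b, lets broadcasting form all pairwise products, and collapses each axis pair with a single reshape. A minimal NumPy sketch of that trick; the helper name is ours, and it assumes both operands already share a rank:

import numpy as np

# Sketch of the broadcast trick used above, assuming a.ndim == b.ndim.
def kron_via_broadcast(a, b):
    # Insert a new length-1 axis after every axis of `a` and before every
    # axis of `b`; the two operands then broadcast against each other pairwise.
    a_exp = np.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
    b_exp = np.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
    # Each output axis is the product of the corresponding input axes.
    return (a_exp * b_exp).reshape(tuple(np.multiply(a.shape, b.shape)))

a = np.arange(6.0).reshape(2, 3)
b = np.arange(12.0).reshape(3, 4)
assert np.array_equal(kron_via_broadcast(a, b), np.kron(a, b))

Because the whole computation is expand_dims, one elementwise multiply, and one reshape, the graph needs no dedicated Op and remains differentiable and shape-inferable for free.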
48 changes: 1 addition & 47 deletions pytensor/tensor/slinalg.py
@@ -15,7 +15,7 @@
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable
@@ -559,51 +559,6 @@ def eigvalsh(a, b, lower=True):
    return Eigvalsh(lower)(a, b)


def kron(a, b):
    """Kronecker product.

    Same as scipy.linalg.kron(a, b).

    Parameters
    ----------
    a: array_like
    b: array_like

    Returns
    -------
    array_like with a.ndim + b.ndim - 2 dimensions

    Notes
    -----
    numpy.kron(a, b) != scipy.linalg.kron(a, b)!
    They do not have the same shape and order when
    a.ndim and b.ndim are not both 2.

    """
    a = as_tensor_variable(a)
    b = as_tensor_variable(b)
    if a.ndim + b.ndim <= 2:
        raise TypeError(
            "kron: inputs dimensions must sum to 3 or more. "
            f"You passed {int(a.ndim)} and {int(b.ndim)}."
        )
    o = ptm.outer(a, b)
    o = o.reshape(ptb.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim)
    shf = o.dimshuffle(0, 2, 1, *range(3, o.ndim))
    if shf.ndim == 3:
        shf = o.dimshuffle(1, 0, 2)
        o = shf.flatten()
    else:
        o = shf.reshape(
            (
                o.shape[0] * o.shape[2],
                o.shape[1] * o.shape[3],
                *(o.shape[i] for i in range(4, o.ndim)),
            )
        )
    return o


class Expm(Op):
"""
Compute the matrix exponential of a square array.
@@ -1021,7 +976,6 @@ def block_diag(*matrices: TensorVariable):
    "cholesky",
    "solve",
    "eigvalsh",
    "kron",
    "expm",
    "solve_discrete_lyapunov",
    "solve_continuous_lyapunov",
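
For contrast, the version removed from slinalg.py built the full outer product and then shuffled axes back into place. A rough NumPy restatement of that route for the 2-d case; the helper name is ours, for illustration only:

import numpy as np

# Rough restatement of the removed outer-product route, 2-d case only.
def kron_via_outer(a, b):
    # outer() flattens both inputs, so the (i, j, k, l) structure has to be
    # rebuilt with a reshape before the axes can be reordered.
    o = np.outer(a, b).reshape(a.shape + b.shape)  # axes (i, j, k, l)
    o = o.transpose(0, 2, 1, 3)                    # axes (i, k, j, l)
    return o.reshape(a.shape[0] * b.shape[0], a.shape[1] * b.shape[1])

a = np.arange(4.0).reshape(2, 2)
b = np.arange(9.0).reshape(3, 3)
assert np.array_equal(kron_via_outer(a, b), np.kron(a, b))

The outer-product intermediate is the same size as the final result, but the extra dimshuffle branches made the old code harder to extend past the matrix case, which is what the broadcast formulation avoids.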
40 changes: 40 additions & 0 deletions tests/tensor/test_nlinalg.py
@@ -17,6 +17,7 @@
    det,
    eig,
    eigh,
    kron,
    lstsq,
    matrix_dot,
    matrix_inverse,
@@ -580,3 +581,42 @@ def test_eval(self):
        t_binv1 = tf_b1(self.b1)
        assert _allclose(t_binv, n_binv)
        assert _allclose(t_binv1, n_binv1)


class TestKron(utt.InferShapeTester):
    rng = np.random.default_rng(43)

    def setup_method(self):
        self.op = kron
        super().setup_method()

    @pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
    @pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
    def test_perform(self, shp0, shp1):
        if len(shp0) + len(shp1) == 2:
            pytest.skip("Sum of the ndims of shp0 and shp1 must be more than 2")
        x = tensor(dtype="floatX", shape=(None,) * len(shp0))
        a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
        y = tensor(dtype="floatX", shape=(None,) * len(shp1))
        f = function([x, y], kron(x, y))
        b = self.rng.random(shp1).astype(config.floatX)
        out = f(a, b)
        # Use np.kron as the reference implementation
        np_val = np.kron(a, b)
        np.testing.assert_allclose(out, np_val)

    @pytest.mark.parametrize(
        "i, shp0, shp1",
        [(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))],
    )
    def test_kron_commutes_with_inv(self, i, shp0, shp1):
        if config.floatX == "float32" and i == 2:
            pytest.skip("Single precision is insufficient for the i == 2 case to pass")
        x = tensor(dtype="floatX", shape=(None,) * len(shp0))
        a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
        y = tensor(dtype="floatX", shape=(None,) * len(shp1))
        b = self.rng.random(shp1).astype(config.floatX)
        lhs_f = function([x, y], pinv(kron(x, y)))
        rhs_f = function([x, y], kron(pinv(x), pinv(y)))
        atol = 1e-4 if config.floatX == "float32" else 1e-12
        np.testing.assert_allclose(lhs_f(a, b), rhs_f(a, b), atol=atol)
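
The second test leans on the Moore-Penrose identity pinv(kron(A, B)) == kron(pinv(A), pinv(B)). A quick standalone NumPy check of the 2-d case, as a sketch of what the pytensor test verifies symbolically:

import numpy as np

# Sanity check of the identity exercised by test_kron_commutes_with_inv.
rng = np.random.default_rng(0)
A = rng.random((2, 3))
B = rng.random((6, 7))
lhs = np.linalg.pinv(np.kron(A, B))                      # shape (21, 12)
rhs = np.kron(np.linalg.pinv(A), np.linalg.pinv(B))      # shape (21, 12)
np.testing.assert_allclose(lhs, rhs, atol=1e-10)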
41 changes: 0 additions & 41 deletions tests/tensor/test_slinalg.py
@@ -20,7 +20,6 @@
    cholesky,
    eigvalsh,
    expm,
    kron,
    solve,
    solve_continuous_lyapunov,
    solve_discrete_are,
@@ -512,46 +511,6 @@ def test_expm_grad_3():
    utt.verify_grad(expm, [A], rng=rng)


class TestKron(utt.InferShapeTester):
    rng = np.random.default_rng(43)

    def setup_method(self):
        self.op = kron
        super().setup_method()

    def test_perform(self):
        for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
            x = tensor(dtype="floatX", shape=(None,) * len(shp0))
            a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
            for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
                if len(shp0) + len(shp1) == 2:
                    continue
                y = tensor(dtype="floatX", shape=(None,) * len(shp1))
                f = function([x, y], kron(x, y))
                b = self.rng.random(shp1).astype(config.floatX)
                out = f(a, b)
                # Newer versions of scipy want 4 dimensions at least,
                # so we have to add a dimension to a and flatten the result.
                if len(shp0) + len(shp1) == 3:
                    scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
                else:
                    scipy_val = scipy.linalg.kron(a, b)
                np.testing.assert_allclose(out, scipy_val)

    def test_numpy_2d(self):
        for shp0 in [(2, 3)]:
            x = tensor(dtype="floatX", shape=(None,) * len(shp0))
            a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
            for shp1 in [(6, 7)]:
                if len(shp0) + len(shp1) == 2:
                    continue
                y = tensor(dtype="floatX", shape=(None,) * len(shp1))
                f = function([x, y], kron(x, y))
                b = self.rng.random(shp1).astype(config.floatX)
                out = f(a, b)
                assert np.allclose(out, np.kron(a, b))


def test_solve_discrete_lyapunov_via_direct_real():
    N = 5
    rng = np.random.default_rng(utt.fetch_seed())
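
The scipy special-casing in the removed test_perform above is exactly what the rewrite makes unnecessary: np.kron simply promotes the lower-rank operand with leading length-1 axes, and the new pytensor kron mirrors that convention, so the tests can compare against np.kron directly. A small illustration of that promotion, for reference:

import numpy as np

# np.kron prepends length-1 axes to the lower-rank operand,
# the convention the rewritten pytensor kron follows.
a = np.arange(2.0)                 # shape (2,)
b = np.arange(6.0).reshape(2, 3)   # shape (2, 3)
print(np.kron(a, b).shape)         # (2, 6): `a` is treated as shape (1, 2)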