diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 81ff82ada2..c79178a74c 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -8,6 +8,7 @@ Det, Eig, Eigh, + KroneckerProduct, MatrixInverse, MatrixPinv, QRFull, @@ -104,6 +105,14 @@ def batched_dot(a, b): return batched_dot +@jax_funcify.register(KroneckerProduct) +def jax_funcify_KroneckerProduct(op, **kwargs): + def _kron(x, y): + return jnp.kron(x, y) + + return _kron + + @jax_funcify.register(Max) def jax_funcify_Max(op, **kwargs): axis = op.axis diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 2175670ee6..0bc6749448 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -165,3 +165,15 @@ def test_pinv_hermitian(): assert not np.allclose( jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True) ) + + +def test_kron(): + x = matrix("x") + y = matrix("y") + z = pt_nlinalg.kron(x, y) + + fgraph = FunctionGraph([x, y], [z]) + x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + + compare_jax_and_py(fgraph, [x_np, y_np])