# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple

import jax


class EighResult(NamedTuple):
  eigenvalues: jax.Array
  eigenvectors: jax.Array

class QRResult(NamedTuple):
  Q: jax.Array
  R: jax.Array

class SlogdetResult(NamedTuple):
  sign: jax.Array
  logabsdet: jax.Array

class SVDResult(NamedTuple):
  U: jax.Array
  S: jax.Array
  Vh: jax.Array
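
These result types mirror the array API standard's named-tuple returns for `eigh`, `qr`, `slogdet`, and `svd`: each behaves like a plain tuple but also exposes named fields. A minimal sketch of how one might be populated, using a hypothetical `svd` wrapper (not part of this file) around `jax.numpy.linalg.svd`:

```python
import jax.numpy as jnp

def svd(x, /, *, full_matrices=True):
  # Hypothetical wrapper: jnp.linalg.svd returns a plain (u, s, vh) tuple,
  # which SVDResult re-exposes with named fields.
  u, s, vh = jnp.linalg.svd(x, full_matrices=full_matrices)
  return SVDResult(U=u, S=s, Vh=vh)

result = svd(jnp.eye(3))
assert result.S.shape == (3,)  # access by field name...
u, s, vh = result              # ...or unpack positionally, like a tuple
```
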
def matmul(x1, x2, /):
  """Computes the matrix product."""
  return jax.numpy.matmul(x1, x2)
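
Since this simply defers to `jax.numpy.matmul`, it inherits NumPy's batch semantics: leading (stack) dimensions broadcast, and the last axis of `x1` contracts against the second-to-last axis of `x2`. A quick shape check, assuming the definition above is in scope:

```python
import jax.numpy as jnp

a = jnp.ones((2, 1, 3, 4))  # stack dims (2, 1), matrices of shape 3x4
b = jnp.ones((5, 4, 6))     # stack dim (5,), matrices of shape 4x6
assert matmul(a, b).shape == (2, 5, 3, 6)  # stack dims broadcast to (2, 5)
```
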
def matrix_transpose(x, /):
  """Transposes a matrix (or a stack of matrices) x."""
  if x.ndim < 2:
    raise ValueError(f"matrix_transpose requires at least 2 dimensions; got {x.ndim=}")
  return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
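
Unlike `jnp.transpose`, which reverses all axes by default, this permutation leaves any leading batch axes alone and swaps only the trailing two; for `x.ndim == 3` it is `(0, 2, 1)`. A small check, assuming the definition above is in scope:

```python
import jax.numpy as jnp

x = jnp.zeros((7, 2, 3))  # a stack of seven 2x3 matrices
assert matrix_transpose(x).shape == (7, 3, 2)  # only the last two axes swap
```
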
def tensordot(x1, x2, /, *, axes=2):
  """Returns a tensor contraction of x1 and x2 over specific axes."""
  return jax.numpy.tensordot(x1, x2, axes=axes)
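
With the default `axes=2`, `jax.numpy.tensordot` contracts the last two axes of `x1` against the first two axes of `x2`; passing a pair of sequences selects the contracted axes explicitly. A sketch, assuming the definition above is in scope:

```python
import jax.numpy as jnp

a = jnp.ones((3, 4, 5))
b = jnp.ones((4, 5, 6))
assert tensordot(a, b).shape == (3, 6)                         # default axes=2
assert tensordot(a, b, axes=([1, 2], [0, 1])).shape == (3, 6)  # explicit pairs
```
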
def vecdot(x1, x2, /, *, axis=-1):
  """Computes the (vector) dot product of two arrays."""
  rank = max(x1.ndim, x2.ndim)
  x1 = jax.lax.broadcast_to_rank(x1, rank)
  x2 = jax.lax.broadcast_to_rank(x2, rank)
  if x1.shape[axis] != x2.shape[axis]:
    raise ValueError("x1 and x2 must have the same size along specified axis.")
  x1, x2 = jax.numpy.broadcast_arrays(x1, x2)
  x1 = jax.numpy.moveaxis(x1, axis, -1)
  x2 = jax.numpy.moveaxis(x2, axis, -1)
  return jax.numpy.matmul(x1[..., None, :], x2[..., None])[..., 0, 0]
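
The trailing `matmul` multiplies a row vector (`x1[..., None, :]`, shape `(..., 1, n)`) by a column vector (`x2[..., None]`, shape `(..., n, 1)`), giving a `(..., 1, 1)` result whose two unit axes are indexed away; this is what lets one batched `matmul` perform the reduction. A usage sketch, assuming the definition above is in scope:

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)  # two 3-vectors
b = jnp.ones(3)                    # broadcasts against the leading batch axis
out = vecdot(a, b)                 # contracts along axis=-1
assert out.shape == (2,)
assert jnp.allclose(out, a.sum(axis=-1))  # dotting with ones == row sums
```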