Skip to content

Commit 7bceb3b

Browse files
committed
more functions
1 parent f48c2f5 commit 7bceb3b

File tree

4 files changed

+102
-1
lines changed

4 files changed

+102
-1
lines changed

jax/experimental/array_api/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,15 @@
173173
sum as sum,
174174
var as var
175175
)
176+
177+
from jax.experimental.array_api._utility_functions import (
178+
all as all,
179+
any as any,
180+
)
181+
182+
from jax.experimental.array_api._linear_algebra_functions import (
183+
matmul as matmul,
184+
matrix_transpose as matrix_transpose,
185+
tensordot as tensordot,
186+
vecdot as vecdot,
187+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import NamedTuple
16+
17+
import jax
18+
19+
20+
class EighResult(NamedTuple):
21+
eigenvalues: jax.Array
22+
eigenvectors: jax.Array
23+
24+
class QRResult(NamedTuple):
25+
Q: jax.Array
26+
R: jax.Array
27+
28+
class SlogdetResult(NamedTuple):
29+
sign: jax.Array
30+
logabsdet: jax.Array
31+
32+
class SVDResult(NamedTuple):
33+
U: jax.Array
34+
S: jax.Array
35+
Vh: jax.Array
36+
37+
38+
def matmul(x1, x2, /):
39+
"""Computes the matrix product."""
40+
return jax.numpy.matmul(x1, x2)
41+
42+
43+
def matrix_transpose(x, /):
44+
"""Transposes a matrix (or a stack of matrices) x."""
45+
if x.ndim < 2:
46+
raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {x.ndim=}")
47+
return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
48+
49+
50+
def tensordot(x1, x2, /, *, axes=2):
51+
"""Returns a tensor contraction of x1 and x2 over specific axes."""
52+
return jax.numpy.tensordot(x1, x2, axes=axes)
53+
54+
55+
def vecdot(x1, x2, /, *, axis=-1):
56+
"""Computes the (vector) dot product of two arrays."""
57+
rank = max(x1.ndim, x2.ndim)
58+
x1 = jax.lax.broadcst_to_rank(x1, rank)
59+
x2 = jax.lax.broadcast_to_rank(x2, rank)
60+
if x1.shape[axis] != x2.shape[axis]:
61+
raise ValueError("x1 and x2 must have the same size along specified axis.")
62+
x1, x2 = jax.lax.broadcast_arrays(x1, x2)
63+
x1 = jax.numpy.moveaxis(x1, axis, -1)
64+
x2 = jax.numpy.moveaxis(x2, axis, -1)
65+
return jax.numpy.matmul(x1[..., None, :], x2[..., None])[..., 0, 0]

jax/experimental/array_api/_sorting_functions.py

-1
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,3 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
2828
if descending:
2929
return jax.lax.rev(result, dimensions=[axis])
3030
return result
31-
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax
16+
17+
18+
def all(x, /, *, axis=None, keepdims=False):
19+
"""Tests whether all input array elements evaluate to True along a specified axis."""
20+
return jax.numpy.all(x, axis=axis, keepdims=keepdims)
21+
22+
23+
def any(x, /, *, axis=None, keepdims=False):
24+
"""Tests whether any input array element evaluates to True along a specified axis."""
25+
return jax.numpy.any(x, axis=axis, keepdims=keepdims)

0 commit comments

Comments
 (0)