|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import operator |
| 16 | +from typing import List, Optional, Tuple, Union |
| 17 | + |
15 | 18 | import jax
|
| 19 | +from jax import Array |
16 | 20 | from jax.experimental.array_api._data_type_functions import result_type as _result_type
|
17 | 21 |
|
18 | 22 |
|
19 |
| -def broadcast_arrays(*arrays): |
| 23 | +def broadcast_arrays(*arrays: Array) -> List[Array]: |
20 | 24 | """Broadcasts one or more arrays against one another."""
|
21 | 25 | return jax.numpy.broadcast_arrays(*arrays)
|
22 | 26 |
|
23 | 27 |
|
24 |
| -def broadcast_to(x, /, shape): |
| 28 | +def broadcast_to(x: Array, /, shape: Tuple[int]) -> Array: |
25 | 29 | """Broadcasts an array to a specified shape."""
|
26 | 30 | return jax.numpy.broadcast_to(x, shape=shape)
|
27 | 31 |
|
28 | 32 |
|
29 |
| -def concat(arrays, /, *, axis=0): |
| 33 | +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array: |
30 | 34 | """Joins a sequence of arrays along an existing axis."""
|
31 | 35 | dtype = _result_type(*arrays)
|
| 36 | + if axis is None: |
| 37 | + arrays = [reshape(arr, (arr.size,)) for arr in arrays] |
| 38 | + axis = 0 |
32 | 39 | return jax.numpy.concatenate(arrays, axis=axis, dtype=dtype)
|
33 | 40 |
|
34 | 41 |
|
35 |
| -def expand_dims(x, /, *, axis=0): |
| 42 | +def expand_dims(x: Array, /, *, axis: int = 0) -> Array: |
36 | 43 | """Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis."""
|
37 | 44 | return jax.lax.expand_dims(x, dimensions=[axis])
|
38 | 45 |
|
39 | 46 |
|
40 |
| -def flip(x, /, *, axis=None): |
| 47 | +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: |
41 | 48 | """Reverses the order of elements in an array along the given axis."""
|
42 |
| - dimensions = list(axis) if isinstance(axis, tuple) else [axis] |
| 49 | + if axis is None: |
| 50 | + dimensions = tuple(range(x.ndim)) |
| 51 | + elif isinstance(axis, int): |
| 52 | + dimensions = (axis,) |
| 53 | + elif isinstance(axis, tuple): |
| 54 | + dimensions = tuple(operator.index(ax) for ax in axis) |
| 55 | + else: |
| 56 | + raise TypeError(f"Unexpected input axis={axis}: expected None, int, or tuple of ints") |
43 | 57 | return jax.lax.rev(x, dimensions=dimensions)
|
44 | 58 |
|
45 | 59 |
|
46 |
| -def permute_dims(x, /, axes): |
| 60 | +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: |
47 | 61 | """Permutes the axes (dimensions) of an array x."""
|
48 | 62 | return jax.lax.transpose(x, axes)
|
49 | 63 |
|
50 | 64 |
|
51 |
| -def reshape(x, /, shape, *, copy=None): |
| 65 | +def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array: |
52 | 66 | """Reshapes an array without changing its data."""
|
53 | 67 | del copy # unused
|
54 | 68 | return jax.lax.reshape(x, shape)
|
55 | 69 |
|
56 | 70 |
|
57 |
| -def roll(x, /, shift, *, axis=None): |
| 71 | +def roll(x: Array, /, shift: Union[int, Tuple[int]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: |
58 | 72 | """Rolls array elements along a specified axis."""
|
59 | 73 | return jax.numpy.roll(x, shift=shift, axis=axis)
|
60 | 74 |
|
61 | 75 |
|
62 |
| -def squeeze(x, /, axis): |
| 76 | +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: |
63 | 77 | """Removes singleton dimensions (axes) from x."""
|
64 |
| - dimensions = list(axis) if isinstance(axis, tuple) else [axis] |
| 78 | + dimensions = axis if isinstance(axis, tuple) else (axis,) |
65 | 79 | return jax.lax.squeeze(x, dimensions=dimensions)
|
66 | 80 |
|
67 | 81 |
|
68 |
| -def stack(arrays, /, *, axis=0): |
| 82 | +def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: |
69 | 83 | """Joins a sequence of arrays along a new axis."""
|
70 | 84 | dtype = _result_type(*arrays)
|
71 | 85 | return jax.numpy.stack(arrays, axis=axis, dtype=dtype)
|
0 commit comments