Skip to content

Commit 70d207e

Browse files
committed
register namespace
1 parent e62afb7 commit 70d207e

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

jax/experimental/array_api/__init__.py

+15
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,18 @@
185185
tensordot as tensordot,
186186
vecdot as vecdot,
187187
)
188+
189+
def _array_namespace(self, /, *, api_version: None | str = None):
190+
import sys
191+
if api_version is not None and api_version != __array_api_version__:
192+
raise ValueError(f"{api_version=!r} is not available; "
193+
f"available versions are: {[__array_api_version__]}")
194+
return sys.modules[__name__]
195+
196+
def _setup_array_type():
197+
# TODO(jakevdp): set on tracers as well?
198+
from jax._src.array import ArrayImpl
199+
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
200+
201+
_setup_array_type()
202+
del _setup_array_type

jax/experimental/array_api/_data_type_functions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,12 @@ def _promote_types(t1, t2):
122122

123123

124124
def astype(x, dtype, /, *, copy=True):
125-
return jnp.asarray(x, dtype=dtype, copy=copy)
125+
return jnp.array(x, dtype=dtype, copy=copy)
126126

127127

128128
def can_cast(from_, to, /):
129+
if isinstance(from_, jax.Array):
130+
from_ = from_.dtype
129131
if not _is_valid_dtype(from_):
130132
raise ValueError(f"{from_} is not a valid dtype")
131133
if not _is_valid_dtype(to):

jax/experimental/array_api/_manipulation_functions.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -12,60 +12,74 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import operator
16+
from typing import List, Optional, Tuple, Union
17+
1518
import jax
19+
from jax import Array
1620
from jax.experimental.array_api._data_type_functions import result_type as _result_type
1721

1822

19-
def broadcast_arrays(*arrays):
23+
def broadcast_arrays(*arrays: Array) -> List[Array]:
2024
"""Broadcasts one or more arrays against one another."""
2125
return jax.numpy.broadcast_arrays(*arrays)
2226

2327

24-
def broadcast_to(x, /, shape):
28+
def broadcast_to(x: Array, /, shape: Tuple[int]) -> Array:
2529
"""Broadcasts an array to a specified shape."""
2630
return jax.numpy.broadcast_to(x, shape=shape)
2731

2832

29-
def concat(arrays, /, *, axis=0):
33+
def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
3034
"""Joins a sequence of arrays along an existing axis."""
3135
dtype = _result_type(*arrays)
36+
if axis is None:
37+
arrays = [reshape(arr, (arr.size,)) for arr in arrays]
38+
axis = 0
3239
return jax.numpy.concatenate(arrays, axis=axis, dtype=dtype)
3340

3441

35-
def expand_dims(x, /, *, axis=0):
42+
def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
3643
"""Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis."""
3744
return jax.lax.expand_dims(x, dimensions=[axis])
3845

3946

40-
def flip(x, /, *, axis=None):
47+
def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
4148
"""Reverses the order of elements in an array along the given axis."""
42-
dimensions = list(axis) if isinstance(axis, tuple) else [axis]
49+
if axis is None:
50+
dimensions = tuple(range(x.ndim))
51+
elif isinstance(axis, int):
52+
dimensions = (axis,)
53+
elif isinstance(axis, tuple):
54+
dimensions = tuple(operator.index(ax) for ax in axis)
55+
else:
56+
raise TypeError(f"Unexpected input axis={axis}: expected None, int, or tuple of ints")
4357
return jax.lax.rev(x, dimensions=dimensions)
4458

4559

46-
def permute_dims(x, /, axes):
60+
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
4761
"""Permutes the axes (dimensions) of an array x."""
4862
return jax.lax.transpose(x, axes)
4963

5064

51-
def reshape(x, /, shape, *, copy=None):
65+
def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array:
5266
"""Reshapes an array without changing its data."""
5367
del copy # unused
5468
return jax.lax.reshape(x, shape)
5569

5670

57-
def roll(x, /, shift, *, axis=None):
71+
def roll(x: Array, /, shift: Union[int, Tuple[int]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
5872
"""Rolls array elements along a specified axis."""
5973
return jax.numpy.roll(x, shift=shift, axis=axis)
6074

6175

62-
def squeeze(x, /, axis):
76+
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
6377
"""Removes singleton dimensions (axes) from x."""
64-
dimensions = list(axis) if isinstance(axis, tuple) else [axis]
78+
dimensions = axis if isinstance(axis, tuple) else (axis,)
6579
return jax.lax.squeeze(x, dimensions=dimensions)
6680

6781

68-
def stack(arrays, /, *, axis=0):
82+
def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
6983
"""Joins a sequence of arrays along a new axis."""
7084
dtype = _result_type(*arrays)
7185
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)

0 commit comments

Comments
 (0)