diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 335008e4..de5d1a5d 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,7 +2,7 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union, Literal import torch @@ -828,6 +828,12 @@ def sign(x: Array, /) -> Array: return out +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]: + # enforce the default of 'xy' + # TODO: is the return type a list or a tuple + return list(torch.meshgrid(*arrays, indexing='xy')) + + __all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', @@ -844,6 +850,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid'] _all_ignore = ['torch', 'get_xp'] diff --git a/tests/test_torch.py b/tests/test_torch.py index e8340f31..7adb4ab3 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -100,3 +100,20 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b): assert dtype_1 == dtype_2 finally: torch.set_default_dtype(prev_default) + + +def test_meshgrid(): + """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'.""" + + x, y = xp.asarray([1, 2]), xp.asarray([4]) + + X, Y = xp.meshgrid(x, y) + + # output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different + X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]]) + + assert X.shape == X_xy.shape + assert xp.all(X == X_xy) + + assert Y.shape == Y_xy.shape + assert xp.all(Y == Y_xy)