From 00e7cceb338025d9428af2bb6afbe7eaac8cf414 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 15 Apr 2025 11:53:21 +0200 Subject: [PATCH] BUG: add torch.repeat --- array_api_compat/torch/_aliases.py | 7 ++++++- torch-xfails.txt | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a2ed1449..0a604b8c 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -574,6 +574,11 @@ def count_nonzero( return result +# "repeat" is torch.repeat_interleave; also the dim argument +def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array: + return torch.repeat_interleave(x, repeats, axis) + + def where( condition: Array, x1: Array | bool | int | float | complex, @@ -854,6 +859,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat'] _all_ignore = ['torch', 'get_xp'] diff --git a/torch-xfails.txt b/torch-xfails.txt index e556fa4f..ab11f457 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -120,9 +120,8 @@ array_api_tests/test_data_type_functions.py::test_finfo_dtype array_api_tests/test_data_type_functions.py::test_iinfo_dtype # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] +# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers array_api_tests/test_manipulation_functions.py::test_repeat -array_api_tests/test_signatures.py::test_func_signature[repeat] # Argument 'device' missing from signature array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature