Skip to content

Commit 9726bc0

Browse files
committed
Add a draft reciprocal function for 2024.12
1 parent 548f071 commit 9726bc0

File tree

4 files changed

+15
-0
lines changed

4 files changed

+15
-0
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@
177177
positive,
178178
pow,
179179
real,
180+
reciprocal,
180181
remainder,
181182
round,
182183
sign,
@@ -246,6 +247,7 @@
246247
"positive",
247248
"pow",
248249
"real",
250+
"reciprocal",
249251
"remainder",
250252
"round",
251253
"sign",

array_api_strict/_elementwise_functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,17 @@ def real(x: Array, /) -> Array:
872872
return Array._new(np.real(x._array), device=x.device)
873873

874874

875+
@requires_api_version('2024.12')
876+
def reciprocal(x: Array, /) -> Array:
877+
"""
878+
Array API compatible wrapper for :py:func:`np.reciprocal <numpy.reciprocal>`.
879+
880+
See its docstring for more information.
881+
"""
882+
if x.dtype not in _floating_dtypes:
883+
raise TypeError("Only floating-point dtypes are allowed in reciprocal")
884+
return Array._new(np.reciprocal(x._array), device=x.device)
885+
875886
def remainder(x1: Array, x2: Array, /) -> Array:
876887
"""
877888
Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.

array_api_strict/tests/test_elementwise_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def nargs(func):
8686
"positive": "numeric",
8787
"pow": "numeric",
8888
"real": "complex floating-point",
89+
"reciprocal": "floating-point",
8990
"remainder": "real numeric",
9091
"round": "numeric",
9192
"sign": "numeric",

array_api_strict/tests/test_flags.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ def test_fft(func_name):
287287
api_version_2024_12_examples = {
288288
'diff': lambda: xp.diff(xp.asarray([0, 1, 2])),
289289
'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)),
290+
'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])),
290291
}
291292

292293
@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys())

0 commit comments

Comments
 (0)