Skip to content

Commit 2912c9e

Browse files
committed
Refactor how get_xp works
Now instead of guessing the array library from the input with get_namespace, the array library is hard-coded into the wrapped function based on which subnamespace it is imported from.
1 parent 005852f commit 2912c9e

File tree

8 files changed

+301
-213
lines changed

8 files changed

+301
-213
lines changed

array_api_compat/_internal.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,40 @@
55
from functools import wraps
66
from inspect import signature
77

8-
from .common._helpers import get_namespace
9-
10-
def get_xp(f):
8+
def get_xp(xp):
119
"""
12-
Decorator to automatically replace xp with the corresponding array module
10+
Decorator to automatically replace xp with the corresponding array module.
1311
1412
Use like
1513
16-
@get_xp
14+
import numpy as np
15+
16+
@get_xp(np)
1717
def func(x, /, xp, kwarg=None):
1818
return xp.func(x, kwarg=kwarg)
1919
20-
Note that xp must be able to be passed as a keyword argument.
20+
Note that xp must be a keyword argument and come after all non-keyword
21+
arguments.
22+
2123
"""
22-
@wraps(f)
23-
def inner(*args, **kwargs):
24-
xp = get_namespace(*args, _use_compat=False)
25-
return f(*args, xp=xp, **kwargs)
24+
def inner(f):
25+
sig = signature(f)
26+
27+
@wraps(f)
28+
def wrapped_f(*args, **kwargs):
29+
return f(*args, xp=xp, **kwargs)
2630

27-
sig = signature(f)
28-
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])
31+
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])
2932

30-
if inner.__doc__ is None:
31-
inner.__doc__ = f"""\
33+
if wrapped_f.__doc__ is None:
34+
wrapped_f.__doc__ = f"""\
3235
Array API compatibility wrapper for {f.__name__}.
3336
3437
See the corresponding documentation in NumPy/CuPy and/or the array API
3538
specification for more details.
3639
3740
"""
38-
inner.__signature__ = new_sig
41+
# wrapped_f.__signature__ = new_sig
42+
return wrapped_f
3943

4044
return inner

0 commit comments

Comments
 (0)