|
5 | 5 | from functools import wraps
|
6 | 6 | from inspect import signature
|
7 | 7 |
|
8 |
| -from .common._helpers import get_namespace |
9 |
| - |
10 |
| -def get_xp(f): |
| 8 | +def get_xp(xp): |
11 | 9 | """
|
12 |
| - Decorator to automatically replace xp with the corresponding array module |
| 10 | + Decorator to automatically replace xp with the corresponding array module. |
13 | 11 |
|
14 | 12 | Use like
|
15 | 13 |
|
16 |
| - @get_xp |
| 14 | + import numpy as np |
| 15 | +
|
| 16 | + @get_xp(np) |
17 | 17 | def func(x, /, xp, kwarg=None):
|
18 | 18 | return xp.func(x, kwarg=kwarg)
|
19 | 19 |
|
20 |
| - Note that xp must be able to be passed as a keyword argument. |
| 20 | + Note that xp must be a keyword argument and come after all non-keyword |
| 21 | + arguments. |
| 22 | +
|
21 | 23 | """
|
22 |
| - @wraps(f) |
23 |
| - def inner(*args, **kwargs): |
24 |
| - xp = get_namespace(*args, _use_compat=False) |
25 |
| - return f(*args, xp=xp, **kwargs) |
| 24 | + def inner(f): |
| 25 | + sig = signature(f) |
| 26 | + |
| 27 | + @wraps(f) |
| 28 | + def wrapped_f(*args, **kwargs): |
| 29 | + return f(*args, xp=xp, **kwargs) |
26 | 30 |
|
27 |
| - sig = signature(f) |
28 |
| - new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) |
| 31 | + new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) |
29 | 32 |
|
30 |
| - if inner.__doc__ is None: |
31 |
| - inner.__doc__ = f"""\ |
| 33 | + if wrapped_f.__doc__ is None: |
| 34 | + wrapped_f.__doc__ = f"""\ |
32 | 35 | Array API compatibility wrapper for {f.__name__}.
|
33 | 36 |
|
34 | 37 | See the corresponding documentation in NumPy/CuPy and/or the array API
|
35 | 38 | specification for more details.
|
36 | 39 |
|
37 | 40 | """
|
38 |
| - inner.__signature__ = new_sig |
| 41 | + # wrapped_f.__signature__ = new_sig |
| 42 | + return wrapped_f |
39 | 43 |
|
40 | 44 | return inner
|
0 commit comments