Skip to content

Commit 05cfd66

Browse files
committed
ENH: on failure, the normalizer decorator either raise (default) or return a value
On failure, @ normalizer without arguments (default) raises, while @Normalizer(return_or_failure=value) return a specified `value`. This is useful, e.g. for @-normalizer def is_scalar(a: ArrayLike): .... Consider is_scalar(int): the argument cannot be converted to tensor, so the function returns False. But the actual failure is in the decorator, so cannot do try-except in the body of the function!
1 parent b858ade commit 05cfd66

File tree

2 files changed

+44
-33
lines changed

2 files changed

+44
-33
lines changed

torch_np/_normalizations.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -81,35 +81,49 @@ def normalize_ndarray(arg, name=None):
8181

8282
import functools
8383

84+
_sentinel = object()
8485

85-
def maybe_normalize(arg, parm):
86+
87+
def maybe_normalize(arg, parm, return_on_failure=_sentinel):
8688
"""Normalize arg if a normalizer is registred."""
8789
normalizer = normalizers.get(parm.annotation, None)
88-
return normalizer(arg) if normalizer else arg
89-
90-
91-
def normalizer(func):
92-
@functools.wraps(func)
93-
def wrapped(*args, **kwds):
94-
params = inspect.signature(func).parameters
95-
first_param = next(iter(params.values()))
96-
# NumPy's API does not have positional args before variadic positional args
97-
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
98-
args = [maybe_normalize(arg, first_param) for arg in args]
90+
try:
91+
return normalizer(arg) if normalizer else arg
92+
except Exception as exc:
93+
if return_on_failure is not _sentinel:
94+
return return_on_failure
9995
else:
100-
# NB: extra unknown arguments: pass through, will raise in func(*args) below
101-
args = (
102-
tuple(
103-
maybe_normalize(arg, parm)
104-
for arg, parm in zip(args, params.values())
96+
raise exc from None
97+
98+
99+
def normalizer(_func=None, *, return_on_failure=_sentinel):
100+
def normalizer_inner(func):
101+
@functools.wraps(func)
102+
def wrapped(*args, **kwds):
103+
params = inspect.signature(func).parameters
104+
first_param = next(iter(params.values()))
105+
# NumPy's API does not have positional args before variadic positional args
106+
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
107+
args = [maybe_normalize(arg, first_param, return_on_failure) for arg in args]
108+
else:
109+
# NB: extra unknown arguments: pass through, will raise in func(*args) below
110+
args = (
111+
tuple(
112+
maybe_normalize(arg, parm, return_on_failure)
113+
for arg, parm in zip(args, params.values())
114+
)
115+
+ args[len(params.values()) :]
105116
)
106-
+ args[len(params.values()) :]
107-
)
108117

109-
kwds = {
110-
name: maybe_normalize(arg, params[name]) if name in params else arg
111-
for name, arg in kwds.items()
112-
}
113-
return func(*args, **kwds)
118+
kwds = {
119+
name: maybe_normalize(arg, params[name]) if name in params else arg
120+
for name, arg in kwds.items()
121+
}
122+
return func(*args, **kwds)
123+
return wrapped
124+
125+
if _func is None:
126+
return normalizer_inner
127+
else:
128+
return normalizer_inner(_func)
114129

115-
return wrapped

torch_np/_wrapper.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -775,15 +775,12 @@ def i0(x: ArrayLike):
775775
return _helpers.array_from(result)
776776

777777

778-
def isscalar(a):
778+
@normalizer(return_on_failure=False)
779+
def isscalar(a: ArrayLike):
779780
# XXX: this is a stub
780-
try:
781-
from ._ndarray import asarray
782-
783-
t = asarray(a).get()
784-
return t.numel() == 1
785-
except Exception:
786-
return False
781+
if a is False:
782+
return a
783+
return a.numel() == 1
787784

788785

789786
@normalizer

0 commit comments

Comments
 (0)