Skip to content

Commit d3d69dd

Browse files
committed
ENH: on failure, the normalizer decorator either raise (default) or return a value
On failure, @ normalizer without arguments (default) raises, while @Normalizer(return_or_failure=value) return a specified `value`. This is useful, e.g. for @-normalizer def is_scalar(a: ArrayLike): .... Consider is_scalar(int): the argument cannot be converted to tensor, so the function returns False. But the actual failure is in the decorator, so cannot do try-except in the body of the function!
1 parent 35985c0 commit d3d69dd

File tree

2 files changed

+84
-73
lines changed

2 files changed

+84
-73
lines changed

torch_np/_normalizations.py

Lines changed: 79 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -85,80 +85,94 @@ def normalize_ndarray(arg, name=None):
8585

8686
import functools
8787

88+
_sentinel = object()
8889

89-
def normalize_this(arg, parm):
90+
91+
def normalize_this(arg, parm, return_on_failure=_sentinel):
9092
"""Normalize arg if a normalizer is registred."""
9193
normalizer = normalizers.get(parm.annotation, None)
9294
if normalizer:
93-
return normalizer(arg)
95+
try:
96+
return normalizer(arg)
97+
except Exception as exc:
98+
if return_on_failure is not _sentinel:
99+
return return_on_failure
100+
else:
101+
raise exc from None
94102
else:
95103
# untyped arguments pass through
96104
return arg
97105

98106

99-
def normalizer(func):
100-
@functools.wraps(func)
101-
def wrapped(*args, **kwds):
102-
sig = inspect.signature(func)
103-
104-
# first, check for *args in positional parameters. Case in point:
105-
# atleast_1d(*arys: UnpackedSequenceArrayLike)
106-
# if found, consume all args into a tuple to normalize as a whole
107-
for j, param in enumerate(sig.parameters.values()):
108-
if param.annotation == UnpackedSeqArrayLike:
109-
if j == 0:
110-
args = (args,)
111-
else:
112-
# args = args[:j] + (args[j:],) would likely work
113-
# not present in numpy codebase, so do not bother just yet.
114-
# NB: branching on j ==0 is to avoid the empty tuple, args[:j]
115-
raise NotImplementedError
116-
break
117-
118-
# loop over positional parameters and actual arguments
119-
lst, dct = [], {}
120-
for arg, (name, parm) in zip(args, sig.parameters.items()):
121-
print(arg, name, parm.annotation)
122-
lst.append(normalize_this(arg, parm))
123-
124-
# normalize keyword arguments
125-
for name, arg in kwds.items():
126-
if not name in sig.parameters:
127-
# unknown kwarg, bail out
107+
def normalizer(_func=None, *, return_on_failure=_sentinel):
108+
def normalizer_inner(func):
109+
@functools.wraps(func)
110+
def wrapped(*args, **kwds):
111+
sig = inspect.signature(func)
112+
113+
# first, check for *args in positional parameters. Case in point:
114+
# atleast_1d(*arys: UnpackedSequenceArrayLike)
115+
# if found, consume all args into a tuple to normalize as a whole
116+
for j, param in enumerate(sig.parameters.values()):
117+
if param.annotation == UnpackedSeqArrayLike:
118+
if j == 0:
119+
args = (args,)
120+
else:
121+
# args = args[:j] + (args[j:],) would likely work
122+
# not present in numpy codebase, so do not bother just yet.
123+
# NB: branching on j ==0 is to avoid the empty tuple, args[:j]
124+
raise NotImplementedError
125+
break
126+
127+
# loop over positional parameters and actual arguments
128+
lst, dct = [], {}
129+
for arg, (name, parm) in zip(args, sig.parameters.items()):
130+
print(arg, name, parm.annotation)
131+
lst.append(normalize_this(arg, parm, return_on_failure))
132+
133+
# normalize keyword arguments
134+
for name, arg in kwds.items():
135+
if not name in sig.parameters:
136+
# unknown kwarg, bail out
137+
raise TypeError(
138+
f"{func.__name__}() got an unexpected keyword argument '{name}'."
139+
)
140+
141+
print("kw: ", name, sig.parameters[name].annotation)
142+
parm = sig.parameters[name]
143+
dct[name] = normalize_this(arg, parm, return_on_failure)
144+
145+
ba = sig.bind(*lst, **dct)
146+
ba.apply_defaults()
147+
148+
# Now that all parameters have been consumed, check:
149+
# Anything that has not been bound is unexpected positional arg => raise.
150+
# If there are too few actual arguments, this fill fail in func(*ba.args) below
151+
if len(args) > len(ba.args):
128152
raise TypeError(
129-
f"{func.__name__}() got an unexpected keyword argument '{name}'."
153+
f"{func.__name__}() takes {len(ba.args)} positional argument but {len(args)} were given."
130154
)
131155

132-
print("kw: ", name, sig.parameters[name].annotation)
133-
parm = sig.parameters[name]
134-
dct[name] = normalize_this(arg, parm)
135-
136-
ba = sig.bind(*lst, **dct)
137-
ba.apply_defaults()
138-
139-
# Now that all parameters have been consumed, check:
140-
# Anything that has not been bound is unexpected positional arg => raise.
141-
# If there are too few actual arguments, this fill fail in func(*ba.args) below
142-
if len(args) > len(ba.args):
143-
raise TypeError(
144-
f"{func.__name__}() takes {len(ba.args)} positional argument but {len(args)} were given."
145-
)
146-
147-
# TODO:
148-
# 1. [LOOKS OK] kw-only parameters : see vstack
149-
# 2. [LOOKS OK] extra unknown args -- error out : nonzero([2, 0, 3], oops=42)
150-
# 3. [LOOKS OK] optional (tensor_or_none) : untyped => pass through
151-
# 4. [LOOKS OK] DTypeLike : positional or kw
152-
# 5. axes : live in _impl or in types? several ways of handling them
153-
# 6. [OK, NOT HERE] keepdims : peel off, postprocess
154-
# 7. OutLike : normal & keyword-only, peel off, postprocess
155-
# 8. [LOOKS OK] *args
156-
# 9. [LOOKS OK] consolidate normalizations (_funcs, _wrapper)
157-
# 10. [LOOKS OK] consolidate decorators (_{unary,binary}_ufuncs)
158-
# 11. out= arg : validate it's an ndarray
159-
160-
# finally, pass normalized arguments through
161-
result = func(*ba.args, **ba.kwargs)
162-
return result
163-
164-
return wrapped
156+
# TODO:
157+
# 1. [LOOKS OK] kw-only parameters : see vstack
158+
# 2. [LOOKS OK] extra unknown args -- error out : nonzero([2, 0, 3], oops=42)
159+
# 3. [LOOKS OK] optional (tensor_or_none) : untyped => pass through
160+
# 4. [LOOKS OK] DTypeLike : positional or kw
161+
# 5. axes : live in _impl or in types? several ways of handling them
162+
# 6. [OK, NOT HERE] keepdims : peel off, postprocess
163+
# 7. OutLike : normal & keyword-only, peel off, postprocess
164+
# 8. [LOOKS OK] *args
165+
# 9. [LOOKS OK] consolidate normalizations (_funcs, _wrapper)
166+
# 10. [LOOKS OK] consolidate decorators (_{unary,binary}_ufuncs)
167+
# 11. out= arg : validate it's an ndarray
168+
169+
# finally, pass normalized arguments through
170+
result = func(*ba.args, **ba.kwargs)
171+
return result
172+
173+
return wrapped
174+
175+
if _func is None:
176+
return normalizer_inner
177+
else:
178+
return normalizer_inner(_func)

torch_np/_wrapper.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -789,15 +789,12 @@ def i0(x: ArrayLike):
789789
return _helpers.array_from(result)
790790

791791

792-
def isscalar(a):
792+
@normalizer(return_on_failure=False)
793+
def isscalar(a: ArrayLike):
793794
# XXX: this is a stub
794-
try:
795-
from ._ndarray import asarray
796-
797-
t = asarray(a).get()
798-
return t.numel() == 1
799-
except Exception:
800-
return False
795+
if a is False:
796+
return a
797+
return a.numel() == 1
801798

802799

803800
@normalizer

0 commit comments

Comments
 (0)