Skip to content

Commit 9302568

Browse files
committed
MAINT: remove return_on_failure at normalizer
1 parent 011ecbf commit 9302568

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

torch_np/_funcs.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
NDArray,
2020
OutArray,
2121
SubokLike,
22+
normalize_array_like,
2223
normalizer,
2324
)
2425

@@ -1789,10 +1790,13 @@ def i0(x: ArrayLike):
17891790
return torch.special.i0(x)
17901791

17911792

1792-
@normalizer(return_on_failure=False)
1793-
def isscalar(a: ArrayLike):
1793+
def isscalar(a):
17941794
# XXX: this is a stub
1795-
return a.numel() == 1
1795+
try:
1796+
t = normalize_array_like(a)
1797+
return t.numel() == 1
1798+
except Exception:
1799+
return False
17961800

17971801

17981802
"""

torch_np/_normalizations.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
22
"""
3+
import functools
34
import operator
45
import typing
56
from typing import Optional, Sequence
@@ -81,21 +82,11 @@ def normalize_ndarray(arg, name=None):
8182
AxisLike: normalize_axis_like,
8283
}
8384

84-
import functools
85-
86-
_sentinel = object()
8785

88-
89-
def maybe_normalize(arg, parm, return_on_failure=_sentinel):
86+
def maybe_normalize(arg, parm):
9087
"""Normalize arg if a normalizer is registred."""
9188
normalizer = normalizers.get(parm.annotation, None)
92-
try:
93-
return normalizer(arg, parm.name) if normalizer else arg
94-
except Exception as exc:
95-
if return_on_failure is not _sentinel:
96-
return return_on_failure
97-
else:
98-
raise exc from None
89+
return normalizer(arg, parm.name) if normalizer else arg
9990

10091

10192
# ### Return value helpers ###
@@ -152,7 +143,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False):
152143
# ### The main decorator to normalize arguments / postprocess the output ###
153144

154145

155-
def normalizer(_func=None, *, return_on_failure=_sentinel, promote_scalar_result=False):
146+
def normalizer(_func=None, *, promote_scalar_result=False):
156147
def normalizer_inner(func):
157148
@functools.wraps(func)
158149
def wrapped(*args, **kwds):
@@ -161,14 +152,12 @@ def wrapped(*args, **kwds):
161152
first_param = next(iter(params.values()))
162153
# NumPy's API does not have positional args before variadic positional args
163154
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
164-
args = [
165-
maybe_normalize(arg, first_param, return_on_failure) for arg in args
166-
]
155+
args = [maybe_normalize(arg, first_param) for arg in args]
167156
else:
168157
# NB: extra unknown arguments: pass through, will raise in func(*args) below
169158
args = (
170159
tuple(
171-
maybe_normalize(arg, parm, return_on_failure)
160+
maybe_normalize(arg, parm)
172161
for arg, parm in zip(args, params.values())
173162
)
174163
+ args[len(params.values()) :]

0 commit comments

Comments
 (0)