1
1
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
2
2
"""
3
+ import functools
3
4
import operator
4
5
import typing
5
6
from typing import Optional , Sequence
@@ -81,21 +82,11 @@ def normalize_ndarray(arg, name=None):
81
82
AxisLike : normalize_axis_like ,
82
83
}
83
84
84
- import functools
85
-
86
- _sentinel = object ()
87
85
88
-
89
- def maybe_normalize (arg , parm , return_on_failure = _sentinel ):
86
+ def maybe_normalize (arg , parm ):
90
87
"""Normalize arg if a normalizer is registred."""
91
88
normalizer = normalizers .get (parm .annotation , None )
92
- try :
93
- return normalizer (arg , parm .name ) if normalizer else arg
94
- except Exception as exc :
95
- if return_on_failure is not _sentinel :
96
- return return_on_failure
97
- else :
98
- raise exc from None
89
+ return normalizer (arg , parm .name ) if normalizer else arg
99
90
100
91
101
92
# ### Return value helpers ###
@@ -152,7 +143,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False):
152
143
# ### The main decorator to normalize arguments / postprocess the output ###
153
144
154
145
155
- def normalizer (_func = None , * , return_on_failure = _sentinel , promote_scalar_result = False ):
146
+ def normalizer (_func = None , * , promote_scalar_result = False ):
156
147
def normalizer_inner (func ):
157
148
@functools .wraps (func )
158
149
def wrapped (* args , ** kwds ):
@@ -161,14 +152,12 @@ def wrapped(*args, **kwds):
161
152
first_param = next (iter (params .values ()))
162
153
# NumPy's API does not have positional args before variadic positional args
163
154
if first_param .kind == inspect .Parameter .VAR_POSITIONAL :
164
- args = [
165
- maybe_normalize (arg , first_param , return_on_failure ) for arg in args
166
- ]
155
+ args = [maybe_normalize (arg , first_param ) for arg in args ]
167
156
else :
168
157
# NB: extra unknown arguments: pass through, will raise in func(*args) below
169
158
args = (
170
159
tuple (
171
- maybe_normalize (arg , parm , return_on_failure )
160
+ maybe_normalize (arg , parm )
172
161
for arg , parm in zip (args , params .values ())
173
162
)
174
163
+ args [len (params .values ()) :]
0 commit comments