@@ -82,44 +82,31 @@ def normalize_ndarray(arg, name=None):
82
82
import functools
83
83
84
84
85
- def normalize_this (arg , parm ):
85
+ def maybe_normalize (arg , parm ):
86
86
"""Normalize arg if a normalizer is registred."""
87
87
normalizer = normalizers .get (parm .annotation , None )
88
- if normalizer :
89
- return normalizer (arg )
90
- else :
91
- # untyped arguments pass through
92
- return arg
88
+ return normalizer (arg ) if normalizer else arg
93
89
94
90
95
91
def normalizer (func ):
96
92
@functools .wraps (func )
97
93
def wrapped (* args , ** kwds ):
98
- sig = inspect .signature (func )
99
- sp = dict (sig .parameters )
100
-
101
- # check for *args. If detected, duplicate the correspoding parameter
102
- # to have len(args) annotations for each element of *args.
103
- for j , param in enumerate (sp .values ()):
104
- if param .kind == inspect .Parameter .VAR_POSITIONAL :
105
- sp .pop (param .name )
106
- variadic = {param .name + str (i ): param for i in range (len (args ))}
107
- variadic .update (sp )
108
- sp = variadic
109
- break
110
-
111
- # normalize positional and keyword arguments
112
- # NB: extra unknown arguments: pass through, will raise in func(*lst) below
113
- lst = [normalize_this (arg , parm ) for arg , parm in zip (args , sp .values ())]
114
- lst += args [len (lst ) :]
115
-
116
- dct = {
117
- name : normalize_this (arg , sp [name ]) if name in sp else arg
94
+ params = inspect .signature (func ).parameters
95
+ first_param = next (iter (params .values ()))
96
+ # NumPy's API does not have positional args before variadic positional args
97
+ if first_param .kind == inspect .Parameter .VAR_POSITIONAL :
98
+ args = [maybe_normalize (arg , first_param ) for arg in args ]
99
+ else :
100
+ # NB: extra unknown arguments: pass through, will raise in func(*args) below
101
+ args = tuple (
102
+ maybe_normalize (arg , parm ) for arg , parm in zip (args , params .values ())
103
+ ) + args [len (params .values ()) :]
104
+
105
+ kwds = {
106
+ name : maybe_normalize (arg , params [name ]) if name in params else arg
118
107
for name , arg in kwds .items ()
119
108
}
120
-
121
- result = func (* lst , ** dct )
122
-
123
- return result
109
+ return func (* args , ** kwds )
124
110
125
111
return wrapped
112
+
0 commit comments