Skip to content

Commit 7dced32

Browse files
lezcanoev-br
authored andcommitted
MAINT: simplify normalizer
1 parent 9d75cab commit 7dced32

File tree

1 file changed

+20
-30
lines changed

1 file changed

+20
-30
lines changed

torch_np/_normalizations.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,44 +82,34 @@ def normalize_ndarray(arg, name=None):
8282
import functools
8383

8484

85-
def normalize_this(arg, parm):
85+
def maybe_normalize(arg, parm):
8686
"""Normalize arg if a normalizer is registred."""
8787
normalizer = normalizers.get(parm.annotation, None)
88-
if normalizer:
89-
return normalizer(arg)
90-
else:
91-
# untyped arguments pass through
92-
return arg
88+
return normalizer(arg) if normalizer else arg
9389

9490

9591
def normalizer(func):
9692
@functools.wraps(func)
9793
def wrapped(*args, **kwds):
98-
sig = inspect.signature(func)
99-
sp = dict(sig.parameters)
100-
101-
# check for *args. If detected, duplicate the correspoding parameter
102-
# to have len(args) annotations for each element of *args.
103-
for j, param in enumerate(sp.values()):
104-
if param.kind == inspect.Parameter.VAR_POSITIONAL:
105-
sp.pop(param.name)
106-
variadic = {param.name + str(i): param for i in range(len(args))}
107-
variadic.update(sp)
108-
sp = variadic
109-
break
110-
111-
# normalize positional and keyword arguments
112-
# NB: extra unknown arguments: pass through, will raise in func(*lst) below
113-
lst = [normalize_this(arg, parm) for arg, parm in zip(args, sp.values())]
114-
lst += args[len(lst) :]
115-
116-
dct = {
117-
name: normalize_this(arg, sp[name]) if name in sp else arg
94+
params = inspect.signature(func).parameters
95+
first_param = next(iter(params.values()))
96+
# NumPy's API does not have positional args before variadic positional args
97+
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
98+
args = [maybe_normalize(arg, first_param) for arg in args]
99+
else:
100+
# NB: extra unknown arguments: pass through, will raise in func(*args) below
101+
args = (
102+
tuple(
103+
maybe_normalize(arg, parm)
104+
for arg, parm in zip(args, params.values())
105+
)
106+
+ args[len(params.values()) :]
107+
)
108+
109+
kwds = {
110+
name: maybe_normalize(arg, params[name]) if name in params else arg
118111
for name, arg in kwds.items()
119112
}
120-
121-
result = func(*lst, **dct)
122-
123-
return result
113+
return func(*args, **kwds)
124114

125115
return wrapped

0 commit comments

Comments
 (0)