Skip to content

Commit f5e5eaf

Browse files
lezcanoev-br
authored andcommitted
MAINT: simplify normalizer
1 parent 9d75cab commit f5e5eaf

File tree

1 file changed

+17
-30
lines changed

1 file changed

+17
-30
lines changed

torch_np/_normalizations.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,44 +82,31 @@ def normalize_ndarray(arg, name=None):
8282
import functools
8383

8484

85-
def normalize_this(arg, parm):
85+
def maybe_normalize(arg, parm):
8686
"""Normalize arg if a normalizer is registred."""
8787
normalizer = normalizers.get(parm.annotation, None)
88-
if normalizer:
89-
return normalizer(arg)
90-
else:
91-
# untyped arguments pass through
92-
return arg
88+
return normalizer(arg) if normalizer else arg
9389

9490

9591
def normalizer(func):
9692
@functools.wraps(func)
9793
def wrapped(*args, **kwds):
98-
sig = inspect.signature(func)
99-
sp = dict(sig.parameters)
100-
101-
# check for *args. If detected, duplicate the correspoding parameter
102-
# to have len(args) annotations for each element of *args.
103-
for j, param in enumerate(sp.values()):
104-
if param.kind == inspect.Parameter.VAR_POSITIONAL:
105-
sp.pop(param.name)
106-
variadic = {param.name + str(i): param for i in range(len(args))}
107-
variadic.update(sp)
108-
sp = variadic
109-
break
110-
111-
# normalize positional and keyword arguments
112-
# NB: extra unknown arguments: pass through, will raise in func(*lst) below
113-
lst = [normalize_this(arg, parm) for arg, parm in zip(args, sp.values())]
114-
lst += args[len(lst) :]
115-
116-
dct = {
117-
name: normalize_this(arg, sp[name]) if name in sp else arg
94+
params = inspect.signature(func).parameters
95+
first_param = next(iter(params.values()))
96+
# NumPy's API does not have positional args before variadic positional args
97+
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
98+
args = [maybe_normalize(arg, first_param) for arg in args]
99+
else:
100+
# NB: extra unknown arguments: pass through, will raise in func(*args) below
101+
args = tuple(
102+
maybe_normalize(arg, parm) for arg, parm in zip(args, params.values())
103+
) + args[len(params.values()) :]
104+
105+
kwds = {
106+
name: maybe_normalize(arg, params[name]) if name in params else arg
118107
for name, arg in kwds.items()
119108
}
120-
121-
result = func(*lst, **dct)
122-
123-
return result
109+
return func(*args, **kwds)
124110

125111
return wrapped
112+

0 commit comments

Comments
 (0)