Skip to content

Commit cc6cc7c

Browse files
lezcanoev-br
authored andcommitted
MAINT: simplify normalizer
1 parent 9d75cab commit cc6cc7c

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

torch_np/_normalizations.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,36 @@ def normalize_ndarray(arg, name=None):
8282
import functools
8383

8484

85+
def maybe_normalize(arg, parm):
86+
"""Normalize arg if a normalizer is registred."""
87+
normalizer = normalizers.get(parm.annotation, None)
88+
return normalizer(arg) if normalizer else arg
89+
90+
91+
def normalizer(func):
92+
@functools.wraps(func)
93+
def wrapped(*args, **kwds):
94+
params = inspect.signature(func).parameters
95+
first_param = next(iter(params.values()))
96+
# NumPy's API does not have positional args before variadic positional args
97+
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
98+
args = [maybe_normalize(arg, first_param) for arg in args]
99+
else:
100+
# NB: extra unknown arguments: pass through, will raise in func(*args) below
101+
args = [
102+
maybe_normalize(arg, parm) for arg, parm in zip(args, params.values())
103+
] + args[len(params.values()) :]
104+
105+
kwds = {
106+
name: maybe_normalize(arg, params[name]) if name in params else arg
107+
for name, arg in kwds.items()
108+
}
109+
return func(*args, **kwds)
110+
111+
return wrapped
112+
113+
114+
'''
85115
def normalize_this(arg, parm):
86116
"""Normalize arg if a normalizer is registred."""
87117
normalizer = normalizers.get(parm.annotation, None)
@@ -123,3 +153,4 @@ def wrapped(*args, **kwds):
123153
return result
124154
125155
return wrapped
156+
'''

0 commit comments

Comments
 (0)