Skip to content

Commit 8dc2628

Browse files
lezcanoev-br
authored andcommitted
MAINT: simplify normalizer
1 parent 9d75cab commit 8dc2628

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

torch_np/_normalizations.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,39 @@ def normalize_ndarray(arg, name=None):
8282
import functools
8383

8484

85+
def maybe_normalize(arg, parm):
86+
"""Normalize arg if a normalizer is registred."""
87+
normalizer = normalizers.get(parm.annotation, None)
88+
return normalizer(arg) if normalizer else arg
89+
90+
91+
def normalizer(func):
92+
@functools.wraps(func)
93+
def wrapped(*args, **kwds):
94+
params = inspect.signature(func).parameters
95+
first_param = next(iter(params.values()))
96+
args_copy = args[:]
97+
# NumPy's API does not have positional args before variadic positional args
98+
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
99+
args = [maybe_normalize(arg, first_param) for arg in args]
100+
else:
101+
args = [
102+
maybe_normalize(arg, parm) for arg, parm in zip(args, params.values())
103+
]
104+
105+
# NB: extra unknown arguments: pass through, will raise in func(*args) below
106+
args += args_copy[len(args) :]
107+
108+
kwds = {
109+
name: maybe_normalize(arg, params[name]) if name in params else arg
110+
for name, arg in kwds.items()
111+
}
112+
return func(*args, **kwds)
113+
114+
return wrapped
115+
116+
117+
'''
85118
def normalize_this(arg, parm):
86119
"""Normalize arg if a normalizer is registred."""
87120
normalizer = normalizers.get(parm.annotation, None)
@@ -123,3 +156,4 @@ def wrapped(*args, **kwds):
123156
return result
124157
125158
return wrapped
159+
'''

0 commit comments

Comments
 (0)