Skip to content

Commit a394a88

Browse files
lezcanoev-br
authored andcommitted
MAINT: simplify normalizer
1 parent 9d75cab commit a394a88

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

torch_np/_normalizations.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,34 @@ def normalize_ndarray(arg, name=None):
8282
import functools
8383

8484

85+
def maybe_normalize(arg, parm):
86+
"""Normalize arg if a normalizer is registred."""
87+
normalizer = normalizers.get(parm.annotation, None)
88+
return normalizer(arg) if normalizer else arg
89+
90+
def normalizer(func):
91+
def wrapped(*args, **kwds):
92+
params = inspect.signature(func).parameters
93+
first_param = next(iter(params.values()))
94+
args_copy = args[:]
95+
# NumPy's API does not have positional args before variadic positional args
96+
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
97+
args = [maybe_normalize(arg, first_param) for arg in args]
98+
else:
99+
args = [maybe_normalize(arg, parm) for arg, parm in zip(args, params.values())]
100+
101+
# NB: extra unknown arguments: pass through, will raise in func(*args) below
102+
args += args_copy[len(args) :]
103+
104+
kwds = {
105+
name: maybe_normalize(arg, params[name]) if name in params else arg
106+
for name, arg in kwds.items()
107+
}
108+
return func(*args, **kwds)
109+
return wrapped
110+
111+
112+
'''
85113
def normalize_this(arg, parm):
86114
"""Normalize arg if a normalizer is registred."""
87115
normalizer = normalizers.get(parm.annotation, None)
@@ -123,3 +151,4 @@ def wrapped(*args, **kwds):
123151
return result
124152
125153
return wrapped
154+
'''

0 commit comments

Comments
 (0)