Skip to content

Commit 35e647b

Browse files
committed
MAINT: simplify postprocessing in normalizer
1 parent fc372c2 commit 35e647b

File tree

1 file changed

+47
-28
lines changed

1 file changed

+47
-28
lines changed

torch_np/_normalizations.py

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
22
"""
3+
import functools
4+
import inspect
35
import operator
46
import typing
57
from typing import Optional, Sequence, Union
68

79
import torch
810

9-
from . import _helpers
11+
from . import _dtypes, _helpers
1012

1113
ArrayLike = typing.TypeVar("ArrayLike")
1214
DTypeLike = typing.TypeVar("DTypeLike")
@@ -22,10 +24,6 @@
2224
NDArrayOrSequence = Union[NDArray, Sequence[NDArray]]
2325
OutArray = typing.TypeVar("OutArray")
2426

25-
import inspect
26-
27-
from . import _dtypes
28-
2927

3028
def normalize_array_like(x, name=None):
3129
(tensor,) = _helpers.to_tensors(x)
@@ -87,7 +85,6 @@ def normalize_ndarray(arg, name=None):
8785
AxisLike: normalize_axis_like,
8886
}
8987

90-
import functools
9188

9289
_sentinel = object()
9390

@@ -108,6 +105,44 @@ def normalize_this(arg, parm, return_on_failure=_sentinel):
108105
return arg
109106

110107

108+
# postprocess return values
109+
110+
111+
def postprocess_ndarray(result, **kwds):
112+
return _helpers.array_from(result)
113+
114+
115+
def postprocess_out(result, **kwds):
116+
result, out = result
117+
return _helpers.result_or_out(result, out, **kwds)
118+
119+
120+
def postprocess_tuple(result, **kwds):
121+
return _helpers.tuple_arrays_from(result)
122+
123+
124+
def postprocess_list(result, **kwds):
125+
return list(_helpers.tuple_arrays_from(result))
126+
127+
128+
def postprocess_variadic(result, **kwds):
129+
# a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
130+
if isinstance(result, (tuple, list)):
131+
seq = type(result)
132+
return seq(_helpers.tuple_arrays_from(result))
133+
else:
134+
return _helpers.array_from(result)
135+
136+
137+
postprocessors = {
138+
NDArray: postprocess_ndarray,
139+
OutArray: postprocess_out,
140+
NDArrayOrSequence: postprocess_variadic,
141+
tuple[NDArray]: postprocess_tuple,
142+
list[NDArray]: postprocess_list,
143+
}
144+
145+
111146
def normalizer(_func=None, *, return_on_failure=_sentinel, promote_scalar_out=False):
112147
def normalizer_inner(func):
113148
@functools.wraps(func)
@@ -154,33 +189,17 @@ def wrapped(*args, **kwds):
154189
raise TypeError(
155190
f"{func.__name__}() takes {len(ba.args)} positional argument but {len(args)} were given."
156191
)
192+
157193
# finally, pass normalized arguments through
158194
result = func(*ba.args, **ba.kwargs)
159195

160196
# handle returns
161197
r = sig.return_annotation
162-
if r == NDArray:
163-
return _helpers.array_from(result)
164-
elif r == inspect._empty:
165-
return result
166-
elif hasattr(r, "__origin__") and r.__origin__ in (list, tuple):
167-
# this is tuple[NDArray] or list[NDArray]
168-
# XXX: change to separate tuple and list normalizers?
169-
return r.__origin__(_helpers.tuple_arrays_from(result))
170-
elif r == NDArrayOrSequence:
171-
# a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
172-
if isinstance(result, (tuple, list)):
173-
seq = type(result)
174-
return seq(_helpers.tuple_arrays_from(result))
175-
else:
176-
return _helpers.array_from(result)
177-
elif r == OutArray:
178-
result, out = result
179-
return _helpers.result_or_out(
180-
result, out, promote_scalar=promote_scalar_out
181-
)
182-
else:
183-
raise ValueError(f"Unknown return annotation {return_annotation}")
198+
postprocess = postprocessors.get(r, None)
199+
if postprocess:
200+
kwds = {"promote_scalar": promote_scalar_out}
201+
result = postprocess(result, **kwds)
202+
return result
184203

185204
return wrapped
186205

0 commit comments

Comments
 (0)