1
1
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
2
2
"""
3
+ import functools
4
+ import inspect
3
5
import operator
4
6
import typing
5
7
from typing import Optional , Sequence , Union
6
8
7
9
import torch
8
10
9
- from . import _helpers
11
+ from . import _dtypes , _helpers
10
12
11
13
ArrayLike = typing .TypeVar ("ArrayLike" )
12
14
DTypeLike = typing .TypeVar ("DTypeLike" )
22
24
NDArrayOrSequence = Union [NDArray , Sequence [NDArray ]]
23
25
OutArray = typing .TypeVar ("OutArray" )
24
26
25
- import inspect
26
-
27
- from . import _dtypes
28
-
29
27
30
28
def normalize_array_like (x , name = None ):
31
29
(tensor ,) = _helpers .to_tensors (x )
@@ -87,7 +85,6 @@ def normalize_ndarray(arg, name=None):
87
85
AxisLike : normalize_axis_like ,
88
86
}
89
87
90
- import functools
91
88
92
89
_sentinel = object ()
93
90
@@ -108,6 +105,44 @@ def normalize_this(arg, parm, return_on_failure=_sentinel):
108
105
return arg
109
106
110
107
108
+ # postprocess return values
109
+
110
+
111
+ def postprocess_ndarray (result , ** kwds ):
112
+ return _helpers .array_from (result )
113
+
114
+
115
+ def postprocess_out (result , ** kwds ):
116
+ result , out = result
117
+ return _helpers .result_or_out (result , out , ** kwds )
118
+
119
+
120
+ def postprocess_tuple (result , ** kwds ):
121
+ return _helpers .tuple_arrays_from (result )
122
+
123
+
124
+ def postprocess_list (result , ** kwds ):
125
+ return list (_helpers .tuple_arrays_from (result ))
126
+
127
+
128
+ def postprocess_variadic (result , ** kwds ):
129
+ # a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
130
+ if isinstance (result , (tuple , list )):
131
+ seq = type (result )
132
+ return seq (_helpers .tuple_arrays_from (result ))
133
+ else :
134
+ return _helpers .array_from (result )
135
+
136
+
137
+ postprocessors = {
138
+ NDArray : postprocess_ndarray ,
139
+ OutArray : postprocess_out ,
140
+ NDArrayOrSequence : postprocess_variadic ,
141
+ tuple [NDArray ]: postprocess_tuple ,
142
+ list [NDArray ]: postprocess_list ,
143
+ }
144
+
145
+
111
146
def normalizer (_func = None , * , return_on_failure = _sentinel , promote_scalar_out = False ):
112
147
def normalizer_inner (func ):
113
148
@functools .wraps (func )
@@ -154,33 +189,17 @@ def wrapped(*args, **kwds):
154
189
raise TypeError (
155
190
f"{ func .__name__ } () takes { len (ba .args )} positional argument but { len (args )} were given."
156
191
)
192
+
157
193
# finally, pass normalized arguments through
158
194
result = func (* ba .args , ** ba .kwargs )
159
195
160
196
# handle returns
161
197
r = sig .return_annotation
162
- if r == NDArray :
163
- return _helpers .array_from (result )
164
- elif r == inspect ._empty :
165
- return result
166
- elif hasattr (r , "__origin__" ) and r .__origin__ in (list , tuple ):
167
- # this is tuple[NDArray] or list[NDArray]
168
- # XXX: change to separate tuple and list normalizers?
169
- return r .__origin__ (_helpers .tuple_arrays_from (result ))
170
- elif r == NDArrayOrSequence :
171
- # a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
172
- if isinstance (result , (tuple , list )):
173
- seq = type (result )
174
- return seq (_helpers .tuple_arrays_from (result ))
175
- else :
176
- return _helpers .array_from (result )
177
- elif r == OutArray :
178
- result , out = result
179
- return _helpers .result_or_out (
180
- result , out , promote_scalar = promote_scalar_out
181
- )
182
- else :
183
- raise ValueError (f"Unknown return annotation { return_annotation } " )
198
+ postprocess = postprocessors .get (r , None )
199
+ if postprocess :
200
+ kwds = {"promote_scalar" : promote_scalar_out }
201
+ result = postprocess (result , ** kwds )
202
+ return result
184
203
185
204
return wrapped
186
205
0 commit comments