|
14 | 14 | SubokLike = typing.TypeVar("SubokLike")
|
15 | 15 | AxisLike = typing.TypeVar("AxisLike")
|
16 | 16 | NDArray = typing.TypeVar("NDarray")
|
| 17 | + |
| 18 | +# OutArray is to annotate the out= array argument. |
| 19 | +# |
| 20 | +# This one is special is several respects: |
| 21 | +# First, It needs to be an NDArray, and we need to preserve the `result is out` |
| 22 | +# semantics. Therefore, we cannot just extract the Tensor from the out array. |
| 23 | +# So we never pass the out array to implementer functions and handle it in the |
| 24 | +# `normalizer` below. |
| 25 | +# Second, the out= argument can be either keyword or positional argument, and |
| 26 | +# as a positional arg, it can be anywhere in the signature. |
| 27 | +# To handle all this, we define a special `OutArray` annotation and dispatch on it. |
| 28 | +# |
17 | 29 | OutArray = typing.TypeVar("OutArray")
|
18 | 30 |
|
19 | 31 |
|
@@ -123,6 +135,7 @@ def maybe_copy_to(out, result, promote_scalar_result=False):
|
123 | 135 | out.tensor.copy_(result)
|
124 | 136 | return out
|
125 | 137 | elif isinstance(result, (tuple, list)):
|
| 138 | + # FIXME: this is broken (there is no copy_to) |
126 | 139 | return type(result)(map(copy_to, zip(result, out)))
|
127 | 140 | else:
|
128 | 141 | assert False # We should never hit this path
|
@@ -180,9 +193,6 @@ def wrapped(*args, **kwds):
|
180 | 193 |
|
181 | 194 | if "out" in params:
|
182 | 195 | out = sig.bind(*args, **kwds).arguments.get("out")
|
183 |
| - |
184 |
| - ### if out is not None: breakpoint() |
185 |
| - |
186 | 196 | result = maybe_copy_to(out, result, promote_scalar_result)
|
187 | 197 | result = wrap_tensors(result)
|
188 | 198 |
|
|
0 commit comments