Skip to content

Commit e06cb9d

Browse files
committed
MAINT: add comments on OutArray, remove cruft
1 parent 3887e26 commit e06cb9d

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

torch_np/_normalizations.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@
1414
SubokLike = typing.TypeVar("SubokLike")
1515
AxisLike = typing.TypeVar("AxisLike")
1616
NDArray = typing.TypeVar("NDarray")
17+
18+
# OutArray is to annotate the out= array argument.
19+
#
20+
# This one is special is several respects:
21+
# First, It needs to be an NDArray, and we need to preserve the `result is out`
22+
# semantics. Therefore, we cannot just extract the Tensor from the out array.
23+
# So we never pass the out array to implementer functions and handle it in the
24+
# `normalizer` below.
25+
# Second, the out= argument can be either keyword or positional argument, and
26+
# as a positional arg, it can be anywhere in the signature.
27+
# To handle all this, we define a special `OutArray` annotation and dispatch on it.
28+
#
1729
OutArray = typing.TypeVar("OutArray")
1830

1931

@@ -123,6 +135,7 @@ def maybe_copy_to(out, result, promote_scalar_result=False):
123135
out.tensor.copy_(result)
124136
return out
125137
elif isinstance(result, (tuple, list)):
138+
# FIXME: this is broken (there is no copy_to)
126139
return type(result)(map(copy_to, zip(result, out)))
127140
else:
128141
assert False # We should never hit this path
@@ -180,9 +193,6 @@ def wrapped(*args, **kwds):
180193

181194
if "out" in params:
182195
out = sig.bind(*args, **kwds).arguments.get("out")
183-
184-
### if out is not None: breakpoint()
185-
186196
result = maybe_copy_to(out, result, promote_scalar_result)
187197
result = wrap_tensors(result)
188198

0 commit comments

Comments
 (0)