wrap returns out annotation #93

Merged 2 commits on Mar 31, 2023
74 changes: 31 additions & 43 deletions torch_np/_binary_ufuncs.py
@@ -44,7 +44,7 @@ def wrapped(
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)

result = torch_func(*tensors)
return _helpers.result_or_out(result, out)
return result

return wrapped

@@ -77,69 +77,57 @@ def matmul(

# NB: do not broadcast input tensors against the out=... array
result = _binary_ufuncs.matmul(*tensors)
return _helpers.result_or_out(result, out)


#
# For each torch ufunc implementation, decorate and attach the decorated name
# to this module. Its contents is then exported to the public namespace in __init__.py
#
for name in __all__:
ufunc = getattr(_binary_ufuncs, name)
decorated = normalizer(deco_binary_ufunc(ufunc))

decorated.__qualname__ = name # XXX: is this really correct?
decorated.__name__ = name
vars()[name] = decorated


# a stub implementation of divmod, should be improved after
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
#
# Implementation details: we just call two ufuncs which have been created
# just above, for x1 // x2 and x1 % x2.
# This means we are normalizing x1, x2 in each of the ufuncs --- note that there
# is no @normalizer on divmod.
return result


def divmod(
x1,
x2,
x1: ArrayLike,
x2: ArrayLike,
out1: Optional[NDArray] = None,
out2: Optional[NDArray] = None,
/,
out=None,
out: Optional[tuple[NDArray]] = (None, None),
Collaborator: nit: tuple[Optional[NDArray], Optional[NDArray]].

Collaborator (Author): Why? It's either (None, None) or (NDArray, NDArray); you can't have (out1, None) (I think).

Collaborator: tuple, as opposed to list, takes the number of arguments it accepts explicitly. As such, you are declaring there a tuple of length one holding a single NDArray. Then, Optional[tuple[NDArray, NDArray]] accepts either a tuple of ndarrays (fine) or None (not (None, None)). What you want is a (partial) function that accepts tuple[Optional[NDArray], Optional[NDArray]] and errors out when just one of the elements of the tuple is None.

Collaborator: Feel free to patch this small correction in any other PR.

Collaborator (Author): Ouch. Python's typing meta-language is seriously weird.

Collaborator: Note that this is not specific to Python; the same would be true in C++. This would be a std::tuple<std::optional<array>, std::optional<array>> rather than a std::optional<std::tuple<array, array>> :)
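
To make the distinction concrete, here is a short, self-contained sketch of the annotations discussed above (the NDArray class and check_out function are illustrative stand-ins, not code from this PR):

```python
from typing import Optional


class NDArray:  # stand-in for the project's ndarray wrapper
    ...


# tuple[NDArray] declares a 1-tuple holding exactly one NDArray.
# Optional[tuple[NDArray, NDArray]] accepts None or (NDArray, NDArray),
# but not (None, None).
# tuple[Optional[NDArray], Optional[NDArray]] accepts (None, None) and
# (arr, arr), but also the mixed (arr, None), which must be rejected at runtime:
Out2 = tuple[Optional[NDArray], Optional[NDArray]]


def check_out(out: Out2 = (None, None)) -> None:
    o1, o2 = out
    if (o1 is None) != (o2 is None):
        raise ValueError("both out1 and out2 need to be provided")
```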

*,
where=True,
casting="same_kind",
order="K",
dtype=None,
dtype: DTypeLike = None,
subok: SubokLike = False,
signature=None,
extobj=None,
):
out1, out2 = None, None
num_outs = sum(x is None for x in [out1, out2])
if sum_outs == 1:
raise ValueError("both out1 and out2 need to be provided")
if sum_outs != 0 and out != (None, None):
raise ValueError("Either provide out1 and out2, or out.")
if out is not None:
out1, out2 = out
if out1.shape != out2.shape or out1.dtype != out2.dtype:
raise ValueError("out1, out2 must be compatible")

kwds = dict(
where=where,
casting=casting,
order=order,
dtype=dtype,
subok=subok,
signature=signature,
extobj=extobj,
tensors = _helpers.ufunc_preprocess(
(x1, x2), out, True, casting, order, dtype, subok, signature, extobj
Collaborator: As discussed, ufunc_preprocess should not take an out parameter, since setting out= should not affect the dtype of the computation. Or is this fixed in some other PR and we'll just need to update this one afterwards?

Collaborator (Author): It's still in the TODO pipeline. Will do after the current "stack" is done.
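
A minimal sketch of the separation being asked for, with hypothetical names rather than the project's actual helpers: the computation dtype is derived from the inputs only, and out=, if given, only receives a cast copy of the result.

```python
import torch


def compute_then_cast(torch_func, x, y, out=None):
    # Promotion is decided by the inputs alone ...
    dtype = torch.result_type(x, y)
    result = torch_func(x.to(dtype), y.to(dtype))
    # ... and out= only receives a cast copy of the finished result.
    if out is not None:
        out.copy_(result)
        return out
    return result
```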

)

# NB: use local names for
quot = floor_divide(x1, x2, out=out1, **kwds)
rem = remainder(x1, x2, out=out2, **kwds)

quot = _helpers.result_or_out(quot.tensor, out1)
rem = _helpers.result_or_out(rem.tensor, out2)
result = _binary_ufuncs.divmod(*tensors)

return quot, rem


#
# For each torch ufunc implementation, decorate and attach the decorated name
# to this module. Its contents is then exported to the public namespace in __init__.py
#
for name in __all__:
ufunc = getattr(_binary_ufuncs, name)
decorated = normalizer(deco_binary_ufunc(ufunc))

decorated.__qualname__ = name # XXX: is this really correct?
decorated.__name__ = name
vars()[name] = decorated


def modf(x, /, *args, **kwds):
quot, rem = divmod(x, 1, *args, **kwds)
return rem, quot
24 changes: 0 additions & 24 deletions torch_np/_decorators.py

This file was deleted.

6 changes: 6 additions & 0 deletions torch_np/_detail/_binary_ufuncs.py
@@ -70,3 +70,9 @@ def matmul(x, y):
result = result.to(dtype)

return result


# a stub implementation of divmod, should be improved after
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
def divmod(x, y):
return x // y, x % y
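
A quick sanity check of the stub (not part of the PR), assuming integer tensors: the floor-division/remainder pair it returns satisfies the usual divmod invariant x == q * y + r.

```python
import torch

x = torch.tensor([7, -7, 7, -7])
y = torch.tensor([3, 3, -3, -3])
q, r = x // y, x % y              # what the stub computes
assert torch.equal(q * y + r, x)  # divmod invariant holds
```
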
44 changes: 22 additions & 22 deletions torch_np/_funcs.py
@@ -45,7 +45,7 @@ def clip(
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
# one of them to be None. Follow the more lax version.
result = _impl.clip(a, min, max)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -80,7 +80,7 @@ def trace(
out: Optional[NDArray] = None,
):
result = _impl.trace(a, offset, axis1, axis2, dtype)
return _helpers.result_or_out(result, out)
return result


@normalizer
Expand Down Expand Up @@ -135,7 +135,7 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
@normalizer
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
result = _impl.dot(a, b)
return _helpers.result_or_out(result, out)
return result


# ### sort and partition ###
@@ -234,7 +234,7 @@ def imag(a: ArrayLike):
@normalizer
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
result = _impl.round(a, decimals)
return _helpers.result_or_out(result, out)
return result


around = round_
@@ -257,7 +257,7 @@ def sum(
result = _impl.sum(
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -273,7 +273,7 @@ def prod(
result = _impl.prod(
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
)
return _helpers.result_or_out(result, out)
return result


product = prod
@@ -290,7 +290,7 @@ def mean(
where=NoValue,
):
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -307,7 +307,7 @@ def var(
result = _impl.var(
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -324,7 +324,7 @@ def std(
result = _impl.std(
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -336,7 +336,7 @@ def argmin(
keepdims=NoValue,
):
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -348,7 +348,7 @@ def argmax(
keepdims=NoValue,
):
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -361,7 +361,7 @@ def amax(
where=NoValue,
):
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
return _helpers.result_or_out(result, out)
return result


max = amax
@@ -377,7 +377,7 @@ def amin(
where=NoValue,
):
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
return _helpers.result_or_out(result, out)
return result


min = amin
@@ -388,7 +388,7 @@ def ptp(
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
):
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -401,7 +401,7 @@ def all(
where=NoValue,
):
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -414,7 +414,7 @@ def any(
where=NoValue,
):
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -431,7 +431,7 @@ def cumsum(
out: Optional[NDArray] = None,
):
result = _impl.cumsum(a, axis=axis, dtype=dtype)
return _helpers.result_or_out(result, out)
return result


@normalizer
@@ -442,13 +442,13 @@ def cumprod(
out: Optional[NDArray] = None,
):
result = _impl.cumprod(a, axis=axis, dtype=dtype)
return _helpers.result_or_out(result, out)
return result


cumproduct = cumprod


@normalizer
@normalizer(promote_scalar_result=True)
def quantile(
a: ArrayLike,
q: ArrayLike,
@@ -469,10 +469,10 @@ def quantile(
keepdims=keepdims,
interpolation=interpolation,
)
return _helpers.result_or_out(result, out, promote_scalar=True)
return result


@normalizer
@normalizer(promote_scalar_result=True)
def percentile(
a: ArrayLike,
q: ArrayLike,
@@ -493,7 +493,7 @@ def percentile(
keepdims=keepdims,
interpolation=interpolation,
)
return _helpers.result_or_out(result, out, promote_scalar=True)
return result


def median(
33 changes: 0 additions & 33 deletions torch_np/_helpers.py
@@ -30,39 +30,6 @@ def ufunc_preprocess(
return tensors


# ### Return helpers: wrap a single tensor, a tuple of tensors, out= etc ###


def result_or_out(result_tensor, out_array=None, promote_scalar=False):
"""A helper for returns with out= argument.

If `promote_scalar is True`, then:
if result_tensor.numel() == 1 and out is zero-dimensional,
result_tensor is placed into the out array.
This weirdness is used e.g. in `np.percentile`
"""
if out_array is not None:
if result_tensor.shape != out_array.shape:
can_fit = result_tensor.numel() == 1 and out_array.ndim == 0
if promote_scalar and can_fit:
result_tensor = result_tensor.squeeze()
else:
raise ValueError(
f"Bad size of the out array: out.shape = {out_array.shape}"
f" while result.shape = {result_tensor.shape}."
)
out_tensor = out_array.tensor
out_tensor.copy_(result_tensor)
return out_array
else:
from ._ndarray import ndarray

return ndarray(result_tensor)


# ### Various ways of converting array-likes to tensors ###


def ndarrays_to_tensors(*inputs):
"""Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
from ._ndarray import asarray, ndarray
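
For context on the promote_scalar path described in the removed result_or_out docstring above: in NumPy itself, a reduction such as np.percentile with a scalar q can write its single-element result into a zero-dimensional out array, which is the calling pattern the squeeze was there to support. A small NumPy-only illustration (not this project's code):

```python
import numpy as np

a = np.arange(10, dtype=float)
out = np.zeros(())                  # zero-dimensional out array
np.percentile(a, 50.0, out=out)     # the scalar result is placed into `out`
print(out)                          # 4.5
```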