bare-bones normalizations via type hints #70

Merged
merged 33 commits on Mar 22, 2023
Changes from all commits
Commits
33 commits
9647fb6
MAINT: bare-bones normalizations via type hints
ev-br Feb 26, 2023
0b8264f
BUG: normalizations: raise on mismatch between parameters and actual …
ev-br Feb 28, 2023
d583c62
MAINT: normalize dtype in concatenate and *stack family
ev-br Mar 2, 2023
ffe46fa
normalize Optional[ArrayLike] via annotations
ev-br Mar 3, 2023
352f715
MAINT: use normalizer/ArrayLike in _funcs
ev-br Mar 3, 2023
7d26871
MAINT: modify tests arr.base --> arr.get()._base
ev-br Mar 3, 2023
ce9861a
BUG: handle positional-only parameters in @ normalize
ev-br Mar 3, 2023
eec7bc3
MAINT: remove to_tensors_or_none, use Optional[ArrayLike] instead
ev-br Mar 3, 2023
0c98dfb
ENH: normalize tuples of array_likes
ev-br Mar 4, 2023
f5731e5
ENH: annotate *args
ev-br Mar 4, 2023
b7112e3
MAINT: use normalizations across namespace functions
ev-br Mar 4, 2023
94e21dd
lint
ev-br Mar 4, 2023
2fdd3c6
MAINT: simplify sum, prod, mean, var, std, argmin, argmax
ev-br Mar 6, 2023
ab85d72
MAINT: emulate_keepdims via a decorator in _detail/reductions.py
ev-br Mar 7, 2023
93acb7a
MAINT: count_nonzero
ev-br Mar 8, 2023
47d8a1e
MAINT: cumsum/cumprod
ev-br Mar 8, 2023
eb7c4c6
MAINT: quantile/percentile/median
ev-br Mar 8, 2023
a20320b
MAINT: simplify/normalize average
ev-br Mar 8, 2023
7b447af
MAINT: normalize array-like arg of full(), a few others
ev-br Mar 8, 2023
b1ca69a
MAINT: rm dead code from decorators.py
ev-br Mar 8, 2023
0bfddd0
MAINT: rework unary and binary ufuncs w/ normalizations
ev-br Mar 9, 2023
649431c
lint
ev-br Mar 10, 2023
a7ac280
MAINT: move normalization logic to _normalizations
ev-br Mar 10, 2023
10672bb
MAINT: use normalizations in tnp.random
ev-br Mar 10, 2023
16c5aed
MAINT: annotate out as NDArray, remove scattered isinstance checks
ev-br Mar 10, 2023
fe9011d
MAINT: better error message for wrong axis arguments
ev-br Mar 10, 2023
b3d5f0a
MAINT: remove isort:skip directives (circ imports are well hidden now)
ev-br Mar 10, 2023
69e657a
TST: unxfail tests of out and dtype as positional args
ev-br Mar 15, 2023
27cb10f
MAINT: remove debug leftovers
ev-br Mar 20, 2023
a6eb581
MAINT: add a comment on axis=() in reductions
ev-br Mar 20, 2023
8c78725
MAINT: simplify arg/param handing in normalize
ev-br Mar 21, 2023
9d75cab
MAINT: simplify handling of variadic *args in normalize
ev-br Mar 21, 2023
7dced32
MAINT: simplify normalizer
lezcano Mar 22, 2023
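
The mechanism the PR title refers to, sketched below under my own assumptions (an illustration, not the code this PR adds in torch_np/_normalizations.py): a @normalizer decorator inspects a function's parameter annotations and converts every argument annotated as ArrayLike into a tensor before the wrapped function runs, so the torch-level implementations only ever see tensors.

import functools
import inspect
import typing

import torch

ArrayLike = typing.TypeVar("ArrayLike")  # marker annotation (illustrative)


def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        bound = sig.bind(*args, **kwds)
        for name, value in bound.arguments.items():
            # convert only the parameters annotated as ArrayLike
            if sig.parameters[name].annotation is ArrayLike:
                bound.arguments[name] = torch.as_tensor(value)
        return func(*bound.args, **bound.kwargs)

    return wrapped


@normalizer
def sin(x: ArrayLike):
    return torch.sin(x)


print(sin([0.0, 1.0]))  # tensor([0.0000, 0.8415])

Per the commit messages above, the real decorator also covers Optional[ArrayLike], tuples of array-likes, variadic *args, DTypeLike, out arguments annotated as NDArray, and positional-only parameters.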
2 changes: 1 addition & 1 deletion torch_np/__init__.py
@@ -1,4 +1,3 @@
from ._wrapper import * # isort: skip # XXX: currently this prevents circular imports
from . import random
from ._binary_ufuncs import *
from ._detail._index_tricks import *
@@ -8,6 +7,7 @@
from ._getlimits import errstate, finfo, iinfo
from ._ndarray import array, asarray, can_cast, ndarray, newaxis, result_type
from ._unary_ufuncs import *
from ._wrapper import *

# from . import testing

96 changes: 48 additions & 48 deletions torch_np/_binary_ufuncs.py
@@ -1,52 +1,52 @@
from ._decorators import deco_binary_ufunc_from_impl
from ._detail import _ufunc_impl
from typing import Optional

from . import _helpers
from ._detail import _binary_ufuncs
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer

__all__ = [
name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
]


def deco_binary_ufunc(torch_func):
"""Common infra for binary ufuncs.

Normalize arguments, sort out type casting, broadcasting and delegate to
the pytorch functions for the actual work.
"""

def wrapped(
x1: ArrayLike,
x2: ArrayLike,
/,
out: Optional[NDArray] = None,
*,
where=True,
casting="same_kind",
order="K",
dtype: DTypeLike = None,
subok: SubokLike = False,
signature=None,
extobj=None,
):
tensors = _helpers.ufunc_preprocess(
(x1, x2), out, where, casting, order, dtype, subok, signature, extobj
)
result = torch_func(*tensors)
return _helpers.result_or_out(result, out)

return wrapped


#
# Functions in this file implement binary ufuncs: wrap two first arguments in
# asarray and delegate to functions from _ufunc_impl.
#
# Functions in _detail/_ufunc_impl.py receive tensors, implement common tasks
# with ufunc args, and delegate heavy lifting to pytorch equivalents.
# For each torch ufunc implementation, decorate and attach the decorated name
# to this module. Its contents are then exported to the public namespace in __init__.py
#
for name in __all__:
ufunc = getattr(_binary_ufuncs, name)
decorated = normalizer(deco_binary_ufunc(ufunc))

# the list is autogenerated, cf autogen/gen_ufunc_2.py
add = deco_binary_ufunc_from_impl(_ufunc_impl.add)
arctan2 = deco_binary_ufunc_from_impl(_ufunc_impl.arctan2)
bitwise_and = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_and)
bitwise_or = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_or)
bitwise_xor = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_xor)
copysign = deco_binary_ufunc_from_impl(_ufunc_impl.copysign)
divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
equal = deco_binary_ufunc_from_impl(_ufunc_impl.equal)
float_power = deco_binary_ufunc_from_impl(_ufunc_impl.float_power)
floor_divide = deco_binary_ufunc_from_impl(_ufunc_impl.floor_divide)
fmax = deco_binary_ufunc_from_impl(_ufunc_impl.fmax)
fmin = deco_binary_ufunc_from_impl(_ufunc_impl.fmin)
fmod = deco_binary_ufunc_from_impl(_ufunc_impl.fmod)
gcd = deco_binary_ufunc_from_impl(_ufunc_impl.gcd)
greater = deco_binary_ufunc_from_impl(_ufunc_impl.greater)
greater_equal = deco_binary_ufunc_from_impl(_ufunc_impl.greater_equal)
heaviside = deco_binary_ufunc_from_impl(_ufunc_impl.heaviside)
hypot = deco_binary_ufunc_from_impl(_ufunc_impl.hypot)
lcm = deco_binary_ufunc_from_impl(_ufunc_impl.lcm)
ldexp = deco_binary_ufunc_from_impl(_ufunc_impl.ldexp)
left_shift = deco_binary_ufunc_from_impl(_ufunc_impl.left_shift)
less = deco_binary_ufunc_from_impl(_ufunc_impl.less)
less_equal = deco_binary_ufunc_from_impl(_ufunc_impl.less_equal)
logaddexp = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp)
logaddexp2 = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp2)
logical_and = deco_binary_ufunc_from_impl(_ufunc_impl.logical_and)
logical_or = deco_binary_ufunc_from_impl(_ufunc_impl.logical_or)
logical_xor = deco_binary_ufunc_from_impl(_ufunc_impl.logical_xor)
matmul = deco_binary_ufunc_from_impl(_ufunc_impl.matmul)
maximum = deco_binary_ufunc_from_impl(_ufunc_impl.maximum)
minimum = deco_binary_ufunc_from_impl(_ufunc_impl.minimum)
remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
multiply = deco_binary_ufunc_from_impl(_ufunc_impl.multiply)
nextafter = deco_binary_ufunc_from_impl(_ufunc_impl.nextafter)
not_equal = deco_binary_ufunc_from_impl(_ufunc_impl.not_equal)
power = deco_binary_ufunc_from_impl(_ufunc_impl.power)
remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
right_shift = deco_binary_ufunc_from_impl(_ufunc_impl.right_shift)
subtract = deco_binary_ufunc_from_impl(_ufunc_impl.subtract)
divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
decorated.__qualname__ = name # XXX: is this really correct?
decorated.__name__ = name
vars()[name] = decorated
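
For illustration only (assuming the package is imported as torch_np, abbreviated tnp as elsewhere in the PR), a decorated binary ufunc now accepts plain array-likes and relies on the annotations for conversion:

import torch_np as tnp

# the @normalizer-decorated wrapper converts both lists to tensors,
# then delegates to the torch-level add
res = tnp.add([1, 2, 3], [10, 20, 30])
print(res)  # wraps tensor([11, 22, 33])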
115 changes: 0 additions & 115 deletions torch_np/_decorators.py
@@ -1,39 +1,10 @@
import functools
import operator

import torch

from . import _dtypes, _helpers
from ._detail import _util

NoValue = None


def dtype_to_torch(func):
@functools.wraps(func)
def wrapped(*args, dtype=None, **kwds):
torch_dtype = None
if dtype is not None:
dtype = _dtypes.dtype(dtype)
torch_dtype = dtype.torch_dtype
return func(*args, dtype=torch_dtype, **kwds)

return wrapped


def emulate_out_arg(func):
"""Simulate the out=... handling: move the result tensor to the out array.

With this decorator, the inner function just does not see the out array.
"""

@functools.wraps(func)
def wrapped(*args, out=None, **kwds):
result_tensor = func(*args, **kwds)
return _helpers.result_or_out(result_tensor, out)

return wrapped


def out_shape_dtype(func):
"""Handle out=... kwarg for ufuncs.
@@ -51,89 +22,3 @@ def wrapped(*args, out=None, **kwds):
return _helpers.result_or_out(result_tensor, out)

return wrapped


def deco_unary_ufunc_from_impl(impl_func):
@functools.wraps(impl_func)
@dtype_to_torch
@out_shape_dtype
def wrapped(x1, *args, **kwds):
from ._ndarray import asarray

x1_tensor = asarray(x1).get()
result = impl_func((x1_tensor,), *args, **kwds)
return result

return wrapped


# TODO: deduplicate with _ndarray/asarray_replacer,
# and _wrapper/concatenate et al
def deco_binary_ufunc_from_impl(impl_func):
@functools.wraps(impl_func)
@dtype_to_torch
@out_shape_dtype
def wrapped(x1, x2, *args, **kwds):
from ._ndarray import asarray

x1_tensor = asarray(x1).get()
x2_tensor = asarray(x2).get()
return impl_func((x1_tensor, x2_tensor), *args, **kwds)

return wrapped


def axis_keepdims_wrapper(func):
"""`func` accepts an array-like as a 1st arg, returns a tensor.

This decorator implements the generic handling of axis, out and keepdims
arguments for reduction functions.

Note that we peel off `out=...` and `keepdims=...` args (torch functions never
see them). The `axis` argument we normalize and pass through to pytorch functions.

"""
# TODO: sort out function signatures: how they flow through all decorators etc
@functools.wraps(func)
def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds):
from ._ndarray import asarray, ndarray

tensor = asarray(a).get()

# standardize the axis argument
if isinstance(axis, ndarray):
axis = operator.index(axis)

result = _util.axis_expand_func(func, tensor, axis, *args, **kwds)

if keepdims:
result = _util.apply_keepdims(result, axis, tensor.ndim)

return result

return wrapped


def axis_none_ravel_wrapper(func):
"""`func` accepts an array-like as a 1st arg, returns a tensor.

This decorator implements the generic handling of axis=None acting on a
raveled array. One use is cumprod / cumsum. concatenate also uses a
similar logic.

"""

@functools.wraps(func)
def wrapped(a, axis=None, *args, **kwds):
from ._ndarray import asarray, ndarray

tensor = asarray(a).get()

# standardize the axis argument
if isinstance(axis, ndarray):
axis = operator.index(axis)

result = _util.axis_ravel_func(func, tensor, axis, *args, **kwds)
return result

return wrapped
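
The removed axis_keepdims_wrapper is superseded by a keepdims-emulating decorator in _detail/reductions.py (see the commit list). A minimal sketch of the emulation itself, under my own assumptions rather than the PR's exact code: after reducing, the collapsed axes are re-inserted with length 1 so the result still broadcasts against the original array.

import torch


def apply_keepdims_sketch(result, axis, ndim):
    # illustrative helper, not the PR's implementation
    if axis is None:
        # a full reduction collapses all axes; restore the original rank
        return result.reshape((1,) * ndim)
    return result.unsqueeze(axis)


t = torch.ones(3, 4)
s = t.sum(dim=1)                                   # shape (3,)
print(apply_keepdims_sketch(s, 1, t.ndim).shape)   # torch.Size([3, 1])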
53 changes: 53 additions & 0 deletions torch_np/_detail/_binary_ufuncs.py
@@ -0,0 +1,53 @@
"""Export torch work functions for binary ufuncs, rename/tweak to match numpy.
This listing is further exported to public symbols in the `torch_np/_binary_ufuncs.py` module.
"""

import torch

# renames
from torch import add, arctan2, bitwise_and
from torch import bitwise_left_shift as left_shift
from torch import bitwise_or
from torch import bitwise_right_shift as right_shift
from torch import bitwise_xor, copysign, divide
from torch import eq as equal
from torch import (
float_power,
floor_divide,
fmax,
fmin,
fmod,
gcd,
greater,
greater_equal,
heaviside,
hypot,
lcm,
ldexp,
less,
less_equal,
logaddexp,
logaddexp2,
logical_and,
logical_or,
logical_xor,
maximum,
minimum,
multiply,
nextafter,
not_equal,
)
from torch import pow as power
from torch import remainder, subtract

from . import _dtypes_impl, _util


# work around torch limitations w.r.t. numpy
def matmul(x, y):
# work around RuntimeError: expected scalar type Int but found Double
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
x = _util.cast_if_needed(x, dtype)
y = _util.cast_if_needed(y, dtype)
result = torch.matmul(x, y)
return result
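
The matmul wrapper above exists because torch.matmul does not promote mismatched dtypes on its own. A hypothetical illustration of the promotion it performs, with torch.result_type standing in for _dtypes_impl.result_type_impl:

import torch

x = torch.arange(3)              # int64
y = torch.linspace(0.0, 1.0, 3)  # float32

# torch.matmul(x, y) can raise the RuntimeError mentioned in the comment above;
# casting both operands to a common dtype first avoids it.
common = torch.result_type(x, y)
print(torch.matmul(x.to(common), y.to(common)))  # tensor(2.5000)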