MAINT: split remaining functions into normalizations and implementations #82

Merged (5 commits) on Mar 23, 2023
102 changes: 66 additions & 36 deletions torch_np/_detail/_reductions.py
@@ -317,55 +317,51 @@ def average(a, axis, weights, returned=False, keepdims=False):
     return result, wsum
 
 
-def average_noweights(a_tensor, axis, keepdims=False):
-    result = mean(a_tensor, axis=axis, keepdims=keepdims)
-    scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype)
+def average_noweights(a, axis, keepdims=False):
+    result = mean(a, axis=axis, keepdims=keepdims)
+    scl = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
     return result, scl
 
 
-def average_weights(a_tensor, axis, w_tensor, keepdims=False):
+def average_weights(a, axis, w, keepdims=False):
 
     # dtype
     # FIXME: 1. use result_type
     #        2. actually implement multiply w/dtype
-    if not a_tensor.dtype.is_floating_point:
+    if not a.dtype.is_floating_point:
         result_dtype = torch.float64
-        a_tensor = a_tensor.to(result_dtype)
+        a = a.to(result_dtype)
 
-    result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
+    result_dtype = _dtypes_impl.result_type_impl([a.dtype, w.dtype])
 
-    a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
-    w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
+    a = _util.cast_if_needed(a, result_dtype)
+    w = _util.cast_if_needed(w, result_dtype)
 
     # axis=None ravels, so store the originals to reuse with keepdims=True below
-    ax, ndim = axis, a_tensor.ndim
+    ax, ndim = axis, a.ndim
 
     # axis
     if axis is None:
-        (a_tensor, w_tensor), axis = _util.axis_none_ravel(
-            a_tensor, w_tensor, axis=axis
-        )
+        (a, w), axis = _util.axis_none_ravel(a, w, axis=axis)
 
     # axis & weights
-    if a_tensor.shape != w_tensor.shape:
+    if a.shape != w.shape:
         if axis is None:
             raise TypeError(
                 "Axis must be specified when shapes of a and weights " "differ."
             )
-        if w_tensor.ndim != 1:
+        if w.ndim != 1:
             raise TypeError("1D weights expected when shapes of a and weights differ.")
-        if w_tensor.shape[0] != a_tensor.shape[axis]:
+        if w.shape[0] != a.shape[axis]:
             raise ValueError("Length of weights not compatible with specified axis.")
 
     # setup weight to broadcast along axis
-    w_tensor = torch.broadcast_to(
-        w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
-    )
-    w_tensor = w_tensor.swapaxes(-1, axis)
+    w = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)
+    w = w.swapaxes(-1, axis)
 
     # do the work
-    numerator = torch.mul(a_tensor, w_tensor).sum(axis)
-    denominator = w_tensor.sum(axis)
+    numerator = torch.mul(a, w).sum(axis)
+    denominator = w.sum(axis)
     result = numerator / denominator
 
     # keepdims
@@ -375,36 +371,70 @@ def average_weights(a_tensor, axis, w_tensor, keepdims=False):
     return result, denominator
 
 
-def quantile(a_tensor, q_tensor, axis, method, keepdims=False):
-
-    if (0 > q_tensor).any() or (q_tensor > 1).any():
-        raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
-
-    if not a_tensor.dtype.is_floating_point:
+def quantile(
+    a,
+    q,
+    axis,
+    overwrite_input,
+    method,
+    keepdims=False,
+    interpolation=None,
+):
+    if overwrite_input:
+        # raise NotImplementedError("overwrite_input in quantile not implemented.")
+        # NumPy documents that `overwrite_input` MAY modify inputs:
+        # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
+        # Here we choose to work out-of-place because why not.
+        pass
+
+    if interpolation is not None:
+        raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
+
+    if not a.dtype.is_floating_point:
         dtype = _dtypes_impl.default_float_dtype
-        a_tensor = a_tensor.to(dtype)
+        a = a.to(dtype)
 
     # edge case: torch.quantile only supports float32 and float64
-    if a_tensor.dtype == torch.float16:
-        a_tensor = a_tensor.to(torch.float32)
+    if a.dtype == torch.float16:
+        a = a.to(torch.float32)
 
     # TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all.
     # axis
     if axis is not None:
-        axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
+        axis = _util.normalize_axis_tuple(axis, a.ndim)
         axis = _util.allow_only_single_axis(axis)
 
-    q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
+    q = _util.cast_if_needed(q, a.dtype)
 
     # axis=None ravels, so store the originals to reuse with keepdims=True below
-    ax, ndim = axis, a_tensor.ndim
-    (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
+    ax, ndim = axis, a.ndim
+    (a, q), axis = _util.axis_none_ravel(a, q, axis=axis)
 
-    result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
+    result = torch.quantile(a, q, axis=axis, interpolation=method)
 
     # NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...)
     # while the decorator expects (a, axis, ...)
     # this can be fixed, of course, but the cure seems worse then the desease
     if keepdims:
        result = _util.apply_keepdims(result, ax, ndim)
     return result
+
+
+def percentile(
+    a,
+    q,
+    axis,
+    overwrite_input,
+    method,
+    keepdims=False,
+    interpolation=None,
+):
+    return quantile(
+        a,
+        q / 100.0,
+        axis=axis,
+        overwrite_input=overwrite_input,
+        method=method,
+        keepdims=keepdims,
+        interpolation=interpolation,
+    )
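For reference, the delegation above boils down to dividing q by 100 and letting torch.quantile's interpolation keyword stand in for NumPy's method. A minimal sketch in plain torch, independent of the torch_np helpers:

import torch

a = torch.tensor([1.0, 2.0, 3.0, 4.0])
via_quantile = torch.quantile(a, 0.5, interpolation="linear")
via_percentile = torch.quantile(a, torch.tensor(50.0) / 100.0, interpolation="linear")
assert torch.equal(via_quantile, via_percentile)  # both give the median, 2.5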
86 changes: 78 additions & 8 deletions torch_np/_detail/implementations.py
@@ -274,6 +274,10 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
     if n < 0:
         raise ValueError(f"order must be non-negative but got {n}")
 
+    if n == 0:
+        # match numpy and return the input immediately
+        return a_tensor
+
     if prepend_tensor is not None:
         shape = list(a_tensor.shape)
         shape[axis] = prepend_tensor.shape[axis] if prepend_tensor.ndim > 0 else 1
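The new n == 0 branch mirrors NumPy, where np.diff(a, n=0) is the identity, so torch.diff is only reached for n >= 1. A quick illustration in plain torch:

import torch

a = torch.tensor([1, 3, 6])
assert torch.equal(torch.diff(a, n=1), torch.tensor([2, 3]))
# n == 0 never reaches torch.diff; the wrapper simply returns `a` unchanged.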
@@ -357,6 +361,14 @@ def vstack(tensors, *, dtype=None, casting="same_kind"):
     return result
+
+
+def tile(tensor, reps):
+    if isinstance(reps, int):
+        reps = (reps,)
+
+    result = torch.tile(tensor, reps)
+    return result
 
 
 # #### cov & corrcoef
 
 
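The int-to-tuple normalization exists because torch.tile expects a sequence of repeat counts, while np.tile also accepts a bare integer. A minimal check of the tuple form:

import torch

t = torch.tensor([1, 2])
assert torch.equal(torch.tile(t, (2,)), torch.tensor([1, 2, 1, 2]))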
@@ -449,17 +461,29 @@ def indices(dimensions, dtype=int, sparse=False):
     return res
 
 
-def bincount(x_tensor, /, weights_tensor=None, minlength=0):
+def bincount(x, /, weights=None, minlength=0):
+    if x.numel() == 0:
+        # edge case allowed by numpy
+        x = x.new_empty(0, dtype=int)
+
     int_dtype = _dtypes_impl.default_int_dtype
-    (x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
+    (x,) = _util.cast_dont_broadcast((x,), int_dtype, casting="safe")
 
-    result = torch.bincount(x_tensor, weights_tensor, minlength)
+    result = torch.bincount(x, weights, minlength)
     return result
 
 
 # ### linspace, geomspace, logspace and arange ###
 
 
+def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
+    if axis != 0 or retstep or not endpoint:
+        raise NotImplementedError
+    # XXX: raises TypeError if start or stop are not scalars
+    result = torch.linspace(start, stop, num, dtype=dtype)
+    return result
+
+
 def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
     if axis != 0 or not endpoint:
         raise NotImplementedError
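The numel() == 0 branch keeps NumPy's behaviour, where bincount of an empty array is legal and yields zero counts padded to minlength. Roughly, in plain torch:

import torch

x = torch.empty(0, dtype=torch.int64)
counts = torch.bincount(x, minlength=3)
assert torch.equal(counts, torch.zeros(3, dtype=torch.int64))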
@@ -474,6 +498,13 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
     return result
 
 
+def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
+    if axis != 0 or not endpoint:
+        raise NotImplementedError
+    result = torch.logspace(start, stop, num, base=base, dtype=dtype)
+    return result
+
+
 def arange(start=None, stop=None, step=1, dtype=None):
     if step == 0:
         raise ZeroDivisionError
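With the default endpoint=True and axis=0, the wrapper above is a direct pass-through to torch.logspace, which already matches np.logspace. For example:

import torch

result = torch.logspace(0, 2, 3, base=10.0)
assert torch.allclose(result, torch.tensor([1.0, 10.0, 100.0]))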
@@ -523,36 +554,75 @@ def eye(N, M=None, k=0, dtype=float):
     return z
 
 
-def zeros_like(a, dtype=None, shape=None):
+def zeros(shape, dtype=None, order="C"):
+    if order != "C":
+        raise NotImplementedError
+    if dtype is None:
+        dtype = _dtypes_impl.default_float_dtype
+    result = torch.zeros(shape, dtype=dtype)
+    return result
+
+
+def zeros_like(a, dtype=None, shape=None, order="K"):
+    if order != "K":
+        raise NotImplementedError
     result = torch.zeros_like(a, dtype=dtype)
     if shape is not None:
         result = result.reshape(shape)
     return result
 
 
-def ones_like(a, dtype=None, shape=None):
+def ones(shape, dtype=None, order="C"):
+    if order != "C":
+        raise NotImplementedError
+    if dtype is None:
+        dtype = _dtypes_impl.default_float_dtype
+    result = torch.ones(shape, dtype=dtype)
+    return result
+
+
+def ones_like(a, dtype=None, shape=None, order="K"):
+    if order != "K":
+        raise NotImplementedError
     result = torch.ones_like(a, dtype=dtype)
     if shape is not None:
         result = result.reshape(shape)
     return result
 
 
-def full_like(a, fill_value, dtype=None, shape=None):
+def full_like(a, fill_value, dtype=None, shape=None, order="K"):
+    if order != "K":
+        raise NotImplementedError
     # XXX: fill_value broadcasts
     result = torch.full_like(a, fill_value, dtype=dtype)
     if shape is not None:
         result = result.reshape(shape)
     return result
 
 
-def empty_like(prototype, dtype=None, shape=None):
+def empty(shape, dtype=None, order="C"):
+    if order != "C":
+        raise NotImplementedError
+    if dtype is None:
+        dtype = _dtypes_impl.default_float_dtype
+    result = torch.empty(shape, dtype=dtype)
+    return result
+
+
+def empty_like(prototype, dtype=None, shape=None, order="K"):
+    if order != "K":
+        raise NotImplementedError
     result = torch.empty_like(prototype, dtype=dtype)
     if shape is not None:
         result = result.reshape(shape)
     return result
 
 
-def full(shape, fill_value, dtype=None):
+def full(shape, fill_value, dtype=None, order="C"):
+    if isinstance(shape, int):
+        shape = (shape,)
+    if order != "C":
+        raise NotImplementedError
     if dtype is None:
         dtype = fill_value.dtype
     if not isinstance(shape, (tuple, list)):
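The *_like wrappers above also accept a shape override, and the bare constructors fall back to torch_np's default float dtype when none is given. A rough sketch in plain torch, assuming that default is float64 (chosen for NumPy compatibility) and using a stand-in variable for _dtypes_impl.default_float_dtype:

import torch

default_float = torch.float64  # stand-in for _dtypes_impl.default_float_dtype
z = torch.zeros((2, 3), dtype=default_float)                   # zeros((2, 3))
e = torch.empty_like(z, dtype=torch.float32).reshape((3, 2))   # empty_like(z, dtype=..., shape=(3, 2))
assert z.dtype == torch.float64 and e.shape == (3, 2)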
61 changes: 55 additions & 6 deletions torch_np/_funcs.py
@@ -23,9 +23,15 @@ def nonzero(a: ArrayLike):
     return _helpers.tuple_arrays_from(result)
 
 
-def argwhere(a):
-    (tensor,) = _helpers.to_tensors(a)
-    result = torch.argwhere(tensor)
+@normalizer
+def argwhere(a: ArrayLike):
+    result = torch.argwhere(a)
     return _helpers.array_from(result)
+
+
+@normalizer
+def flatnonzero(a: ArrayLike):
+    result = a.ravel().nonzero(as_tuple=True)[0]
+    return _helpers.array_from(result)
 
 
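flatnonzero is spelled as ravel().nonzero(...)[0] because that is the flat-index equivalent of np.flatnonzero. Checking the core expression in plain torch:

import torch

a = torch.tensor([[0, 1], [2, 0]])
flat_idx = a.ravel().nonzero(as_tuple=True)[0]
assert torch.equal(flat_idx, torch.tensor([1, 2]))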
@@ -49,6 +55,12 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
     return _helpers.array_from(result)
 
 
+@normalizer
+def tile(A: ArrayLike, reps):
+    result = _impl.tile(A, reps)
+    return _helpers.array_from(result)
+
+
 # ### diag et al ###
 
 
@@ -448,8 +460,45 @@ def quantile(
     *,
     interpolation=None,
 ):
-    if interpolation is not None:
-        raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
-    result = _impl.quantile(a, q, axis, method=method, keepdims=keepdims)
+    result = _impl.quantile(
+        a,
+        q,
+        axis,
+        overwrite_input=overwrite_input,
+        method=method,
+        keepdims=keepdims,
+        interpolation=interpolation,
+    )
     return _helpers.result_or_out(result, out, promote_scalar=True)
 
+
+@normalizer
+def percentile(
+    a: ArrayLike,
+    q: ArrayLike,
+    axis: AxisLike = None,
+    out: Optional[NDArray] = None,
+    overwrite_input=False,
+    method="linear",
+    keepdims=False,
+    *,
+    interpolation=None,
+):
+    result = _impl.percentile(
+        a,
+        q,
+        axis,
+        overwrite_input=overwrite_input,
+        method=method,
+        keepdims=keepdims,
+        interpolation=interpolation,
+    )
+    return _helpers.result_or_out(result, out, promote_scalar=True)
+
+
+def median(
+    a, axis=None, out: Optional[NDArray] = None, overwrite_input=False, keepdims=False
+):
+    return quantile(
+        a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims
+    )
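As the forwarding above shows, median is just the 0.5 quantile with linear interpolation. The same identity in plain torch:

import torch

a = torch.tensor([3.0, 1.0, 2.0, 4.0])
assert torch.quantile(a, 0.5).item() == 2.5  # what median(a) evaluates to here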