Fixes and minors #129

Merged: 2 commits, Apr 27, 2023

torch_np/_funcs.py (1 addition, 8 deletions)
@@ -19,15 +19,8 @@
if inspect.isfunction(getattr(_funcs_impl, x)) and not x.startswith("_")
]

# these implement ndarray methods but need not be public functions
semi_private = [
"_flatten",
"_ndarray_resize",
]


# decorate implementer functions with argument normalizers and export to the top namespace
for name in __all__ + semi_private:
for name in __all__:
func = getattr(_funcs_impl, name)
if name in ["percentile", "quantile", "median"]:
decorated = normalizer(func, promote_scalar_result=True)
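With `_flatten` and `_ndarray_resize` no longer implemented in `_funcs_impl.py`, the `semi_private` list that let them pass through the same decoration machinery without becoming public functions is gone, and the loop only wraps the public names. A minimal sketch of the remaining pattern, with a placeholder `normalizer` (the project's real normalizer does much more argument coercion):

```python
import inspect
import types


def normalizer(func, promote_scalar_result=False):
    # Placeholder: the real decorator normalizes array-likes, dtypes, axis
    # arguments, out= handling, etc. before calling the implementer function.
    def wrapped(*args, **kwargs):
        return func(*args, **kwargs)

    wrapped.__name__ = func.__name__
    return wrapped


def export_public(impl: types.ModuleType, namespace: dict):
    # Collect every public function defined on the implementation module and
    # re-export a normalized wrapper under the same name.
    public = [
        name
        for name in dir(impl)
        if inspect.isfunction(getattr(impl, name)) and not name.startswith("_")
    ]
    for name in public:
        namespace[name] = normalizer(getattr(impl, name))
    return public
```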

torch_np/_funcs_impl.py (33 additions, 89 deletions)
@@ -8,7 +8,6 @@
from __future__ import annotations

import builtins
import math
import operator
from typing import Optional, Sequence

@@ -100,7 +99,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):

def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
# pure torch implementation, used below and in cov/corrcoef below
tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
tensors = _concat_cast_helper(tensors, out, dtype, casting)
return torch.cat(tensors, axis)
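The call sites below switch from `_util.axis_none_ravel` to the renamed `_util.axis_none_flatten`. The helper itself lives in `_util.py` and is not part of this diff; a minimal sketch of the behaviour these call sites rely on, assuming it follows NumPy's convention that `axis=None` means "operate on the flattened input":

```python
import torch


def axis_none_flatten(*tensors, axis=None):
    # Sketch of the assumed helper: if no axis is given, flatten every input
    # and treat the result as one-dimensional (axis 0); otherwise pass through.
    if axis is None:
        return tuple(t.flatten() for t in tensors), 0
    return tensors, axis
```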

@@ -881,21 +880,21 @@ def take(
out: Optional[OutArray] = None,
mode: NotImplementedType = "raise",
):
(a,), axis = _util.axis_none_ravel(a, axis=axis)
(a,), axis = _util.axis_none_flatten(a, axis=axis)
axis = _util.normalize_axis_index(axis, a.ndim)
idx = (slice(None),) * axis + (indices, ...)
result = a[idx]
return result


def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
(arr,), axis = _util.axis_none_ravel(arr, axis=axis)
(arr,), axis = _util.axis_none_flatten(arr, axis=axis)
axis = _util.normalize_axis_index(axis, arr.ndim)
return torch.take_along_dim(arr, indices, axis)


def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
(arr,), axis = _util.axis_none_ravel(arr, axis=axis)
(arr,), axis = _util.axis_none_flatten(arr, axis=axis)
axis = _util.normalize_axis_index(axis, arr.ndim)

indices, values = torch.broadcast_tensors(indices, values)
@@ -917,9 +916,7 @@ def unique(
*,
equal_nan: NotImplementedType = True,
):
if axis is None:
ar = ar.ravel()
axis = 0
(ar,), axis = _util.axis_none_flatten(ar, axis=axis)
axis = _util.normalize_axis_index(axis, ar.ndim)

is_half = ar.dtype == torch.float16
@@ -948,7 +945,7 @@ def argwhere(a: ArrayLike):


def flatnonzero(a: ArrayLike):
return torch.ravel(a).nonzero(as_tuple=True)[0]
return torch.flatten(a).nonzero(as_tuple=True)[0]


def clip(
@@ -980,7 +977,7 @@ def resize(a: ArrayLike, new_shape=None):
if isinstance(new_shape, int):
new_shape = (new_shape,)

a = ravel(a)
a = a.flatten()

new_size = 1
for dim_length in new_shape:
@@ -998,38 +995,6 @@ def resize(a: ArrayLike, new_shape=None):
return reshape(a, new_shape)


def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
# implementation of ndarray.resize.
# NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
if refcheck:
raise NotImplementedError(
f"resize(..., refcheck={refcheck} is not implemented."
)

if new_shape in [(), (None,)]:
return a

# support both x.resize((2, 2)) and x.resize(2, 2)
if len(new_shape) == 1:
new_shape = new_shape[0]
if isinstance(new_shape, int):
new_shape = (new_shape,)

a = ravel(a)

if builtins.any(x < 0 for x in new_shape):
raise ValueError("all elements of `new_shape` must be non-negative")

new_numel = math.prod(new_shape)
if new_numel < a.numel():
# shrink
return a[:new_numel].reshape(new_shape)
else:
b = torch.zeros(new_numel)
b[: a.numel()] = a
return b.reshape(new_shape)


# ### diag et al ###


@@ -1132,13 +1097,13 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):


def vdot(a: ArrayLike, b: ArrayLike, /):
# 1. torch only accepts 1D arrays, numpy ravels
# 1. torch only accepts 1D arrays, numpy flattens
# 2. torch requires matching dtype, while numpy casts (?)
t_a, t_b = torch.atleast_1d(a, b)
if t_a.ndim > 1:
t_a = t_a.ravel()
t_a = t_a.flatten()
if t_b.ndim > 1:
t_b = t_b.ravel()
t_b = t_b.flatten()

dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
is_half = dtype == torch.float16
@@ -1212,7 +1177,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):


def _sort_helper(tensor, axis, kind, order):
(tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
(tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
axis = _util.normalize_axis_index(axis, tensor.ndim)

stable = kind == "stable"
@@ -1328,14 +1293,6 @@ def transpose(a: ArrayLike, axes=None):


def ravel(a: ArrayLike, order: NotImplementedType = "C"):
return torch.ravel(a)


# leading underscore since arr.flatten exists but np.flatten does not


def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
# may return a copy
return torch.flatten(a)
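With `_flatten` removed from this module, `ravel` now delegates to `torch.flatten` directly. Like `torch.ravel`, `torch.flatten` may return a view rather than a copy, so the removed "may return a copy" caveat still applies; a small illustration with a contiguous input:

```python
import torch

x = torch.arange(6).reshape(2, 3)

# For a contiguous tensor, torch.flatten returns a view, so writes through the
# flattened tensor show up in the original. NumPy's ndarray.flatten, by
# contrast, always returns a copy.
f = torch.flatten(x)
f[0] = 100
print(x[0, 0])  # tensor(100)
```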


@@ -1647,7 +1604,7 @@ def diff(
def angle(z: ArrayLike, deg=False):
result = torch.angle(z)
if deg:
result = result * 180 / torch.pi
result = result * (180 / torch.pi)
return result
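Grouping the conversion factor as `(180 / torch.pi)` folds it into a single Python scalar, so the degrees conversion is one elementwise multiply instead of a tensor multiply followed by a tensor divide. A quick consistency check against `torch.rad2deg`, which computes the same conversion:

```python
import torch

z = torch.tensor([1 + 1j, -1 + 0j, 0 - 2j])
rad = torch.angle(z)

deg = rad * (180 / torch.pi)  # scalar factor computed once, one tensor op
print(torch.allclose(deg, torch.rad2deg(rad)))  # True
```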


@@ -1658,26 +1615,14 @@ def sinc(x: ArrayLike):
# ### Type/shape etc queries ###


def real(a: ArrayLike):
return torch.real(a)


def imag(a: ArrayLike):
if a.is_complex():
result = a.imag
else:
result = torch.zeros_like(a)
return result


def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
if a.is_floating_point():
result = torch.round(a, decimals=decimals)
elif a.is_complex():
# RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
result = (
torch.round(a.real, decimals=decimals)
+ torch.round(a.imag, decimals=decimals) * 1j
result = torch.complex(
torch.round(a.real, decimals=decimals),
torch.round(a.imag, decimals=decimals),
)
else:
# RuntimeError: "round_cpu" not implemented for 'int'
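For complex inputs, `torch.complex` rounds the real and imaginary parts and assembles the result directly, replacing the old `real + imag * 1j` formulation that went through a Python complex scalar plus an extra elementwise multiply and add. A small sketch:

```python
import torch

z = torch.tensor([1.2345 + 5.6789j, -0.25 - 0.75j])

# Round each component separately, then reassemble into a complex tensor of
# the same dtype as the input (complex64 here).
rounded = torch.complex(
    torch.round(z.real, decimals=2),
    torch.round(z.imag, decimals=2),
)
print(rounded)
```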
@@ -1690,7 +1635,6 @@ def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):


def real_if_close(a: ArrayLike, tol=100):
# XXX: copies vs views; numpy seems to return a copy?
if not torch.is_complex(a):
return a
if tol > 1:
@@ -1703,47 +1647,49 @@ def real_if_close(a: ArrayLike, tol=100):
return a.real if mask.all() else a


def real(a: ArrayLike):
return torch.real(a)


def imag(a: ArrayLike):
if a.is_complex():
return a.imag
return torch.zeros_like(a)


def iscomplex(x: ArrayLike):
if torch.is_complex(x):
return x.imag != 0
result = torch.zeros_like(x, dtype=torch.bool)
if result.ndim == 0:
result = result.item()
return result
return torch.zeros_like(x, dtype=torch.bool)


def isreal(x: ArrayLike):
if torch.is_complex(x):
return x.imag == 0
result = torch.ones_like(x, dtype=torch.bool)
if result.ndim == 0:
result = result.item()
return result
return torch.ones_like(x, dtype=torch.bool)
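`iscomplex` and `isreal` no longer convert 0-d results to Python scalars with `.item()`; they simply return boolean tensors, and any scalar conversion is presumably left to the wrapping layer or the caller. For instance, for a 0-d real input:

```python
import torch

x = torch.tensor(3.0)  # 0-d real tensor

result = torch.zeros_like(x, dtype=torch.bool)  # what iscomplex now returns
print(result)         # tensor(False): a 0-d bool tensor, not a Python bool
print(result.item())  # False, if a plain Python bool is needed
```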


def iscomplexobj(x: ArrayLike):
result = torch.is_complex(x)
return result
return torch.is_complex(x)


def isrealobj(x: ArrayLike):
return not torch.is_complex(x)


def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
return torch.isneginf(x, out=out)
return torch.isneginf(x)


def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
return torch.isposinf(x, out=out)
return torch.isposinf(x)


def i0(x: ArrayLike):
return torch.special.i0(x)


def isscalar(a):
# XXX: this is a stub
try:
t = normalize_array_like(a)
return t.numel() == 1
@@ -1798,8 +1744,6 @@ def bartlett(M):


def common_type(*tensors: ArrayLike):
import builtins

is_complex = False
precision = 0
for a in tensors:
@@ -1836,7 +1780,7 @@ def histogram(
is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
is_w_int = weights is None or not weights.dtype.is_floating_point
if is_a_int:
a = a.to(float)
a = a.double()

if weights is not None:
weights = _util.cast_if_needed(weights, a.dtype)
@@ -1856,8 +1800,8 @@
)

if not density and is_w_int:
h = h.to(int)
h = h.long()
if is_a_int:
b = b.to(int)
b = b.long()

return h, b
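`a.to(float)` becomes `a.double()` and `h.to(int)` / `b.to(int)` become `.long()`, naming the target dtypes (float64, int64) explicitly instead of relying on how PyTorch interprets the Python builtins as dtype arguments. A sketch of the cast-and-restore pattern, assuming the underlying call is `torch.histogram`, which only supports floating-point input:

```python
import torch

# Integer data is promoted to float64 before the histogram call and the
# counts are converted back to int64 afterwards (no density normalization).
a = torch.tensor([0, 1, 1, 2, 3, 3, 3])
h, edges = torch.histogram(a.double(), bins=4)
h = h.long()
print(h)      # int64 bin counts
print(edges)  # float64 bin edges
```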

torch_np/_getlimits.py (2 additions, 4 deletions)
@@ -1,3 +1,5 @@
import contextlib

import torch

from . import _dtypes
@@ -13,10 +15,6 @@ def iinfo(dtyp):
return torch.iinfo(torch_dtype)


import contextlib


# FIXME: this is only a stub
@contextlib.contextmanager
def errstate(*args, **kwds):
yield