Skip to content

add copyto, resize, histogram #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torch_np/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# these implement ndarray methods but need not be public functions
semi_private = [
"_flatten",
"_ndarray_resize",
]


Expand Down
114 changes: 113 additions & 1 deletion torch_np/_funcs_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
# Contents of this module ends up in the main namespace via _funcs.py
# where type annotations are used in conjunction with the @normalizer decorator.


import builtins
import math
import operator
from typing import Optional, Sequence

import torch
Expand Down Expand Up @@ -36,6 +38,13 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
return a.clone()


def copyto(dst: NDArray, src: ArrayLike, casting="same_kind", where=NoValue):
if where is not NoValue:
raise NotImplementedError
(src,) = _util.typecast_tensors((src,), dst.tensor.dtype, casting=casting)
dst.tensor.copy_(src)


def atleast_1d(*arys: ArrayLike):
res = torch.atleast_1d(*arys)
if isinstance(res, tuple):
Expand Down Expand Up @@ -987,6 +996,65 @@ def tile(A: ArrayLike, reps):
return torch.tile(A, reps)


def resize(a: ArrayLike, new_shape=None):
# implementation vendored from
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497
if new_shape is None:
return a

if isinstance(new_shape, int):
new_shape = (new_shape,)

a = ravel(a)

new_size = 1
for dim_length in new_shape:
new_size *= dim_length
if dim_length < 0:
raise ValueError("all elements of `new_shape` must be non-negative")

if a.numel() == 0 or new_size == 0:
# First case must zero fill. The second would have repeats == 0.
return torch.zeros(new_shape, dtype=a.dtype)

repeats = -(-new_size // a.numel()) # ceil division
a = concatenate((a,) * repeats)[:new_size]

return reshape(a, new_shape)


def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
# implementation of ndarray.resize.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my my...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I wanted to reference the next line)

# NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
if refcheck:
raise NotImplementedError(
f"resize(..., refcheck={refcheck} is not implemented."
)

if new_shape in [(), (None,)]:
return a

# support both x.resize((2, 2)) and x.resize(2, 2)
if len(new_shape) == 1:
new_shape = new_shape[0]
if isinstance(new_shape, int):
new_shape = (new_shape,)

a = ravel(a)

if builtins.any(x < 0 for x in new_shape):
raise ValueError("all elements of `new_shape` must be non-negative")

new_numel = math.prod(new_shape)
if new_numel < a.numel():
# shrink
return a[:new_numel].reshape(new_shape)
else:
b = torch.zeros(new_numel)
b[: a.numel()] = a
return b.reshape(new_shape)


# ### diag et al ###


Expand Down Expand Up @@ -1811,3 +1879,47 @@ def common_type(*tensors: ArrayLike):
return array_type[1][precision]
else:
return array_type[0][precision]


# ### histograms ###


def histogram(
a: ArrayLike,
bins: ArrayLike = 10,
range=None,
normed=None,
weights: Optional[ArrayLike] = None,
density=None,
):
if normed is not None:
raise ValueError("normed argument is deprecated, use density= instead")

is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
is_w_int = weights is None or not weights.dtype.is_floating_point
if is_a_int:
a = a.to(float)

if weights is not None:
weights = _util.cast_if_needed(weights, a.dtype)

if isinstance(bins, torch.Tensor):
if bins.ndim == 0:
# bins was a single int
bins = operator.index(bins)
else:
bins = _util.cast_if_needed(bins, a.dtype)

if range is None:
h, b = torch.histogram(a, bins, weight=weights, density=bool(density))
else:
h, b = torch.histogram(
a, bins, range=range, weight=weights, density=bool(density)
)

if not density and is_w_int:
h = h.to(int)
if is_a_int:
b = b.to(int)

return h, b
6 changes: 6 additions & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,12 @@ def reshape(self, *shape, order="C"):
ravel = _funcs.ravel
flatten = _funcs._flatten

def resize(self, *new_shape, refcheck=False):
# ndarray.resize works in-place (may cause a reallocation though)
self.tensor = _funcs_impl._ndarray_resize(
self.tensor, new_shape, refcheck=refcheck
)

nonzero = _funcs.nonzero
clip = _funcs.clip
repeat = _funcs.repeat
Expand Down
23 changes: 6 additions & 17 deletions torch_np/tests/numpy_tests/core/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4513,7 +4513,6 @@ def test_index_getset(self):
assert it.index == it.base.size


@pytest.mark.xfail(reason='TODO')
class TestResize:

@_no_tracing
Expand All @@ -4523,10 +4522,11 @@ def test_basic(self):
x.resize((5, 5), refcheck=False)
else:
x.resize((5, 5))
assert_array_equal(x.flat[:9],
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).flat)
assert_array_equal(x[9:].flat, 0)
assert_array_equal(x.ravel()[:9],
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).ravel())
assert_array_equal(x[9:].ravel(), 0)

@pytest.mark.skip(reason="how to find if someone is refencing an array")
def test_check_reference(self):
x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
y = x
Expand Down Expand Up @@ -4565,7 +4565,7 @@ def test_invalid_arguments(self):
assert_raises(TypeError, np.eye(3).resize, 'hi')
assert_raises(ValueError, np.eye(3).resize, -1)
assert_raises(TypeError, np.eye(3).resize, order=1)
assert_raises(TypeError, np.eye(3).resize, refcheck='hi')
assert_raises((NotImplementedError, TypeError), np.eye(3).resize, refcheck='hi')

@_no_tracing
def test_freeform_shape(self):
Expand All @@ -4586,18 +4586,6 @@ def test_zeros_appended(self):
assert_array_equal(x[0], np.eye(3))
assert_array_equal(x[1], np.zeros((3, 3)))

@_no_tracing
def test_obj_obj(self):
# check memory is initialized on resize, gh-4857
a = np.ones(10, dtype=[('k', object, 2)])
if IS_PYPY:
a.resize(15, refcheck=False)
else:
a.resize(15,)
assert_equal(a.shape, (15,))
assert_array_equal(a['k'][-5:], 0)
assert_array_equal(a['k'][:-5], 1)

def test_empty_view(self):
# check that sizes containing a zero don't trigger a reallocate for
# already empty arrays
Expand All @@ -4606,6 +4594,7 @@ def test_empty_view(self):
x_view.resize((0, 10))
x_view.resize((0, 100))

@pytest.mark.skip(reason="ignore weakrefs for ndarray.resize")
def test_check_weakref(self):
x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
xref = weakref.ref(x)
Expand Down
16 changes: 2 additions & 14 deletions torch_np/tests/numpy_tests/core/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from hypothesis.extra import numpy as hynp


@pytest.mark.xfail(reason="TODO")
class TestResize:
def test_copies(self):
A = np.array([[1, 2], [3, 4]])
Expand Down Expand Up @@ -64,28 +63,17 @@ def test_zeroresize(self):

def test_reshape_from_zero(self):
# See also gh-6740
A = np.zeros(0, dtype=[('a', np.float32)])
A = np.zeros(0, dtype=np.float32)
Ar = np.resize(A, (2, 1))
assert_array_equal(Ar, np.zeros((2, 1), Ar.dtype))
assert_equal(A.dtype, Ar.dtype)

def test_negative_resize(self):
A = np.arange(0, 10, dtype=np.float32)
new_shape = (-10, -1)
with pytest.raises(ValueError, match=r"negative"):
with pytest.raises((RuntimeError, ValueError)):
np.resize(A, new_shape=new_shape)

def test_subclass(self):
class MyArray(np.ndarray):
__array_priority__ = 1.

my_arr = np.array([1]).view(MyArray)
assert type(np.resize(my_arr, 5)) is MyArray
assert type(np.resize(my_arr, 0)) is MyArray

my_arr = np.array([]).view(MyArray)
assert type(np.resize(my_arr, 5)) is MyArray


class TestNonarrayArgs:
# check that non-array arguments to functions wrap them in arrays
Expand Down
Loading