Skip to content

Commit 13ba799

Browse files
authored
Merge pull request #106 from Quansight-Labs/resize_3
add copyto, resize, histogram
2 parents d8f1461 + 3e7e937 commit 13ba799

File tree

6 files changed

+927
-32
lines changed

6 files changed

+927
-32
lines changed

torch_np/_funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# these implement ndarray methods but need not be public functions
1919
semi_private = [
2020
"_flatten",
21+
"_ndarray_resize",
2122
]
2223

2324

torch_np/_funcs_impl.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
# Contents of this module ends up in the main namespace via _funcs.py
77
# where type annotations are used in conjunction with the @normalizer decorator.
88

9-
9+
import builtins
10+
import math
11+
import operator
1012
from typing import Optional, Sequence
1113

1214
import torch
@@ -36,6 +38,13 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
3638
return a.clone()
3739

3840

41+
def copyto(dst: NDArray, src: ArrayLike, casting="same_kind", where=NoValue):
42+
if where is not NoValue:
43+
raise NotImplementedError
44+
(src,) = _util.typecast_tensors((src,), dst.tensor.dtype, casting=casting)
45+
dst.tensor.copy_(src)
46+
47+
3948
def atleast_1d(*arys: ArrayLike):
4049
res = torch.atleast_1d(*arys)
4150
if isinstance(res, tuple):
@@ -987,6 +996,65 @@ def tile(A: ArrayLike, reps):
987996
return torch.tile(A, reps)
988997

989998

999+
def resize(a: ArrayLike, new_shape=None):
1000+
# implementation vendored from
1001+
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497
1002+
if new_shape is None:
1003+
return a
1004+
1005+
if isinstance(new_shape, int):
1006+
new_shape = (new_shape,)
1007+
1008+
a = ravel(a)
1009+
1010+
new_size = 1
1011+
for dim_length in new_shape:
1012+
new_size *= dim_length
1013+
if dim_length < 0:
1014+
raise ValueError("all elements of `new_shape` must be non-negative")
1015+
1016+
if a.numel() == 0 or new_size == 0:
1017+
# First case must zero fill. The second would have repeats == 0.
1018+
return torch.zeros(new_shape, dtype=a.dtype)
1019+
1020+
repeats = -(-new_size // a.numel()) # ceil division
1021+
a = concatenate((a,) * repeats)[:new_size]
1022+
1023+
return reshape(a, new_shape)
1024+
1025+
1026+
def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
1027+
# implementation of ndarray.resize.
1028+
# NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
1029+
if refcheck:
1030+
raise NotImplementedError(
1031+
f"resize(..., refcheck={refcheck} is not implemented."
1032+
)
1033+
1034+
if new_shape in [(), (None,)]:
1035+
return a
1036+
1037+
# support both x.resize((2, 2)) and x.resize(2, 2)
1038+
if len(new_shape) == 1:
1039+
new_shape = new_shape[0]
1040+
if isinstance(new_shape, int):
1041+
new_shape = (new_shape,)
1042+
1043+
a = ravel(a)
1044+
1045+
if builtins.any(x < 0 for x in new_shape):
1046+
raise ValueError("all elements of `new_shape` must be non-negative")
1047+
1048+
new_numel = math.prod(new_shape)
1049+
if new_numel < a.numel():
1050+
# shrink
1051+
return a[:new_numel].reshape(new_shape)
1052+
else:
1053+
b = torch.zeros(new_numel)
1054+
b[: a.numel()] = a
1055+
return b.reshape(new_shape)
1056+
1057+
9901058
# ### diag et al ###
9911059

9921060

@@ -1811,3 +1879,47 @@ def common_type(*tensors: ArrayLike):
18111879
return array_type[1][precision]
18121880
else:
18131881
return array_type[0][precision]
1882+
1883+
1884+
# ### histograms ###
1885+
1886+
1887+
def histogram(
1888+
a: ArrayLike,
1889+
bins: ArrayLike = 10,
1890+
range=None,
1891+
normed=None,
1892+
weights: Optional[ArrayLike] = None,
1893+
density=None,
1894+
):
1895+
if normed is not None:
1896+
raise ValueError("normed argument is deprecated, use density= instead")
1897+
1898+
is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
1899+
is_w_int = weights is None or not weights.dtype.is_floating_point
1900+
if is_a_int:
1901+
a = a.to(float)
1902+
1903+
if weights is not None:
1904+
weights = _util.cast_if_needed(weights, a.dtype)
1905+
1906+
if isinstance(bins, torch.Tensor):
1907+
if bins.ndim == 0:
1908+
# bins was a single int
1909+
bins = operator.index(bins)
1910+
else:
1911+
bins = _util.cast_if_needed(bins, a.dtype)
1912+
1913+
if range is None:
1914+
h, b = torch.histogram(a, bins, weight=weights, density=bool(density))
1915+
else:
1916+
h, b = torch.histogram(
1917+
a, bins, range=range, weight=weights, density=bool(density)
1918+
)
1919+
1920+
if not density and is_w_int:
1921+
h = h.to(int)
1922+
if is_a_int:
1923+
b = b.to(int)
1924+
1925+
return h, b

torch_np/_ndarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,12 @@ def reshape(self, *shape, order="C"):
358358
ravel = _funcs.ravel
359359
flatten = _funcs._flatten
360360

361+
def resize(self, *new_shape, refcheck=False):
362+
# ndarray.resize works in-place (may cause a reallocation though)
363+
self.tensor = _funcs_impl._ndarray_resize(
364+
self.tensor, new_shape, refcheck=refcheck
365+
)
366+
361367
nonzero = _funcs.nonzero
362368
clip = _funcs.clip
363369
repeat = _funcs.repeat

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4513,7 +4513,6 @@ def test_index_getset(self):
45134513
assert it.index == it.base.size
45144514

45154515

4516-
@pytest.mark.xfail(reason='TODO')
45174516
class TestResize:
45184517

45194518
@_no_tracing
@@ -4523,10 +4522,11 @@ def test_basic(self):
45234522
x.resize((5, 5), refcheck=False)
45244523
else:
45254524
x.resize((5, 5))
4526-
assert_array_equal(x.flat[:9],
4527-
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).flat)
4528-
assert_array_equal(x[9:].flat, 0)
4525+
assert_array_equal(x.ravel()[:9],
4526+
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).ravel())
4527+
assert_array_equal(x[9:].ravel(), 0)
45294528

4529+
@pytest.mark.skip(reason="how to find if someone is refencing an array")
45304530
def test_check_reference(self):
45314531
x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
45324532
y = x
@@ -4565,7 +4565,7 @@ def test_invalid_arguments(self):
45654565
assert_raises(TypeError, np.eye(3).resize, 'hi')
45664566
assert_raises(ValueError, np.eye(3).resize, -1)
45674567
assert_raises(TypeError, np.eye(3).resize, order=1)
4568-
assert_raises(TypeError, np.eye(3).resize, refcheck='hi')
4568+
assert_raises((NotImplementedError, TypeError), np.eye(3).resize, refcheck='hi')
45694569

45704570
@_no_tracing
45714571
def test_freeform_shape(self):
@@ -4586,18 +4586,6 @@ def test_zeros_appended(self):
45864586
assert_array_equal(x[0], np.eye(3))
45874587
assert_array_equal(x[1], np.zeros((3, 3)))
45884588

4589-
@_no_tracing
4590-
def test_obj_obj(self):
4591-
# check memory is initialized on resize, gh-4857
4592-
a = np.ones(10, dtype=[('k', object, 2)])
4593-
if IS_PYPY:
4594-
a.resize(15, refcheck=False)
4595-
else:
4596-
a.resize(15,)
4597-
assert_equal(a.shape, (15,))
4598-
assert_array_equal(a['k'][-5:], 0)
4599-
assert_array_equal(a['k'][:-5], 1)
4600-
46014589
def test_empty_view(self):
46024590
# check that sizes containing a zero don't trigger a reallocate for
46034591
# already empty arrays
@@ -4606,6 +4594,7 @@ def test_empty_view(self):
46064594
x_view.resize((0, 10))
46074595
x_view.resize((0, 100))
46084596

4597+
@pytest.mark.skip(reason="ignore weakrefs for ndarray.resize")
46094598
def test_check_weakref(self):
46104599
x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
46114600
xref = weakref.ref(x)

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from hypothesis.extra import numpy as hynp
2727

2828

29-
@pytest.mark.xfail(reason="TODO")
3029
class TestResize:
3130
def test_copies(self):
3231
A = np.array([[1, 2], [3, 4]])
@@ -64,28 +63,17 @@ def test_zeroresize(self):
6463

6564
def test_reshape_from_zero(self):
6665
# See also gh-6740
67-
A = np.zeros(0, dtype=[('a', np.float32)])
66+
A = np.zeros(0, dtype=np.float32)
6867
Ar = np.resize(A, (2, 1))
6968
assert_array_equal(Ar, np.zeros((2, 1), Ar.dtype))
7069
assert_equal(A.dtype, Ar.dtype)
7170

7271
def test_negative_resize(self):
7372
A = np.arange(0, 10, dtype=np.float32)
7473
new_shape = (-10, -1)
75-
with pytest.raises(ValueError, match=r"negative"):
74+
with pytest.raises((RuntimeError, ValueError)):
7675
np.resize(A, new_shape=new_shape)
7776

78-
def test_subclass(self):
79-
class MyArray(np.ndarray):
80-
__array_priority__ = 1.
81-
82-
my_arr = np.array([1]).view(MyArray)
83-
assert type(np.resize(my_arr, 5)) is MyArray
84-
assert type(np.resize(my_arr, 0)) is MyArray
85-
86-
my_arr = np.array([]).view(MyArray)
87-
assert type(np.resize(my_arr, 5)) is MyArray
88-
8977

9078
class TestNonarrayArgs:
9179
# check that non-array arguments to functions wrap them in arrays

0 commit comments

Comments
 (0)