Commit 9c627db

Merge pull request #129 from Quansight-Labs/review
Fixes and minors
2 parents cf64dbd + c3a9b7e

File tree

8 files changed: +133, -360 lines


torch_np/_funcs.py

Lines changed: 1 addition & 8 deletions
@@ -19,15 +19,8 @@
     if inspect.isfunction(getattr(_funcs_impl, x)) and not x.startswith("_")
 ]

-# these implement ndarray methods but need not be public functions
-semi_private = [
-    "_flatten",
-    "_ndarray_resize",
-]
-
-
 # decorate implementer functions with argument normalizers and export to the top namespace
-for name in __all__ + semi_private:
+for name in __all__:
     func = getattr(_funcs_impl, name)
     if name in ["percentile", "quantile", "median"]:
         decorated = normalizer(func, promote_scalar_result=True)

torch_np/_funcs_impl.py

Lines changed: 33 additions & 89 deletions
@@ -8,7 +8,6 @@
 from __future__ import annotations

 import builtins
-import math
 import operator
 from typing import Optional, Sequence

@@ -100,7 +99,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):

 def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
     # pure torch implementation, used below and in cov/corrcoef below
-    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+    tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
     return torch.cat(tensors, axis)

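Note: the recurring change in this file replaces _util.axis_none_ravel with _util.axis_none_flatten. Judging from the call sites here (and from the inlined pattern removed from unique below), the helper presumably flattens its inputs and resets axis to 0 when axis is None. A minimal sketch under that assumption; the real helper lives in torch_np/_util.py and may differ:

    import torch

    def axis_none_flatten(*tensors, axis=None):
        # NumPy-style semantics: axis=None means "operate on the flattened input"
        if axis is None:
            tensors = tuple(torch.flatten(t) for t in tensors)
            axis = 0
        return tensors, axis
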
@@ -881,15 +880,15 @@ def take(
     out: Optional[OutArray] = None,
     mode: NotImplementedType = "raise",
 ):
-    (a,), axis = _util.axis_none_ravel(a, axis=axis)
+    (a,), axis = _util.axis_none_flatten(a, axis=axis)
     axis = _util.normalize_axis_index(axis, a.ndim)
     idx = (slice(None),) * axis + (indices, ...)
     result = a[idx]
     return result


 def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)
     return torch.take_along_dim(arr, indices, axis)

@@ -916,7 +915,7 @@ def put(


 def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)

     indices, values = torch.broadcast_tensors(indices, values)

@@ -938,9 +937,7 @@ def unique(
     *,
     equal_nan: NotImplementedType = True,
 ):
-    if axis is None:
-        ar = ar.ravel()
-        axis = 0
+    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
     axis = _util.normalize_axis_index(axis, ar.ndim)

     is_half = ar.dtype == torch.float16

@@ -969,7 +966,7 @@ def argwhere(a: ArrayLike):


 def flatnonzero(a: ArrayLike):
-    return torch.ravel(a).nonzero(as_tuple=True)[0]
+    return torch.flatten(a).nonzero(as_tuple=True)[0]


 def clip(

@@ -1001,7 +998,7 @@ def resize(a: ArrayLike, new_shape=None):
     if isinstance(new_shape, int):
         new_shape = (new_shape,)

-    a = ravel(a)
+    a = a.flatten()

     new_size = 1
     for dim_length in new_shape:

@@ -1019,38 +1016,6 @@ def resize(a: ArrayLike, new_shape=None):
     return reshape(a, new_shape)


-def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
-    # implementation of ndarray.resize.
-    # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
-    if refcheck:
-        raise NotImplementedError(
-            f"resize(..., refcheck={refcheck} is not implemented."
-        )
-
-    if new_shape in [(), (None,)]:
-        return a
-
-    # support both x.resize((2, 2)) and x.resize(2, 2)
-    if len(new_shape) == 1:
-        new_shape = new_shape[0]
-    if isinstance(new_shape, int):
-        new_shape = (new_shape,)
-
-    a = ravel(a)
-
-    if builtins.any(x < 0 for x in new_shape):
-        raise ValueError("all elements of `new_shape` must be non-negative")
-
-    new_numel = math.prod(new_shape)
-    if new_numel < a.numel():
-        # shrink
-        return a[:new_numel].reshape(new_shape)
-    else:
-        b = torch.zeros(new_numel)
-        b[: a.numel()] = a
-        return b.reshape(new_shape)
-
-
 # ### diag et al ###


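Note: the deleted _ndarray_resize implemented ndarray.resize, and its NB comment points at a real NumPy asymmetry: np.resize fills extra space by repeating the input, while ndarray.resize zero-fills it. A quick NumPy illustration (run as a script; the in-place resize can raise if other references to the buffer exist):

    import numpy as np

    a = np.arange(4)
    print(np.resize(a, (2, 3)))  # repeats the data: [[0 1 2], [3 0 1]]

    b = np.arange(4)
    b.resize((2, 3))             # zero-fills in place: [[0 1 2], [3 0 0]]
    print(b)
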
@@ -1153,13 +1118,13 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):


 def vdot(a: ArrayLike, b: ArrayLike, /):
-    # 1. torch only accepts 1D arrays, numpy ravels
+    # 1. torch only accepts 1D arrays, numpy flattens
     # 2. torch requires matching dtype, while numpy casts (?)
     t_a, t_b = torch.atleast_1d(a, b)
     if t_a.ndim > 1:
-        t_a = t_a.ravel()
+        t_a = t_a.flatten()
     if t_b.ndim > 1:
-        t_b = t_b.ravel()
+        t_b = t_b.flatten()

     dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
     is_half = dtype == torch.float16

@@ -1233,7 +1198,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):


 def _sort_helper(tensor, axis, kind, order):
-    (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
+    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
     axis = _util.normalize_axis_index(axis, tensor.ndim)

     stable = kind == "stable"

@@ -1349,14 +1314,6 @@ def transpose(a: ArrayLike, axes=None):


 def ravel(a: ArrayLike, order: NotImplementedType = "C"):
-    return torch.ravel(a)
-
-
-# leading underscore since arr.flatten exists but np.flatten does not
-
-
-def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
-    # may return a copy
     return torch.flatten(a)


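Note: after this hunk, ravel simply delegates to torch.flatten, and the separate _flatten helper (with its "may return a copy" caveat) is gone. torch.flatten gives a view for contiguous input and a copy otherwise, which matters when porting NumPy code that relies on ravel returning a view. A small check:

    import torch

    x = torch.arange(6).reshape(2, 3)
    f = torch.flatten(x)   # contiguous input: a view that aliases x
    f[0] = 99
    print(x[0, 0])         # tensor(99)

    g = torch.flatten(x.t())  # non-contiguous input: a copy
    g[0] = -1
    print(x[0, 0])            # still tensor(99); the copy is independent
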
@@ -1668,7 +1625,7 @@ def diff(
 def angle(z: ArrayLike, deg=False):
     result = torch.angle(z)
     if deg:
-        result = result * 180 / torch.pi
+        result = result * (180 / torch.pi)
     return result


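Note: the angle change only regroups the degree conversion: computing the scalar 180 / torch.pi once and multiplying turns two elementwise operations (a multiply, then a divide) into one. Same result, for example:

    import torch

    z = torch.tensor(1j)
    print(torch.angle(z) * (180 / torch.pi))  # pi/2 radians, i.e. 90.0 degrees
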
@@ -1679,26 +1636,14 @@ def sinc(x: ArrayLike):
 # ### Type/shape etc queries ###


-def real(a: ArrayLike):
-    return torch.real(a)
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
 def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     if a.is_floating_point():
         result = torch.round(a, decimals=decimals)
     elif a.is_complex():
         # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
+        result = torch.complex(
+            torch.round(a.real, decimals=decimals),
+            torch.round(a.imag, decimals=decimals),
         )
     else:
         # RuntimeError: "round_cpu" not implemented for 'int'

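Note: the round rewrite assembles the complex result with torch.complex instead of adding real + imag * 1j, rounding each component and recombining without a complex multiply-add; torch.complex also keeps the component dtype (a float32 pair gives complex64). For example:

    import torch

    z = torch.tensor([1.2345 + 6.789j, -0.5 - 2.25j])
    rounded = torch.complex(
        torch.round(z.real, decimals=2),
        torch.round(z.imag, decimals=2),
    )
    print(rounded)  # real and imaginary parts each rounded to 2 decimals
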
@@ -1711,7 +1656,6 @@ def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):


 def real_if_close(a: ArrayLike, tol=100):
-    # XXX: copies vs views; numpy seems to return a copy?
     if not torch.is_complex(a):
         return a
     if tol > 1:

@@ -1724,47 +1668,49 @@ def real_if_close(a: ArrayLike, tol=100):
     return a.real if mask.all() else a


+def real(a: ArrayLike):
+    return torch.real(a)
+
+
+def imag(a: ArrayLike):
+    if a.is_complex():
+        return a.imag
+    return torch.zeros_like(a)
+
+
 def iscomplex(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag != 0
-    result = torch.zeros_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.zeros_like(x, dtype=torch.bool)


 def isreal(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag == 0
-    result = torch.ones_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.ones_like(x, dtype=torch.bool)


 def iscomplexobj(x: ArrayLike):
-    result = torch.is_complex(x)
-    return result
+    return torch.is_complex(x)


 def isrealobj(x: ArrayLike):
     return not torch.is_complex(x)


 def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isneginf(x, out=out)
+    return torch.isneginf(x)


 def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isposinf(x, out=out)
+    return torch.isposinf(x)


 def i0(x: ArrayLike):
     return torch.special.i0(x)


 def isscalar(a):
-    # XXX: this is a stub
     try:
         t = normalize_array_like(a)
         return t.numel() == 1

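Note: one nuance in the iscomplex/isreal cleanup: the old code converted 0-d results to Python scalars via .item(), whereas the new code always returns a tensor and leaves any scalar handling to the wrapping layer. At the torch level:

    import torch

    x = torch.tensor(3.0)                         # real, 0-d
    print(torch.zeros_like(x, dtype=torch.bool))  # tensor(False), a 0-d tensor

    z = torch.tensor([1 + 0j, 2 + 3j])
    print(z.imag != 0)                            # tensor([False,  True])
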
@@ -1819,8 +1765,6 @@ def bartlett(M):


 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors:

@@ -1857,7 +1801,7 @@ def histogram(
     is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
     is_w_int = weights is None or not weights.dtype.is_floating_point
     if is_a_int:
-        a = a.to(float)
+        a = a.double()

     if weights is not None:
         weights = _util.cast_if_needed(weights, a.dtype)

@@ -1877,8 +1821,8 @@
     )

     if not density and is_w_int:
-        h = h.to(int)
+        h = h.long()
     if is_a_int:
-        b = b.to(int)
+        b = b.long()

     return h, b
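
Note: in histogram, a.to(float) becomes a.double() and h.to(int)/b.to(int) become h.long()/b.long(). Assuming torch's usual mapping of Python builtins (float to torch.float64, int to torch.int64), the two spellings are equivalent, so this reads as an idiom cleanup rather than a behavior change:

    import torch

    a = torch.tensor([1, 2, 3])
    assert a.to(float).dtype == a.double().dtype == torch.float64
    assert a.to(int).dtype == a.long().dtype == torch.int64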

torch_np/_getlimits.py

Lines changed: 2 additions & 4 deletions
@@ -1,3 +1,5 @@
+import contextlib
+
 import torch

 from . import _dtypes

@@ -13,10 +15,6 @@ def iinfo(dtyp):
     return torch.iinfo(torch_dtype)


-import contextlib
-
-
-# FIXME: this is only a stub
 @contextlib.contextmanager
 def errstate(*args, **kwds):
     yield
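
Note: errstate remains a no-op stub even after the FIXME comment is dropped: the context manager accepts numpy-style arguments and simply yields, so NumPy code wrapped in errstate runs unchanged under torch_np with no actual error-state handling. Usage, with the import path taken from this diff:

    from torch_np._getlimits import errstate

    with errstate(divide="ignore", invalid="ignore"):
        pass  # arguments are accepted and ignored; nothing is configured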
