
Commit 91ab9a7

Fix and minors
Don't use out= explicitly; improved the implementation of `average` a bit. Minor improvements here and there, e.g.:
- prefer flatten over ravel, as it's more PyTorch-y
- prefer `.double()` or `.long()` over `to(float)` (I didn't even know that worked...) for the same reason
- don't call .item() if we can avoid it (added to the list of differences)
- remove _mappings
- remove the need for semi_private methods
- fixed ndarray.fill, as the normalizer was not working because of the future annotations
1 parent 2ce7d21 commit 91ab9a7
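
For illustration only (not part of the diff), a minimal sketch of the tensor idioms the commit message refers to; the example tensors are made up:

import torch

t = torch.arange(6).reshape(2, 3)

# prefer flatten over ravel: torch.Tensor has both, but flatten is the more PyTorch-y spelling
flat = t.flatten()      # instead of t.ravel()

# prefer .double() / .long() over .to(float) / .to(int)
a = t.double()          # instead of t.to(float)
b = flat.long()         # instead of flat.to(int)

# avoid .item() where possible: keep 0-d results as tensors rather than Python scalars
total = t.sum()         # a 0-d tensor; total.item() would force a copy out to a Python int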

File tree

8 files changed (+131, -357 lines)


torch_np/_funcs.py

Lines changed: 1 addition & 8 deletions
@@ -19,15 +19,8 @@
     if inspect.isfunction(getattr(_funcs_impl, x)) and not x.startswith("_")
 ]

-# these implement ndarray methods but need not be public functions
-semi_private = [
-    "_flatten",
-    "_ndarray_resize",
-]
-
-
 # decorate implementer functions with argument normalizers and export to the top namespace
-for name in __all__ + semi_private:
+for name in __all__:
     func = getattr(_funcs_impl, name)
     if name in ["percentile", "quantile", "median"]:
         decorated = normalizer(func, promote_scalar_result=True)

torch_np/_funcs_impl.py

Lines changed: 33 additions & 89 deletions
@@ -8,7 +8,6 @@
 from __future__ import annotations

 import builtins
-import math
 import operator
 from typing import Optional, Sequence

@@ -100,7 +99,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):

 def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
     # pure torch implementation, used below and in cov/corrcoef below
-    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+    tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
     return torch.cat(tensors, axis)

@@ -881,21 +880,21 @@ def take(
     out: Optional[OutArray] = None,
     mode: NotImplementedType = "raise",
 ):
-    (a,), axis = _util.axis_none_ravel(a, axis=axis)
+    (a,), axis = _util.axis_none_flatten(a, axis=axis)
     axis = _util.normalize_axis_index(axis, a.ndim)
     idx = (slice(None),) * axis + (indices, ...)
     result = a[idx]
     return result


 def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)
     return torch.take_along_dim(arr, indices, axis)


 def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)

     indices, values = torch.broadcast_tensors(indices, values)
@@ -917,9 +916,7 @@ def unique(
     *,
     equal_nan: NotImplementedType = True,
 ):
-    if axis is None:
-        ar = ar.ravel()
-        axis = 0
+    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
     axis = _util.normalize_axis_index(axis, ar.ndim)

     is_half = ar.dtype == torch.float16
@@ -948,7 +945,7 @@ def argwhere(a: ArrayLike):


 def flatnonzero(a: ArrayLike):
-    return torch.ravel(a).nonzero(as_tuple=True)[0]
+    return torch.flatten(a).nonzero(as_tuple=True)[0]


 def clip(
@@ -980,7 +977,7 @@ def resize(a: ArrayLike, new_shape=None):
     if isinstance(new_shape, int):
         new_shape = (new_shape,)

-    a = ravel(a)
+    a = a.flatten()

     new_size = 1
     for dim_length in new_shape:
@@ -998,38 +995,6 @@ def resize(a: ArrayLike, new_shape=None):
     return reshape(a, new_shape)


-def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
-    # implementation of ndarray.resize.
-    # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
-    if refcheck:
-        raise NotImplementedError(
-            f"resize(..., refcheck={refcheck} is not implemented."
-        )
-
-    if new_shape in [(), (None,)]:
-        return a
-
-    # support both x.resize((2, 2)) and x.resize(2, 2)
-    if len(new_shape) == 1:
-        new_shape = new_shape[0]
-    if isinstance(new_shape, int):
-        new_shape = (new_shape,)
-
-    a = ravel(a)
-
-    if builtins.any(x < 0 for x in new_shape):
-        raise ValueError("all elements of `new_shape` must be non-negative")
-
-    new_numel = math.prod(new_shape)
-    if new_numel < a.numel():
-        # shrink
-        return a[:new_numel].reshape(new_shape)
-    else:
-        b = torch.zeros(new_numel)
-        b[: a.numel()] = a
-        return b.reshape(new_shape)
-
-
 # ### diag et al ###


@@ -1132,13 +1097,13 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):


 def vdot(a: ArrayLike, b: ArrayLike, /):
-    # 1. torch only accepts 1D arrays, numpy ravels
+    # 1. torch only accepts 1D arrays, numpy flattens
     # 2. torch requires matching dtype, while numpy casts (?)
     t_a, t_b = torch.atleast_1d(a, b)
     if t_a.ndim > 1:
-        t_a = t_a.ravel()
+        t_a = t_a.flatten()
     if t_b.ndim > 1:
-        t_b = t_b.ravel()
+        t_b = t_b.flatten()

     dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
     is_half = dtype == torch.float16
@@ -1212,7 +1177,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):


 def _sort_helper(tensor, axis, kind, order):
-    (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
+    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
     axis = _util.normalize_axis_index(axis, tensor.ndim)

     stable = kind == "stable"
@@ -1328,14 +1293,6 @@ def transpose(a: ArrayLike, axes=None):


 def ravel(a: ArrayLike, order: NotImplementedType = "C"):
-    return torch.ravel(a)
-
-
-# leading underscore since arr.flatten exists but np.flatten does not
-
-
-def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
-    # may return a copy
     return torch.flatten(a)


@@ -1647,7 +1604,7 @@ def diff(
 def angle(z: ArrayLike, deg=False):
     result = torch.angle(z)
     if deg:
-        result = result * 180 / torch.pi
+        result = result * (180 / torch.pi)
     return result


@@ -1658,26 +1615,14 @@ def sinc(x: ArrayLike):
 # ### Type/shape etc queries ###


-def real(a: ArrayLike):
-    return torch.real(a)
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
 def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     if a.is_floating_point():
         result = torch.round(a, decimals=decimals)
     elif a.is_complex():
         # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
+        result = torch.complex(
+            torch.round(a.real, decimals=decimals),
+            torch.round(a.imag, decimals=decimals),
         )
     else:
         # RuntimeError: "round_cpu" not implemented for 'int'
@@ -1690,7 +1635,6 @@ def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):


 def real_if_close(a: ArrayLike, tol=100):
-    # XXX: copies vs views; numpy seems to return a copy?
     if not torch.is_complex(a):
         return a
     if tol > 1:
@@ -1703,47 +1647,49 @@ def real_if_close(a: ArrayLike, tol=100):
     return a.real if mask.all() else a


+def real(a: ArrayLike):
+    return torch.real(a)
+
+
+def imag(a: ArrayLike):
+    if a.is_complex():
+        return a.imag
+    return torch.zeros_like(a)
+
+
 def iscomplex(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag != 0
-    result = torch.zeros_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.zeros_like(x, dtype=torch.bool)


 def isreal(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag == 0
-    result = torch.ones_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.ones_like(x, dtype=torch.bool)


 def iscomplexobj(x: ArrayLike):
-    result = torch.is_complex(x)
-    return result
+    return torch.is_complex(x)


 def isrealobj(x: ArrayLike):
     return not torch.is_complex(x)


 def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isneginf(x, out=out)
+    return torch.isneginf(x)


 def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isposinf(x, out=out)
+    return torch.isposinf(x)


 def i0(x: ArrayLike):
     return torch.special.i0(x)


 def isscalar(a):
-    # XXX: this is a stub
     try:
         t = normalize_array_like(a)
         return t.numel() == 1
@@ -1798,8 +1744,6 @@ def bartlett(M):


 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors:
@@ -1836,7 +1780,7 @@ def histogram(
     is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
     is_w_int = weights is None or not weights.dtype.is_floating_point
     if is_a_int:
-        a = a.to(float)
+        a = a.double()

     if weights is not None:
         weights = _util.cast_if_needed(weights, a.dtype)
@@ -1856,8 +1800,8 @@ def histogram(
     )

     if not density and is_w_int:
-        h = h.to(int)
+        h = h.long()
     if is_a_int:
-        b = b.to(int)
+        b = b.long()

     return h, b
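
The hunks above repeatedly swap `_util.axis_none_ravel` for `_util.axis_none_flatten`. That helper lives in torch_np/_util.py and is not shown in this diff; judging from its call sites and from the inline code it replaces in `unique` (flatten when `axis is None`, then work along axis 0), it presumably behaves roughly like the sketch below. The real implementation may differ.

def axis_none_flatten(*tensors, axis=None):
    # If axis is None, flatten every input and operate along axis 0 of the
    # flattened tensors; otherwise pass the inputs and axis through unchanged.
    if axis is None:
        tensors = tuple(t.flatten() for t in tensors)
        axis = 0
    return tensors, axis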

torch_np/_getlimits.py

Lines changed: 2 additions & 4 deletions
@@ -1,3 +1,5 @@
+import contextlib
+
 import torch

 from . import _dtypes
@@ -13,10 +15,6 @@ def iinfo(dtyp):
     return torch.iinfo(torch_dtype)


-import contextlib
-
-
-# FIXME: this is only a stub
 @contextlib.contextmanager
 def errstate(*args, **kwds):
     yield
