
Commit 5054611

Merge branch 'main' into einsum_tests

2 parents cd5f74a + 9c627db

14 files changed: 325 additions, 376 deletions

autogen/gen_dtypes.py

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@
 import numpy as np
 import torch

+np._set_promotion_state("weak")
+

 class dtype:
     def __init__(self, name):
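For context: `np._set_promotion_state("weak")` switches NumPy to NEP 50 "weak" scalar promotion on the transitional releases that expose this private switch, so Python scalars no longer upcast array dtypes based on their value (closer to PyTorch's promotion rules). A rough illustration, assuming a NumPy build where the switch exists:

```python
import numpy as np

# Transitional, private API -- newer NumPy releases use weak promotion by default.
np._set_promotion_state("legacy")
print(np.result_type(np.int8, 200))   # int16: value-based promotion widens for 200

np._set_promotion_state("weak")
print(np.result_type(np.int8, 200))   # int8: the Python int no longer upcasts
```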

torch_np/_dtypes.py

Lines changed: 1 addition & 1 deletion

@@ -243,7 +243,7 @@ def sctype_from_string(s):
         return _aliases[s]
     if s in _python_types:
         return _python_types[s]
-    raise TypeError(f"data type '{s}' not understood")
+    raise TypeError(f"data type {s!r} not understood")


 def sctype_from_torch_dtype(torch_dtype):

torch_np/_funcs.py

Lines changed: 1 addition & 8 deletions

@@ -19,15 +19,8 @@
     if inspect.isfunction(getattr(_funcs_impl, x)) and not x.startswith("_")
 ]

-# these implement ndarray methods but need not be public functions
-semi_private = [
-    "_flatten",
-    "_ndarray_resize",
-]
-
-
 # decorate implementer functions with argument normalizers and export to the top namespace
-for name in __all__ + semi_private:
+for name in __all__:
     func = getattr(_funcs_impl, name)
     if name in ["percentile", "quantile", "median"]:
         decorated = normalizer(func, promote_scalar_result=True)
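The loop above wraps each public implementer function in `_funcs_impl` with the argument `normalizer` and re-exports it; with `semi_private` gone, only names listed in `__all__` are decorated. A minimal sketch of that export pattern, using a stand-in `normalizer` and implementer namespace (hypothetical, not the project's actual decorator):

```python
import inspect
from types import SimpleNamespace

def normalizer(func, promote_scalar_result=False):
    # stand-in for the real argument-normalizing decorator
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

# stand-in for the _funcs_impl module: one public and one private function
_funcs_impl = SimpleNamespace(
    take=lambda a, indices: [a[i] for i in indices],
    _helper=lambda a: a,
)

__all__ = [
    x for x in vars(_funcs_impl)
    if inspect.isfunction(getattr(_funcs_impl, x)) and not x.startswith("_")
]

# decorate each implementer and export it to the top-level namespace
for name in __all__:
    globals()[name] = normalizer(getattr(_funcs_impl, name))

print(take("abc", [0, 2]))  # ['a', 'c']
```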

torch_np/_funcs_impl.py

Lines changed: 54 additions & 86 deletions

@@ -110,7 +110,7 @@ def _concatenate(
     tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
 ):
     # pure torch implementation, used below and in cov/corrcoef below
-    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+    tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
     return torch.cat(tensors, axis)
@@ -903,21 +903,42 @@ def take(
     out: Optional[OutArray] = None,
     mode: NotImplementedType = "raise",
 ):
-    (a,), axis = _util.axis_none_ravel(a, axis=axis)
+    (a,), axis = _util.axis_none_flatten(a, axis=axis)
     axis = _util.normalize_axis_index(axis, a.ndim)
     idx = (slice(None),) * axis + (indices, ...)
     result = a[idx]
     return result


 def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)
     return torch.take_along_dim(arr, indices, axis)


+def put(
+    a: NDArray,
+    ind: ArrayLike,
+    v: ArrayLike,
+    mode: NotImplementedType = "raise",
+):
+    v = v.type(a.dtype)
+    # If ind is larger than v, expand v to at least the size of ind. Any
+    # unnecessary trailing elements are then trimmed.
+    if ind.numel() > v.numel():
+        ratio = (ind.numel() + v.numel() - 1) // v.numel()
+        v = v.unsqueeze(0).expand((ratio,) + v.shape)
+    # Trim unnecessary elements, regardless of whether v was expanded. Note
+    # that np.put() trims v to match ind by default, too.
+    if ind.numel() < v.numel():
+        v = v.flatten()
+        v = v[: ind.numel()]
+    a.put_(ind, v)
+    return None
+
+
 def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)

     indices, values = torch.broadcast_tensors(indices, values)
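The new `put` repeats `v` (via `expand`) when it is shorter than `ind` and then trims it, mirroring NumPy's documented `np.put` semantics. For reference, plain NumPy behaves like this:

```python
import numpy as np

a = np.arange(6)
np.put(a, [0, 2, 4], [-1, -2])   # v shorter than ind: v is repeated as needed
print(a)                         # [-1  1 -2  3 -1  5]

b = np.arange(6)
np.put(b, [1], [10, 20, 30])     # v longer than ind: trailing values are ignored
print(b)                         # [ 0 10  2  3  4  5]
```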
@@ -939,9 +960,7 @@ def unique(
     *,
     equal_nan: NotImplementedType = True,
 ):
-    if axis is None:
-        ar = ar.ravel()
-        axis = 0
+    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
     axis = _util.normalize_axis_index(axis, ar.ndim)

     is_half = ar.dtype == torch.float16
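`_util.axis_none_flatten` (renamed from `axis_none_ravel`) replaces the inline "if axis is None, flatten and use axis 0" logic deleted in this hunk. Inferring its contract from the diff alone, a hypothetical re-creation looks roughly like this (not the project's actual `_util` code):

```python
import torch

def axis_none_flatten(*tensors, axis=None):
    # Hypothetical sketch: when axis is None, flatten every tensor and
    # treat the result as one-dimensional (axis 0); otherwise pass through.
    if axis is None:
        return [t.flatten() for t in tensors], 0
    return list(tensors), axis

(ar,), axis = axis_none_flatten(torch.arange(6).reshape(2, 3), axis=None)
print(ar.shape, axis)  # torch.Size([6]) 0
```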
@@ -970,7 +989,7 @@ def argwhere(a: ArrayLike):


 def flatnonzero(a: ArrayLike):
-    return torch.ravel(a).nonzero(as_tuple=True)[0]
+    return torch.flatten(a).nonzero(as_tuple=True)[0]


 def clip(
@@ -1002,7 +1021,7 @@ def resize(a: ArrayLike, new_shape=None):
     if isinstance(new_shape, int):
         new_shape = (new_shape,)

-    a = ravel(a)
+    a = a.flatten()

     new_size = 1
     for dim_length in new_shape:
@@ -1020,38 +1039,6 @@ def resize(a: ArrayLike, new_shape=None):
     return reshape(a, new_shape)


-def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
-    # implementation of ndarray.resize.
-    # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
-    if refcheck:
-        raise NotImplementedError(
-            f"resize(..., refcheck={refcheck} is not implemented."
-        )
-
-    if new_shape in [(), (None,)]:
-        return a
-
-    # support both x.resize((2, 2)) and x.resize(2, 2)
-    if len(new_shape) == 1:
-        new_shape = new_shape[0]
-        if isinstance(new_shape, int):
-            new_shape = (new_shape,)
-
-    a = ravel(a)
-
-    if builtins.any(x < 0 for x in new_shape):
-        raise ValueError("all elements of `new_shape` must be non-negative")
-
-    new_numel = math.prod(new_shape)
-    if new_numel < a.numel():
-        # shrink
-        return a[:new_numel].reshape(new_shape)
-    else:
-        b = torch.zeros(new_numel)
-        b[: a.numel()] = a
-        return b.reshape(new_shape)
-
-
 # ### diag et al ###

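The removed `_ndarray_resize` backed the `ndarray.resize` method, which zero-fills when growing, unlike `np.resize`, which repeats the input (as the deleted NB comment says). For reference, the NumPy behaviour being modelled:

```python
import numpy as np

a = np.arange(3)
print(np.resize(a, (2, 3)))  # repeats the input: [[0 1 2] [0 1 2]]

b = np.arange(3)
b.resize((2, 3))             # the in-place method zero-fills instead
print(b)                     # [[0 1 2] [0 0 0]]
```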

@@ -1154,13 +1141,13 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):


 def vdot(a: ArrayLike, b: ArrayLike, /):
-    # 1. torch only accepts 1D arrays, numpy ravels
+    # 1. torch only accepts 1D arrays, numpy flattens
     # 2. torch requires matching dtype, while numpy casts (?)
     t_a, t_b = torch.atleast_1d(a, b)
     if t_a.ndim > 1:
-        t_a = t_a.ravel()
+        t_a = t_a.flatten()
     if t_b.ndim > 1:
-        t_b = t_b.ravel()
+        t_b = t_b.flatten()

     dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
     is_half = dtype == torch.float16
@@ -1310,7 +1297,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=


 def _sort_helper(tensor, axis, kind, order):
-    (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
+    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
     axis = _util.normalize_axis_index(axis, tensor.ndim)

     stable = kind == "stable"
@@ -1426,14 +1413,6 @@ def transpose(a: ArrayLike, axes=None):


 def ravel(a: ArrayLike, order: NotImplementedType = "C"):
-    return torch.ravel(a)
-
-
-# leading underscore since arr.flatten exists but np.flatten does not
-
-
-def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
-    # may return a copy
     return torch.flatten(a)

@@ -1745,7 +1724,7 @@ def diff(
 def angle(z: ArrayLike, deg=False):
     result = torch.angle(z)
     if deg:
-        result = result * 180 / torch.pi
+        result = result * (180 / torch.pi)
     return result

@@ -1756,26 +1735,14 @@ def sinc(x: ArrayLike):
 # ### Type/shape etc queries ###


-def real(a: ArrayLike):
-    return torch.real(a)
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
 def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     if a.is_floating_point():
         result = torch.round(a, decimals=decimals)
     elif a.is_complex():
         # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
+        result = torch.complex(
+            torch.round(a.real, decimals=decimals),
+            torch.round(a.imag, decimals=decimals),
         )
     else:
         # RuntimeError: "round_cpu" not implemented for 'int'
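Building the rounded result with `torch.complex(real, imag)` avoids the extra multiply-and-add of `real + imag * 1j` and keeps the dtype explicit. A quick stand-alone check of the technique (plain PyTorch, not the wrapper's public API):

```python
import torch

z = torch.tensor([1.234 + 5.678j, -0.5 - 2.5j])
rounded = torch.complex(
    torch.round(z.real, decimals=2),
    torch.round(z.imag, decimals=2),
)
print(rounded)  # complex entries with real/imag parts rounded to 2 decimals
```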
@@ -1788,7 +1755,6 @@ def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):


 def real_if_close(a: ArrayLike, tol=100):
-    # XXX: copies vs views; numpy seems to return a copy?
     if not torch.is_complex(a):
         return a
     if tol > 1:
@@ -1801,47 +1767,49 @@ def real_if_close(a: ArrayLike, tol=100):
     return a.real if mask.all() else a


+def real(a: ArrayLike):
+    return torch.real(a)
+
+
+def imag(a: ArrayLike):
+    if a.is_complex():
+        return a.imag
+    return torch.zeros_like(a)
+
+
 def iscomplex(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag != 0
-    result = torch.zeros_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.zeros_like(x, dtype=torch.bool)


 def isreal(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag == 0
-    result = torch.ones_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.ones_like(x, dtype=torch.bool)


 def iscomplexobj(x: ArrayLike):
-    result = torch.is_complex(x)
-    return result
+    return torch.is_complex(x)


 def isrealobj(x: ArrayLike):
     return not torch.is_complex(x)


 def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isneginf(x, out=out)
+    return torch.isneginf(x)


 def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isposinf(x, out=out)
+    return torch.isposinf(x)


 def i0(x: ArrayLike):
     return torch.special.i0(x)


 def isscalar(a):
-    # XXX: this is a stub
     try:
         t = normalize_array_like(a)
         return t.numel() == 1
@@ -1932,7 +1900,7 @@ def histogram(
     is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
     is_w_int = weights is None or not weights.dtype.is_floating_point
     if is_a_int:
-        a = a.to(float)
+        a = a.double()

     if weights is not None:
         weights = _util.cast_if_needed(weights, a.dtype)
@@ -1952,8 +1920,8 @@
     )

     if not density and is_w_int:
-        h = h.to(int)
+        h = h.long()
     if is_a_int:
-        b = b.to(int)
+        b = b.long()

     return h, b

torch_np/_getlimits.py

Lines changed: 2 additions & 4 deletions

@@ -1,3 +1,5 @@
+import contextlib
+
 import torch

 from . import _dtypes

@@ -13,10 +15,6 @@ def iinfo(dtyp):
     return torch.iinfo(torch_dtype)


-import contextlib
-
-
-# FIXME: this is only a stub
 @contextlib.contextmanager
 def errstate(*args, **kwds):
     yield
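With the import hoisted to the top of the module, `errstate` stays a no-op stub: a `contextlib.contextmanager` that only yields, so `with errstate(...)` blocks run unchanged. A minimal sketch of how such a stub behaves (it accepts np.errstate-style arguments but does not alter floating-point error handling):

```python
import contextlib

@contextlib.contextmanager
def errstate(*args, **kwds):
    # no-op stub: accepts the same call shape as np.errstate, changes nothing
    yield

with errstate(divide="ignore"):
    print(1.0 / 4.0)  # 0.25 -- error-handling state is untouched
```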
