@@ -8,7 +8,6 @@
 from __future__ import annotations

 import builtins
-import math
 import operator
 from typing import Optional, Sequence

@@ -100,7 +99,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):

 def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
     # pure torch implementation, used below and in cov/corrcoef below
-    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+    tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
     return torch.cat(tensors, axis)

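The renamed `_util.axis_none_flatten` helper implements the NumPy convention that `axis=None` means "flatten the inputs and operate along axis 0". Its real implementation lives in `torch._numpy._util`; the following is a minimal sketch inferred only from the call sites in this diff:

```python
import torch

def axis_none_flatten(*tensors, axis=None):
    # Sketch only: if axis is None, flatten every input and treat the
    # result as one-dimensional (axis 0); otherwise pass through.
    # The real helper lives in torch._numpy._util.
    if axis is None:
        return tuple(t.flatten() for t in tensors), 0
    return tensors, axis
```

This is exactly the inline `if axis is None: ar = ar.ravel(); axis = 0` pattern that the `unique` hunk below replaces with a helper call.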
@@ -881,21 +880,21 @@ def take(
     out: Optional[OutArray] = None,
     mode: NotImplementedType = "raise",
 ):
-    (a,), axis = _util.axis_none_ravel(a, axis=axis)
+    (a,), axis = _util.axis_none_flatten(a, axis=axis)
     axis = _util.normalize_axis_index(axis, a.ndim)
     idx = (slice(None),) * axis + (indices, ...)
     result = a[idx]
     return result


 def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)
     return torch.take_along_dim(arr, indices, axis)


 def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)

     indices, values = torch.broadcast_tensors(indices, values)
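In `take`, the tuple `(slice(None),) * axis + (indices, ...)` builds an index equivalent to writing `a[:, ..., indices, ...]` with `indices` in the `axis`-th slot. An illustrative check, not part of the diff:

```python
import torch

a = torch.arange(24).reshape(2, 3, 4)
indices = torch.tensor([2, 0])
axis = 1

# (slice(None),) * axis prepends one ':' per leading axis; the trailing
# Ellipsis leaves all remaining axes untouched.
idx = (slice(None),) * axis + (indices, ...)
assert torch.equal(a[idx], a[:, indices, :])
```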
@@ -917,9 +916,7 @@ def unique(
     *,
     equal_nan: NotImplementedType = True,
 ):
-    if axis is None:
-        ar = ar.ravel()
-        axis = 0
+    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
     axis = _util.normalize_axis_index(axis, ar.ndim)

     is_half = ar.dtype == torch.float16
@@ -948,7 +945,7 @@ def argwhere(a: ArrayLike):


 def flatnonzero(a: ArrayLike):
-    return torch.ravel(a).nonzero(as_tuple=True)[0]
+    return torch.flatten(a).nonzero(as_tuple=True)[0]


 def clip(
@@ -980,7 +977,7 @@ def resize(a: ArrayLike, new_shape=None):
     if isinstance(new_shape, int):
         new_shape = (new_shape,)

-    a = ravel(a)
+    a = a.flatten()

     new_size = 1
     for dim_length in new_shape:
@@ -998,38 +995,6 @@ def resize(a: ArrayLike, new_shape=None):
     return reshape(a, new_shape)


-def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
-    # implementation of ndarray.resize.
-    # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
-    if refcheck:
-        raise NotImplementedError(
-            f"resize(..., refcheck={refcheck} is not implemented."
-        )
-
-    if new_shape in [(), (None,)]:
-        return a
-
-    # support both x.resize((2, 2)) and x.resize(2, 2)
-    if len(new_shape) == 1:
-        new_shape = new_shape[0]
-    if isinstance(new_shape, int):
-        new_shape = (new_shape,)
-
-    a = ravel(a)
-
-    if builtins.any(x < 0 for x in new_shape):
-        raise ValueError("all elements of `new_shape` must be non-negative")
-
-    new_numel = math.prod(new_shape)
-    if new_numel < a.numel():
-        # shrink
-        return a[:new_numel].reshape(new_shape)
-    else:
-        b = torch.zeros(new_numel)
-        b[: a.numel()] = a
-        return b.reshape(new_shape)
-
-
 # ### diag et al ###

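Per its own comment, the deleted `_ndarray_resize` emulated `ndarray.resize`, which zero-fills, unlike `np.resize`, which tiles the input. For reference, the NumPy behaviors it was distinguishing:

```python
import numpy as np

a = np.arange(3)                 # [0 1 2]

print(np.resize(a, (5,)))        # [0 1 2 0 1] -- repeats the input

b = a.copy()
b.resize((5,), refcheck=False)   # in-place ndarray.resize
print(b)                         # [0 1 2 0 0] -- zero-fills instead
```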
@@ -1132,13 +1097,13 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):


 def vdot(a: ArrayLike, b: ArrayLike, /):
-    # 1. torch only accepts 1D arrays, numpy ravels
+    # 1. torch only accepts 1D arrays, numpy flattens
     # 2. torch requires matching dtype, while numpy casts (?)
     t_a, t_b = torch.atleast_1d(a, b)
     if t_a.ndim > 1:
-        t_a = t_a.ravel()
+        t_a = t_a.flatten()
     if t_b.ndim > 1:
-        t_b = t_b.ravel()
+        t_b = t_b.flatten()

     dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
     is_half = dtype == torch.float16
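`torch.vdot` only accepts 1-D tensors, hence the explicit flattening above; NumPy's `vdot` flattens multi-dimensional arguments implicitly. Illustration:

```python
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

# torch.vdot raises on 2-D inputs, so flatten first; this mirrors what
# np.vdot does implicitly for multi-dimensional arrays.
result = torch.vdot(a.flatten(), b.flatten())
print(result)  # tensor(70.)
```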
@@ -1212,7 +1177,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):


 def _sort_helper(tensor, axis, kind, order):
-    (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
+    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
     axis = _util.normalize_axis_index(axis, tensor.ndim)

     stable = kind == "stable"
@@ -1328,14 +1293,6 @@ def transpose(a: ArrayLike, axes=None):


 def ravel(a: ArrayLike, order: NotImplementedType = "C"):
-    return torch.ravel(a)
-
-
-# leading underscore since arr.flatten exists but np.flatten does not
-
-
-def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
-    # may return a copy
     return torch.flatten(a)

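With `_flatten` gone, `ravel` is now implemented via `torch.flatten`. The two torch functions differ only in view-versus-copy behavior (the removed `# may return a copy` note), never in values:

```python
import torch

a = torch.arange(6).reshape(2, 3)
# Same values either way; only the copy/view semantics can differ,
# and this compat layer does not guarantee views in the first place.
assert torch.equal(torch.ravel(a), torch.flatten(a))
```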
@@ -1647,7 +1604,7 @@ def diff(
 def angle(z: ArrayLike, deg=False):
     result = torch.angle(z)
     if deg:
-        result = result * 180 / torch.pi
+        result = result * (180 / torch.pi)
     return result

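Parenthesizing makes `180 / torch.pi` a single Python float computed once, so the tensor sees one multiply instead of a multiply followed by an elementwise divide; results agree up to floating-point rounding. A quick check:

```python
import torch

z = torch.tensor([1j, -1.0 + 0j])
rad = torch.angle(z)

# 180 / torch.pi is folded to a plain scalar before the tensor op.
deg = rad * (180 / torch.pi)
print(deg)  # tensor([ 90., 180.])
```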
@@ -1658,26 +1615,14 @@ def sinc(x: ArrayLike):
 # ### Type/shape etc queries ###


-def real(a: ArrayLike):
-    return torch.real(a)
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
 def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     if a.is_floating_point():
         result = torch.round(a, decimals=decimals)
     elif a.is_complex():
         # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
+        result = torch.complex(
+            torch.round(a.real, decimals=decimals),
+            torch.round(a.imag, decimals=decimals),
         )
     else:
         # RuntimeError: "round_cpu" not implemented for 'int'
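`torch.complex(re, im)` assembles the result directly from the two rounded parts, replacing the old `re + im * 1j` form, which went through an extra complex multiply and add. The values are the same:

```python
import torch

a = torch.tensor([1.26 + 3.84j])

# torch.round has no complex kernel, so round the real and imaginary
# parts separately and recombine with torch.complex.
rounded = torch.complex(
    torch.round(a.real, decimals=1),
    torch.round(a.imag, decimals=1),
)
print(rounded)  # tensor([1.3000+3.8000j])
```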
@@ -1690,7 +1635,6 @@ def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):


 def real_if_close(a: ArrayLike, tol=100):
-    # XXX: copies vs views; numpy seems to return a copy?
     if not torch.is_complex(a):
         return a
     if tol > 1:
@@ -1703,47 +1647,49 @@ def real_if_close(a: ArrayLike, tol=100):
     return a.real if mask.all() else a


+def real(a: ArrayLike):
+    return torch.real(a)
+
+
+def imag(a: ArrayLike):
+    if a.is_complex():
+        return a.imag
+    return torch.zeros_like(a)
+
+
 def iscomplex(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag != 0
-    result = torch.zeros_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.zeros_like(x, dtype=torch.bool)


 def isreal(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag == 0
-    result = torch.ones_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.ones_like(x, dtype=torch.bool)


 def iscomplexobj(x: ArrayLike):
-    result = torch.is_complex(x)
-    return result
+    return torch.is_complex(x)


 def isrealobj(x: ArrayLike):
     return not torch.is_complex(x)


 def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isneginf(x, out=out)
+    return torch.isneginf(x)


 def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isposinf(x, out=out)
+    return torch.isposinf(x)


 def i0(x: ArrayLike):
     return torch.special.i0(x)


 def isscalar(a):
-    # XXX: this is a stub
     try:
         t = normalize_array_like(a)
         return t.numel() == 1
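`iscomplex`/`isreal` now always return boolean tensors; the deleted `.item()` branch unwrapped 0-d results into Python bools, making scalar and array inputs behave differently. `isneginf`/`isposinf` also stop forwarding `out=`, presumably because `out=` handling belongs to the wrapping layer rather than to the torch call. Behavior sketch:

```python
import torch

z = torch.tensor([1 + 0j, 2 + 3j])
print(z.imag != 0)  # tensor([False,  True]) -- elementwise iscomplex

x = torch.tensor(3.0)  # 0-d real input
# Previously unwrapped via .item() to a plain bool; now a 0-d boolean
# tensor is returned, uniform with the n-d case.
print(torch.zeros_like(x, dtype=torch.bool))  # tensor(False)
```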
@@ -1798,8 +1744,6 @@ def bartlett(M):


 def common_type(*tensors: ArrayLike):
-    import builtins
-
     is_complex = False
     precision = 0
     for a in tensors:
@@ -1836,7 +1780,7 @@ def histogram(
     is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
     is_w_int = weights is None or not weights.dtype.is_floating_point
     if is_a_int:
-        a = a.to(float)
+        a = a.double()

     if weights is not None:
         weights = _util.cast_if_needed(weights, a.dtype)
@@ -1856,8 +1800,8 @@
     )

     if not density and is_w_int:
-        h = h.to(int)
+        h = h.long()
     if is_a_int:
-        b = b.to(int)
+        b = b.long()

     return h, b
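`.double()`/`.long()` spell out what `.to(float)`/`.to(int)` already did: PyTorch maps the Python `float` and `int` types to `torch.float64` and `torch.int64` when used as dtypes. Equivalence check:

```python
import torch

a = torch.tensor([1, 2, 3])

# Python float/int act as aliases for torch.float64/torch.int64 here,
# so the new spellings are equivalent but explicit about width.
assert a.to(float).dtype == a.double().dtype == torch.float64
assert a.to(int).dtype == a.long().dtype == torch.int64
```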