@@ -110,7 +110,7 @@ def _concatenate(
     tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
 ):
     # pure torch implementation, used below and in cov/corrcoef below
-    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+    tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
     return torch.cat(tensors, axis)

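Note: _util.axis_none_flatten itself lives in _util and is not part of this diff. Judging from the unique() hunk further down, which previously open-coded the same logic, it presumably flattens its tensor arguments and resets axis to 0 when axis is None, and passes everything through unchanged otherwise. A minimal sketch under that assumption (hypothetical, for orientation only):

    import torch

    def axis_none_flatten(*tensors, axis=None):
        # Hypothetical sketch, not the actual _util implementation: NumPy
        # treats axis=None as "operate on the flattened input", so flatten
        # every operand and work along axis 0.
        if axis is None:
            return [torch.flatten(t) for t in tensors], 0
        return list(tensors), axis
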
@@ -903,21 +903,42 @@ def take(
     out: Optional[OutArray] = None,
     mode: NotImplementedType = "raise",
 ):
-    (a,), axis = _util.axis_none_ravel(a, axis=axis)
+    (a,), axis = _util.axis_none_flatten(a, axis=axis)
     axis = _util.normalize_axis_index(axis, a.ndim)
     idx = (slice(None),) * axis + (indices, ...)
     result = a[idx]
     return result


 def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)
     return torch.take_along_dim(arr, indices, axis)


+def put(
+    a: NDArray,
+    ind: ArrayLike,
+    v: ArrayLike,
+    mode: NotImplementedType = "raise",
+):
+    v = v.type(a.dtype)
+    # If ind is larger than v, expand v to at least the size of ind. Any
+    # unnecessary trailing elements are then trimmed.
+    if ind.numel() > v.numel():
+        ratio = (ind.numel() + v.numel() - 1) // v.numel()
+        v = v.unsqueeze(0).expand((ratio,) + v.shape)
+    # Trim unnecessary elements, regardless of whether v was expanded or not.
+    # Note that np.put() trims v to match ind by default too.
+    if ind.numel() < v.numel():
+        v = v.flatten()
+        v = v[: ind.numel()]
+    a.put_(ind, v)
+    return None
+
+
 def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
     axis = _util.normalize_axis_index(axis, arr.ndim)

     indices, values = torch.broadcast_tensors(indices, values)
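Note: a quick worked example of the expand-and-trim logic in the new put() above, calling the torch-level helper directly with tensors (values are made up; np.put likewise tiles and truncates v to the length of ind):

    import torch

    a = torch.zeros(6, dtype=torch.int64)
    ind = torch.tensor([0, 2, 3, 4, 5])   # 5 target positions
    v = torch.tensor([7, 8])              # only 2 replacement values
    # ratio = (5 + 2 - 1) // 2 = 3, so v is expanded to 3 x 2 = 6 elements,
    # flattened, and trimmed to the first 5: [7, 8, 7, 8, 7].
    put(a, ind, v)
    # a is now tensor([7, 0, 8, 7, 8, 7])
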
@@ -939,9 +960,7 @@ def unique(
     *,
     equal_nan: NotImplementedType = True,
 ):
-    if axis is None:
-        ar = ar.ravel()
-        axis = 0
+    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
     axis = _util.normalize_axis_index(axis, ar.ndim)

     is_half = ar.dtype == torch.float16
@@ -970,7 +989,7 @@ def argwhere(a: ArrayLike):


 def flatnonzero(a: ArrayLike):
-    return torch.ravel(a).nonzero(as_tuple=True)[0]
+    return torch.flatten(a).nonzero(as_tuple=True)[0]


 def clip(
@@ -1002,7 +1021,7 @@ def resize(a: ArrayLike, new_shape=None):
     if isinstance(new_shape, int):
         new_shape = (new_shape,)

-    a = ravel(a)
+    a = a.flatten()

     new_size = 1
     for dim_length in new_shape:
@@ -1020,38 +1039,6 @@ def resize(a: ArrayLike, new_shape=None):
     return reshape(a, new_shape)


-def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
-    # implementation of ndarray.resize.
-    # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
-    if refcheck:
-        raise NotImplementedError(
-            f"resize(..., refcheck={refcheck} is not implemented."
-        )
-
-    if new_shape in [(), (None,)]:
-        return a
-
-    # support both x.resize((2, 2)) and x.resize(2, 2)
-    if len(new_shape) == 1:
-        new_shape = new_shape[0]
-    if isinstance(new_shape, int):
-        new_shape = (new_shape,)
-
-    a = ravel(a)
-
-    if builtins.any(x < 0 for x in new_shape):
-        raise ValueError("all elements of `new_shape` must be non-negative")
-
-    new_numel = math.prod(new_shape)
-    if new_numel < a.numel():
-        # shrink
-        return a[:new_numel].reshape(new_shape)
-    else:
-        b = torch.zeros(new_numel)
-        b[: a.numel()] = a
-        return b.reshape(new_shape)
-
-
 # ### diag et al ###


@@ -1154,13 +1141,13 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):


 def vdot(a: ArrayLike, b: ArrayLike, /):
-    # 1. torch only accepts 1D arrays, numpy ravels
+    # 1. torch only accepts 1D arrays, numpy flattens
     # 2. torch requires matching dtype, while numpy casts (?)
     t_a, t_b = torch.atleast_1d(a, b)
     if t_a.ndim > 1:
-        t_a = t_a.ravel()
+        t_a = t_a.flatten()
     if t_b.ndim > 1:
-        t_b = t_b.ravel()
+        t_b = t_b.flatten()

     dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
     is_half = dtype == torch.float16
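Note: the rewritten comment above matches NumPy's documented behavior: np.vdot flattens multi-dimensional inputs to 1-D before taking the dot product, which is what the .flatten() calls mirror. A small illustration with made-up values:

    import numpy as np

    a = np.array([[1, 2], [3, 4]])
    b = np.array([[5, 6], [7, 8]])
    # Both operands are flattened to 1-D first:
    # 1*5 + 2*6 + 3*7 + 4*8 = 70
    assert np.vdot(a, b) == 70
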
@@ -1310,7 +1297,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=


 def _sort_helper(tensor, axis, kind, order):
-    (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
+    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
     axis = _util.normalize_axis_index(axis, tensor.ndim)

     stable = kind == "stable"
@@ -1426,14 +1413,6 @@ def transpose(a: ArrayLike, axes=None):


 def ravel(a: ArrayLike, order: NotImplementedType = "C"):
-    return torch.ravel(a)
-
-
-# leading underscore since arr.flatten exists but np.flatten does not
-
-
-def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
-    # may return a copy
     return torch.flatten(a)


@@ -1745,7 +1724,7 @@ def diff(
 def angle(z: ArrayLike, deg=False):
     result = torch.angle(z)
     if deg:
-        result = result * 180 / torch.pi
+        result = result * (180 / torch.pi)
     return result


@@ -1756,26 +1735,14 @@ def sinc(x: ArrayLike):
 # ### Type/shape etc queries ###


-def real(a: ArrayLike):
-    return torch.real(a)
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
 def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     if a.is_floating_point():
         result = torch.round(a, decimals=decimals)
     elif a.is_complex():
         # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
+        result = torch.complex(
+            torch.round(a.real, decimals=decimals),
+            torch.round(a.imag, decimals=decimals),
         )
     else:
         # RuntimeError: "round_cpu" not implemented for 'int'
@@ -1788,7 +1755,6 @@ def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):


 def real_if_close(a: ArrayLike, tol=100):
-    # XXX: copies vs views; numpy seems to return a copy?
     if not torch.is_complex(a):
         return a
     if tol > 1:
@@ -1801,47 +1767,49 @@ def real_if_close(a: ArrayLike, tol=100):
     return a.real if mask.all() else a


+def real(a: ArrayLike):
+    return torch.real(a)
+
+
+def imag(a: ArrayLike):
+    if a.is_complex():
+        return a.imag
+    return torch.zeros_like(a)
+
+
 def iscomplex(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag != 0
-    result = torch.zeros_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.zeros_like(x, dtype=torch.bool)


 def isreal(x: ArrayLike):
     if torch.is_complex(x):
         return x.imag == 0
-    result = torch.ones_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.ones_like(x, dtype=torch.bool)


 def iscomplexobj(x: ArrayLike):
-    result = torch.is_complex(x)
-    return result
+    return torch.is_complex(x)


 def isrealobj(x: ArrayLike):
     return not torch.is_complex(x)


 def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isneginf(x, out=out)
+    return torch.isneginf(x)


 def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isposinf(x, out=out)
+    return torch.isposinf(x)


 def i0(x: ArrayLike):
     return torch.special.i0(x)


 def isscalar(a):
-    # XXX: this is a stub
     try:
         t = normalize_array_like(a)
         return t.numel() == 1
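Note: dropping the ndim == 0 / .item() branches means iscomplex() and isreal() now return 0-d boolean tensors for scalar inputs rather than Python bools. A hypothetical call at the torch level:

    import torch

    x = torch.tensor(3.0)   # 0-d real tensor
    # Previously the 0-d result was converted to a Python bool via .item();
    # now a 0-d torch.bool tensor comes back instead.
    iscomplex(x)            # tensor(False)
    isreal(x)               # tensor(True)
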
@@ -1932,7 +1900,7 @@ def histogram(
     is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
     is_w_int = weights is None or not weights.dtype.is_floating_point
     if is_a_int:
-        a = a.to(float)
+        a = a.double()

     if weights is not None:
         weights = _util.cast_if_needed(weights, a.dtype)
@@ -1952,8 +1920,8 @@
     )

     if not density and is_w_int:
-        h = h.to(int)
+        h = h.long()
     if is_a_int:
-        b = b.to(int)
+        b = b.long()

     return h, b
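Note: torch interprets the Python builtins float and int as torch.float64 and torch.int64 when used as dtypes, so .double() and .long() above are presumably just more explicit spellings of the previous .to(float) and .to(int). A quick check of that assumption:

    import torch

    t = torch.arange(3)
    assert t.to(float).dtype == t.double().dtype == torch.float64
    assert t.to(int).dtype == t.long().dtype == torch.int64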