@@ -8,7 +8,6 @@
from __future__ import annotations

import builtins
-import math
import operator
from typing import Optional, Sequence

@@ -100,7 +99,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):

def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
    # pure torch implementation, used below and in cov/corrcoef below
-    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+    tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
    tensors = _concat_cast_helper(tensors, out, dtype, casting)
    return torch.cat(tensors, axis)

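The rename from `axis_none_ravel` to `axis_none_flatten` tracks the switch from ravel to flatten made throughout this file. A minimal sketch of what such a helper presumably does, assuming the usual NumPy `axis=None` convention (the real `_util.axis_none_flatten` may differ in detail): when `axis` is None the inputs are flattened and the axis becomes 0, otherwise both pass through unchanged.

import torch

def axis_none_flatten(*tensors, axis=None):
    # axis=None means "operate on the flattened array" in NumPy terms:
    # flatten every input and treat the result as 1-D (axis 0).
    if axis is None:
        return tuple(torch.flatten(t) for t in tensors), 0
    return tensors, axis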
@@ -881,15 +880,15 @@ def take(
    out: Optional[OutArray] = None,
    mode: NotImplementedType = "raise",
):
-    (a,), axis = _util.axis_none_ravel(a, axis=axis)
+    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)
    idx = (slice(None),) * axis + (indices, ...)
    result = a[idx]
    return result


def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
    axis = _util.normalize_axis_index(axis, arr.ndim)
    return torch.take_along_dim(arr, indices, axis)

@@ -916,7 +915,7 @@ def put(


def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
-    (arr,), axis = _util.axis_none_ravel(arr, axis=axis)
+    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
    axis = _util.normalize_axis_index(axis, arr.ndim)

    indices, values = torch.broadcast_tensors(indices, values)
@@ -938,9 +937,7 @@ def unique(
    *,
    equal_nan: NotImplementedType = True,
):
-    if axis is None:
-        ar = ar.ravel()
-        axis = 0
+    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
    axis = _util.normalize_axis_index(axis, ar.ndim)

    is_half = ar.dtype == torch.float16
@@ -969,7 +966,7 @@ def argwhere(a: ArrayLike):


def flatnonzero(a: ArrayLike):
-    return torch.ravel(a).nonzero(as_tuple=True)[0]
+    return torch.flatten(a).nonzero(as_tuple=True)[0]


def clip(
@@ -1001,7 +998,7 @@ def resize(a: ArrayLike, new_shape=None):
    if isinstance(new_shape, int):
        new_shape = (new_shape,)

-    a = ravel(a)
+    a = a.flatten()

    new_size = 1
    for dim_length in new_shape:
@@ -1019,38 +1016,6 @@ def resize(a: ArrayLike, new_shape=None):
    return reshape(a, new_shape)


-def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False):
-    # implementation of ndarray.resize.
-    # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
-    if refcheck:
-        raise NotImplementedError(
-            f"resize(..., refcheck={refcheck} is not implemented."
-        )
-
-    if new_shape in [(), (None,)]:
-        return a
-
-    # support both x.resize((2, 2)) and x.resize(2, 2)
-    if len(new_shape) == 1:
-        new_shape = new_shape[0]
-    if isinstance(new_shape, int):
-        new_shape = (new_shape,)
-
-    a = ravel(a)
-
-    if builtins.any(x < 0 for x in new_shape):
-        raise ValueError("all elements of `new_shape` must be non-negative")
-
-    new_numel = math.prod(new_shape)
-    if new_numel < a.numel():
-        # shrink
-        return a[:new_numel].reshape(new_shape)
-    else:
-        b = torch.zeros(new_numel)
-        b[: a.numel()] = a
-        return b.reshape(new_shape)
-
-
# ### diag et al ###

@@ -1153,13 +1118,13 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):


def vdot(a: ArrayLike, b: ArrayLike, /):
-    # 1. torch only accepts 1D arrays, numpy ravels
+    # 1. torch only accepts 1D arrays, numpy flattens
    # 2. torch requires matching dtype, while numpy casts (?)
    t_a, t_b = torch.atleast_1d(a, b)
    if t_a.ndim > 1:
-        t_a = t_a.ravel()
+        t_a = t_a.flatten()
    if t_b.ndim > 1:
-        t_b = t_b.ravel()
+        t_b = t_b.flatten()

    dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
    is_half = dtype == torch.float16
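The comment change ("ravels" to "flattens") mirrors the code: NumPy's `vdot` flattens multi-dimensional inputs before taking the dot product, while `torch.vdot` only accepts 1-D tensors. A quick illustration of that contract:

import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
# NumPy's vdot flattens multi-dimensional inputs; torch.vdot wants 1-D.
assert torch.vdot(a.flatten(), b.flatten()) == (a * b).sum()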
@@ -1233,7 +1198,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):


def _sort_helper(tensor, axis, kind, order):
-    (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
+    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
    axis = _util.normalize_axis_index(axis, tensor.ndim)

    stable = kind == "stable"
@@ -1349,14 +1314,6 @@ def transpose(a: ArrayLike, axes=None):


def ravel(a: ArrayLike, order: NotImplementedType = "C"):
-    return torch.ravel(a)
-
-
-# leading underscore since arr.flatten exists but np.flatten does not
-
-
-def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
-    # may return a copy
    return torch.flatten(a)

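`ravel` now delegates to `torch.flatten`, and the `_flatten` alias is gone. For the 1-D results produced here the two calls are interchangeable; a quick check, assuming a contiguous input:

import torch

a = torch.arange(6).reshape(2, 3)
# torch.flatten and torch.ravel agree on the values; both may avoid a
# copy for contiguous inputs.
assert torch.equal(torch.flatten(a), torch.ravel(a))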
@@ -1668,7 +1625,7 @@ def diff(
def angle(z: ArrayLike, deg=False):
    result = torch.angle(z)
    if deg:
-        result = result * 180 / torch.pi
+        result = result * (180 / torch.pi)
    return result

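Parenthesizing `180 / torch.pi` folds the conversion factor into a single Python float, so the degree conversion costs one elementwise multiply instead of a multiply followed by a divide. The result is the same up to float rounding:

import torch

rad = torch.angle(torch.tensor([1 + 1j, -1 + 0j]))
# (180 / torch.pi) is evaluated once as a plain float scalar.
assert torch.allclose(rad * (180 / torch.pi), rad * 180 / torch.pi)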
@@ -1679,26 +1636,14 @@ def sinc(x: ArrayLike):
# ### Type/shape etc queries ###


-def real(a: ArrayLike):
-    return torch.real(a)
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
    if a.is_floating_point():
        result = torch.round(a, decimals=decimals)
    elif a.is_complex():
        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
+        result = torch.complex(
+            torch.round(a.real, decimals=decimals),
+            torch.round(a.imag, decimals=decimals),
        )
    else:
        # RuntimeError: "round_cpu" not implemented for 'int'
@@ -1711,7 +1656,6 @@ def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):


def real_if_close(a: ArrayLike, tol=100):
-    # XXX: copies vs views; numpy seems to return a copy?
    if not torch.is_complex(a):
        return a
    if tol > 1:
@@ -1724,47 +1668,49 @@ def real_if_close(a: ArrayLike, tol=100):
    return a.real if mask.all() else a


+def real(a: ArrayLike):
+    return torch.real(a)
+
+
+def imag(a: ArrayLike):
+    if a.is_complex():
+        return a.imag
+    return torch.zeros_like(a)
+
+
def iscomplex(x: ArrayLike):
    if torch.is_complex(x):
        return x.imag != 0
-    result = torch.zeros_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.zeros_like(x, dtype=torch.bool)


def isreal(x: ArrayLike):
    if torch.is_complex(x):
        return x.imag == 0
-    result = torch.ones_like(x, dtype=torch.bool)
-    if result.ndim == 0:
-        result = result.item()
-    return result
+    return torch.ones_like(x, dtype=torch.bool)


def iscomplexobj(x: ArrayLike):
-    result = torch.is_complex(x)
-    return result
+    return torch.is_complex(x)


def isrealobj(x: ArrayLike):
    return not torch.is_complex(x)


def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isneginf(x, out=out)
+    return torch.isneginf(x)


def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
-    return torch.isposinf(x, out=out)
+    return torch.isposinf(x)


def i0(x: ArrayLike):
    return torch.special.i0(x)


def isscalar(a):
-    # XXX: this is a stub
    try:
        t = normalize_array_like(a)
        return t.numel() == 1
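Dropping the `ndim == 0` / `.item()` special case means `iscomplex` and `isreal` now always return tensors, presumably leaving scalar unwrapping to the wrapper layer above (the same layer that would handle the `out` argument now removed from the `isneginf`/`isposinf` calls). For a 0-dim input:

import torch

x = torch.tensor(3.0)  # 0-dim real tensor
res = torch.zeros_like(x, dtype=torch.bool)
# A 0-dim bool tensor comes back instead of a Python bool.
assert res.ndim == 0 and res.dtype == torch.bool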
@@ -1819,8 +1765,6 @@ def bartlett(M):


def common_type(*tensors: ArrayLike):
-    import builtins
-
    is_complex = False
    precision = 0
    for a in tensors:
@@ -1857,7 +1801,7 @@ def histogram(
    is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
    is_w_int = weights is None or not weights.dtype.is_floating_point
    if is_a_int:
-        a = a.to(float)
+        a = a.double()

    if weights is not None:
        weights = _util.cast_if_needed(weights, a.dtype)
@@ -1877,8 +1821,8 @@ def histogram(
    )

    if not density and is_w_int:
-        h = h.to(int)
+        h = h.long()
    if is_a_int:
-        b = b.to(int)
+        b = b.long()

    return h, b
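`.double()` and `.long()` are explicit spellings of what `.to(float)` and `.to(int)` already did: PyTorch maps the Python builtins `float` and `int` to `torch.float64` and `torch.int64` when used as dtypes. A quick check of that equivalence:

import torch

t = torch.tensor([1, 2, 3])
assert t.to(float).dtype == t.double().dtype == torch.float64
assert t.to(int).dtype == t.long().dtype == torch.int64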