@@ -31,7 +31,7 @@ def tensor_equiv(a1_t, a2_t):
31
31
return tensor_equal (a1_t , a2_t )
32
32
33
33
34
- def tensor_isclose (a , b , rtol = 1.0e-5 , atol = 1.0e-8 , equal_nan = False ):
34
+ def isclose (a , b , rtol = 1.0e-5 , atol = 1.0e-8 , equal_nan = False ):
35
35
dtype = _dtypes_impl .result_type_impl ((a .dtype , b .dtype ))
36
36
a = a .to (dtype )
37
37
b = b .to (dtype )
@@ -42,7 +42,7 @@ def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
42
42
# ### is arg real or complex valued ###
43
43
44
44
45
- def tensor_iscomplex (x ):
45
+ def iscomplex (x ):
46
46
if torch .is_complex (x ):
47
47
return torch .as_tensor (x ).imag != 0
48
48
result = torch .zeros_like (x , dtype = torch .bool )
@@ -51,7 +51,7 @@ def tensor_iscomplex(x):
51
51
return result
52
52
53
53
54
- def tensor_isreal (x ):
54
+ def isreal (x ):
55
55
if torch .is_complex (x ):
56
56
return torch .as_tensor (x ).imag == 0
57
57
result = torch .ones_like (x , dtype = torch .bool )
@@ -60,7 +60,8 @@ def tensor_isreal(x):
60
60
return result
61
61
62
62
63
- def tensor_real_if_close (x , tol = 100 ):
63
+ def real_if_close (x , tol = 100 ):
64
+ # XXX: copies vs views; numpy seems to return a copy?
64
65
if not torch .is_complex (x ):
65
66
return x
66
67
mask = torch .abs (x .imag ) < tol * torch .finfo (x .dtype ).eps
@@ -73,20 +74,20 @@ def tensor_real_if_close(x, tol=100):
73
74
# ### math functions ###
74
75
75
76
76
- def tensor_angle (z , deg = False ):
77
+ def angle (z , deg = False ):
77
78
result = torch .angle (z )
78
79
if deg :
79
- result *= 180 / torch .pi
80
+ result = result * 180 / torch .pi
80
81
return result
81
82
82
83
83
84
# ### sorting ###
84
85
85
86
86
- def tensor_argsort (tensor , axis = - 1 , kind = None , order = None ):
87
+ def argsort (tensor , axis = - 1 , kind = None , order = None ):
87
88
if order is not None :
88
89
raise NotImplementedError
89
- stable = True if kind == "stable" else False
90
+ stable = kind == "stable"
90
91
if axis is None :
91
92
axis = - 1
92
93
return torch .argsort (tensor , stable = stable , dim = axis , descending = False )
@@ -387,11 +388,11 @@ def bincount(x_tensor, /, weights_tensor=None, minlength=0):
387
388
def geomspace (start , stop , num = 50 , endpoint = True , dtype = None , axis = 0 ):
388
389
if axis != 0 or not endpoint :
389
390
raise NotImplementedError
390
- tstart , tstop = torch .as_tensor ([ start , stop ] )
391
- base = torch .pow ( tstop / tstart , 1.0 / ( num - 1 ) )
391
+ base = torch .pow ( stop / start , 1.0 / ( num - 1 ) )
392
+ logbase = torch .log ( base )
392
393
result = torch .logspace (
393
- torch .log (tstart ) / torch . log ( base ) ,
394
- torch .log (tstop ) / torch . log ( base ) ,
394
+ torch .log (start ) / logbase ,
395
+ torch .log (stop ) / logbase ,
395
396
num ,
396
397
base = base ,
397
398
)
0 commit comments