Skip to content

Commit cf4d5c1

Browse files
committed
MAINT: address review comments
1 parent 85d1ba6 commit cf4d5c1

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

torch_np/_detail/implementations.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def tensor_equiv(a1_t, a2_t):
3131
return tensor_equal(a1_t, a2_t)
3232

3333

34-
def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
34+
def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
3535
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
3636
a = a.to(dtype)
3737
b = b.to(dtype)
@@ -42,7 +42,7 @@ def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
4242
# ### is arg real or complex valued ###
4343

4444

45-
def tensor_iscomplex(x):
45+
def iscomplex(x):
4646
if torch.is_complex(x):
4747
return torch.as_tensor(x).imag != 0
4848
result = torch.zeros_like(x, dtype=torch.bool)
@@ -51,7 +51,7 @@ def tensor_iscomplex(x):
5151
return result
5252

5353

54-
def tensor_isreal(x):
54+
def isreal(x):
5555
if torch.is_complex(x):
5656
return torch.as_tensor(x).imag == 0
5757
result = torch.ones_like(x, dtype=torch.bool)
@@ -60,7 +60,8 @@ def tensor_isreal(x):
6060
return result
6161

6262

63-
def tensor_real_if_close(x, tol=100):
63+
def real_if_close(x, tol=100):
64+
# XXX: copies vs views; numpy seems to return a copy?
6465
if not torch.is_complex(x):
6566
return x
6667
mask = torch.abs(x.imag) < tol * torch.finfo(x.dtype).eps
@@ -73,17 +74,17 @@ def tensor_real_if_close(x, tol=100):
7374
# ### math functions ###
7475

7576

76-
def tensor_angle(z, deg=False):
77+
def angle(z, deg=False):
7778
result = torch.angle(z)
7879
if deg:
79-
result *= 180 / torch.pi
80+
result = result * 180 / torch.pi
8081
return result
8182

8283

8384
# ### sorting ###
8485

8586

86-
def tensor_argsort(tensor, axis=-1, kind=None, order=None):
87+
def argsort(tensor, axis=-1, kind=None, order=None):
8788
if order is not None:
8889
raise NotImplementedError
8990
stable = True if kind == "stable" else False
@@ -387,11 +388,11 @@ def bincount(x_tensor, /, weights_tensor=None, minlength=0):
387388
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
388389
if axis != 0 or not endpoint:
389390
raise NotImplementedError
390-
tstart, tstop = torch.as_tensor([start, stop])
391-
base = torch.pow(tstop / tstart, 1.0 / (num - 1))
391+
base = torch.pow(stop / start, 1.0 / (num - 1))
392+
logbase = torch.log(base)
392393
result = torch.logspace(
393-
torch.log(tstart) / torch.log(base),
394-
torch.log(tstop) / torch.log(base),
394+
torch.log(start) / logbase,
395+
torch.log(stop) / logbase,
395396
num,
396397
base=base,
397398
)

torch_np/_wrapper.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
261261
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
262262
if axis != 0 or not endpoint:
263263
raise NotImplementedError
264+
start, stop = _helpers.to_tensors(start, stop)
264265
result = _impl.geomspace(start, stop, num, endpoint, dtype, axis)
265266
return asarray(result)
266267

@@ -954,7 +955,7 @@ def diff(a, n=1, axis=-1, prepend=NoValue, append=NoValue):
954955

955956
@asarray_replacer()
956957
def argsort(a, axis=-1, kind=None, order=None):
957-
result = _impl.tensor_argsort(a, axis, kind, order)
958+
result = _impl.argsort(a, axis, kind, order)
958959
return result
959960

960961

@@ -963,7 +964,7 @@ def argsort(a, axis=-1, kind=None, order=None):
963964

964965
@asarray_replacer()
965966
def angle(z, deg=False):
966-
result = _impl.tensor_angle(z, deg)
967+
result = _impl.angle(z, deg)
967968
return result
968969

969970

@@ -984,19 +985,19 @@ def imag(a):
984985

985986
@asarray_replacer()
986987
def real_if_close(a, tol=100):
987-
result = _impl.tensor_real_if_close(a, tol=tol)
988+
result = _impl.real_if_close(a, tol=tol)
988989
return result
989990

990991

991992
@asarray_replacer()
992993
def iscomplex(x):
993-
result = _impl.tensor_iscomplex(x)
994+
result = _impl.iscomplex(x)
994995
return result # XXX: missing .item on a zero-dim value; a case for array_or_scalar(value) ?
995996

996997

997998
@asarray_replacer()
998999
def isreal(x):
999-
result = _impl.tensor_isreal(x)
1000+
result = _impl.isreal(x)
10001001
return result
10011002

10021003

@@ -1036,13 +1037,13 @@ def isscalar(a):
10361037

10371038
def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
10381039
a_t, b_t = _helpers.to_tensors(a, b)
1039-
result = _impl.tensor_isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan)
1040+
result = _impl.isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan)
10401041
return asarray(result)
10411042

10421043

10431044
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
10441045
a_t, b_t = _helpers.to_tensors(a, b)
1045-
result = _impl.tensor_isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan)
1046+
result = _impl.isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan)
10461047
return result.all()
10471048

10481049

0 commit comments

Comments
 (0)