Skip to content

Commit 8b83bfb

Browse files
committed
MAINT: isreal, iscomplex, real_if_close, angle
1 parent 5806a5e commit 8b83bfb

File tree

2 files changed

+49
-18
lines changed

2 files changed

+49
-18
lines changed

torch_np/_detail/implementations.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,47 @@ def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
3939
return result
4040

4141

42+
# ### is arg real or complex valued ###
43+
44+
45+
def tensor_iscomplex(x):
46+
if torch.is_complex(x):
47+
return torch.as_tensor(x).imag != 0
48+
result = torch.zeros_like(x, dtype=torch.bool)
49+
if result.ndim == 0:
50+
result = result.item()
51+
return result
52+
53+
54+
def tensor_isreal(x):
55+
if torch.is_complex(x):
56+
return torch.as_tensor(x).imag == 0
57+
result = torch.zeros_like(x, dtype=torch.bool)
58+
if result.ndim == 0:
59+
result = result.item()
60+
return result
61+
62+
63+
def tensor_real_if_close(x, tol=100):
64+
if not torch.is_complex(x):
65+
return x
66+
mask = torch.abs(x.imag) < tol * torch.finfo(x.dtype).eps
67+
if mask.all():
68+
return x.real
69+
else:
70+
return x
71+
72+
73+
# ### math functions ###
74+
75+
76+
def tensor_angle(z, deg=False):
77+
result = torch.angle(z)
78+
if deg:
79+
result *= 180 / torch.pi
80+
return result
81+
82+
4283
# ### splits ###
4384

4485

torch_np/_wrapper.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,10 +1025,8 @@ def argsort(a, axis=-1, kind=None, order=None):
10251025

10261026
@asarray_replacer()
10271027
def angle(z, deg=False):
1028-
result = torch.angle(z)
1029-
if deg:
1030-
result *= 180 / torch.pi
1031-
return asarray(result)
1028+
result = _impl.tensor_angle(z, deg)
1029+
return result
10321030

10331031

10341032
@asarray_replacer()
@@ -1048,28 +1046,20 @@ def imag(a):
10481046

10491047
@asarray_replacer()
10501048
def real_if_close(a, tol=100):
1051-
if not torch.is_complex(a):
1052-
return a
1053-
if torch.abs(torch.imag) < tol * torch.finfo(a.dtype).eps:
1054-
return torch.real(a)
1055-
else:
1056-
return a
1049+
result = _impl.tensor_real_if_close(a, tol=tol)
1050+
return result
10571051

10581052

10591053
@asarray_replacer()
10601054
def iscomplex(x):
1061-
if torch.is_complex(x):
1062-
return torch.as_tensor(x).imag != 0
1063-
result = torch.zeros_like(x, dtype=torch.bool)
1064-
return result[()]
1055+
result = _impl.tensor_iscomplex(x)
1056+
return result # XXX: missing .item on a zero-dim value; a case for array_or_scalar(value) ?
10651057

10661058

10671059
@asarray_replacer()
10681060
def isreal(x):
1069-
if torch.is_complex(x):
1070-
return torch.as_tensor(x).imag == 0
1071-
result = torch.zeros_like(x, dtype=torch.bool)
1072-
return result[()]
1061+
result = _impl.tensor_isreal(x)
1062+
return result
10731063

10741064

10751065
@asarray_replacer()

0 commit comments

Comments
 (0)