Skip to content

Commit 7724e7b

Browse files
committed
BUG: fix tolerance of real_if_close
1 parent 1dbacff commit 7724e7b

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

torch_np/_detail/implementations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,13 @@ def real_if_close(x, tol=100):
6464
# XXX: copies vs views; numpy seems to return a copy?
6565
if not torch.is_complex(x):
6666
return x
67-
mask = torch.abs(x.imag) < tol * torch.finfo(x.dtype).eps
67+
if tol > 1:
68+
# Undocumented in numpy: if tol < 1, it's an absolute tolerance!
69+
# Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon
70+
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577
71+
tol = tol * torch.finfo(x.dtype).eps
72+
73+
mask = torch.abs(x.imag) < tol
6874
if mask.all():
6975
return x.real
7076
else:

torch_np/tests/numpy_tests/lib/test_type_check.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_default_3(self):
8282
assert_equal(mintypecode('idD'), 'D')
8383

8484

85-
@pytest.mark.xfail(reason="not implemented")
85+
@pytest.mark.xfail(reason="TODO: decide on if [1] is a scalar or not")
8686
class TestIsscalar:
8787

8888
def test_basic(self):
@@ -188,7 +188,6 @@ def test_isreal_real(self):
188188
assert res.all()
189189

190190

191-
@pytest.mark.xfail(reason="not implemented")
192191
class TestIscomplexobj:
193192

194193
def test_basic(self):
@@ -206,7 +205,7 @@ def test_list(self):
206205
assert_(not iscomplexobj([3, 1, True]))
207206

208207

209-
@pytest.mark.xfail(reason="not implemented")
208+
210209
class TestIsrealobj:
211210
def test_basic(self):
212211
z = np.array([-1, 0, 1])
@@ -215,7 +214,6 @@ def test_basic(self):
215214
assert_(not isrealobj(z))
216215

217216

218-
@pytest.mark.xfail(reason="not implemented")
219217
class TestIsnan:
220218

221219
def test_goodvalues(self):
@@ -246,7 +244,6 @@ def test_complex1(self):
246244
assert_all(np.isnan(np.array(0+0j)/0.) == 1)
247245

248246

249-
@pytest.mark.xfail(reason="not implemented")
250247
class TestIsfinite:
251248
# Fixme, wrong place, isfinite now ufunc
252249

@@ -278,7 +275,6 @@ def test_complex1(self):
278275
assert_all(np.isfinite(np.array(1+1j)/0.) == 0)
279276

280277

281-
@pytest.mark.xfail(reason="not implemented")
282278
class TestIsinf:
283279
# Fixme, wrong place, isinf now ufunc
284280

@@ -308,7 +304,6 @@ def test_ind(self):
308304
assert_all(np.isinf(np.array((0.,))/0.) == 0)
309305

310306

311-
@pytest.mark.xfail(reason="not implemented")
312307
class TestIsposinf:
313308

314309
def test_generic(self):
@@ -319,7 +314,6 @@ def test_generic(self):
319314
assert_(vals[2] == 1)
320315

321316

322-
@pytest.mark.xfail(reason="not implemented")
323317
class TestIsneginf:
324318

325319
def test_generic(self):
@@ -436,7 +430,6 @@ def test_do_not_rewrite_previous_keyword(self):
436430
assert_equal(type(vals), np.ndarray)
437431

438432

439-
@pytest.mark.xfail(reason="not implemented")
440433
class TestRealIfClose:
441434

442435
def test_basic(self):

0 commit comments

Comments
 (0)