Skip to content

Commit 2940386

Browse files
committed
MAINT: make .real/imag attributes writeable
1 parent 1222061 commit 2940386

File tree

6 files changed

+63
-75
lines changed

6 files changed

+63
-75
lines changed

torch_np/_ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def T(self):
8686
def real(self):
8787
return asarray(self._tensor.real)
8888

89+
@real.setter
90+
def real(self, value):
91+
self._tensor.real = asarray(value).get()
92+
8993
@property
9094
def imag(self):
9195
try:
@@ -94,6 +98,10 @@ def imag(self):
9498
zeros = torch.zeros_like(self._tensor)
9599
return ndarray._from_tensor_and_base(zeros, None)
96100

101+
@imag.setter
102+
def imag(self, value):
103+
self._tensor.imag = asarray(value).get()
104+
97105
# ctors
98106
def astype(self, dtype):
99107
newt = ndarray()

torch_np/_wrapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,11 @@ def isscalar(a):
757757
return False
758758

759759

760+
def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
761+
a = asarray(a).get()
762+
b = asarray(a).get()
763+
return asarray(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
764+
760765
###### mapping from numpy API objects to wrappers from this module ######
761766

762767
# All is in the mapping dict in _mapping.py

torch_np/testing/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .utils import (assert_equal, assert_array_equal, assert_almost_equal,
2-
assert_warns, assert_)
2+
assert_warns, assert_, assert_allclose)
33
from .utils import _gen_alignment_data
44

5-
from .testing import assert_allclose # FIXME
5+
#from .testing import assert_allclose # FIXME
66

77

torch_np/testing/testing.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

torch_np/testing/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,7 +1169,7 @@ def _assert_valid_refcount(op):
11691169

11701170

11711171
def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
1172-
err_msg='', verbose=True):
1172+
err_msg='', verbose=True, check_dtype=False):
11731173
"""
11741174
Raises an AssertionError if two objects are not equal up to desired
11751175
tolerance.
@@ -1226,14 +1226,17 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
12261226
12271227
"""
12281228
__tracebackhide__ = True # Hide traceback for py.test
1229-
import numpy as np
12301229

12311230
def compare(x, y):
1232-
return np.core.numeric.isclose(x, y, rtol=rtol, atol=atol,
1231+
return np.isclose(x, y, rtol=rtol, atol=atol,
12331232
equal_nan=equal_nan)
12341233

12351234
actual, desired = asanyarray(actual), asanyarray(desired)
12361235
header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
1236+
1237+
if check_dtype:
1238+
assert actual.dtype == desired.dtype
1239+
12371240
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
12381241
verbose=verbose, header=header, equal_nan=equal_nan)
12391242

torch_np/tests/test_reductions.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from pytest import raises as assert_raises
33

44
import torch_np as np
5-
from torch_np.testing import assert_equal, assert_array_equal, assert_allclose
5+
from torch_np.testing import (assert_equal, assert_array_equal, assert_allclose,
6+
assert_almost_equal)
67

78
import torch_np._util as _util
89

@@ -321,12 +322,10 @@ def test_sum(self):
321322
def test_sum_stability(self):
322323
a = np.ones(500, dtype=np.float32)
323324
zero = np.zeros(1, dtype='float32')[0]
324-
assert_allclose((a / 10.).sum() - a.size / 10., zero, atol=1.5e-4,
325-
check_dtype=False)
325+
assert_allclose((a / 10.).sum() - a.size / 10., zero, atol=1.5e-4)
326326

327327
a = np.ones(500, dtype=np.float64)
328-
assert_allclose((a / 10.).sum() - a.size / 10., 0., atol=1.5e-13,
329-
check_dtype=False)
328+
assert_allclose((a / 10.).sum() - a.size / 10., 0., atol=1.5e-13)
330329

331330
def test_sum_boolean(self):
332331
a = (np.arange(7) % 2 == 0)
@@ -338,8 +337,8 @@ def test_sum_boolean(self):
338337
assert res_float.dtype == 'float64'
339338

340339

341-
@pytest.mark.xfail(reason="dtype(value) needs implementing")
342-
def test_sum_dtypes(self):
340+
@pytest.mark.xfail(reason="sum: does not warn on overflow")
341+
def test_sum_dtypes_warnings(self):
343342
for dt in (int, np.float16, np.float32, np.float64):
344343
for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127,
345344
128, 1024, 1235):
@@ -357,48 +356,54 @@ def test_sum_dtypes(self):
357356
assert_almost_equal(np.sum(d), tgt)
358357
assert_equal(len(w), 2 * overflow)
359358

360-
assert_almost_equal(np.sum(d[::-1]), tgt)
359+
assert_almost_equal(np.sum(np.flip(d)), tgt)
361360
assert_equal(len(w), 3 * overflow)
362361

362+
def test_sum_dtypes_2(self):
363+
for dt in (int, np.float16, np.float32, np.float64):
363364
d = np.ones(500, dtype=dt)
364365
assert_almost_equal(np.sum(d[::2]), 250.)
365366
assert_almost_equal(np.sum(d[1::2]), 250.)
366367
assert_almost_equal(np.sum(d[::3]), 167.)
367368
assert_almost_equal(np.sum(d[1::3]), 167.)
368-
assert_almost_equal(np.sum(d[::-2]), 250.)
369-
assert_almost_equal(np.sum(d[-1::-2]), 250.)
370-
assert_almost_equal(np.sum(d[::-3]), 167.)
371-
assert_almost_equal(np.sum(d[-1::-3]), 167.)
369+
assert_almost_equal(np.sum(np.flip(d)[::2]), 250.)
370+
371+
assert_almost_equal(np.sum(np.flip(d)[1::2]), 250.)
372+
373+
assert_almost_equal(np.sum(np.flip(d)[::3]), 167.)
374+
assert_almost_equal(np.sum(np.flip(d)[1::3]), 167.)
375+
372376
# sum with first reduction entry != 0
373377
d = np.ones((1,), dtype=dt)
374378
d += d
375379
assert_almost_equal(d, 2.)
376380

377-
@pytest.mark.xfail(reason="dtype(value) needs implementing")
378-
def test_sum_complex(self):
379-
for dt in (np.complex64, np.complex128):
380-
for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127,
381-
128, 1024, 1235):
382-
tgt = dt(v * (v + 1) / 2) - dt((v * (v + 1) / 2) * 1j)
383-
d = np.empty(v, dtype=dt)
384-
d.real = np.arange(1, v + 1)
385-
d.imag = -np.arange(1, v + 1)
386-
assert_allclose(np.sum(d), tgt, atol=1.5e-5)
387-
assert_allcllose(np.sum(d[::-1]), tgt, atol=1.5e-7)
388-
389-
d = np.ones(500, dtype=dt) + 1j
390-
assert_allclose(np.sum(d[::2]), 250. + 250j, atol=1.5e-7)
391-
assert_allclose(np.sum(d[1::2]), 250. + 250j, atol=1.5e-7)
392-
assert_allclose(np.sum(d[::3]), 167. + 167j, atol=1.5e-7)
393-
assert_allclose(np.sum(d[1::3]), 167. + 167j, atol=1.5e-7)
394-
assert_allclose(np.sum(d[::-2]), 250. + 250j, atol=1.5e-7)
395-
assert_allclose(np.sum(d[-1::-2]), 250. + 250j, atol=1.5e-7)
396-
assert_allclose(np.sum(d[::-3]), 167. + 167j, atol=1.5e-7)
397-
assert_allclose(np.sum(d[-1::-3]), 167. + 167j, atol=1.5e-7)
398-
# sum with first reduction entry != 0
399-
d = np.ones((1,), dtype=dt) + 1j
400-
d += d
401-
assert_allclose(d, 2. + 2j, atol=1.5e-7)
381+
@pytest.mark.parametrize("dt", [np.complex64, np.complex128])
382+
def test_sum_complex_1(self, dt):
383+
for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127,
384+
128, 1024, 1235):
385+
tgt = dt(v * (v + 1) / 2) - dt((v * (v + 1) / 2) * 1j)
386+
d = np.empty(v, dtype=dt)
387+
d.real = np.arange(1, v + 1)
388+
d.imag = -np.arange(1, v + 1)
389+
assert_allclose(np.sum(d), tgt, atol=1.5e-5)
390+
assert_allclose(np.sum(np.flip(d)), tgt, atol=1.5e-7)
391+
392+
@pytest.mark.parametrize("dt", [np.complex64, np.complex128])
393+
def test_sum_complex_2(self, dt):
394+
d = np.ones(500, dtype=dt) + 1j
395+
assert_allclose(np.sum(d[::2]), 250. + 250j, atol=1.5e-7)
396+
assert_allclose(np.sum(d[1::2]), 250. + 250j, atol=1.5e-7)
397+
assert_allclose(np.sum(d[::3]), 167. + 167j, atol=1.5e-7)
398+
assert_allclose(np.sum(d[1::3]), 167. + 167j, atol=1.5e-7)
399+
assert_allclose(np.sum(np.flip(d)[::2]), 250. + 250j, atol=1.5e-7)
400+
assert_allclose(np.sum(np.flip(d)[1::2]), 250. + 250j, atol=1.5e-7)
401+
assert_allclose(np.sum(np.flip(d)[::3]), 167. + 167j, atol=1.5e-7)
402+
assert_allclose(np.sum(np.flip(d)[1::3]), 167. + 167j, atol=1.5e-7)
403+
# sum with first reduction entry != 0
404+
d = np.ones((1,), dtype=dt) + 1j
405+
d += d
406+
assert_allclose(d, 2. + 2j, atol=1.5e-7)
402407

403408
@pytest.mark.xfail(reason='initial=... need implementing')
404409
def test_sum_initial(self):

0 commit comments

Comments
 (0)