Skip to content

Commit 3a16b37

Browse files
committed
MAINT: move imag, round from ndarray to free functions
1 parent dd21e56 commit 3a16b37

File tree

7 files changed

+40
-46
lines changed

7 files changed

+40
-46
lines changed

torch_np/_detail/implementations.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,15 @@ def round(tensor, decimals=0):
646646
return result
647647

648648

649+
def imag(tensor):
650+
try:
651+
result = tensor.imag
652+
except RuntimeError:
653+
# RuntimeError: imag is not implemented for tensors with non-complex dtypes.
654+
result = torch.zeros_like(tensor)
655+
return result
656+
657+
649658
# ### put/take along axis ###
650659

651660

torch_np/_funcs.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,28 @@ def _flatten(a, order="C"):
169169
(tensor,) = _helpers.to_tensors(a)
170170
result = _impl._flatten(tensor)
171171
return _helpers.array_from(result, a)
172+
173+
174+
# ### Type/shape etc queries ###
175+
176+
177+
def real(a):
178+
(tensor,) = _helpers.to_tensors(a)
179+
result = torch.real(tensor)
180+
return _helpers.array_from(result)
181+
182+
183+
def imag(a):
184+
(tensor,) = _helpers.to_tensors(a)
185+
result = _impl.imag(tensor)
186+
return _helpers.array_from(result)
187+
188+
189+
def round_(a, decimals=0, out=None):
190+
(tensor,) = _helpers.to_tensors(a)
191+
result = _impl.round(tensor, decimals)
192+
return _helpers.result_or_out(result, out)
193+
194+
195+
around = round_
196+
round = round_

torch_np/_helpers.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,3 @@ def to_tensors_or_none(*inputs):
124124
from ._ndarray import asarray, ndarray
125125

126126
return tuple(None if value is None else asarray(value).get() for value in inputs)
127-
128-
129-
def _outer(x, y):
130-
from ._ndarray import asarray
131-
132-
x_tensor, y_tensor = to_tensors(x, y)
133-
result = torch.outer(x_tensor, y_tensor)
134-
return asarray(result)

torch_np/_ndarray.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,27 +139,21 @@ def T(self):
139139

140140
@property
141141
def real(self):
142-
return asarray(self._tensor.real)
142+
return _funcs.real(self)
143143

144144
@real.setter
145145
def real(self, value):
146146
self._tensor.real = asarray(value).get()
147147

148148
@property
149149
def imag(self):
150-
try:
151-
return asarray(self._tensor.imag)
152-
except RuntimeError:
153-
zeros = torch.zeros_like(self._tensor)
154-
return ndarray._from_tensor_and_base(zeros, None)
150+
return _funcs.imag(self)
155151

156152
@imag.setter
157153
def imag(self, value):
158154
self._tensor.imag = asarray(value).get()
159155

160-
def round(self, decimals=0, out=None):
161-
result = _impl.round(self._tensor, decimals)
162-
return _helpers.result_or_out(result, out)
156+
round = _funcs.round
163157

164158
# ctors
165159
def astype(self, dtype):

torch_np/_wrapper.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -457,11 +457,6 @@ def size(a, axis=None):
457457
###### shape manipulations and indexing
458458

459459

460-
def ravel(a, order="C"):
461-
arr = asarray(a)
462-
return arr.ravel(order=order)
463-
464-
465460
def expand_dims(a, axis):
466461
a = asarray(a)
467462
shape = _util.expand_shape(a.shape, axis)
@@ -559,15 +554,6 @@ def roll(a, shift, axis=None):
559554
return asarray(result)
560555

561556

562-
def round_(a, decimals=0, out=None):
563-
arr = asarray(a)
564-
return arr.round(decimals, out=out)
565-
566-
567-
around = round_
568-
round = round_
569-
570-
571557
###### tri{l, u} and related
572558
@asarray_replacer()
573559
def tril(m, k=0):
@@ -897,16 +883,6 @@ def sinc(x):
897883
return torch.sinc(x)
898884

899885

900-
def real(a):
901-
arr = asarray(a)
902-
return arr.real
903-
904-
905-
def imag(a):
906-
arr = asarray(a)
907-
return arr.imag
908-
909-
910886
@asarray_replacer()
911887
def real_if_close(a, tol=100):
912888
result = _impl.real_if_close(a, tol=tol)

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@
3838
def get_mat(n):
3939
data = np.arange(n)
4040
# data = np.add.outer(data, data)
41-
from torch_np._helpers import _outer
42-
data = _outer(data, data)
41+
data = data[:, None] + data[None, :]
4342
return data
4443

4544

torch_np/tests/numpy_tests/lib/test_twodim_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121

2222
def get_mat(n):
2323
data = np.arange(n)
24-
# data = np.add.outer(data, data)
25-
from torch_np._helpers import _outer
26-
data = _outer(data, data)
24+
# data = np.add.outer(data, data)
25+
data = data[:, None] + data[None, :]
2726
return data
2827

2928

0 commit comments

Comments
 (0)