Skip to content

Commit 3d9f09c

Browse files
committed
MAINT: expand types/dtypes, adapt test_scalarmath
1 parent 6993215 commit 3d9f09c

File tree

7 files changed

+130
-84
lines changed

7 files changed

+130
-84
lines changed

torch_np/_dtypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __repr__(self):
6767

6868
__str__ = __repr__
6969

70+
@property
7071
def itemsize(self):
7172
elem = self.type(1)
7273
return elem.get().element_size()

torch_np/_ndarray.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,18 @@ def base(self):
8181
def T(self):
8282
return self.transpose()
8383

84+
@property
85+
def real(self):
86+
return asarray(self._tensor.real)
87+
88+
@property
89+
def imag(self):
90+
try:
91+
return asarray(self._tensor.imag)
92+
except RuntimeError:
93+
zeros = torch.zeros_like(self._tensor)
94+
return ndarray._from_tensor_and_base(zeros, None)
95+
8496
# ctors
8597
def astype(self, dtype):
8698
newt = ndarray()
@@ -102,6 +114,13 @@ def __str__(self):
102114

103115
### comparisons ###
104116
def __eq__(self, other):
117+
try:
118+
t_other = asarray(other).get
119+
except RuntimeError:
120+
# Failed to convert other to array: definitely not equal.
121+
# TODO: generalize, delegate to ufuncs
122+
falsy = torch.full(self.shape, fill_value=False, dtype=bool)
123+
return asarray(falsy)
105124
return asarray(self._tensor == asarray(other).get())
106125

107126
def __neq__(self, other):
@@ -119,7 +138,6 @@ def __ge__(self, other):
119138
def __le__(self, other):
120139
return asarray(self._tensor <= asarray(other).get())
121140

122-
123141
def __bool__(self):
124142
try:
125143
return bool(self._tensor)
@@ -167,7 +185,10 @@ def __iadd__(self, other):
167185

168186
def __sub__(self, other):
169187
other_tensor = asarray(other).get()
170-
return asarray(self._tensor.__sub__(other_tensor))
188+
try:
189+
return asarray(self._tensor.__sub__(other_tensor))
190+
except RuntimeError as e:
191+
raise TypeError(e.args)
171192

172193
def __mul__(self, other):
173194
other_tensor = asarray(other).get()
@@ -177,10 +198,30 @@ def __rmul__(self, other):
177198
other_tensor = asarray(other).get()
178199
return asarray(self._tensor.__rmul__(other_tensor))
179200

201+
def __floordiv__(self, other):
202+
other_tensor = asarray(other).get()
203+
return asarray(self._tensor.__floordiv__(other_tensor))
204+
205+
def __ifloordiv__(self, other):
206+
other_tensor = asarray(other).get()
207+
return asarray(self._tensor.__ifloordiv__(other_tensor))
208+
180209
def __truediv__(self, other):
181210
other_tensor = asarray(other).get()
182211
return asarray(self._tensor.__truediv__(other_tensor))
183212

213+
def __itruediv__(self, other):
214+
other_tensor = asarray(other).get()
215+
return asarray(self._tensor.__itruediv__(other_tensor))
216+
217+
def __mod__(self, other):
218+
other_tensor = asarray(other).get()
219+
return asarray(self._tensor.__mod__(other_tensor))
220+
221+
def __imod__(self, other):
222+
other_tensor = asarray(other).get()
223+
return asarray(self._tensor.__imod__(other_tensor))
224+
184225
def __or__(self, other):
185226
other_tensor = asarray(other).get()
186227
return asarray(self._tensor.__or__(other_tensor))
@@ -189,10 +230,22 @@ def __ior__(self, other):
189230
other_tensor = asarray(other).get()
190231
return asarray(self._tensor.__ior__(other_tensor))
191232

192-
193233
def __invert__(self):
194234
return asarray(self._tensor.__invert__())
195235

236+
def __abs__(self):
237+
return asarray(self._tensor.__abs__())
238+
239+
def __neg__(self):
240+
try:
241+
return asarray(self._tensor.__neg__())
242+
except RuntimeError as e:
243+
raise TypeError(e.args)
244+
245+
def __pow__(self, exponent):
246+
exponent_tensor = asarray(exponent).get()
247+
return asarray(self._tensor.__pow__(exponent_tensor))
248+
196249
### methods to match namespace functions
197250

198251
def squeeze(self, axis=None):

torch_np/_scalar_types.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@ def __new__(self, value):
2525
if isinstance(value, _ndarray.ndarray):
2626
tensor = value.get()
2727
else:
28-
tensor = torch.as_tensor(value, dtype=torch_dtype)
28+
try:
29+
tensor = torch.as_tensor(value, dtype=torch_dtype)
30+
except RuntimeError as e:
31+
if "Overflow" in str(e):
32+
raise OverflowError(e.args)
33+
raise e
2934
#
3035
# With numpy:
3136
# >>> a = np.ones(3)
@@ -135,6 +140,7 @@ class bool_(generic):
135140
half = float16
136141
single = float32
137142
double = float64
143+
float_ = float64
138144

139145
csingle = complex64
140146
cdouble = complex128
@@ -169,8 +175,8 @@ class bool_(generic):
169175
__all__ = list(_typemap.keys())
170176
__all__.remove('bool')
171177

172-
__all__ += ['bool_', 'intp', 'int_', 'intc', 'byte', 'short', 'longlong', 'ubyte', 'half', 'single', 'double',
173-
'csingle', 'cdouble']
178+
__all__ += ['bool_', 'intp', 'int_', 'intc', 'byte', 'short', 'longlong',
179+
'ubyte', 'half', 'single', 'double', 'csingle', 'cdouble', 'float_']
174180
__all__ += ['sctypes']
175181
__all__ += ['generic', 'number',
176182
'integer', 'signedinteger', 'unsignedinteger',

torch_np/_wrapper.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,11 @@ def argwhere(a):
510510
return asarray(torch.argwhere(tensor))
511511

512512

513+
def abs(a):
514+
# FIXME: should go the other way, together with other ufuncs
515+
arr = asarray(a)
516+
return a.__abs__()
517+
513518
from ._ndarray import axis_out_keepdims_wrapper
514519

515520
@axis_out_keepdims_wrapper
@@ -702,18 +707,14 @@ def angle(z, deg=False):
702707
return result
703708

704709

705-
@asarray_replacer()
706710
def real(a):
707-
return torch.real(a)
711+
arr = asarray(a)
712+
return arr.real
708713

709714

710-
@asarray_replacer()
711715
def imag(a):
712-
# torch.imag raises on real-valued inputs
713-
if torch.is_complex(a):
714-
return torch.imag(a)
715-
else:
716-
return torch.zeros_like(a)
716+
arr = asarray(a)
717+
return arr.imag
717718

718719

719720
@asarray_replacer()

torch_np/testing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .utils import (assert_equal, assert_array_equal, assert_almost_equal,
22
assert_warns, assert_)
3+
from .utils import _gen_alignment_data
34

45
from .testing import assert_allclose # FIXME
56

torch_np/testing/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -869,9 +869,9 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
869869
870870
"""
871871
__tracebackhide__ = True # Hide traceback for py.test
872-
from numpy.core import number, float_, result_type, array
873-
from numpy.core.numerictypes import issubdtype
874-
from numpy.core.fromnumeric import any as npany
872+
from torch_np import number, float_, result_type, array
873+
from torch_np import issubdtype
874+
from torch_np import any as npany
875875

876876
def compare(x, y):
877877
try:

0 commit comments

Comments
 (0)