Skip to content

Commit f63c6bb

Browse files
committed
Merge branch 'scalarmath' into main
reviewed at gh-16
2 parents bc75eb2 + 2b95ac2 commit f63c6bb

File tree

8 files changed

+234
-139
lines changed

8 files changed

+234
-139
lines changed

torch_np/_dtypes.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __repr__(self):
6767

6868
__str__ = __repr__
6969

70+
@property
7071
def itemsize(self):
7172
elem = self.type(1)
7273
return elem.get().element_size()
@@ -193,6 +194,8 @@ def torch_dtype_from(dtyp):
193194
raise TypeError
194195

195196

197+
# ### Defaults and dtype discovery
198+
196199
def default_int_type():
197200
return dtype('int64')
198201

@@ -201,14 +204,33 @@ def default_float_type():
201204
return dtype('float64')
202205

203206

207+
def default_complex_type():
208+
return dtype('complex128')
209+
210+
204211
def is_floating(dtyp):
205212
dtyp = dtype(dtyp)
206-
return dtyp.typecode in typecodes['AllFloat']
213+
return issubclass(dtyp.type, _scalar_types.floating)
214+
207215

208216
def is_integer(dtyp):
209217
dtyp = dtype(dtyp)
210-
return dtyp.typecode in typecodes['AllInteger']
211-
218+
return issubclass(dtyp.type, _scalar_types.integer)
219+
220+
221+
def get_default_dtype_for(dtyp):
222+
typ = dtype(dtyp).type
223+
if issubclass(typ, _scalar_types.integer):
224+
result = default_int_type()
225+
elif issubclass(typ, _scalar_types.floating):
226+
result = default_float_type()
227+
elif issubclass(typ, _scalar_types.complexfloating):
228+
result = default_complex_type()
229+
elif issubclass(typ, _scalar_types.bool_):
230+
result = dtype('bool')
231+
else:
232+
raise TypeError("dtype %s not understood." % dtyp)
233+
return result
212234

213235

214236
def issubclass_(arg, klass):

torch_np/_ndarray.py

Lines changed: 117 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,18 @@ def base(self):
8181
def T(self):
8282
return self.transpose()
8383

84+
@property
85+
def real(self):
86+
return asarray(self._tensor.real)
87+
88+
@property
89+
def imag(self):
90+
try:
91+
return asarray(self._tensor.imag)
92+
except RuntimeError:
93+
zeros = torch.zeros_like(self._tensor)
94+
return ndarray._from_tensor_and_base(zeros, None)
95+
8496
# ctors
8597
def astype(self, dtype):
8698
newt = ndarray()
@@ -102,6 +114,13 @@ def __str__(self):
102114

103115
### comparisons ###
104116
def __eq__(self, other):
117+
try:
118+
t_other = asarray(other).get
119+
except RuntimeError:
120+
# Failed to convert other to array: definitely not equal.
121+
# TODO: generalize, delegate to ufuncs
122+
falsy = torch.full(self.shape, fill_value=False, dtype=bool)
123+
return asarray(falsy)
105124
return asarray(self._tensor == asarray(other).get())
106125

107126
def __neq__(self, other):
@@ -119,7 +138,6 @@ def __ge__(self, other):
119138
def __le__(self, other):
120139
return asarray(self._tensor <= asarray(other).get())
121140

122-
123141
def __bool__(self):
124142
try:
125143
return bool(self._tensor)
@@ -141,6 +159,9 @@ def __hash__(self):
141159
def __float__(self):
142160
return float(self._tensor)
143161

162+
def __int__(self):
163+
return int(self._tensor)
164+
144165
# XXX : are single-element ndarrays scalars?
145166
def is_integer(self):
146167
if self.shape == ():
@@ -167,7 +188,10 @@ def __iadd__(self, other):
167188

168189
def __sub__(self, other):
169190
other_tensor = asarray(other).get()
170-
return asarray(self._tensor.__sub__(other_tensor))
191+
try:
192+
return asarray(self._tensor.__sub__(other_tensor))
193+
except RuntimeError as e:
194+
raise TypeError(e.args)
171195

172196
def __mul__(self, other):
173197
other_tensor = asarray(other).get()
@@ -177,10 +201,30 @@ def __rmul__(self, other):
177201
other_tensor = asarray(other).get()
178202
return asarray(self._tensor.__rmul__(other_tensor))
179203

204+
def __floordiv__(self, other):
205+
other_tensor = asarray(other).get()
206+
return asarray(self._tensor.__floordiv__(other_tensor))
207+
208+
def __ifloordiv__(self, other):
209+
other_tensor = asarray(other).get()
210+
return asarray(self._tensor.__ifloordiv__(other_tensor))
211+
180212
def __truediv__(self, other):
181213
other_tensor = asarray(other).get()
182214
return asarray(self._tensor.__truediv__(other_tensor))
183215

216+
def __itruediv__(self, other):
217+
other_tensor = asarray(other).get()
218+
return asarray(self._tensor.__itruediv__(other_tensor))
219+
220+
def __mod__(self, other):
221+
other_tensor = asarray(other).get()
222+
return asarray(self._tensor.__mod__(other_tensor))
223+
224+
def __imod__(self, other):
225+
other_tensor = asarray(other).get()
226+
return asarray(self._tensor.__imod__(other_tensor))
227+
184228
def __or__(self, other):
185229
other_tensor = asarray(other).get()
186230
return asarray(self._tensor.__or__(other_tensor))
@@ -189,10 +233,22 @@ def __ior__(self, other):
189233
other_tensor = asarray(other).get()
190234
return asarray(self._tensor.__ior__(other_tensor))
191235

192-
193236
def __invert__(self):
194237
return asarray(self._tensor.__invert__())
195238

239+
def __abs__(self):
240+
return asarray(self._tensor.__abs__())
241+
242+
def __neg__(self):
243+
try:
244+
return asarray(self._tensor.__neg__())
245+
except RuntimeError as e:
246+
raise TypeError(e.args)
247+
248+
def __pow__(self, exponent):
249+
exponent_tensor = asarray(exponent).get()
250+
return asarray(self._tensor.__pow__(exponent_tensor))
251+
196252
### methods to match namespace functions
197253

198254
def squeeze(self, axis=None):
@@ -301,7 +357,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal
301357

302358
if dtype is None:
303359
dtype = self.dtype
304-
if not _dtypes.is_floating(dtype):
360+
if _dtypes.is_integer(dtype):
305361
dtype = _dtypes.default_float_type()
306362
torch_dtype = _dtypes.torch_dtype_from(dtype)
307363

@@ -321,7 +377,7 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue,
321377

322378
if dtype is None:
323379
dtype = self.dtype
324-
if not _dtypes.is_floating(dtype):
380+
if _dtypes.is_integer(dtype):
325381
dtype = _dtypes.default_float_type()
326382
torch_dtype = _dtypes.torch_dtype_from(dtype)
327383

@@ -343,67 +399,80 @@ def __setitem__(self, index, value):
343399
return self._tensor.__setitem__(index, value)
344400

345401

346-
def asarray(a, dtype=None, order=None, *, like=None):
347-
_util.subok_not_ok(like)
348-
if order is not None:
402+
# This is the ideally the only place which talks to ndarray directly.
403+
# The rest goes through asarray (preferred) or array.
404+
405+
def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,
406+
like=None):
407+
_util.subok_not_ok(like, subok)
408+
if order != 'K':
349409
raise NotImplementedError
350410

351-
if isinstance(a, ndarray):
352-
if dtype is not None and dtype != a.dtype:
353-
a = a.astype(dtype)
354-
return a
411+
# a happy path
412+
if isinstance(object, ndarray):
413+
if copy is False and dtype is None and ndmin <= object.ndim:
414+
return object
355415

356-
if isinstance(a, (list, tuple)):
357-
# handle lists of ndarrays, [1, [2, 3], ndarray(4)] etc
416+
# lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
417+
if isinstance(object, (list, tuple)):
358418
a1 = []
359-
for elem in a:
419+
for elem in object:
360420
if isinstance(elem, ndarray):
361421
a1.append(elem.get().tolist())
362422
else:
363423
a1.append(elem)
424+
object = a1
425+
426+
# get the tensor from "object"
427+
if isinstance(object, ndarray):
428+
tensor = object._tensor
429+
base = object
430+
elif isinstance(object, torch.Tensor):
431+
tensor = object
432+
base = None
364433
else:
365-
a1 = a
434+
tensor = torch.as_tensor(object)
435+
base = None
366436

367-
torch_dtype = _dtypes.torch_dtype_from(dtype)
437+
# At this point, `tensor.dtype` is the pytorch default. Our default may
438+
# differ, so need to typecast. However, we cannot just do `tensor.to`,
439+
# because if our desired dtype is wider then pytorch's, `tensor`
440+
# may have lost precision:
368441

369-
# This and array(...) are the only places which talk to ndarray directly.
370-
# The rest goes through asarray (preferred) or array.
371-
out = ndarray()
372-
tt = torch.as_tensor(a1, dtype=torch_dtype)
373-
out._tensor = tt
374-
return out
442+
# int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!)
375443

444+
# Therefore, we treat `tensor.dtype` as a hint, and convert the
445+
# original object *again*, this time with an explicit dtype.
446+
dtyp = _dtypes.dtype_from_torch(tensor.dtype)
447+
default = _dtypes.get_default_dtype_for(dtyp)
448+
torch_dtype = _dtypes.torch_dtype_from(default)
376449

377-
def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,
378-
like=None):
379-
_util.subok_not_ok(like, subok)
380-
if order != 'K':
381-
raise NotImplementedError
382-
383-
if isinstance(object, (list, tuple)):
384-
obj = asarray(object)
385-
return array(obj, dtype, copy=copy, order=order, subok=subok,
386-
ndmin=ndmin, like=like)
450+
tensor = torch.as_tensor(object, dtype=torch_dtype)
387451

388-
if isinstance(object, ndarray):
389-
result = object._tensor
390-
391-
if dtype != object.dtype:
392-
torch_dtype = _dtypes.torch_dtype_from(dtype)
393-
result = result.to(torch_dtype)
394-
else:
452+
# type cast if requested
453+
if dtype is not None:
395454
torch_dtype = _dtypes.torch_dtype_from(dtype)
396-
result = torch.as_tensor(object, dtype=torch_dtype)
455+
tensor = tensor.to(torch_dtype)
456+
base = None
397457

458+
# adjust ndim if needed
459+
ndim_extra = ndmin - tensor.ndim
460+
if ndim_extra > 0:
461+
tensor = tensor.view((1,)*ndim_extra + tensor.shape)
462+
base = None
463+
464+
# copy if requested
398465
if copy:
399-
result = result.clone()
466+
tensor = tensor.clone()
467+
base = None
400468

401-
ndim_extra = ndmin - result.ndim
402-
if ndim_extra > 0:
403-
result = result.reshape((1,)*ndim_extra + result.shape)
404-
out = ndarray()
405-
out._tensor = result
406-
return out
469+
return ndarray._from_tensor_and_base(tensor, base)
470+
471+
472+
def asarray(a, dtype=None, order=None, *, like=None):
473+
if order is None:
474+
order = 'K'
475+
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)
407476

408477

409478

torch_np/_scalar_types.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@ def __new__(self, value):
2525
if isinstance(value, _ndarray.ndarray):
2626
tensor = value.get()
2727
else:
28-
tensor = torch.as_tensor(value, dtype=torch_dtype)
28+
try:
29+
tensor = torch.as_tensor(value, dtype=torch_dtype)
30+
except RuntimeError as e:
31+
if "Overflow" in str(e):
32+
raise OverflowError(e.args)
33+
raise e
2934
#
3035
# With numpy:
3136
# >>> a = np.ones(3)
@@ -135,6 +140,7 @@ class bool_(generic):
135140
half = float16
136141
single = float32
137142
double = float64
143+
float_ = float64
138144

139145
csingle = complex64
140146
cdouble = complex128
@@ -169,8 +175,8 @@ class bool_(generic):
169175
__all__ = list(_typemap.keys())
170176
__all__.remove('bool')
171177

172-
__all__ += ['bool_', 'intp', 'int_', 'intc', 'byte', 'short', 'longlong', 'ubyte', 'half', 'single', 'double',
173-
'csingle', 'cdouble']
178+
__all__ += ['bool_', 'intp', 'int_', 'intc', 'byte', 'short', 'longlong',
179+
'ubyte', 'half', 'single', 'double', 'csingle', 'cdouble', 'float_']
174180
__all__ += ['sctypes']
175181
__all__ += ['generic', 'number',
176182
'integer', 'signedinteger', 'unsignedinteger',

torch_np/_wrapper.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,11 @@ def argwhere(a):
510510
return asarray(torch.argwhere(tensor))
511511

512512

513+
def abs(a):
514+
# FIXME: should go the other way, together with other ufuncs
515+
arr = asarray(a)
516+
return a.__abs__()
517+
513518
from ._ndarray import axis_out_keepdims_wrapper
514519

515520
@axis_out_keepdims_wrapper
@@ -702,18 +707,14 @@ def angle(z, deg=False):
702707
return result
703708

704709

705-
@asarray_replacer()
706710
def real(a):
707-
return torch.real(a)
711+
arr = asarray(a)
712+
return arr.real
708713

709714

710-
@asarray_replacer()
711715
def imag(a):
712-
# torch.imag raises on real-valued inputs
713-
if torch.is_complex(a):
714-
return torch.imag(a)
715-
else:
716-
return torch.zeros_like(a)
716+
arr = asarray(a)
717+
return arr.imag
717718

718719

719720
@asarray_replacer()

torch_np/testing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .utils import (assert_equal, assert_array_equal, assert_almost_equal,
22
assert_warns, assert_)
3+
from .utils import _gen_alignment_data
34

45
from .testing import assert_allclose # FIXME
56

0 commit comments

Comments
 (0)