Commit 1f1973d

MAINT: rationalize the array/asarray split

1 parent: 3d9f09c

4 files changed: +101 -59 lines

torch_np/_dtypes.py

Lines changed: 24 additions & 3 deletions
@@ -194,6 +194,8 @@ def torch_dtype_from(dtyp):
     raise TypeError


+# ### Defaults and dtype discovery
+
 def default_int_type():
     return dtype('int64')

@@ -202,14 +204,33 @@ def default_float_type():
     return dtype('float64')


+def default_complex_type():
+    return dtype('complex128')
+
+
 def is_floating(dtyp):
     dtyp = dtype(dtyp)
-    return dtyp.typecode in typecodes['AllFloat']
+    return issubclass(dtyp.type, _scalar_types.floating)
+

 def is_integer(dtyp):
     dtyp = dtype(dtyp)
-    return dtyp.typecode in typecodes['AllInteger']
-
+    return issubclass(dtyp.type, _scalar_types.integer)
+
+
+def get_default_dtype_for(dtyp):
+    typ = dtype(dtyp).type
+    if issubclass(typ, _scalar_types.integer):
+        result = default_int_type()
+    elif issubclass(typ, _scalar_types.floating):
+        result = default_float_type()
+    elif issubclass(typ, _scalar_types.complexfloating):
+        result = default_complex_type()
+    elif issubclass(typ, _scalar_types.bool_):
+        result = dtype('bool')
+    else:
+        raise TypeError("dtype %s not understood." % dtyp)
+    return result


 def issubclass_(arg, klass):
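
As a usage sketch (not part of the commit): the new helpers classify a dtype by kind and map it to its NumPy-style wide default. Assuming the torch_np package is importable and that its dtype objects compare equal as shown, the intended behavior is roughly:

    # a minimal sketch, assuming torch_np._dtypes exposes the helpers above
    from torch_np import _dtypes

    assert _dtypes.is_floating('float32')
    assert _dtypes.is_integer('uint8')

    # every dtype kind maps to the wide default of the same kind
    assert _dtypes.get_default_dtype_for('int8') == _dtypes.dtype('int64')
    assert _dtypes.get_default_dtype_for('float32') == _dtypes.dtype('float64')
    assert _dtypes.get_default_dtype_for('complex64') == _dtypes.dtype('complex128')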

torch_np/_ndarray.py

Lines changed: 61 additions & 45 deletions
@@ -159,6 +159,9 @@ def __hash__(self):
     def __float__(self):
         return float(self._tensor)

+    def __int__(self):
+        return int(self._tensor)
+
     # XXX : are single-element ndarrays scalars?
     def is_integer(self):
         if self.shape == ():
@@ -354,7 +357,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue):

         if dtype is None:
             dtype = self.dtype
-        if not _dtypes.is_floating(dtype):
+        if _dtypes.is_integer(dtype):
             dtype = _dtypes.default_float_type()
         torch_dtype = _dtypes.torch_dtype_from(dtype)

@@ -374,7 +377,7 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue,

         if dtype is None:
             dtype = self.dtype
-        if not _dtypes.is_floating(dtype):
+        if _dtypes.is_integer(dtype):
             dtype = _dtypes.default_float_type()
         torch_dtype = _dtypes.torch_dtype_from(dtype)

@@ -396,67 +399,80 @@ def __setitem__(self, index, value):
         return self._tensor.__setitem__(index, value)


-def asarray(a, dtype=None, order=None, *, like=None):
-    _util.subok_not_ok(like)
-    if order is not None:
+# This is ideally the only place which talks to ndarray directly.
+# The rest goes through asarray (preferred) or array.
+
+def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,
+          like=None):
+    _util.subok_not_ok(like, subok)
+    if order != 'K':
         raise NotImplementedError

-    if isinstance(a, ndarray):
-        if dtype is not None and dtype != a.dtype:
-            a = a.astype(dtype)
-        return a
+    # a happy path
+    if isinstance(object, ndarray):
+        if copy is False and dtype is None and ndmin <= object.ndim:
+            return object

-    if isinstance(a, (list, tuple)):
-        # handle lists of ndarrays, [1, [2, 3], ndarray(4)] etc
+    # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
+    if isinstance(object, (list, tuple)):
         a1 = []
-        for elem in a:
+        for elem in object:
             if isinstance(elem, ndarray):
                 a1.append(elem.get().tolist())
             else:
                 a1.append(elem)
+        object = a1
+
+    # get the tensor from "object"
+    if isinstance(object, ndarray):
+        tensor = object._tensor
+        base = object
+    elif isinstance(object, torch.Tensor):
+        tensor = object
+        base = None
     else:
-        a1 = a
+        tensor = torch.as_tensor(object)
+        base = None

-    torch_dtype = _dtypes.torch_dtype_from(dtype)
+    # At this point, `tensor.dtype` is the pytorch default. Our default may
+    # differ, so we need to typecast. However, we cannot just do `tensor.to`,
+    # because if our desired dtype is wider than pytorch's, `tensor`
+    # may have lost precision:

-    # This and array(...) are the only places which talk to ndarray directly.
-    # The rest goes through asarray (preferred) or array.
-    out = ndarray()
-    tt = torch.as_tensor(a1, dtype=torch_dtype)
-    out._tensor = tt
-    return out
+    # int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!)

+    # Therefore, we treat `tensor.dtype` as a hint, and convert the
+    # original object *again*, this time with an explicit dtype.
+    dtyp = _dtypes.dtype_from_torch(tensor.dtype)
+    default = _dtypes.get_default_dtype_for(dtyp)
+    torch_dtype = _dtypes.torch_dtype_from(default)

-def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,
-          like=None):
-    _util.subok_not_ok(like, subok)
-    if order != 'K':
-        raise NotImplementedError
-
-    if isinstance(object, (list, tuple)):
-        obj = asarray(object)
-        return array(obj, dtype, copy=copy, order=order, subok=subok,
-                     ndmin=ndmin, like=like)
+    tensor = torch.as_tensor(object, dtype=torch_dtype)

-    if isinstance(object, ndarray):
-        result = object._tensor
-
-        if dtype != object.dtype:
-            torch_dtype = _dtypes.torch_dtype_from(dtype)
-            result = result.to(torch_dtype)
-    else:
+    # type cast if requested
+    if dtype is not None:
         torch_dtype = _dtypes.torch_dtype_from(dtype)
-        result = torch.as_tensor(object, dtype=torch_dtype)
+        tensor = tensor.to(torch_dtype)
+        base = None

+    # adjust ndim if needed
+    ndim_extra = ndmin - tensor.ndim
+    if ndim_extra > 0:
+        tensor = tensor.view((1,)*ndim_extra + tensor.shape)
+        base = None
+
+    # copy if requested
     if copy:
-        result = result.clone()
+        tensor = tensor.clone()
+        base = None

-    ndim_extra = ndmin - result.ndim
-    if ndim_extra > 0:
-        result = result.reshape((1,)*ndim_extra + result.shape)
-    out = ndarray()
-    out._tensor = result
-    return out
+    return ndarray._from_tensor_and_base(tensor, base)
+
+
+def asarray(a, dtype=None, order=None, *, like=None):
+    if order is None:
+        order = 'K'
+    return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)

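
The precision remark in the comments above is easy to check with stock pytorch; a minimal sketch (plain torch, nothing from this repo) of why array() converts the original object a second time with an explicit dtype:

    import torch

    # pytorch stores a Python float as float32 by default, and float32
    # cannot represent 1e12 exactly:
    t = torch.as_tensor(1e12)
    print(t.dtype)             # torch.float32
    print(int(t) - 10**12)     # -4096: the precision is already gone

    # converting the *original* object again with the wider dtype is exact;
    # this is what the re-conversion step in array() accomplishes
    t64 = torch.as_tensor(1e12, dtype=torch.float64)
    print(int(t64) - 10**12)   # 0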

torch_np/tests/numpy_tests/core/test_scalarmath.py

Lines changed: 3 additions & 1 deletion
@@ -191,7 +191,6 @@ def test_integers_to_negative_integer_power(self):
         # 1 ** -1 possible special case
         base = [np.array(1, dt)[()] for dt in 'bhilB']
         for i1, i2 in itertools.product(base, exp):
-            pass
             if i1.dtype != np.uint64:
                 assert_raises(ValueError, operator.pow, i1, i2)
             else:
@@ -435,15 +434,18 @@ def test_branches(self):

 class TestConversion:
     def test_int_from_long(self):
+        # NB: this test assumes that the default fp type is float64
         l = [1e6, 1e12, 1e18, -1e6, -1e12, -1e18]
         li = [10**6, 10**12, 10**18, -10**6, -10**12, -10**18]
         for T in [None, np.float64, np.int64]:
             a = np.array(l, dtype=T)
             assert_equal([int(_m) for _m in a], li)

+
     @pytest.mark.xfail(reason="pytorch does not emit this warning.")
     def test_iinfo_long_values_1(self):
         for code in 'bBh':
+
             with pytest.warns(DeprecationWarning):
                 res = np.array(np.iinfo(code).max + 1, dtype=code)
                 tgt = np.iinfo(code).min
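
For reference (not part of the commit), the single-character typecodes these tests loop over are NumPy dtype codes, and np.iinfo supplies the bounds the xfailed test expects; a quick look with stock numpy:

    import numpy as np  # stock numpy, for reference

    for code in 'bBhil':
        print(code, np.dtype(code).name)
    # b int8, B uint8, h int16, i int32, l int64 ('l' is platform dependent)

    info = np.iinfo('b')         # bounds for int8
    print(info.min, info.max)    # -128 127
    # the xfailed test expects max + 1 to wrap around to min, with a
    # DeprecationWarning that pytorch-backed arrays do not emit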

torch_np/tests/test_reductions.py

Lines changed: 13 additions & 10 deletions
@@ -248,16 +248,19 @@ def test_mean_values(self):
         rmat = np.arange(20, dtype=float).reshape((4, 5))
         cmat = rmat + 1j*rmat

-        for mat in [rmat, cmat]:
-            for axis in [0, 1]:
-                tgt = mat.sum(axis=axis)
-                res = np.mean(mat, axis=axis) * mat.shape[axis]
-                assert_allclose(res, tgt)
-
-            for axis in [None]:
-                tgt = mat.sum(axis=axis)
-                res = np.mean(mat, axis=axis) * mat.size
-                assert_allclose(res, tgt)
+        import warnings
+        with warnings.catch_warnings():
+            warnings.simplefilter('error')
+            for mat in [rmat, cmat]:
+                for axis in [0, 1]:
+                    tgt = mat.sum(axis=axis)
+                    res = np.mean(mat, axis=axis) * mat.shape[axis]
+                    assert_allclose(res, tgt)
+
+                for axis in [None]:
+                    tgt = mat.sum(axis=axis)
+                    res = np.mean(mat, axis=axis) * mat.size
+                    assert_allclose(res, tgt)

     @pytest.mark.xfail(reason="see pytorch/gh-91597")
     def test_mean_float16(self):
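
A note on the mechanism used here: warnings.simplefilter('error') inside a catch_warnings() block escalates every warning raised in the block to an exception, so the test now fails loudly if mean() starts warning. A self-contained sketch of that stdlib behavior:

    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter('error')   # escalate all warnings to errors
        try:
            warnings.warn("something iffy")
        except UserWarning as exc:
            print("warning was raised as an exception:", exc)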
