
Commit 134db55

Merge pull request #121 from Quansight-Labs/review

Assorted fixes and simplifications

2 parents: 53a6da2 + 509e1c2

7 files changed: 76 insertions(+), 150 deletions(-)

.gitignore (1 addition, 0 deletions)

@@ -2,4 +2,5 @@
 __pycache__/
 *.py[cod]
 .coverage
+.hypothesis
 
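The new entry keeps Hypothesis's local example database out of version control: by default, running any Hypothesis-based test creates a .hypothesis/ directory in the working tree. A minimal illustration (hypothetical test file, standard hypothesis API):

# test_example.py: running this under pytest creates ./.hypothesis/,
# the example database that the new .gitignore entry excludes.
from hypothesis import given, strategies as st

@given(st.integers())
def test_int_roundtrip(x):
    assert int(str(x)) == x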

torch_np/_funcs_impl.py (37 additions, 121 deletions)
@@ -14,7 +14,7 @@
 
 import torch
 
-from . import _dtypes_impl, _helpers
+from . import _dtypes_impl
 from . import _reductions as _impl
 from . import _util
 from ._normalizations import (
@@ -27,7 +27,7 @@
     normalize_array_like,
 )
 
-###### array creation routines
+# ###### array creation routines
 
 
 def copy(
@@ -71,18 +71,16 @@ def atleast_3d(*arys: ArrayLike):
 
 
 def _concat_check(tup, dtype, out):
-    """Check inputs in concatenate et al."""
     if tup == ():
-        # XXX:RuntimeError in torch, ValueError in numpy
        raise ValueError("need at least one array to concatenate")
 
-    if out is not None:
-        if dtype is not None:
-            # mimic numpy
-            raise TypeError(
-                "concatenate() only takes `out` or `dtype` as an "
-                "argument, but both were provided."
-            )
+    """Check inputs in concatenate et al."""
+    if out is not None and dtype is not None:
+        # mimic numpy
+        raise TypeError(
+            "concatenate() only takes `out` or `dtype` as an "
+            "argument, but both were provided."
+        )
 
 
 def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
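The nested if collapses into a single guard with unchanged behavior. A sketch of the two error paths it preserves, assuming torch_np is importable as np:

import torch_np as np

try:
    np.concatenate(())
except ValueError as e:
    print(e)  # need at least one array to concatenate

a = np.ones(3)
try:
    np.concatenate((a, a), out=np.empty(6), dtype=np.float64)
except TypeError as e:
    print(e)  # concatenate() only takes `out` or `dtype` as an argument ...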
@@ -104,12 +102,7 @@ def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
-
-    try:
-        result = torch.cat(tensors, axis)
-    except (IndexError, RuntimeError) as e:
-        raise _util.AxisError(*e.args)
-    return result
+    return torch.cat(tensors, axis)
 
 
 def concatenate(
@@ -177,11 +170,7 @@ def stack(
     tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
     result_ndim = tensors[0].ndim + 1
     axis = _util.normalize_axis_index(axis, result_ndim)
-    try:
-        result = torch.stack(tensors, axis=axis)
-    except RuntimeError as e:
-        raise ValueError(*e.args)
-    return result
+    return torch.stack(tensors, axis=axis)
 
 
 # ### split ###
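With the try/except shims gone, torch.cat and torch.stack errors propagate as-is instead of being re-raised as AxisError/ValueError; the test suite is relaxed to match (see the test_multiarray.py changes below). For instance, in plain torch:

import torch

t = torch.ones(2, 2)
try:
    torch.cat([t, t], 5)  # out-of-range dim: torch raises IndexError
except IndexError as e:
    print(e)

try:
    torch.stack([torch.ones(2), torch.ones(3)])  # shape mismatch: RuntimeError
except RuntimeError as e:
    print(e)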
@@ -352,24 +341,17 @@ def arange(
         dtype = _dtypes_impl.default_dtypes.int_dtype
     dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
     dt_list.append(dtype)
-    dtype = _dtypes_impl.result_type_impl(dt_list)
+    target_dtype = _dtypes_impl.result_type_impl(dt_list)
 
     # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
-    if dtype.is_complex:
-        work_dtype, target_dtype = torch.float64, dtype
-    else:
-        work_dtype, target_dtype = dtype, dtype
+    work_dtype = torch.float64 if target_dtype.is_complex else target_dtype
 
     if (step > 0 and start > stop) or (step < 0 and start < stop):
         # empty range
         return torch.empty(0, dtype=target_dtype)
 
-    try:
-        result = torch.arange(start, stop, step, dtype=work_dtype)
-        result = _util.cast_if_needed(result, target_dtype)
-    except RuntimeError:
-        raise ValueError("Maximum allowed size exceeded")
-
+    result = torch.arange(start, stop, step, dtype=work_dtype)
+    result = _util.cast_if_needed(result, target_dtype)
     return result
 
 
@@ -593,8 +575,7 @@ def where(
     y: Optional[ArrayLike] = None,
     /,
 ):
-    selector = (x is None) == (y is None)
-    if not selector:
+    if (x is None) != (y is None):
         raise ValueError("either both or neither of x and y should be given")
 
     if condition.dtype != torch.bool:
@@ -603,14 +584,11 @@
     if x is None and y is None:
         result = torch.where(condition)
     else:
-        try:
-            result = torch.where(condition, x, y)
-        except RuntimeError as e:
-            raise ValueError(*e.args)
+        result = torch.where(condition, x, y)
     return result
 
 
-###### module-level queries of object properties
+# ###### module-level queries of object properties
 
 
 def ndim(a: ArrayLike):
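The check (x is None) != (y is None) is the direct form of the old two-step selector: it accepts both-omitted and both-given, and rejects exactly the mixed case, mirroring numpy's three-argument where(). A behavior sketch, assuming torch_np as np:

import torch_np as np

cond = np.array([True, False, True])
print(np.where(cond))        # both omitted: tuple of indices of True entries
print(np.where(cond, 1, 0))  # both given: elementwise selection
try:
    np.where(cond, 1)        # only one given: rejected
except ValueError as e:
    print(e)                 # either both or neither of x and y should be given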
@@ -628,7 +606,7 @@ def size(a: ArrayLike, axis=None):
         return a.shape[axis]
 
 
-###### shape manipulations and indexing
+# ###### shape manipulations and indexing
 
 
 def expand_dims(a: ArrayLike, axis):
@@ -665,6 +643,7 @@ def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
     return torch.broadcast_to(array, size=shape)
 
 
+# This is a function from tuples to tuples, so we just reuse it
 from torch import broadcast_shapes
 
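The added comment records why no wrapper is needed here: torch.broadcast_shapes consumes and returns plain shape tuples, so the torch function can be re-exported unchanged. For example:

import torch

print(torch.broadcast_shapes((2, 1), (1, 3)))  # torch.Size([2, 3])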

@@ -742,16 +721,15 @@ def triu_indices(n, k=0, m=None):
 def tril_indices_from(arr: ArrayLike, k=0):
     if arr.ndim != 2:
         raise ValueError("input array must be 2-d")
-    result = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
-    return tuple(result)
+    # Return a tensor rather than a tuple to avoid a graphbreak
+    return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
 
 
 def triu_indices_from(arr: ArrayLike, k=0):
     if arr.ndim != 2:
         raise ValueError("input array must be 2-d")
-    result = torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)
-    # unpack: numpy returns a 2-tuple of index arrays; torch returns a 2-row tensor
-    return tuple(result)
+    # Return a tensor rather than a tuple to avoid a graphbreak
+    return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)
 
 
 def tri(
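Under torch.compile, converting the index tensor to a Python tuple iterates over it at trace time and forces a graph break; returning the single (2, n) tensor keeps the function traceable, and it still works for fancy indexing. A plain-torch sketch:

import torch

a = torch.arange(9.0).reshape(3, 3)
idx = torch.tril_indices(a.shape[0], a.shape[1], offset=0)  # one (2, n) tensor
print(a[idx[0], idx[1]])  # the lower-triangle elements of a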
@@ -765,34 +743,14 @@ def tri(
     if M is None:
         M = N
     tensor = torch.ones((N, M), dtype=dtype)
-    tensor = torch.tril(tensor, diagonal=k)
-    return tensor
+    return torch.tril(tensor, diagonal=k)
 
 
-# ### nanfunctions ### # FIXME: this is a stub
+# ### nanfunctions ###
 
 
-def nanmean(
-    a: ArrayLike,
-    axis=None,
-    dtype: Optional[DTypeLike] = None,
-    out: Optional[OutArray] = None,
-    keepdims=None,
-    *,
-    where: NotImplementedType = None,
-):
-    # XXX: this needs to be rewritten
-    if dtype is None:
-        dtype = a.dtype
-    if axis is None:
-        result = a.nanmean(dtype=dtype)
-        if keepdims:
-            result = torch.full(a.shape, result, dtype=result.dtype)
-    else:
-        result = a.nanmean(dtype=dtype, dim=axis, keepdim=bool(keepdims))
-    if out is not None:
-        out.copy_(result)
-    return result
+def nanmean():
+    raise NotImplementedError
 
 
 def nanmin():
@@ -999,12 +957,7 @@ def clip(
     max: Optional[ArrayLike] = None,
     out: Optional[OutArray] = None,
 ):
-    # np.clip requires both a_min and a_max not None, while ndarray.clip allows
-    # one of them to be None. Follow the more lax version.
-    if min is None and max is None:
-        raise ValueError("One of max or min must be given")
-    result = torch.clamp(a, min, max)
-    return result
+    return torch.clamp(a, min, max)
 
 
 def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
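The guard could be dropped because torch.clamp already accepts None for either bound, matching the laxer ndarray.clip behavior the old comment described, and it raises on its own when both bounds are None. In plain torch:

import torch

t = torch.tensor([-2.0, 0.5, 3.0])
print(torch.clamp(t, None, 1.0))  # upper bound only: tensor([-2.0000, 0.5000, 1.0000])
print(torch.clamp(t, 0.0, None))  # lower bound only: tensor([0.0000, 0.5000, 3.0000])
# torch.clamp(t, None, None) raises a RuntimeError by itself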
@@ -1368,15 +1321,10 @@ def transpose(a: ArrayLike, axes=None):
     # numpy allows both .tranpose(sh) and .transpose(*sh)
     # also older code uses axes being a list
     if axes in [(), None, (None,)]:
-        axes = tuple(range(a.ndim))[::-1]
+        axes = tuple(reversed(range(a.ndim)))
     elif len(axes) == 1:
         axes = axes[0]
-
-    try:
-        result = a.permute(axes)
-    except RuntimeError:
-        raise ValueError("axes don't match array")
-    return result
+    return a.permute(axes)
 
 
 def ravel(a: ArrayLike, order: NotImplementedType = "C"):
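Again the error translation moves out of the wrapper: Tensor.permute raises RuntimeError for a wrong number of axes, and the transpose tests below now accept either exception type. For example:

import torch

a = torch.ones(2, 3)
print(a.permute(1, 0).shape)  # torch.Size([3, 2])
try:
    a.permute((0,))  # wrong number of dims
except RuntimeError as e:
    print(e)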
@@ -1391,41 +1339,6 @@ def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
     return torch.flatten(a)
 
 
-# ### Type/shape etc queries ###
-
-
-def real(a: ArrayLike):
-    result = torch.real(a)
-    return result
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
-def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
-    if a.is_floating_point():
-        result = torch.round(a, decimals=decimals)
-    elif a.is_complex():
-        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
-        )
-    else:
-        # RuntimeError: "round_cpu" not implemented for 'int'
-        result = a
-    return result
-
-
-around = round_
-round = round_
-
-
 # ### reductions ###
 
 
@@ -1742,6 +1655,9 @@ def sinc(x: ArrayLike):
     return torch.sinc(x)
 
 
+# ### Type/shape etc queries ###
+
+
 def real(a: ArrayLike):
     return torch.real(a)
 
@@ -1754,7 +1670,7 @@ def imag(a: ArrayLike):
     return result
 
 
-def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
+def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     if a.is_floating_point():
         result = torch.round(a, decimals=decimals)
     elif a.is_complex():
@@ -1769,8 +1685,8 @@ def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     return result
 
 
-around = round_
-round = round_
+around = round
+round_ = round
 
 
 def real_if_close(a: ArrayLike, tol=100):
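Making round the primary definition, with round_ and around as aliases, removes the shadowing dance around the builtin name. The complex branch works around the missing complex rounding kernel by rounding real and imaginary parts separately; the same trick in plain torch:

import torch

z = torch.tensor([1.25 + 2.75j])
# torch.round has no complex kernel, so round the parts and recombine:
rounded = torch.round(z.real, decimals=1) + torch.round(z.imag, decimals=1) * 1j
print(rounded)  # tensor([1.2000+2.8000j]), using round-half-to-even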

torch_np/tests/numpy_tests/core/test_multiarray.py (8 additions, 8 deletions)
@@ -1532,9 +1532,9 @@ def test_squeeze(self):
     def test_transpose(self):
         a = np.array([[1, 2], [3, 4]])
         assert_equal(a.transpose(), [[1, 3], [2, 4]])
-        assert_raises(ValueError, lambda: a.transpose(0))
-        assert_raises(ValueError, lambda: a.transpose(0, 0))
-        assert_raises(ValueError, lambda: a.transpose(0, 1, 2))
+        assert_raises((RuntimeError, ValueError), lambda: a.transpose(0))
+        assert_raises((RuntimeError, ValueError), lambda: a.transpose(0, 0))
+        assert_raises((RuntimeError, ValueError), lambda: a.transpose(0, 1, 2))
 
     def test_sort(self):
         # test ordering for floats and complex containing nans. It is only
@@ -7270,8 +7270,8 @@ def test_error(self):
         c = [True, True]
         a = np.ones((4, 5))
         b = np.ones((5, 5))
-        assert_raises(ValueError, np.where, c, a, a)
-        assert_raises(ValueError, np.where, c[0], a, b)
+        assert_raises((RuntimeError, ValueError), np.where, c, a, a)
+        assert_raises((RuntimeError, ValueError), np.where, c[0], a, b)
 
     def test_empty_result(self):
         # pass empty where result through an assignment which reads the data of
@@ -7497,14 +7497,14 @@ def test_view_discard_refcount(self):
 
 class TestArange:
     def test_infinite(self):
-        assert_raises_regex(
-            ValueError, "size exceeded",
+        assert_raises(
+            (RuntimeError, ValueError),  # "unsupported range",
             np.arange, 0, np.inf
         )
 
     def test_nan_step(self):
         assert_raises(
-            ValueError,  # "cannot compute length",
+            (RuntimeError, ValueError),  # "cannot compute length",
             np.arange, 0, 1, np.nan
         )
 
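assert_raises, like pytest.raises, accepts a tuple of exception types, so these tests now pass whether torch's RuntimeError propagates directly or the compat layer translates it to numpy's ValueError. A minimal equivalent with pytest, assuming torch_np as np:

import math
import pytest
import torch_np as np

with pytest.raises((RuntimeError, ValueError)):
    np.arange(0, 1, math.nan)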
