
 import torch

-from . import _dtypes_impl, _helpers
+from . import _dtypes_impl
 from . import _reductions as _impl
 from . import _util
 from ._normalizations import (
     normalize_array_like,
 )

-###### array creation routines
+# ###### array creation routines


 def copy(
@@ -71,18 +71,16 @@ def atleast_3d(*arys: ArrayLike):


 def _concat_check(tup, dtype, out):
     """Check inputs in concatenate et al."""
     if tup == ():
-        # XXX:RuntimeError in torch, ValueError in numpy
         raise ValueError("need at least one array to concatenate")

-    if out is not None:
-        if dtype is not None:
-            # mimic numpy
-            raise TypeError(
-                "concatenate() only takes `out` or `dtype` as an "
-                "argument, but both were provided."
-            )
+    if out is not None and dtype is not None:
+        # mimic numpy
+        raise TypeError(
+            "concatenate() only takes `out` or `dtype` as an "
+            "argument, but both were provided."
+        )


 def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
@@ -104,12 +102,7 @@ def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
-
-    try:
-        result = torch.cat(tensors, axis)
-    except (IndexError, RuntimeError) as e:
-        raise _util.AxisError(*e.args)
-    return result
+    return torch.cat(tensors, axis)


 def concatenate(
@@ -177,11 +170,7 @@ def stack(
     tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
     result_ndim = tensors[0].ndim + 1
     axis = _util.normalize_axis_index(axis, result_ndim)
-    try:
-        result = torch.stack(tensors, axis=axis)
-    except RuntimeError as e:
-        raise ValueError(*e.args)
-    return result
+    return torch.stack(tensors, axis=axis)


 # ### split ###
@@ -352,24 +341,17 @@ def arange(
         dtype = _dtypes_impl.default_dtypes.int_dtype
     dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
     dt_list.append(dtype)
-    dtype = _dtypes_impl.result_type_impl(dt_list)
+    target_dtype = _dtypes_impl.result_type_impl(dt_list)

     # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
-    if dtype.is_complex:
-        work_dtype, target_dtype = torch.float64, dtype
-    else:
-        work_dtype, target_dtype = dtype, dtype
+    work_dtype = torch.float64 if target_dtype.is_complex else target_dtype

     if (step > 0 and start > stop) or (step < 0 and start < stop):
         # empty range
         return torch.empty(0, dtype=target_dtype)

-    try:
-        result = torch.arange(start, stop, step, dtype=work_dtype)
-        result = _util.cast_if_needed(result, target_dtype)
-    except RuntimeError:
-        raise ValueError("Maximum allowed size exceeded")
-
+    result = torch.arange(start, stop, step, dtype=work_dtype)
+    result = _util.cast_if_needed(result, target_dtype)
     return result

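For context, a minimal standalone sketch of the complex-dtype workaround kept above; it is not part of the commit, and the values are illustrative. torch.arange has no complex kernel, so the range is computed in float64 and only then cast (the commit's _util.cast_if_needed is stood in for by a plain .to()):

import torch

target_dtype = torch.complex128
work_dtype = torch.float64 if target_dtype.is_complex else target_dtype

# Compute in the real working dtype, then cast to the requested complex dtype.
result = torch.arange(0, 5, 1, dtype=work_dtype).to(target_dtype)
print(result)  # tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j], dtype=torch.complex128)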
@@ -593,8 +575,7 @@ def where(
     y: Optional[ArrayLike] = None,
     /,
 ):
-    selector = (x is None) == (y is None)
-    if not selector:
+    if (x is None) != (y is None):
         raise ValueError("either both or neither of x and y should be given")

     if condition.dtype != torch.bool:
@@ -603,14 +584,11 @@ def where(
     if x is None and y is None:
         result = torch.where(condition)
     else:
-        try:
-            result = torch.where(condition, x, y)
-        except RuntimeError as e:
-            raise ValueError(*e.args)
+        result = torch.where(condition, x, y)
     return result


-###### module-level queries of object properties
+# ###### module-level queries of object properties


 def ndim(a: ArrayLike):
@@ -628,7 +606,7 @@ def size(a: ArrayLike, axis=None):
     return a.shape[axis]


-###### shape manipulations and indexing
+# ###### shape manipulations and indexing


 def expand_dims(a: ArrayLike, axis):
@@ -665,6 +643,7 @@ def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
     return torch.broadcast_to(array, size=shape)


+# This is a function from tuples to tuples, so we just reuse it
 from torch import broadcast_shapes

@@ -742,16 +721,15 @@ def triu_indices(n, k=0, m=None):
 def tril_indices_from(arr: ArrayLike, k=0):
     if arr.ndim != 2:
         raise ValueError("input array must be 2-d")
-    result = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
-    return tuple(result)
+    # Return a tensor rather than a tuple to avoid a graphbreak
+    return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)


 def triu_indices_from(arr: ArrayLike, k=0):
     if arr.ndim != 2:
         raise ValueError("input array must be 2-d")
-    result = torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)
-    # unpack: numpy returns a 2-tuple of index arrays; torch returns a 2-row tensor
-    return tuple(result)
+    # Return a tensor rather than a tuple to avoid a graphbreak
+    return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)


 def tri(
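A standalone sketch, not part of the commit, of why returning the raw 2-row index tensor stays usable in numpy-style code: advanced indexing treats a tuple of two index tensors and the two rows of one tensor identically, so the removed tuple(result) unpacking only changed the container.

import torch

arr = torch.arange(16).reshape(4, 4)
idx = torch.tril_indices(arr.shape[0], arr.shape[1], offset=0)

# Both spellings select the same lower-triangle elements.
assert torch.equal(arr[idx[0], idx[1]], arr[tuple(idx)])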
@@ -765,34 +743,14 @@ def tri(
     if M is None:
         M = N
     tensor = torch.ones((N, M), dtype=dtype)
-    tensor = torch.tril(tensor, diagonal=k)
-    return tensor
+    return torch.tril(tensor, diagonal=k)


-# ### nanfunctions ###  # FIXME: this is a stub
+# ### nanfunctions ###


-def nanmean(
-    a: ArrayLike,
-    axis=None,
-    dtype: Optional[DTypeLike] = None,
-    out: Optional[OutArray] = None,
-    keepdims=None,
-    *,
-    where: NotImplementedType = None,
-):
-    # XXX: this needs to be rewritten
-    if dtype is None:
-        dtype = a.dtype
-    if axis is None:
-        result = a.nanmean(dtype=dtype)
-        if keepdims:
-            result = torch.full(a.shape, result, dtype=result.dtype)
-    else:
-        result = a.nanmean(dtype=dtype, dim=axis, keepdim=bool(keepdims))
-    if out is not None:
-        out.copy_(result)
-    return result
+def nanmean():
+    raise NotImplementedError


 def nanmin():
@@ -999,12 +957,7 @@ def clip(
     max: Optional[ArrayLike] = None,
     out: Optional[OutArray] = None,
 ):
-    # np.clip requires both a_min and a_max not None, while ndarray.clip allows
-    # one of them to be None. Follow the more lax version.
-    if min is None and max is None:
-        raise ValueError("One of max or min must be given")
-    result = torch.clamp(a, min, max)
-    return result
+    return torch.clamp(a, min, max)


 def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
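A standalone sketch, not part of the commit, of the behavior the simplified clip leans on: torch.clamp accepts None for either bound and itself rejects the case where both are None, so the explicit check could go.

import torch

a = torch.tensor([-2.0, 0.5, 3.0])
print(torch.clamp(a, None, 1.0))  # tensor([-2.0000, 0.5000, 1.0000])
print(torch.clamp(a, 0.0, None))  # tensor([0.0000, 0.5000, 3.0000])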
@@ -1368,15 +1321,10 @@ def transpose(a: ArrayLike, axes=None):
     # numpy allows both .transpose(sh) and .transpose(*sh)
     # also older code uses axes being a list
     if axes in [(), None, (None,)]:
-        axes = tuple(range(a.ndim))[::-1]
+        axes = tuple(reversed(range(a.ndim)))
     elif len(axes) == 1:
         axes = axes[0]
-
-    try:
-        result = a.permute(axes)
-    except RuntimeError:
-        raise ValueError("axes don't match array")
-    return result
+    return a.permute(axes)


 def ravel(a: ArrayLike, order: NotImplementedType = "C"):
@@ -1391,41 +1339,6 @@ def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
     return torch.flatten(a)


-# ### Type/shape etc queries ###
-
-
-def real(a: ArrayLike):
-    result = torch.real(a)
-    return result
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
-def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
-    if a.is_floating_point():
-        result = torch.round(a, decimals=decimals)
-    elif a.is_complex():
-        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
-        )
-    else:
-        # RuntimeError: "round_cpu" not implemented for 'int'
-        result = a
-    return result
-
-
-around = round_
-round = round_
-
-
 # ### reductions ###
@@ -1742,6 +1655,9 @@ def sinc(x: ArrayLike):
     return torch.sinc(x)


+# ### Type/shape etc queries ###
+
+
 def real(a: ArrayLike):
     return torch.real(a)

@@ -1754,7 +1670,7 @@ def imag(a: ArrayLike):
     return result


-def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
+def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     if a.is_floating_point():
         result = torch.round(a, decimals=decimals)
     elif a.is_complex():
@@ -1769,8 +1685,8 @@ def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     return result


-around = round_
-round = round_
+around = round
+round_ = round


 def real_if_close(a: ArrayLike, tol=100):
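Finally, a standalone sketch, not part of the commit, of the complex fallback visible in the round context above: there is no complex "round_cpu" kernel, so the real and imaginary parts are rounded separately and recombined.

import torch

a = torch.tensor([1.234 + 5.678j])
# Round each component, then rebuild the complex value.
rounded = torch.round(a.real, decimals=2) + torch.round(a.imag, decimals=2) * 1j
print(rounded)  # tensor([1.2300+5.6800j])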