22
22
AxisLike ,
23
23
DTypeLike ,
24
24
NDArray ,
25
+ OutArray ,
25
26
SubokLike ,
26
27
normalize_array_like ,
27
28
)
@@ -41,8 +42,8 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
41
42
def copyto (dst : NDArray , src : ArrayLike , casting = "same_kind" , where = NoValue ):
42
43
if where is not NoValue :
43
44
raise NotImplementedError
44
- (src ,) = _util .typecast_tensors ((src ,), dst .tensor . dtype , casting = casting )
45
- dst .tensor . copy_ (src )
45
+ (src ,) = _util .typecast_tensors ((src ,), dst .dtype , casting = casting )
46
+ dst .copy_ (src )
46
47
47
48
48
49
def atleast_1d (* arys : ArrayLike ):
@@ -114,7 +115,7 @@ def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
114
115
def concatenate (
115
116
ar_tuple : Sequence [ArrayLike ],
116
117
axis = 0 ,
117
- out : Optional [NDArray ] = None ,
118
+ out : Optional [OutArray ] = None ,
118
119
dtype : DTypeLike = None ,
119
120
casting = "same_kind" ,
120
121
):
@@ -160,7 +161,7 @@ def column_stack(
160
161
def stack (
161
162
arrays : Sequence [ArrayLike ],
162
163
axis = 0 ,
163
- out : Optional [NDArray ] = None ,
164
+ out : Optional [OutArray ] = None ,
164
165
* ,
165
166
dtype : DTypeLike = None ,
166
167
casting = "same_kind" ,
@@ -754,7 +755,7 @@ def nanmean(
754
755
a : ArrayLike ,
755
756
axis = None ,
756
757
dtype : DTypeLike = None ,
757
- out : Optional [NDArray ] = None ,
758
+ out : Optional [OutArray ] = None ,
758
759
keepdims = NoValue ,
759
760
* ,
760
761
where = NoValue ,
@@ -892,7 +893,7 @@ def take(
892
893
a : ArrayLike ,
893
894
indices : ArrayLike ,
894
895
axis = None ,
895
- out : Optional [NDArray ] = None ,
896
+ out : Optional [OutArray ] = None ,
896
897
mode = "raise" ,
897
898
):
898
899
if mode != "raise" :
@@ -975,7 +976,7 @@ def clip(
975
976
a : ArrayLike ,
976
977
min : Optional [ArrayLike ] = None ,
977
978
max : Optional [ArrayLike ] = None ,
978
- out : Optional [NDArray ] = None ,
979
+ out : Optional [OutArray ] = None ,
979
980
):
980
981
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
981
982
# one of them to be None. Follow the more lax version.
@@ -1070,7 +1071,7 @@ def trace(
1070
1071
axis1 = 0 ,
1071
1072
axis2 = 1 ,
1072
1073
dtype : DTypeLike = None ,
1073
- out : Optional [NDArray ] = None ,
1074
+ out : Optional [OutArray ] = None ,
1074
1075
):
1075
1076
result = torch .diagonal (a , offset , dim1 = axis1 , dim2 = axis2 ).sum (- 1 , dtype = dtype )
1076
1077
return result
@@ -1180,7 +1181,7 @@ def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
1180
1181
return torch .tensordot (a , b , dims = axes )
1181
1182
1182
1183
1183
- def dot (a : ArrayLike , b : ArrayLike , out : Optional [NDArray ] = None ):
1184
+ def dot (a : ArrayLike , b : ArrayLike , out : Optional [OutArray ] = None ):
1184
1185
dtype = _dtypes_impl .result_type_impl ((a .dtype , b .dtype ))
1185
1186
a = _util .cast_if_needed (a , dtype )
1186
1187
b = _util .cast_if_needed (b , dtype )
@@ -1215,7 +1216,7 @@ def inner(a: ArrayLike, b: ArrayLike, /):
1215
1216
return result
1216
1217
1217
1218
1218
- def outer (a : ArrayLike , b : ArrayLike , out : Optional [NDArray ] = None ):
1219
+ def outer (a : ArrayLike , b : ArrayLike , out : Optional [OutArray ] = None ):
1219
1220
return torch .outer (a , b )
1220
1221
1221
1222
@@ -1382,7 +1383,7 @@ def imag(a: ArrayLike):
1382
1383
return result
1383
1384
1384
1385
1385
- def round_ (a : ArrayLike , decimals = 0 , out : Optional [NDArray ] = None ):
1386
+ def round_ (a : ArrayLike , decimals = 0 , out : Optional [OutArray ] = None ):
1386
1387
if a .is_floating_point ():
1387
1388
result = torch .round (a , decimals = decimals )
1388
1389
elif a .is_complex ():
@@ -1408,7 +1409,7 @@ def sum(
1408
1409
a : ArrayLike ,
1409
1410
axis : AxisLike = None ,
1410
1411
dtype : DTypeLike = None ,
1411
- out : Optional [NDArray ] = None ,
1412
+ out : Optional [OutArray ] = None ,
1412
1413
keepdims = NoValue ,
1413
1414
initial = NoValue ,
1414
1415
where = NoValue ,
@@ -1423,7 +1424,7 @@ def prod(
1423
1424
a : ArrayLike ,
1424
1425
axis : AxisLike = None ,
1425
1426
dtype : DTypeLike = None ,
1426
- out : Optional [NDArray ] = None ,
1427
+ out : Optional [OutArray ] = None ,
1427
1428
keepdims = NoValue ,
1428
1429
initial = NoValue ,
1429
1430
where = NoValue ,
@@ -1441,7 +1442,7 @@ def mean(
1441
1442
a : ArrayLike ,
1442
1443
axis : AxisLike = None ,
1443
1444
dtype : DTypeLike = None ,
1444
- out : Optional [NDArray ] = None ,
1445
+ out : Optional [OutArray ] = None ,
1445
1446
keepdims = NoValue ,
1446
1447
* ,
1447
1448
where = NoValue ,
@@ -1454,7 +1455,7 @@ def var(
1454
1455
a : ArrayLike ,
1455
1456
axis : AxisLike = None ,
1456
1457
dtype : DTypeLike = None ,
1457
- out : Optional [NDArray ] = None ,
1458
+ out : Optional [OutArray ] = None ,
1458
1459
ddof = 0 ,
1459
1460
keepdims = NoValue ,
1460
1461
* ,
@@ -1470,7 +1471,7 @@ def std(
1470
1471
a : ArrayLike ,
1471
1472
axis : AxisLike = None ,
1472
1473
dtype : DTypeLike = None ,
1473
- out : Optional [NDArray ] = None ,
1474
+ out : Optional [OutArray ] = None ,
1474
1475
ddof = 0 ,
1475
1476
keepdims = NoValue ,
1476
1477
* ,
@@ -1485,7 +1486,7 @@ def std(
1485
1486
def argmin (
1486
1487
a : ArrayLike ,
1487
1488
axis : AxisLike = None ,
1488
- out : Optional [NDArray ] = None ,
1489
+ out : Optional [OutArray ] = None ,
1489
1490
* ,
1490
1491
keepdims = NoValue ,
1491
1492
):
@@ -1496,7 +1497,7 @@ def argmin(
1496
1497
def argmax (
1497
1498
a : ArrayLike ,
1498
1499
axis : AxisLike = None ,
1499
- out : Optional [NDArray ] = None ,
1500
+ out : Optional [OutArray ] = None ,
1500
1501
* ,
1501
1502
keepdims = NoValue ,
1502
1503
):
@@ -1507,7 +1508,7 @@ def argmax(
1507
1508
def amax (
1508
1509
a : ArrayLike ,
1509
1510
axis : AxisLike = None ,
1510
- out : Optional [NDArray ] = None ,
1511
+ out : Optional [OutArray ] = None ,
1511
1512
keepdims = NoValue ,
1512
1513
initial = NoValue ,
1513
1514
where = NoValue ,
@@ -1522,7 +1523,7 @@ def amax(
1522
1523
def amin (
1523
1524
a : ArrayLike ,
1524
1525
axis : AxisLike = None ,
1525
- out : Optional [NDArray ] = None ,
1526
+ out : Optional [OutArray ] = None ,
1526
1527
keepdims = NoValue ,
1527
1528
initial = NoValue ,
1528
1529
where = NoValue ,
@@ -1535,7 +1536,10 @@ def amin(
1535
1536
1536
1537
1537
1538
def ptp (
1538
- a : ArrayLike , axis : AxisLike = None , out : Optional [NDArray ] = None , keepdims = NoValue
1539
+ a : ArrayLike ,
1540
+ axis : AxisLike = None ,
1541
+ out : Optional [OutArray ] = None ,
1542
+ keepdims = NoValue ,
1539
1543
):
1540
1544
result = _impl .ptp (a , axis = axis , keepdims = keepdims )
1541
1545
return result
@@ -1544,7 +1548,7 @@ def ptp(
1544
1548
def all (
1545
1549
a : ArrayLike ,
1546
1550
axis : AxisLike = None ,
1547
- out : Optional [NDArray ] = None ,
1551
+ out : Optional [OutArray ] = None ,
1548
1552
keepdims = NoValue ,
1549
1553
* ,
1550
1554
where = NoValue ,
@@ -1556,7 +1560,7 @@ def all(
1556
1560
def any (
1557
1561
a : ArrayLike ,
1558
1562
axis : AxisLike = None ,
1559
- out : Optional [NDArray ] = None ,
1563
+ out : Optional [OutArray ] = None ,
1560
1564
keepdims = NoValue ,
1561
1565
* ,
1562
1566
where = NoValue ,
@@ -1574,7 +1578,7 @@ def cumsum(
1574
1578
a : ArrayLike ,
1575
1579
axis : AxisLike = None ,
1576
1580
dtype : DTypeLike = None ,
1577
- out : Optional [NDArray ] = None ,
1581
+ out : Optional [OutArray ] = None ,
1578
1582
):
1579
1583
result = _impl .cumsum (a , axis = axis , dtype = dtype )
1580
1584
return result
@@ -1584,7 +1588,7 @@ def cumprod(
1584
1588
a : ArrayLike ,
1585
1589
axis : AxisLike = None ,
1586
1590
dtype : DTypeLike = None ,
1587
- out : Optional [NDArray ] = None ,
1591
+ out : Optional [OutArray ] = None ,
1588
1592
):
1589
1593
result = _impl .cumprod (a , axis = axis , dtype = dtype )
1590
1594
return result
@@ -1597,7 +1601,7 @@ def quantile(
1597
1601
a : ArrayLike ,
1598
1602
q : ArrayLike ,
1599
1603
axis : AxisLike = None ,
1600
- out : Optional [NDArray ] = None ,
1604
+ out : Optional [OutArray ] = None ,
1601
1605
overwrite_input = False ,
1602
1606
method = "linear" ,
1603
1607
keepdims = False ,
@@ -1620,7 +1624,7 @@ def percentile(
1620
1624
a : ArrayLike ,
1621
1625
q : ArrayLike ,
1622
1626
axis : AxisLike = None ,
1623
- out : Optional [NDArray ] = None ,
1627
+ out : Optional [OutArray ] = None ,
1624
1628
overwrite_input = False ,
1625
1629
method = "linear" ,
1626
1630
keepdims = False ,
@@ -1642,7 +1646,7 @@ def percentile(
1642
1646
def median (
1643
1647
a : ArrayLike ,
1644
1648
axis = None ,
1645
- out : Optional [NDArray ] = None ,
1649
+ out : Optional [OutArray ] = None ,
1646
1650
overwrite_input = False ,
1647
1651
keepdims = False ,
1648
1652
):
@@ -1726,7 +1730,7 @@ def imag(a: ArrayLike):
1726
1730
return result
1727
1731
1728
1732
1729
- def round_ (a : ArrayLike , decimals = 0 , out : Optional [NDArray ] = None ):
1733
+ def round_ (a : ArrayLike , decimals = 0 , out : Optional [OutArray ] = None ):
1730
1734
if a .is_floating_point ():
1731
1735
result = torch .round (a , decimals = decimals )
1732
1736
elif a .is_complex ():
@@ -1786,11 +1790,11 @@ def isrealobj(x: ArrayLike):
1786
1790
return not torch .is_complex (x )
1787
1791
1788
1792
1789
- def isneginf (x : ArrayLike , out : Optional [NDArray ] = None ):
1793
+ def isneginf (x : ArrayLike , out : Optional [OutArray ] = None ):
1790
1794
return torch .isneginf (x , out = out )
1791
1795
1792
1796
1793
- def isposinf (x : ArrayLike , out : Optional [NDArray ] = None ):
1797
+ def isposinf (x : ArrayLike , out : Optional [OutArray ] = None ):
1794
1798
return torch .isposinf (x , out = out )
1795
1799
1796
1800
0 commit comments