Skip to content

Commit 3887e26

Browse files
authored
Merge pull request #108 from Quansight-Labs/ndarray_annot
MAINT: normalize NDArray to tensors, add a special-case for out= NDArrays
2 parents 773d772 + f30f4c4 commit 3887e26

File tree

4 files changed

+93
-38
lines changed

4 files changed

+93
-38
lines changed

torch_np/_funcs_impl.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AxisLike,
2323
DTypeLike,
2424
NDArray,
25+
OutArray,
2526
SubokLike,
2627
normalize_array_like,
2728
)
@@ -41,8 +42,8 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
4142
def copyto(dst: NDArray, src: ArrayLike, casting="same_kind", where=NoValue):
4243
if where is not NoValue:
4344
raise NotImplementedError
44-
(src,) = _util.typecast_tensors((src,), dst.tensor.dtype, casting=casting)
45-
dst.tensor.copy_(src)
45+
(src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
46+
dst.copy_(src)
4647

4748

4849
def atleast_1d(*arys: ArrayLike):
@@ -114,7 +115,7 @@ def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
114115
def concatenate(
115116
ar_tuple: Sequence[ArrayLike],
116117
axis=0,
117-
out: Optional[NDArray] = None,
118+
out: Optional[OutArray] = None,
118119
dtype: DTypeLike = None,
119120
casting="same_kind",
120121
):
@@ -160,7 +161,7 @@ def column_stack(
160161
def stack(
161162
arrays: Sequence[ArrayLike],
162163
axis=0,
163-
out: Optional[NDArray] = None,
164+
out: Optional[OutArray] = None,
164165
*,
165166
dtype: DTypeLike = None,
166167
casting="same_kind",
@@ -754,7 +755,7 @@ def nanmean(
754755
a: ArrayLike,
755756
axis=None,
756757
dtype: DTypeLike = None,
757-
out: Optional[NDArray] = None,
758+
out: Optional[OutArray] = None,
758759
keepdims=NoValue,
759760
*,
760761
where=NoValue,
@@ -892,7 +893,7 @@ def take(
892893
a: ArrayLike,
893894
indices: ArrayLike,
894895
axis=None,
895-
out: Optional[NDArray] = None,
896+
out: Optional[OutArray] = None,
896897
mode="raise",
897898
):
898899
if mode != "raise":
@@ -975,7 +976,7 @@ def clip(
975976
a: ArrayLike,
976977
min: Optional[ArrayLike] = None,
977978
max: Optional[ArrayLike] = None,
978-
out: Optional[NDArray] = None,
979+
out: Optional[OutArray] = None,
979980
):
980981
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
981982
# one of them to be None. Follow the more lax version.
@@ -1070,7 +1071,7 @@ def trace(
10701071
axis1=0,
10711072
axis2=1,
10721073
dtype: DTypeLike = None,
1073-
out: Optional[NDArray] = None,
1074+
out: Optional[OutArray] = None,
10741075
):
10751076
result = torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype)
10761077
return result
@@ -1180,7 +1181,7 @@ def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
11801181
return torch.tensordot(a, b, dims=axes)
11811182

11821183

1183-
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
1184+
def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
11841185
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
11851186
a = _util.cast_if_needed(a, dtype)
11861187
b = _util.cast_if_needed(b, dtype)
@@ -1215,7 +1216,7 @@ def inner(a: ArrayLike, b: ArrayLike, /):
12151216
return result
12161217

12171218

1218-
def outer(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
1219+
def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
12191220
return torch.outer(a, b)
12201221

12211222

@@ -1382,7 +1383,7 @@ def imag(a: ArrayLike):
13821383
return result
13831384

13841385

1385-
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
1386+
def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
13861387
if a.is_floating_point():
13871388
result = torch.round(a, decimals=decimals)
13881389
elif a.is_complex():
@@ -1408,7 +1409,7 @@ def sum(
14081409
a: ArrayLike,
14091410
axis: AxisLike = None,
14101411
dtype: DTypeLike = None,
1411-
out: Optional[NDArray] = None,
1412+
out: Optional[OutArray] = None,
14121413
keepdims=NoValue,
14131414
initial=NoValue,
14141415
where=NoValue,
@@ -1423,7 +1424,7 @@ def prod(
14231424
a: ArrayLike,
14241425
axis: AxisLike = None,
14251426
dtype: DTypeLike = None,
1426-
out: Optional[NDArray] = None,
1427+
out: Optional[OutArray] = None,
14271428
keepdims=NoValue,
14281429
initial=NoValue,
14291430
where=NoValue,
@@ -1441,7 +1442,7 @@ def mean(
14411442
a: ArrayLike,
14421443
axis: AxisLike = None,
14431444
dtype: DTypeLike = None,
1444-
out: Optional[NDArray] = None,
1445+
out: Optional[OutArray] = None,
14451446
keepdims=NoValue,
14461447
*,
14471448
where=NoValue,
@@ -1454,7 +1455,7 @@ def var(
14541455
a: ArrayLike,
14551456
axis: AxisLike = None,
14561457
dtype: DTypeLike = None,
1457-
out: Optional[NDArray] = None,
1458+
out: Optional[OutArray] = None,
14581459
ddof=0,
14591460
keepdims=NoValue,
14601461
*,
@@ -1470,7 +1471,7 @@ def std(
14701471
a: ArrayLike,
14711472
axis: AxisLike = None,
14721473
dtype: DTypeLike = None,
1473-
out: Optional[NDArray] = None,
1474+
out: Optional[OutArray] = None,
14741475
ddof=0,
14751476
keepdims=NoValue,
14761477
*,
@@ -1485,7 +1486,7 @@ def std(
14851486
def argmin(
14861487
a: ArrayLike,
14871488
axis: AxisLike = None,
1488-
out: Optional[NDArray] = None,
1489+
out: Optional[OutArray] = None,
14891490
*,
14901491
keepdims=NoValue,
14911492
):
@@ -1496,7 +1497,7 @@ def argmin(
14961497
def argmax(
14971498
a: ArrayLike,
14981499
axis: AxisLike = None,
1499-
out: Optional[NDArray] = None,
1500+
out: Optional[OutArray] = None,
15001501
*,
15011502
keepdims=NoValue,
15021503
):
@@ -1507,7 +1508,7 @@ def argmax(
15071508
def amax(
15081509
a: ArrayLike,
15091510
axis: AxisLike = None,
1510-
out: Optional[NDArray] = None,
1511+
out: Optional[OutArray] = None,
15111512
keepdims=NoValue,
15121513
initial=NoValue,
15131514
where=NoValue,
@@ -1522,7 +1523,7 @@ def amax(
15221523
def amin(
15231524
a: ArrayLike,
15241525
axis: AxisLike = None,
1525-
out: Optional[NDArray] = None,
1526+
out: Optional[OutArray] = None,
15261527
keepdims=NoValue,
15271528
initial=NoValue,
15281529
where=NoValue,
@@ -1535,7 +1536,10 @@ def amin(
15351536

15361537

15371538
def ptp(
1538-
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
1539+
a: ArrayLike,
1540+
axis: AxisLike = None,
1541+
out: Optional[OutArray] = None,
1542+
keepdims=NoValue,
15391543
):
15401544
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
15411545
return result
@@ -1544,7 +1548,7 @@ def ptp(
15441548
def all(
15451549
a: ArrayLike,
15461550
axis: AxisLike = None,
1547-
out: Optional[NDArray] = None,
1551+
out: Optional[OutArray] = None,
15481552
keepdims=NoValue,
15491553
*,
15501554
where=NoValue,
@@ -1556,7 +1560,7 @@ def all(
15561560
def any(
15571561
a: ArrayLike,
15581562
axis: AxisLike = None,
1559-
out: Optional[NDArray] = None,
1563+
out: Optional[OutArray] = None,
15601564
keepdims=NoValue,
15611565
*,
15621566
where=NoValue,
@@ -1574,7 +1578,7 @@ def cumsum(
15741578
a: ArrayLike,
15751579
axis: AxisLike = None,
15761580
dtype: DTypeLike = None,
1577-
out: Optional[NDArray] = None,
1581+
out: Optional[OutArray] = None,
15781582
):
15791583
result = _impl.cumsum(a, axis=axis, dtype=dtype)
15801584
return result
@@ -1584,7 +1588,7 @@ def cumprod(
15841588
a: ArrayLike,
15851589
axis: AxisLike = None,
15861590
dtype: DTypeLike = None,
1587-
out: Optional[NDArray] = None,
1591+
out: Optional[OutArray] = None,
15881592
):
15891593
result = _impl.cumprod(a, axis=axis, dtype=dtype)
15901594
return result
@@ -1597,7 +1601,7 @@ def quantile(
15971601
a: ArrayLike,
15981602
q: ArrayLike,
15991603
axis: AxisLike = None,
1600-
out: Optional[NDArray] = None,
1604+
out: Optional[OutArray] = None,
16011605
overwrite_input=False,
16021606
method="linear",
16031607
keepdims=False,
@@ -1620,7 +1624,7 @@ def percentile(
16201624
a: ArrayLike,
16211625
q: ArrayLike,
16221626
axis: AxisLike = None,
1623-
out: Optional[NDArray] = None,
1627+
out: Optional[OutArray] = None,
16241628
overwrite_input=False,
16251629
method="linear",
16261630
keepdims=False,
@@ -1642,7 +1646,7 @@ def percentile(
16421646
def median(
16431647
a: ArrayLike,
16441648
axis=None,
1645-
out: Optional[NDArray] = None,
1649+
out: Optional[OutArray] = None,
16461650
overwrite_input=False,
16471651
keepdims=False,
16481652
):
@@ -1726,7 +1730,7 @@ def imag(a: ArrayLike):
17261730
return result
17271731

17281732

1729-
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
1733+
def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
17301734
if a.is_floating_point():
17311735
result = torch.round(a, decimals=decimals)
17321736
elif a.is_complex():
@@ -1786,11 +1790,11 @@ def isrealobj(x: ArrayLike):
17861790
return not torch.is_complex(x)
17871791

17881792

1789-
def isneginf(x: ArrayLike, out: Optional[NDArray] = None):
1793+
def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
17901794
return torch.isneginf(x, out=out)
17911795

17921796

1793-
def isposinf(x: ArrayLike, out: Optional[NDArray] = None):
1797+
def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
17941798
return torch.isposinf(x, out=out)
17951799

17961800

torch_np/_normalizations.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
SubokLike = typing.TypeVar("SubokLike")
1515
AxisLike = typing.TypeVar("AxisLike")
1616
NDArray = typing.TypeVar("NDarray")
17+
OutArray = typing.TypeVar("OutArray")
1718

1819

1920
import inspect
@@ -60,6 +61,19 @@ def normalize_axis_like(arg, name=None):
6061

6162

6263
def normalize_ndarray(arg, name=None):
64+
# check the arg is an ndarray, extract its tensor attribute
65+
if arg is None:
66+
return arg
67+
68+
from ._ndarray import ndarray
69+
70+
if not isinstance(arg, ndarray):
71+
raise TypeError("'out' must be an array")
72+
return arg.tensor
73+
74+
75+
def normalize_outarray(arg, name=None):
76+
# almost normalize_ndarray, only return the array, not its tensor
6377
if arg is None:
6478
return arg
6579

@@ -75,6 +89,8 @@ def normalize_ndarray(arg, name=None):
7589
Optional[ArrayLike]: normalize_optional_array_like,
7690
Sequence[ArrayLike]: normalize_seq_array_like,
7791
Optional[NDArray]: normalize_ndarray,
92+
Optional[OutArray]: normalize_outarray,
93+
NDArray: normalize_ndarray,
7894
DTypeLike: normalize_dtype,
7995
SubokLike: normalize_subok_like,
8096
AxisLike: normalize_axis_like,
@@ -164,6 +180,9 @@ def wrapped(*args, **kwds):
164180

165181
if "out" in params:
166182
out = sig.bind(*args, **kwds).arguments.get("out")
183+
184+
### if out is not None: breakpoint()
185+
167186
result = maybe_copy_to(out, result, promote_scalar_result)
168187
result = wrap_tensors(result)
169188

torch_np/_ufuncs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from . import _binary_ufuncs_impl, _helpers, _unary_ufuncs_impl
66
from ._detail import _dtypes_impl, _util
7-
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
7+
from ._normalizations import ArrayLike, DTypeLike, OutArray, SubokLike, normalizer
88

99

1010
def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
@@ -46,7 +46,7 @@ def wrapped(
4646
x1: ArrayLike,
4747
x2: ArrayLike,
4848
/,
49-
out: Optional[NDArray] = None,
49+
out: Optional[OutArray] = None,
5050
*,
5151
where=True,
5252
casting="same_kind",
@@ -80,7 +80,7 @@ def matmul(
8080
x1: ArrayLike,
8181
x2: ArrayLike,
8282
/,
83-
out: Optional[NDArray] = None,
83+
out: Optional[OutArray] = None,
8484
*,
8585
casting="same_kind",
8686
order="K",
@@ -109,10 +109,10 @@ def matmul(
109109
def divmod(
110110
x1: ArrayLike,
111111
x2: ArrayLike,
112-
out1: Optional[NDArray] = None,
113-
out2: Optional[NDArray] = None,
112+
out1: Optional[OutArray] = None,
113+
out2: Optional[OutArray] = None,
114114
/,
115-
out: tuple[Optional[NDArray], Optional[NDArray]] = (None, None),
115+
out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
116116
*,
117117
where=True,
118118
casting="same_kind",
@@ -181,7 +181,7 @@ def deco_unary_ufunc(torch_func):
181181
def wrapped(
182182
x: ArrayLike,
183183
/,
184-
out: Optional[NDArray] = None,
184+
out: Optional[OutArray] = None,
185185
*,
186186
where=True,
187187
casting="same_kind",

0 commit comments

Comments
 (0)