Skip to content

Commit 7b7224e

Browse files
committed
WIP: bare-bones normalizations via type hints
1 parent f667f8c commit 7b7224e

File tree

1 file changed

+102
-13
lines changed

1 file changed

+102
-13
lines changed

torch_np/_funcs.py

Lines changed: 102 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,102 @@
1+
import typing
2+
13
import torch
24

35
from . import _decorators, _helpers
46
from ._detail import _flips, _util
57
from ._detail import implementations as _impl
68

9+
################################## normalizations
710

8-
def nonzero(a):
9-
(tensor,) = _helpers.to_tensors(a)
10-
result = tensor.nonzero(as_tuple=True)
11+
ArrayLike = typing.TypeVar("ArrayLike")
12+
DTypeLike = typing.TypeVar("DTypeLike")
13+
SubokLike = typing.TypeVar("SubokLike")
14+
15+
16+
import inspect
17+
18+
from . import _dtypes
19+
20+
21+
def normalize_array_like(x, name=None):
22+
(tensor,) = _helpers.to_tensors(x)
23+
return tensor
24+
25+
26+
def normalize_dtype(dtype, name=None):
27+
# cf _decorators.dtype_to_torch
28+
torch_dtype = None
29+
if dtype is not None:
30+
dtype = _dtypes.dtype(dtype)
31+
torch_dtype = dtype.torch_dtype
32+
return torch_dtype
33+
34+
35+
def normalize_subok_like(arg, name):
36+
if arg:
37+
raise ValueError(f"'{name}' parameter is not supported.")
38+
39+
40+
normalizers = {
41+
ArrayLike: normalize_array_like,
42+
DTypeLike: normalize_dtype,
43+
SubokLike: normalize_subok_like,
44+
}
45+
46+
import functools
47+
48+
49+
def normalizer(func):
50+
@functools.wraps(func)
51+
def wrapped(*args, **kwds):
52+
sig = inspect.signature(func)
53+
54+
dct = {}
55+
# loop over positional parameters and actual arguments
56+
for arg, (name, parm) in zip(args, sig.parameters.items()):
57+
print(arg, name, parm.annotation)
58+
normalizer = normalizers.get(parm.annotation, None)
59+
if normalizer:
60+
dct[name] = normalizer(arg, name)
61+
else:
62+
# untyped arguments pass through
63+
dct[name] = arg
64+
65+
# normalize keyword arguments
66+
for name, arg in kwds.items():
67+
print("kw: ", name, sig.parameters[name].annotation)
68+
parm = sig.parameters[name]
69+
normalizer = normalizers.get(parm.annotation, None)
70+
if normalizer:
71+
dct[name] = normalizer(kwds[name], name)
72+
else:
73+
dct[name] = arg
74+
75+
ba = sig.bind(**dct)
76+
ba.apply_defaults()
77+
78+
# TODO:
79+
# 2. extra unknown args -- error out : nonzero([2, 0, 3], oops=42)
80+
# 3. [LOOKS OK] optional (tensor_or_none) : untyped => pass through
81+
# 4. [LOOKS OK] DTypeLike : positional or kw
82+
# 5. axes : live in _impl or in types? several ways of handling them
83+
# 6. keepdims : peel off, postprocess
84+
# 7. OutLike : normal & keyword-only, peel off, postprocess
85+
86+
# finally, pass normalized arguments through
87+
result = func(*ba.args)
88+
return result
89+
90+
return wrapped
91+
92+
93+
##################################
94+
95+
96+
@normalizer
97+
def nonzero(a: ArrayLike):
98+
# (tensor,) = _helpers.to_tensors(a)
99+
result = a.nonzero(as_tuple=True)
11100
return _helpers.tuple_arrays_from(result)
12101

13102

@@ -41,25 +130,25 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
41130
return _helpers.array_from(result)
42131

43132

44-
@_decorators.dtype_to_torch
45-
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
46-
(tensor,) = _helpers.to_tensors(a)
47-
result = _impl.trace(tensor, offset, axis1, axis2, dtype)
133+
@normalizer
134+
def trace(a: ArrayLike, offset=0, axis1=0, axis2=1, dtype: DTypeLike = None, out=None):
135+
# (tensor,) = _helpers.to_tensors(a)
136+
result = _impl.trace(a, offset, axis1, axis2, dtype)
48137
return _helpers.result_or_out(result, out)
49138

50139

51-
@_decorators.dtype_to_torch
52-
def eye(N, M=None, k=0, dtype=float, order="C", *, like=None):
53-
_util.subok_not_ok(like)
140+
@normalizer
141+
def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike = None):
142+
# _util.subok_not_ok(like)
54143
if order != "C":
55144
raise NotImplementedError
56145
result = _impl.eye(N, M, k, dtype)
57146
return _helpers.array_from(result)
58147

59148

60-
@_decorators.dtype_to_torch
61-
def identity(n, dtype=None, *, like=None):
62-
_util.subok_not_ok(like)
149+
@normalizer
150+
def identity(n, dtype: DTypeLike = None, *, like: SubokLike = None):
151+
## _util.subok_not_ok(like)
63152
result = torch.eye(n, dtype=dtype)
64153
return _helpers.array_from(result)
65154

0 commit comments

Comments
 (0)