Skip to content

Commit e94647a

Browse files
committed
MAINT: move normalization logic to _normalizations
1 parent c4dffa5 commit e94647a

File tree

6 files changed

+47
-196
lines changed

6 files changed

+47
-196
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
from ._detail import _binary_ufuncs
22

3+
from ._normalizations import ArrayLike, DTypeLike, SubokLike, normalizer
4+
from . import _helpers
5+
36
__all__ = [
47
name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
58
]
69

710

8-
from . import _helpers
9-
from ._detail import _util
10-
11-
# TODO: consolidate normalizations
12-
from ._funcs import ArrayLike, DTypeLike, SubokLike, normalizer
13-
1411

1512
def deco_binary_ufunc(torch_func):
16-
"""Common infra for unary ufuncs.
13+
"""Common infra for binary ufuncs.
1714
1815
Normalize arguments, sort out type casting, broadcasting and delegate to
1916
the pytorch functions for the actual work.
@@ -33,25 +30,16 @@ def wrapped(
3330
signature=None,
3431
extobj=None,
3532
):
36-
if order != "K" or not where or signature or extobj:
37-
raise NotImplementedError
38-
39-
# XXX: dtype=... parameter
40-
if dtype is not None:
41-
raise NotImplementedError
42-
43-
out_shape_dtype = None
44-
if out is not None:
45-
out_shape_dtype = (out.get().dtype, out.get().shape)
46-
47-
tensors = _util.cast_and_broadcast((x1, x2), out_shape_dtype, casting)
48-
33+
tensors = _helpers.ufunc_preprocess((x1, x2), out, where, casting, order, dtype, subok, signature, extobj)
4934
result = torch_func(*tensors)
5035
return _helpers.result_or_out(result, out)
5136

5237
return wrapped
5338

5439

40+
41+
42+
5543
#
5644
# For each torch ufunc implementation, decorate and attach the decorated name
5745
# to this module. Its contents is then exported to the public namespace in __init__.py

torch_np/_decorators.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from . import _dtypes, _helpers
66
from ._detail import _util
77

8-
NoValue = None
9-
108

119
def out_shape_dtype(func):
1210
"""Handle out=... kwarg for ufuncs.
@@ -24,3 +22,12 @@ def wrapped(*args, out=None, **kwds):
2422
return _helpers.result_or_out(result_tensor, out)
2523

2624
return wrapped
25+
26+
27+
28+
29+
30+
31+
32+
33+

torch_np/_funcs.py

Lines changed: 4 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -1,159 +1,11 @@
1-
import operator
2-
import typing
3-
from typing import Optional, Sequence
4-
51
import torch
62

7-
from . import _decorators, _helpers
8-
from ._detail import _dtypes_impl, _flips, _reductions, _util
3+
from . import _helpers
4+
from ._detail import _flips, _reductions, _util
95
from ._detail import implementations as _impl
106

11-
################################## normalizations
12-
13-
ArrayLike = typing.TypeVar("ArrayLike")
14-
DTypeLike = typing.TypeVar("DTypeLike")
15-
SubokLike = typing.TypeVar("SubokLike")
16-
AxisLike = typing.TypeVar("AxisLike")
17-
18-
# annotate e.g. atleast_1d(*arys)
19-
UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")
20-
21-
22-
import inspect
23-
24-
from . import _dtypes
25-
26-
27-
def normalize_array_like(x, name=None):
28-
(tensor,) = _helpers.to_tensors(x)
29-
return tensor
30-
31-
32-
def normalize_optional_array_like(x, name=None):
33-
# This explicit normalizer is needed because otherwise normalize_array_like
34-
# does not run for a parameter annotated as Optional[ArrayLike]
35-
return None if x is None else normalize_array_like(x, name)
36-
37-
38-
def normalize_seq_array_like(x, name=None):
39-
tensors = _helpers.to_tensors(*x)
40-
return tensors
41-
42-
43-
def normalize_dtype(dtype, name=None):
44-
# cf _decorators.dtype_to_torch
45-
torch_dtype = None
46-
if dtype is not None:
47-
dtype = _dtypes.dtype(dtype)
48-
torch_dtype = dtype.torch_dtype
49-
return torch_dtype
50-
51-
52-
def normalize_subok_like(arg, name):
53-
if arg:
54-
raise ValueError(f"'{name}' parameter is not supported.")
55-
56-
57-
def normalize_axis_like(arg, name=None):
58-
from ._ndarray import ndarray
59-
60-
if isinstance(arg, ndarray):
61-
arg = operator.index(arg)
62-
return arg
63-
64-
65-
normalizers = {
66-
ArrayLike: normalize_array_like,
67-
Optional[ArrayLike]: normalize_optional_array_like,
68-
Sequence[ArrayLike]: normalize_seq_array_like,
69-
UnpackedSeqArrayLike: normalize_seq_array_like, # cf handling in normalize
70-
DTypeLike: normalize_dtype,
71-
SubokLike: normalize_subok_like,
72-
AxisLike: normalize_axis_like,
73-
}
74-
75-
import functools
76-
77-
78-
def normalize_this(arg, parm):
79-
"""Normalize arg if a normalizer is registred."""
80-
normalizer = normalizers.get(parm.annotation, None)
81-
if normalizer:
82-
return normalizer(arg)
83-
else:
84-
# untyped arguments pass through
85-
return arg
86-
87-
88-
def normalizer(func):
89-
@functools.wraps(func)
90-
def wrapped(*args, **kwds):
91-
sig = inspect.signature(func)
92-
93-
# first, check for *args in positional parameters. Case in point:
94-
# atleast_1d(*arys: UnpackedSequenceArrayLike)
95-
# if found, consume all args into a tuple to normalize as a whole
96-
for j, param in enumerate(sig.parameters.values()):
97-
if param.annotation == UnpackedSeqArrayLike:
98-
if j == 0:
99-
args = (args,)
100-
else:
101-
# args = args[:j] + (args[j:],) would likely work
102-
# not present in numpy codebase, so do not bother just yet.
103-
# NB: branching on j ==0 is to avoid the empty tuple, args[:j]
104-
raise NotImplementedError
105-
break
106-
107-
# loop over positional parameters and actual arguments
108-
lst, dct = [], {}
109-
for arg, (name, parm) in zip(args, sig.parameters.items()):
110-
print(arg, name, parm.annotation)
111-
lst.append(normalize_this(arg, parm))
112-
113-
# normalize keyword arguments
114-
for name, arg in kwds.items():
115-
if not name in sig.parameters:
116-
# unknown kwarg, bail out
117-
raise TypeError(
118-
f"{func.__name__}() got an unexpected keyword argument '{name}'."
119-
)
120-
121-
print("kw: ", name, sig.parameters[name].annotation)
122-
parm = sig.parameters[name]
123-
dct[name] = normalize_this(arg, parm)
124-
125-
ba = sig.bind(*lst, **dct)
126-
ba.apply_defaults()
127-
128-
# Now that all parameters have been consumed, check:
129-
# Anything that has not been bound is unexpected positional arg => raise.
130-
# If there are too few actual arguments, this fill fail in func(*ba.args) below
131-
if len(args) > len(ba.args):
132-
raise TypeError(
133-
f"{func.__name__}() takes {len(ba.args)} positional argument but {len(args)} were given."
134-
)
135-
136-
# TODO:
137-
# 1. [LOOKS OK] kw-only parameters : see vstack
138-
# 2. [LOOKS OK] extra unknown args -- error out : nonzero([2, 0, 3], oops=42)
139-
# 3. [LOOKS OK] optional (tensor_or_none) : untyped => pass through
140-
# 4. [LOOKS OK] DTypeLike : positional or kw
141-
# 5. axes : live in _impl or in types? several ways of handling them
142-
# 6. [OK, NOT HERE] keepdims : peel off, postprocess
143-
# 7. OutLike : normal & keyword-only, peel off, postprocess
144-
# 8. [LOOKS OK] *args
145-
# 9. consolidate normalizations (_funcs, _wrapper)
146-
# 10. consolidate decorators (_{unary,binary}_ufuncs)
147-
# 11. out= arg : validate it's an ndarray
148-
149-
# finally, pass normalized arguments through
150-
result = func(*ba.args, **ba.kwargs)
151-
return result
152-
153-
return wrapped
154-
155-
156-
##################################
7+
from ._normalizations import ArrayLike, DTypeLike, AxisLike, SubokLike, UnpackedSeqArrayLike, normalizer
8+
from typing import Optional
1579

15810

15911
@normalizer

torch_np/_helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ def cast_and_broadcast(tensors, out, casting):
4141
return tuple(tensors)
4242

4343

44+
45+
def ufunc_preprocess(tensors, out, where, casting, order, dtype, subok, signature, extobj):
46+
# internal preprocessing or args in ufuncs (cf _unary_ufuncs, _binary_ufuncs)
47+
if order != "K" or not where or signature or extobj:
48+
raise NotImplementedError
49+
50+
# XXX: dtype=... parameter
51+
if dtype is not None:
52+
raise NotImplementedError
53+
54+
out_shape_dtype = None
55+
if out is not None:
56+
out_shape_dtype = (out.get().dtype, out.get().shape)
57+
58+
tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting)
59+
60+
return tensors
61+
62+
63+
4464
# ### Return helpers: wrap a single tensor, a tuple of tensors, out= etc ###
4565

4666

torch_np/_unary_ufuncs.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,14 @@
33

44

55
from ._detail import _unary_ufuncs
6+
from ._normalizations import ArrayLike, DTypeLike, SubokLike, normalizer
7+
from . import _helpers
68

79
__all__ = [
810
name for name in dir(_unary_ufuncs) if not name.startswith("_") and name != "torch"
911
]
1012

1113

12-
from . import _helpers
13-
from ._detail import _util
14-
15-
# TODO: consolidate normalizations
16-
from ._funcs import ArrayLike, DTypeLike, SubokLike, normalizer
17-
18-
# import torch
19-
2014

2115
def deco_unary_ufunc(torch_func):
2216
"""Common infra for unary ufuncs.
@@ -38,25 +32,15 @@ def wrapped(
3832
signature=None,
3933
extobj=None,
4034
):
41-
if order != "K" or not where or signature or extobj:
42-
raise NotImplementedError
43-
44-
# XXX: dtype=... parameter
45-
if dtype is not None:
46-
raise NotImplementedError
47-
48-
out_shape_dtype = None
49-
if out is not None:
50-
out_shape_dtype = (out.get().dtype, out.get().shape)
51-
52-
tensors = _util.cast_and_broadcast((x,), out_shape_dtype, casting)
53-
35+
tensors = _helpers.ufunc_preprocess((x,), out, where, casting, order, dtype, subok, signature, extobj)
5436
result = torch_func(*tensors)
5537
return _helpers.result_or_out(result, out)
5638

5739
return wrapped
5840

5941

42+
43+
6044
#
6145
# For each torch ufunc implementation, decorate and attach the decorated name
6246
# to this module. Its contents is then exported to the public namespace in __init__.py

torch_np/_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ._detail import implementations as _impl
1414

1515
### XXX: order the imports DAG
16-
from ._funcs import ArrayLike, DTypeLike, SubokLike, UnpackedSeqArrayLike, normalizer
16+
from ._normalizations import ArrayLike, DTypeLike, SubokLike, UnpackedSeqArrayLike, normalizer
1717
from ._ndarray import array, asarray, maybe_set_base, ndarray
1818

1919
from . import _dtypes, _helpers, _decorators # isort: skip # XXX

0 commit comments

Comments
 (0)