Commit 302e862

MAINT: move normalization logic to _normalizations

1 parent c4dffa5

File tree

7 files changed: +201 -199 lines changed

torch_np/_binary_ufuncs.py

Lines changed: 6 additions & 21 deletions
@@ -1,19 +1,14 @@
+from . import _helpers
 from ._detail import _binary_ufuncs
+from ._normalizations import ArrayLike, DTypeLike, SubokLike, normalizer

 __all__ = [
     name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
 ]


-from . import _helpers
-from ._detail import _util
-
-# TODO: consolidate normalizations
-from ._funcs import ArrayLike, DTypeLike, SubokLike, normalizer
-
-
 def deco_binary_ufunc(torch_func):
-    """Common infra for unary ufuncs.
+    """Common infra for binary ufuncs.

     Normalize arguments, sort out type casting, broadcasting and delegate to
     the pytorch functions for the actual work.
@@ -33,19 +28,9 @@ def wrapped(
         signature=None,
         extobj=None,
     ):
-        if order != "K" or not where or signature or extobj:
-            raise NotImplementedError
-
-        # XXX: dtype=... parameter
-        if dtype is not None:
-            raise NotImplementedError
-
-        out_shape_dtype = None
-        if out is not None:
-            out_shape_dtype = (out.get().dtype, out.get().shape)
-
-        tensors = _util.cast_and_broadcast((x1, x2), out_shape_dtype, casting)
-
+        tensors = _helpers.ufunc_preprocess(
+            (x1, x2), out, where, casting, order, dtype, subok, signature, extobj
+        )
         result = torch_func(*tensors)
         return _helpers.result_or_out(result, out)

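Taken together, the two hunks above reduce the wrapper body to a preprocess/compute/postprocess pipeline. A minimal, self-contained sketch of that flow, with stub stand-ins for the torch_np helpers (the stubs are illustrative, not the repo's real implementations):

```python
# Sketch of the post-refactor control flow in deco_binary_ufunc.
# ufunc_preprocess / result_or_out are stubbed; the real ones live in
# torch_np/_helpers.py (see the diff further below).
def ufunc_preprocess(tensors, out, where, casting, order, dtype,
                     subok, signature, extobj):
    if order != "K" or not where or signature or extobj:
        raise NotImplementedError
    return tensors  # the real helper also casts and broadcasts

def result_or_out(result, out):
    # the real helper writes the result into out when it is given
    if out is not None:
        raise NotImplementedError("out= not supported in this sketch")
    return result

def deco_binary_ufunc(torch_func):
    def wrapped(x1, x2, /, out=None, *, where=True, casting="same_kind",
                order="K", dtype=None, subok=False, signature=None, extobj=None):
        tensors = ufunc_preprocess((x1, x2), out, where, casting, order,
                                   dtype, subok, signature, extobj)
        result = torch_func(*tensors)
        return result_or_out(result, out)
    return wrapped

add = deco_binary_ufunc(lambda a, b: a + b)  # hypothetical torch_func
print(add(2, 3))  # -> 5
```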
torch_np/_decorators.py

Lines changed: 0 additions & 2 deletions
@@ -5,8 +5,6 @@
 from . import _dtypes, _helpers
 from ._detail import _util

-NoValue = None
-

 def out_shape_dtype(func):
     """Handle out=... kwarg for ufuncs.

torch_np/_funcs.py

Lines changed: 11 additions & 152 deletions
@@ -1,159 +1,18 @@
-import operator
-import typing
-from typing import Optional, Sequence
+from typing import Optional

 import torch

-from . import _decorators, _helpers
-from ._detail import _dtypes_impl, _flips, _reductions, _util
+from . import _helpers
+from ._detail import _flips, _reductions, _util
 from ._detail import implementations as _impl
-
-################################## normalizations
-
-ArrayLike = typing.TypeVar("ArrayLike")
-DTypeLike = typing.TypeVar("DTypeLike")
-SubokLike = typing.TypeVar("SubokLike")
-AxisLike = typing.TypeVar("AxisLike")
-
-# annotate e.g. atleast_1d(*arys)
-UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")
-
-
-import inspect
-
-from . import _dtypes
-
-
-def normalize_array_like(x, name=None):
-    (tensor,) = _helpers.to_tensors(x)
-    return tensor
-
-
-def normalize_optional_array_like(x, name=None):
-    # This explicit normalizer is needed because otherwise normalize_array_like
-    # does not run for a parameter annotated as Optional[ArrayLike]
-    return None if x is None else normalize_array_like(x, name)
-
-
-def normalize_seq_array_like(x, name=None):
-    tensors = _helpers.to_tensors(*x)
-    return tensors
-
-
-def normalize_dtype(dtype, name=None):
-    # cf _decorators.dtype_to_torch
-    torch_dtype = None
-    if dtype is not None:
-        dtype = _dtypes.dtype(dtype)
-        torch_dtype = dtype.torch_dtype
-    return torch_dtype
-
-
-def normalize_subok_like(arg, name):
-    if arg:
-        raise ValueError(f"'{name}' parameter is not supported.")
-
-
-def normalize_axis_like(arg, name=None):
-    from ._ndarray import ndarray
-
-    if isinstance(arg, ndarray):
-        arg = operator.index(arg)
-    return arg
-
-
-normalizers = {
-    ArrayLike: normalize_array_like,
-    Optional[ArrayLike]: normalize_optional_array_like,
-    Sequence[ArrayLike]: normalize_seq_array_like,
-    UnpackedSeqArrayLike: normalize_seq_array_like,  # cf handling in normalize
-    DTypeLike: normalize_dtype,
-    SubokLike: normalize_subok_like,
-    AxisLike: normalize_axis_like,
-}
-
-import functools
-
-
-def normalize_this(arg, parm):
-    """Normalize arg if a normalizer is registered."""
-    normalizer = normalizers.get(parm.annotation, None)
-    if normalizer:
-        return normalizer(arg)
-    else:
-        # untyped arguments pass through
-        return arg
-
-
-def normalizer(func):
-    @functools.wraps(func)
-    def wrapped(*args, **kwds):
-        sig = inspect.signature(func)
-
-        # first, check for *args in positional parameters. Case in point:
-        # atleast_1d(*arys: UnpackedSeqArrayLike)
-        # if found, consume all args into a tuple to normalize as a whole
-        for j, param in enumerate(sig.parameters.values()):
-            if param.annotation == UnpackedSeqArrayLike:
-                if j == 0:
-                    args = (args,)
-                else:
-                    # args = args[:j] + (args[j:],) would likely work
-                    # not present in numpy codebase, so do not bother just yet.
-                    # NB: branching on j == 0 is to avoid the empty tuple, args[:j]
-                    raise NotImplementedError
-                break
-
-        # loop over positional parameters and actual arguments
-        lst, dct = [], {}
-        for arg, (name, parm) in zip(args, sig.parameters.items()):
-            print(arg, name, parm.annotation)
-            lst.append(normalize_this(arg, parm))
-
-        # normalize keyword arguments
-        for name, arg in kwds.items():
-            if not name in sig.parameters:
-                # unknown kwarg, bail out
-                raise TypeError(
-                    f"{func.__name__}() got an unexpected keyword argument '{name}'."
-                )
-
-            print("kw: ", name, sig.parameters[name].annotation)
-            parm = sig.parameters[name]
-            dct[name] = normalize_this(arg, parm)
-
-        ba = sig.bind(*lst, **dct)
-        ba.apply_defaults()
-
-        # Now that all parameters have been consumed, check:
-        # Anything that has not been bound is unexpected positional arg => raise.
-        # If there are too few actual arguments, this will fail in func(*ba.args) below
-        if len(args) > len(ba.args):
-            raise TypeError(
-                f"{func.__name__}() takes {len(ba.args)} positional argument but {len(args)} were given."
-            )
-
-        # TODO:
-        # 1. [LOOKS OK] kw-only parameters : see vstack
-        # 2. [LOOKS OK] extra unknown args -- error out : nonzero([2, 0, 3], oops=42)
-        # 3. [LOOKS OK] optional (tensor_or_none) : untyped => pass through
-        # 4. [LOOKS OK] DTypeLike : positional or kw
-        # 5. axes : live in _impl or in types? several ways of handling them
-        # 6. [OK, NOT HERE] keepdims : peel off, postprocess
-        # 7. OutLike : normal & keyword-only, peel off, postprocess
-        # 8. [LOOKS OK] *args
-        # 9. consolidate normalizations (_funcs, _wrapper)
-        # 10. consolidate decorators (_{unary,binary}_ufuncs)
-        # 11. out= arg : validate it's an ndarray
-
-        # finally, pass normalized arguments through
-        result = func(*ba.args, **ba.kwargs)
-        return result
-
-    return wrapped
-
-
-##################################
+from ._normalizations import (
+    ArrayLike,
+    AxisLike,
+    DTypeLike,
+    SubokLike,
+    UnpackedSeqArrayLike,
+    normalizer,
+)


 @normalizer
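
The block deleted above (now living in `_normalizations`) implements annotation-driven argument normalization: annotations like `ArrayLike` are plain `typing.TypeVar` sentinels used as keys into a `normalizers` dict, and the `normalizer` decorator rewrites bound arguments before calling the real function. A self-contained toy version of the same pattern (`concat` and the `list(...)` converter are made up for illustration; the real table maps to `_helpers.to_tensors` and friends):

```python
import functools
import inspect
import typing

ArrayLike = typing.TypeVar("ArrayLike")  # sentinel annotation, as in the diff

# maps annotation sentinels to converter callables
normalizers = {ArrayLike: lambda x, name=None: list(x)}

def normalizer(func):
    @functools.wraps(func)
    def wrapped(*args, **kwds):
        sig = inspect.signature(func)
        ba = sig.bind(*args, **kwds)
        ba.apply_defaults()
        for name, value in ba.arguments.items():
            norm = normalizers.get(sig.parameters[name].annotation)
            if norm is not None:  # untyped parameters pass through
                ba.arguments[name] = norm(value, name)
        return func(*ba.args, **ba.kwargs)
    return wrapped

@normalizer
def concat(seq: ArrayLike, axis=0):  # hypothetical example function
    return seq, axis

print(concat((1, 2, 3)))  # -> ([1, 2, 3], 0)
```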

torch_np/_helpers.py

Lines changed: 20 additions & 0 deletions
@@ -41,6 +41,26 @@ def cast_and_broadcast(tensors, out, casting):
     return tuple(tensors)


+def ufunc_preprocess(
+    tensors, out, where, casting, order, dtype, subok, signature, extobj
+):
+    # internal preprocessing of args in ufuncs (cf _unary_ufuncs, _binary_ufuncs)
+    if order != "K" or not where or signature or extobj:
+        raise NotImplementedError
+
+    # XXX: dtype=... parameter
+    if dtype is not None:
+        raise NotImplementedError
+
+    out_shape_dtype = None
+    if out is not None:
+        out_shape_dtype = (out.get().dtype, out.get().shape)
+
+    tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting)
+
+    return tensors
+
+
 # ### Return helpers: wrap a single tensor, a tuple of tensors, out= etc ###


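For a feel of the new helper's contract: it rejects the NumPy ufunc kwargs that have no torch counterpart yet, then casts and broadcasts the inputs against the out array's dtype and shape (the `out.get()` calls suggest `out` is a torch_np ndarray wrapping a tensor). A rough stand-alone approximation, with a simplified broadcast step in place of `_util.cast_and_broadcast`:

```python
import torch

def ufunc_preprocess_sketch(tensors, out, where, casting, order,
                            dtype, subok, signature, extobj):
    # same gatekeeping as the real helper above
    if order != "K" or not where or signature or extobj:
        raise NotImplementedError
    if dtype is not None:
        raise NotImplementedError
    # simplified: broadcast to out's shape if given, else to a common shape
    shape = out.shape if out is not None else torch.broadcast_shapes(
        *(t.shape for t in tensors)
    )
    return tuple(t.broadcast_to(shape) for t in tensors)

a, b = torch.arange(3.0), torch.ones(2, 3)
res = ufunc_preprocess_sketch((a, b), None, True, "same_kind", "K",
                              None, False, None, None)
print([t.shape for t in res])  # [torch.Size([2, 3]), torch.Size([2, 3])]
```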