|
1 |
| -import operator |
2 |
| -import typing |
3 |
| -from typing import Optional, Sequence |
4 |
| - |
5 | 1 | import torch
|
6 | 2 |
|
7 |
| -from . import _decorators, _helpers |
8 |
| -from ._detail import _dtypes_impl, _flips, _reductions, _util |
| 3 | +from . import _helpers |
| 4 | +from ._detail import _flips, _reductions, _util |
9 | 5 | from ._detail import implementations as _impl
|
10 | 6 |
|
11 |
| -################################## normalizations |
12 |
| - |
13 |
| -ArrayLike = typing.TypeVar("ArrayLike") |
14 |
| -DTypeLike = typing.TypeVar("DTypeLike") |
15 |
| -SubokLike = typing.TypeVar("SubokLike") |
16 |
| -AxisLike = typing.TypeVar("AxisLike") |
17 |
| - |
18 |
| -# annotate e.g. atleast_1d(*arys) |
19 |
| -UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike") |
20 |
| - |
21 |
| - |
22 |
| -import inspect |
23 |
| - |
24 |
| -from . import _dtypes |
25 |
| - |
26 |
| - |
27 |
| -def normalize_array_like(x, name=None): |
28 |
| - (tensor,) = _helpers.to_tensors(x) |
29 |
| - return tensor |
30 |
| - |
31 |
| - |
32 |
| -def normalize_optional_array_like(x, name=None): |
33 |
| - # This explicit normalizer is needed because otherwise normalize_array_like |
34 |
| - # does not run for a parameter annotated as Optional[ArrayLike] |
35 |
| - return None if x is None else normalize_array_like(x, name) |
36 |
| - |
37 |
| - |
38 |
| -def normalize_seq_array_like(x, name=None): |
39 |
| - tensors = _helpers.to_tensors(*x) |
40 |
| - return tensors |
41 |
| - |
42 |
| - |
43 |
| -def normalize_dtype(dtype, name=None): |
44 |
| - # cf _decorators.dtype_to_torch |
45 |
| - torch_dtype = None |
46 |
| - if dtype is not None: |
47 |
| - dtype = _dtypes.dtype(dtype) |
48 |
| - torch_dtype = dtype.torch_dtype |
49 |
| - return torch_dtype |
50 |
| - |
51 |
| - |
52 |
| -def normalize_subok_like(arg, name): |
53 |
| - if arg: |
54 |
| - raise ValueError(f"'{name}' parameter is not supported.") |
55 |
| - |
56 |
| - |
57 |
| -def normalize_axis_like(arg, name=None): |
58 |
| - from ._ndarray import ndarray |
59 |
| - |
60 |
| - if isinstance(arg, ndarray): |
61 |
| - arg = operator.index(arg) |
62 |
| - return arg |
63 |
| - |
64 |
| - |
65 |
| -normalizers = { |
66 |
| - ArrayLike: normalize_array_like, |
67 |
| - Optional[ArrayLike]: normalize_optional_array_like, |
68 |
| - Sequence[ArrayLike]: normalize_seq_array_like, |
69 |
| - UnpackedSeqArrayLike: normalize_seq_array_like, # cf handling in normalize |
70 |
| - DTypeLike: normalize_dtype, |
71 |
| - SubokLike: normalize_subok_like, |
72 |
| - AxisLike: normalize_axis_like, |
73 |
| -} |
74 |
| - |
75 |
| -import functools |
76 |
| - |
77 |
| - |
78 |
| -def normalize_this(arg, parm): |
79 |
| - """Normalize arg if a normalizer is registred.""" |
80 |
| - normalizer = normalizers.get(parm.annotation, None) |
81 |
| - if normalizer: |
82 |
| - return normalizer(arg) |
83 |
| - else: |
84 |
| - # untyped arguments pass through |
85 |
| - return arg |
86 |
| - |
87 |
| - |
88 |
| -def normalizer(func): |
89 |
| - @functools.wraps(func) |
90 |
| - def wrapped(*args, **kwds): |
91 |
| - sig = inspect.signature(func) |
92 |
| - |
93 |
| - # first, check for *args in positional parameters. Case in point: |
94 |
| - # atleast_1d(*arys: UnpackedSequenceArrayLike) |
95 |
| - # if found, consume all args into a tuple to normalize as a whole |
96 |
| - for j, param in enumerate(sig.parameters.values()): |
97 |
| - if param.annotation == UnpackedSeqArrayLike: |
98 |
| - if j == 0: |
99 |
| - args = (args,) |
100 |
| - else: |
101 |
| - # args = args[:j] + (args[j:],) would likely work |
102 |
| - # not present in numpy codebase, so do not bother just yet. |
103 |
| - # NB: branching on j ==0 is to avoid the empty tuple, args[:j] |
104 |
| - raise NotImplementedError |
105 |
| - break |
106 |
| - |
107 |
| - # loop over positional parameters and actual arguments |
108 |
| - lst, dct = [], {} |
109 |
| - for arg, (name, parm) in zip(args, sig.parameters.items()): |
110 |
| - print(arg, name, parm.annotation) |
111 |
| - lst.append(normalize_this(arg, parm)) |
112 |
| - |
113 |
| - # normalize keyword arguments |
114 |
| - for name, arg in kwds.items(): |
115 |
| - if not name in sig.parameters: |
116 |
| - # unknown kwarg, bail out |
117 |
| - raise TypeError( |
118 |
| - f"{func.__name__}() got an unexpected keyword argument '{name}'." |
119 |
| - ) |
120 |
| - |
121 |
| - print("kw: ", name, sig.parameters[name].annotation) |
122 |
| - parm = sig.parameters[name] |
123 |
| - dct[name] = normalize_this(arg, parm) |
124 |
| - |
125 |
| - ba = sig.bind(*lst, **dct) |
126 |
| - ba.apply_defaults() |
127 |
| - |
128 |
| - # Now that all parameters have been consumed, check: |
129 |
| - # Anything that has not been bound is unexpected positional arg => raise. |
130 |
| - # If there are too few actual arguments, this fill fail in func(*ba.args) below |
131 |
| - if len(args) > len(ba.args): |
132 |
| - raise TypeError( |
133 |
| - f"{func.__name__}() takes {len(ba.args)} positional argument but {len(args)} were given." |
134 |
| - ) |
135 |
| - |
136 |
| - # TODO: |
137 |
| - # 1. [LOOKS OK] kw-only parameters : see vstack |
138 |
| - # 2. [LOOKS OK] extra unknown args -- error out : nonzero([2, 0, 3], oops=42) |
139 |
| - # 3. [LOOKS OK] optional (tensor_or_none) : untyped => pass through |
140 |
| - # 4. [LOOKS OK] DTypeLike : positional or kw |
141 |
| - # 5. axes : live in _impl or in types? several ways of handling them |
142 |
| - # 6. [OK, NOT HERE] keepdims : peel off, postprocess |
143 |
| - # 7. OutLike : normal & keyword-only, peel off, postprocess |
144 |
| - # 8. [LOOKS OK] *args |
145 |
| - # 9. consolidate normalizations (_funcs, _wrapper) |
146 |
| - # 10. consolidate decorators (_{unary,binary}_ufuncs) |
147 |
| - # 11. out= arg : validate it's an ndarray |
148 |
| - |
149 |
| - # finally, pass normalized arguments through |
150 |
| - result = func(*ba.args, **ba.kwargs) |
151 |
| - return result |
152 |
| - |
153 |
| - return wrapped |
154 |
| - |
155 |
| - |
156 |
| -################################## |
| 7 | +from ._normalizations import ArrayLike, DTypeLike, AxisLike, SubokLike, UnpackedSeqArrayLike, normalizer |
| 8 | +from typing import Optional |
157 | 9 |
|
158 | 10 |
|
159 | 11 | @normalizer
|
|
0 commit comments