Skip to content

Commit 33ce4fb

Browse files
simonjayhawkinsproost
authored andcommitted
TYPING: --check-untyped-defs util._decorators (pandas-dev#28128)
1 parent 3fd0e2b commit 33ce4fb

File tree

6 files changed

+74
-60
lines changed

6 files changed

+74
-60
lines changed

pandas/core/groupby/generic.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -833,45 +833,45 @@ def apply(self, func, *args, **kwargs):
833833
axis="",
834834
)
835835
@Appender(_shared_docs["aggregate"])
836-
def aggregate(self, func_or_funcs=None, *args, **kwargs):
836+
def aggregate(self, func=None, *args, **kwargs):
837837
_level = kwargs.pop("_level", None)
838838

839-
relabeling = func_or_funcs is None
839+
relabeling = func is None
840840
columns = None
841-
no_arg_message = "Must provide 'func_or_funcs' or named aggregation **kwargs."
841+
no_arg_message = "Must provide 'func' or named aggregation **kwargs."
842842
if relabeling:
843843
columns = list(kwargs)
844844
if not PY36:
845845
# sort for 3.5 and earlier
846846
columns = list(sorted(columns))
847847

848-
func_or_funcs = [kwargs[col] for col in columns]
848+
func = [kwargs[col] for col in columns]
849849
kwargs = {}
850850
if not columns:
851851
raise TypeError(no_arg_message)
852852

853-
if isinstance(func_or_funcs, str):
854-
return getattr(self, func_or_funcs)(*args, **kwargs)
853+
if isinstance(func, str):
854+
return getattr(self, func)(*args, **kwargs)
855855

856-
if isinstance(func_or_funcs, abc.Iterable):
856+
if isinstance(func, abc.Iterable):
857857
# Catch instances of lists / tuples
858858
# but not the class list / tuple itself.
859-
func_or_funcs = _maybe_mangle_lambdas(func_or_funcs)
860-
ret = self._aggregate_multiple_funcs(func_or_funcs, (_level or 0) + 1)
859+
func = _maybe_mangle_lambdas(func)
860+
ret = self._aggregate_multiple_funcs(func, (_level or 0) + 1)
861861
if relabeling:
862862
ret.columns = columns
863863
else:
864-
cyfunc = self._get_cython_func(func_or_funcs)
864+
cyfunc = self._get_cython_func(func)
865865
if cyfunc and not args and not kwargs:
866866
return getattr(self, cyfunc)()
867867

868868
if self.grouper.nkeys > 1:
869-
return self._python_agg_general(func_or_funcs, *args, **kwargs)
869+
return self._python_agg_general(func, *args, **kwargs)
870870

871871
try:
872-
return self._python_agg_general(func_or_funcs, *args, **kwargs)
872+
return self._python_agg_general(func, *args, **kwargs)
873873
except Exception:
874-
result = self._aggregate_named(func_or_funcs, *args, **kwargs)
874+
result = self._aggregate_named(func, *args, **kwargs)
875875

876876
index = Index(sorted(result), name=self.grouper.names[0])
877877
ret = Series(result, index=index)
@@ -1464,8 +1464,8 @@ class DataFrameGroupBy(NDFrameGroupBy):
14641464
axis="",
14651465
)
14661466
@Appender(_shared_docs["aggregate"])
1467-
def aggregate(self, arg=None, *args, **kwargs):
1468-
return super().aggregate(arg, *args, **kwargs)
1467+
def aggregate(self, func=None, *args, **kwargs):
1468+
return super().aggregate(func, *args, **kwargs)
14691469

14701470
agg = aggregate
14711471

pandas/core/indexes/interval.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def _find_non_overlapping_monotonic_bounds(self, key):
788788
return start, stop
789789

790790
def get_loc(
791-
self, key: Any, method: Optional[str] = None
791+
self, key: Any, method: Optional[str] = None, tolerance=None
792792
) -> Union[int, slice, np.ndarray]:
793793
"""
794794
Get integer location, slice or boolean mask for requested label.
@@ -982,7 +982,7 @@ def get_indexer_for(self, target: AnyArrayLike, **kwargs) -> np.ndarray:
982982
List of indices.
983983
"""
984984
if self.is_overlapping:
985-
return self.get_indexer_non_unique(target, **kwargs)[0]
985+
return self.get_indexer_non_unique(target)[0]
986986
return self.get_indexer(target, **kwargs)
987987

988988
@Appender(_index_shared_docs["get_value"] % _index_doc_kwargs)

pandas/core/window/ewm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def _constructor(self):
206206
axis="",
207207
)
208208
@Appender(_shared_docs["aggregate"])
209-
def aggregate(self, arg, *args, **kwargs):
210-
return super().aggregate(arg, *args, **kwargs)
209+
def aggregate(self, func, *args, **kwargs):
210+
return super().aggregate(func, *args, **kwargs)
211211

212212
agg = aggregate
213213

pandas/core/window/expanding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ def _get_window(self, other=None, **kwargs):
136136
axis="",
137137
)
138138
@Appender(_shared_docs["aggregate"])
139-
def aggregate(self, arg, *args, **kwargs):
140-
return super().aggregate(arg, *args, **kwargs)
139+
def aggregate(self, func, *args, **kwargs):
140+
return super().aggregate(func, *args, **kwargs)
141141

142142
agg = aggregate
143143

pandas/core/window/rolling.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -901,12 +901,12 @@ def func(arg, window, min_periods=None, closed=None):
901901
axis="",
902902
)
903903
@Appender(_shared_docs["aggregate"])
904-
def aggregate(self, arg, *args, **kwargs):
905-
result, how = self._aggregate(arg, *args, **kwargs)
904+
def aggregate(self, func, *args, **kwargs):
905+
result, how = self._aggregate(func, *args, **kwargs)
906906
if result is None:
907907

908908
# these must apply directly
909-
result = arg(self)
909+
result = func(self)
910910

911911
return result
912912

@@ -1788,8 +1788,8 @@ def _validate_freq(self):
17881788
axis="",
17891789
)
17901790
@Appender(_shared_docs["aggregate"])
1791-
def aggregate(self, arg, *args, **kwargs):
1792-
return super().aggregate(arg, *args, **kwargs)
1791+
def aggregate(self, func, *args, **kwargs):
1792+
return super().aggregate(func, *args, **kwargs)
17931793

17941794
agg = aggregate
17951795

pandas/util/_decorators.py

+48-34
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,35 @@
11
from functools import wraps
22
import inspect
33
from textwrap import dedent
4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
4+
from typing import (
5+
Any,
6+
Callable,
7+
Dict,
8+
List,
9+
Optional,
10+
Tuple,
11+
Type,
12+
TypeVar,
13+
Union,
14+
cast,
15+
)
516
import warnings
617

718
from pandas._libs.properties import cache_readonly # noqa
819

20+
FuncType = Callable[..., Any]
21+
F = TypeVar("F", bound=FuncType)
22+
923

1024
def deprecate(
1125
name: str,
12-
alternative: Callable,
26+
alternative: Callable[..., Any],
1327
version: str,
1428
alt_name: Optional[str] = None,
1529
klass: Optional[Type[Warning]] = None,
1630
stacklevel: int = 2,
1731
msg: Optional[str] = None,
18-
) -> Callable:
32+
) -> Callable[..., Any]:
1933
"""
2034
Return a new function that emits a deprecation warning on use.
2135
@@ -47,7 +61,7 @@ def deprecate(
4761
warning_msg = msg or "{} is deprecated, use {} instead".format(name, alt_name)
4862

4963
@wraps(alternative)
50-
def wrapper(*args, **kwargs):
64+
def wrapper(*args, **kwargs) -> Callable[..., Any]:
5165
warnings.warn(warning_msg, klass, stacklevel=stacklevel)
5266
return alternative(*args, **kwargs)
5367

@@ -90,9 +104,9 @@ def wrapper(*args, **kwargs):
90104
def deprecate_kwarg(
91105
old_arg_name: str,
92106
new_arg_name: Optional[str],
93-
mapping: Optional[Union[Dict, Callable[[Any], Any]]] = None,
107+
mapping: Optional[Union[Dict[Any, Any], Callable[[Any], Any]]] = None,
94108
stacklevel: int = 2,
95-
) -> Callable:
109+
) -> Callable[..., Any]:
96110
"""
97111
Decorator to deprecate a keyword argument of a function.
98112
@@ -160,27 +174,27 @@ def deprecate_kwarg(
160174
"mapping from old to new argument values " "must be dict or callable!"
161175
)
162176

163-
def _deprecate_kwarg(func):
177+
def _deprecate_kwarg(func: F) -> F:
164178
@wraps(func)
165-
def wrapper(*args, **kwargs):
179+
def wrapper(*args, **kwargs) -> Callable[..., Any]:
166180
old_arg_value = kwargs.pop(old_arg_name, None)
167181

168-
if new_arg_name is None and old_arg_value is not None:
169-
msg = (
170-
"the '{old_name}' keyword is deprecated and will be "
171-
"removed in a future version. "
172-
"Please take steps to stop the use of '{old_name}'"
173-
).format(old_name=old_arg_name)
174-
warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
175-
kwargs[old_arg_name] = old_arg_value
176-
return func(*args, **kwargs)
177-
178182
if old_arg_value is not None:
179-
if mapping is not None:
180-
if hasattr(mapping, "get"):
181-
new_arg_value = mapping.get(old_arg_value, old_arg_value)
182-
else:
183+
if new_arg_name is None:
184+
msg = (
185+
"the '{old_name}' keyword is deprecated and will be "
186+
"removed in a future version. "
187+
"Please take steps to stop the use of '{old_name}'"
188+
).format(old_name=old_arg_name)
189+
warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
190+
kwargs[old_arg_name] = old_arg_value
191+
return func(*args, **kwargs)
192+
193+
elif mapping is not None:
194+
if callable(mapping):
183195
new_arg_value = mapping(old_arg_value)
196+
else:
197+
new_arg_value = mapping.get(old_arg_value, old_arg_value)
184198
msg = (
185199
"the {old_name}={old_val!r} keyword is deprecated, "
186200
"use {new_name}={new_val!r} instead"
@@ -198,7 +212,7 @@ def wrapper(*args, **kwargs):
198212
).format(old_name=old_arg_name, new_name=new_arg_name)
199213

200214
warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
201-
if kwargs.get(new_arg_name, None) is not None:
215+
if kwargs.get(new_arg_name) is not None:
202216
msg = (
203217
"Can only specify '{old_name}' or '{new_name}', " "not both"
204218
).format(old_name=old_arg_name, new_name=new_arg_name)
@@ -207,17 +221,17 @@ def wrapper(*args, **kwargs):
207221
kwargs[new_arg_name] = new_arg_value
208222
return func(*args, **kwargs)
209223

210-
return wrapper
224+
return cast(F, wrapper)
211225

212226
return _deprecate_kwarg
213227

214228

215229
def rewrite_axis_style_signature(
216230
name: str, extra_params: List[Tuple[str, Any]]
217-
) -> Callable:
218-
def decorate(func):
231+
) -> Callable[..., Any]:
232+
def decorate(func: F) -> F:
219233
@wraps(func)
220-
def wrapper(*args, **kwargs):
234+
def wrapper(*args, **kwargs) -> Callable[..., Any]:
221235
return func(*args, **kwargs)
222236

223237
kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
@@ -234,8 +248,9 @@ def wrapper(*args, **kwargs):
234248

235249
sig = inspect.Signature(params)
236250

237-
func.__signature__ = sig
238-
return wrapper
251+
# https://github.com/python/typing/issues/598
252+
func.__signature__ = sig # type: ignore
253+
return cast(F, wrapper)
239254

240255
return decorate
241256

@@ -279,18 +294,17 @@ def __init__(self, *args, **kwargs):
279294

280295
self.params = args or kwargs
281296

282-
def __call__(self, func: Callable) -> Callable:
297+
def __call__(self, func: F) -> F:
283298
func.__doc__ = func.__doc__ and func.__doc__ % self.params
284299
return func
285300

286301
def update(self, *args, **kwargs) -> None:
287302
"""
288303
Update self.params with supplied args.
289-
290-
If called, we assume self.params is a dict.
291304
"""
292305

293-
self.params.update(*args, **kwargs)
306+
if isinstance(self.params, dict):
307+
self.params.update(*args, **kwargs)
294308

295309

296310
class Appender:
@@ -320,7 +334,7 @@ def __init__(self, addendum: Optional[str], join: str = "", indents: int = 0):
320334
self.addendum = addendum
321335
self.join = join
322336

323-
def __call__(self, func: Callable) -> Callable:
337+
def __call__(self, func: F) -> F:
324338
func.__doc__ = func.__doc__ if func.__doc__ else ""
325339
self.addendum = self.addendum if self.addendum else ""
326340
docitems = [func.__doc__, self.addendum]

0 commit comments

Comments
 (0)