1
1
import functools
2
2
import itertools
3
3
import operator
4
- from typing import Any , Optional , Tuple , Union
4
+ from typing import Any , Callable , Optional , Tuple , Union
5
5
6
6
import numpy as np
7
7
8
8
from pandas ._config import get_option
9
9
10
10
from pandas ._libs import NaT , Timedelta , Timestamp , iNaT , lib
11
+ from pandas ._typing import Dtype , Scalar
11
12
from pandas .compat ._optional import import_optional_dependency
12
13
13
14
from pandas .core .dtypes .cast import _int64_max , maybe_upcast_putmask
37
38
_USE_BOTTLENECK = False
38
39
39
40
40
- def set_use_bottleneck (v = True ):
41
+ def set_use_bottleneck (v : bool = True ) -> None :
41
42
# set/unset to use bottleneck
42
43
global _USE_BOTTLENECK
43
44
if _BOTTLENECK_INSTALLED :
@@ -55,7 +56,7 @@ def __init__(self, *dtypes):
55
56
def check(self, obj) -> bool:
    """
    Report whether ``obj`` carries a dtype matching one of ``self.dtypes``.

    Parameters
    ----------
    obj : object
        Any object; only objects exposing a ``dtype`` attribute can match.

    Returns
    -------
    bool
        True when ``obj`` has a ``dtype`` attribute and ``obj.dtype.type``
        is a subclass of one of the classes held in ``self.dtypes``.
    """
    # Objects without a dtype (plain scalars, lists, ...) never match;
    # the hasattr test short-circuits before obj.dtype is touched.
    return hasattr(obj, "dtype") and issubclass(obj.dtype.type, self.dtypes)
57
58
58
- def __call__ (self , f ):
59
+ def __call__ (self , f ) -> Callable :
59
60
@functools .wraps (f )
60
61
def _f (* args , ** kwargs ):
61
62
obj_iter = itertools .chain (args , kwargs .values ())
@@ -80,11 +81,11 @@ def _f(*args, **kwargs):
80
81
81
82
82
83
class bottleneck_switch :
83
- def __init__ (self , name = None , ** kwargs ):
84
+ def __init__ (self , name : Optional [ str ] = None , ** kwargs ):
84
85
self .name = name
85
86
self .kwargs = kwargs
86
87
87
- def __call__ (self , alt ) :
88
+ def __call__ (self , alt : Callable ) -> Callable :
88
89
bn_name = self .name or alt .__name__
89
90
90
91
try :
@@ -93,7 +94,9 @@ def __call__(self, alt):
93
94
bn_func = None
94
95
95
96
@functools .wraps (alt )
96
- def f (values , axis = None , skipna = True , ** kwds ):
97
+ def f (
98
+ values : np .ndarray , axis : Optional [int ] = None , skipna : bool = True , ** kwds
99
+ ):
97
100
if len (self .kwargs ) > 0 :
98
101
for k , v in self .kwargs .items ():
99
102
if k not in kwds :
@@ -129,7 +132,7 @@ def f(values, axis=None, skipna=True, **kwds):
129
132
return f
130
133
131
134
132
- def _bn_ok_dtype (dt , name : str ) -> bool :
135
+ def _bn_ok_dtype (dt : Dtype , name : str ) -> bool :
133
136
# Bottleneck chokes on datetime64
134
137
if not is_object_dtype (dt ) and not (
135
138
is_datetime_or_timedelta_dtype (dt ) or is_datetime64tz_dtype (dt )
@@ -163,7 +166,9 @@ def _has_infs(result) -> bool:
163
166
return False
164
167
165
168
166
- def _get_fill_value (dtype , fill_value = None , fill_value_typ = None ):
169
+ def _get_fill_value (
170
+ dtype : Dtype , fill_value : Any = None , fill_value_typ : Optional [str ] = None
171
+ ):
167
172
""" return the correct fill value for the dtype of the values """
168
173
if fill_value is not None :
169
174
return fill_value
@@ -326,12 +331,12 @@ def _get_values(
326
331
return values , mask , dtype , dtype_max , fill_value
327
332
328
333
329
- def _na_ok_dtype (dtype ):
334
+ def _na_ok_dtype (dtype ) -> bool :
330
335
# TODO: what about datetime64tz? PeriodDtype?
331
336
return not issubclass (dtype .type , (np .integer , np .timedelta64 , np .datetime64 ))
332
337
333
338
334
- def _wrap_results (result , dtype , fill_value = None ):
339
+ def _wrap_results (result , dtype : Dtype , fill_value = None ):
335
340
""" wrap our results if needed """
336
341
337
342
if is_datetime64_dtype (dtype ) or is_datetime64tz_dtype (dtype ):
@@ -362,7 +367,9 @@ def _wrap_results(result, dtype, fill_value=None):
362
367
return result
363
368
364
369
365
- def _na_for_min_count (values , axis : Optional [int ]):
370
+ def _na_for_min_count (
371
+ values : np .ndarray , axis : Optional [int ]
372
+ ) -> Union [Scalar , np .ndarray ]:
366
373
"""
367
374
Return the missing value for `values`.
368
375
@@ -393,7 +400,12 @@ def _na_for_min_count(values, axis: Optional[int]):
393
400
return result
394
401
395
402
396
- def nanany (values , axis = None , skipna : bool = True , mask = None ):
403
+ def nanany (
404
+ values : np .ndarray ,
405
+ axis : Optional [int ] = None ,
406
+ skipna : bool = True ,
407
+ mask : Optional [np .ndarray ] = None ,
408
+ ) -> bool :
397
409
"""
398
410
Check if any elements along an axis evaluate to True.
399
411
@@ -425,7 +437,12 @@ def nanany(values, axis=None, skipna: bool = True, mask=None):
425
437
return values .any (axis )
426
438
427
439
428
- def nanall (values , axis = None , skipna : bool = True , mask = None ):
440
+ def nanall (
441
+ values : np .ndarray ,
442
+ axis : Optional [int ] = None ,
443
+ skipna : bool = True ,
444
+ mask : Optional [np .ndarray ] = None ,
445
+ ) -> bool :
429
446
"""
430
447
Check if all elements along an axis evaluate to True.
431
448
@@ -458,7 +475,13 @@ def nanall(values, axis=None, skipna: bool = True, mask=None):
458
475
459
476
460
477
@disallow ("M8" )
461
- def nansum (values , axis = None , skipna = True , min_count = 0 , mask = None ):
478
+ def nansum (
479
+ values : np .ndarray ,
480
+ axis : Optional [int ] = None ,
481
+ skipna : bool = True ,
482
+ min_count : int = 0 ,
483
+ mask : Optional [np .ndarray ] = None ,
484
+ ) -> Dtype :
462
485
"""
463
486
Sum the elements along an axis ignoring NaNs
464
487
@@ -629,7 +652,7 @@ def _get_counts_nanvar(
629
652
mask : Optional [np .ndarray ],
630
653
axis : Optional [int ],
631
654
ddof : int ,
632
- dtype = float ,
655
+ dtype : Dtype = float ,
633
656
) -> Tuple [Union [int , np .ndarray ], Union [int , np .ndarray ]]:
634
657
""" Get the count of non-null values along an axis, accounting
635
658
for degrees of freedom.
@@ -776,7 +799,13 @@ def nanvar(values, axis=None, skipna=True, ddof=1, mask=None):
776
799
777
800
778
801
@disallow ("M8" , "m8" )
779
- def nansem (values , axis = None , skipna = True , ddof = 1 , mask = None ):
802
+ def nansem (
803
+ values : np .ndarray ,
804
+ axis : Optional [int ] = None ,
805
+ skipna : bool = True ,
806
+ ddof : int = 1 ,
807
+ mask : Optional [np .ndarray ] = None ,
808
+ ) -> float :
780
809
"""
781
810
Compute the standard error in the mean along given axis while ignoring NaNs
782
811
@@ -819,9 +848,14 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):
819
848
return np .sqrt (var ) / np .sqrt (count )
820
849
821
850
822
- def _nanminmax (meth , fill_value_typ ) :
851
+ def _nanminmax (meth : str , fill_value_typ : str ) -> Callable :
823
852
@bottleneck_switch (name = "nan" + meth )
824
- def reduction (values , axis = None , skipna = True , mask = None ):
853
+ def reduction (
854
+ values : np .ndarray ,
855
+ axis : Optional [int ] = None ,
856
+ skipna : bool = True ,
857
+ mask : Optional [np .ndarray ] = None ,
858
+ ) -> np .ndarray :
825
859
826
860
values , mask , dtype , dtype_max , fill_value = _get_values (
827
861
values , skipna , fill_value_typ = fill_value_typ , mask = mask
@@ -847,7 +881,12 @@ def reduction(values, axis=None, skipna=True, mask=None):
847
881
848
882
849
883
@disallow ("O" )
850
- def nanargmax (values , axis = None , skipna = True , mask = None ):
884
+ def nanargmax (
885
+ values : np .ndarray ,
886
+ axis : Optional [int ] = None ,
887
+ skipna : bool = True ,
888
+ mask : Optional [np .ndarray ] = None ,
889
+ ) -> int :
851
890
"""
852
891
Parameters
853
892
----------
@@ -878,7 +917,12 @@ def nanargmax(values, axis=None, skipna=True, mask=None):
878
917
879
918
880
919
@disallow ("O" )
881
- def nanargmin (values , axis = None , skipna = True , mask = None ):
920
+ def nanargmin (
921
+ values : np .ndarray ,
922
+ axis : Optional [int ] = None ,
923
+ skipna : bool = True ,
924
+ mask : Optional [np .ndarray ] = None ,
925
+ ) -> int :
882
926
"""
883
927
Parameters
884
928
----------
@@ -909,7 +953,12 @@ def nanargmin(values, axis=None, skipna=True, mask=None):
909
953
910
954
911
955
@disallow ("M8" , "m8" )
912
- def nanskew (values , axis = None , skipna = True , mask = None ):
956
+ def nanskew (
957
+ values : np .ndarray ,
958
+ axis : Optional [int ] = None ,
959
+ skipna : bool = True ,
960
+ mask : Optional [np .ndarray ] = None ,
961
+ ) -> float :
913
962
""" Compute the sample skewness.
914
963
915
964
The statistic computed here is the adjusted Fisher-Pearson standardized
@@ -987,7 +1036,12 @@ def nanskew(values, axis=None, skipna=True, mask=None):
987
1036
988
1037
989
1038
@disallow ("M8" , "m8" )
990
- def nankurt (values , axis = None , skipna = True , mask = None ):
1039
+ def nankurt (
1040
+ values : np .ndarray ,
1041
+ axis : Optional [int ] = None ,
1042
+ skipna : bool = True ,
1043
+ mask : Optional [np .ndarray ] = None ,
1044
+ ) -> float :
991
1045
"""
992
1046
Compute the sample excess kurtosis
993
1047
@@ -1075,7 +1129,13 @@ def nankurt(values, axis=None, skipna=True, mask=None):
1075
1129
1076
1130
1077
1131
@disallow ("M8" , "m8" )
1078
- def nanprod (values , axis = None , skipna = True , min_count = 0 , mask = None ):
1132
+ def nanprod (
1133
+ values : np .ndarray ,
1134
+ axis : Optional [int ] = None ,
1135
+ skipna : bool = True ,
1136
+ min_count : int = 0 ,
1137
+ mask : Optional [np .ndarray ] = None ,
1138
+ ) -> Dtype :
1079
1139
"""
1080
1140
Parameters
1081
1141
----------
@@ -1138,7 +1198,7 @@ def _get_counts(
1138
1198
values_shape : Tuple [int ],
1139
1199
mask : Optional [np .ndarray ],
1140
1200
axis : Optional [int ],
1141
- dtype = float ,
1201
+ dtype : Dtype = float ,
1142
1202
) -> Union [int , np .ndarray ]:
1143
1203
""" Get the count of non-null values along an axis
1144
1204
@@ -1218,7 +1278,12 @@ def _zero_out_fperr(arg):
1218
1278
1219
1279
1220
1280
@disallow ("M8" , "m8" )
1221
- def nancorr (a , b , method = "pearson" , min_periods = None ):
1281
+ def nancorr (
1282
+ a : np .ndarray ,
1283
+ b : np .ndarray ,
1284
+ method : str = "pearson" ,
1285
+ min_periods : Optional [int ] = None ,
1286
+ ):
1222
1287
"""
1223
1288
a, b: ndarrays
1224
1289
"""
@@ -1240,7 +1305,7 @@ def nancorr(a, b, method="pearson", min_periods=None):
1240
1305
return f (a , b )
1241
1306
1242
1307
1243
- def get_corr_func (method ):
1308
+ def get_corr_func (method : str ):
1244
1309
if method in ["kendall" , "spearman" ]:
1245
1310
from scipy .stats import kendalltau , spearmanr
1246
1311
elif callable (method ):
@@ -1262,7 +1327,7 @@ def _spearman(a, b):
1262
1327
1263
1328
1264
1329
@disallow ("M8" , "m8" )
1265
- def nancov (a , b , min_periods = None ):
1330
+ def nancov (a : np . ndarray , b : np . ndarray , min_periods : Optional [ int ] = None ):
1266
1331
if len (a ) != len (b ):
1267
1332
raise AssertionError ("Operands to nancov must have same size" )
1268
1333
@@ -1308,7 +1373,7 @@ def _ensure_numeric(x):
1308
1373
# NA-friendly array comparisons
1309
1374
1310
1375
1311
- def make_nancomp (op ):
1376
+ def make_nancomp (op ) -> Callable :
1312
1377
def f (x , y ):
1313
1378
xmask = isna (x )
1314
1379
ymask = isna (y )
@@ -1335,7 +1400,9 @@ def f(x, y):
1335
1400
nanne = make_nancomp (operator .ne )
1336
1401
1337
1402
1338
- def _nanpercentile_1d (values , mask , q , na_value , interpolation ):
1403
+ def _nanpercentile_1d (
1404
+ values : np .ndarray , mask : np .ndarray , q , na_value : Scalar , interpolation : str
1405
+ ) -> Union [Scalar , np .ndarray ]:
1339
1406
"""
1340
1407
Wrapper for np.percentile that skips missing values, specialized to
1341
1408
1-dimensional case.
@@ -1366,7 +1433,15 @@ def _nanpercentile_1d(values, mask, q, na_value, interpolation):
1366
1433
return np .percentile (values , q , interpolation = interpolation )
1367
1434
1368
1435
1369
- def nanpercentile (values , q , axis , na_value , mask , ndim , interpolation ):
1436
+ def nanpercentile (
1437
+ values : np .ndarray ,
1438
+ q ,
1439
+ axis : int ,
1440
+ na_value ,
1441
+ mask : np .ndarray ,
1442
+ ndim : int ,
1443
+ interpolation : str ,
1444
+ ):
1370
1445
"""
1371
1446
Wrapper for np.percentile that skips missing values.
1372
1447
0 commit comments