Skip to content

Commit d15a210

Browse files
authored
TYP: nanops (#43264)
1 parent beb7c48 commit d15a210

File tree

1 file changed

+42
-44
lines changed

1 file changed

+42
-44
lines changed

pandas/core/nanops.py

+42 -44
Original file line number | Diff line number | Diff line change
@@ -27,11 +27,11 @@
2727
F,
2828
Scalar,
2929
Shape,
30+
npt,
3031
)
3132
from pandas.compat._optional import import_optional_dependency
3233

3334
from pandas.core.dtypes.common import (
34-
get_dtype,
3535
is_any_int_dtype,
3636
is_bool_dtype,
3737
is_complex,
@@ -209,8 +209,8 @@ def _get_fill_value(
209209

210210

211211
def _maybe_get_mask(
212-
values: np.ndarray, skipna: bool, mask: np.ndarray | None
213-
) -> np.ndarray | None:
212+
values: np.ndarray, skipna: bool, mask: npt.NDArray[np.bool_] | None
213+
) -> npt.NDArray[np.bool_] | None:
214214
"""
215215
Compute a mask if and only if necessary.
216216
@@ -239,7 +239,7 @@ def _maybe_get_mask(
239239
240240
Returns
241241
-------
242-
Optional[np.ndarray]
242+
Optional[np.ndarray[bool]]
243243
"""
244244
if mask is None:
245245
if is_bool_dtype(values.dtype) or is_integer_dtype(values.dtype):
@@ -257,8 +257,8 @@ def _get_values(
257257
skipna: bool,
258258
fill_value: Any = None,
259259
fill_value_typ: str | None = None,
260-
mask: np.ndarray | None = None,
261-
) -> tuple[np.ndarray, np.ndarray | None, np.dtype, np.dtype, Any]:
260+
mask: npt.NDArray[np.bool_] | None = None,
261+
) -> tuple[np.ndarray, npt.NDArray[np.bool_] | None, np.dtype, np.dtype, Any]:
262262
"""
263263
Utility to get the values view, mask, dtype, dtype_max, and fill_value.
264264
@@ -279,7 +279,7 @@ def _get_values(
279279
value to fill NaNs with
280280
fill_value_typ : str
281281
Set to '+inf' or '-inf' to handle dtype-specific infinities
282-
mask : Optional[np.ndarray]
282+
mask : Optional[np.ndarray[bool]]
283283
nan-mask if known
284284
285285
Returns
@@ -396,7 +396,7 @@ def new_func(
396396
*,
397397
axis: int | None = None,
398398
skipna: bool = True,
399-
mask: np.ndarray | None = None,
399+
mask: npt.NDArray[np.bool_] | None = None,
400400
**kwargs,
401401
):
402402
orig_values = values
@@ -454,7 +454,7 @@ def nanany(
454454
*,
455455
axis: int | None = None,
456456
skipna: bool = True,
457-
mask: np.ndarray | None = None,
457+
mask: npt.NDArray[np.bool_] | None = None,
458458
) -> bool:
459459
"""
460460
Check if any elements along an axis evaluate to True.
@@ -500,7 +500,7 @@ def nanall(
500500
*,
501501
axis: int | None = None,
502502
skipna: bool = True,
503-
mask: np.ndarray | None = None,
503+
mask: npt.NDArray[np.bool_] | None = None,
504504
) -> bool:
505505
"""
506506
Check if all elements along an axis evaluate to True.
@@ -549,7 +549,7 @@ def nansum(
549549
axis: int | None = None,
550550
skipna: bool = True,
551551
min_count: int = 0,
552-
mask: np.ndarray | None = None,
552+
mask: npt.NDArray[np.bool_] | None = None,
553553
) -> float:
554554
"""
555555
Sum the elements along an axis ignoring NaNs
@@ -592,7 +592,7 @@ def nansum(
592592
def _mask_datetimelike_result(
593593
result: np.ndarray | np.datetime64 | np.timedelta64,
594594
axis: int | None,
595-
mask: np.ndarray,
595+
mask: npt.NDArray[np.bool_],
596596
orig_values: np.ndarray,
597597
) -> np.ndarray | np.datetime64 | np.timedelta64 | NaTType:
598598
if isinstance(result, np.ndarray):
@@ -616,7 +616,7 @@ def nanmean(
616616
*,
617617
axis: int | None = None,
618618
skipna: bool = True,
619-
mask: np.ndarray | None = None,
619+
mask: npt.NDArray[np.bool_] | None = None,
620620
) -> float:
621621
"""
622622
Compute the mean of the element along an axis ignoring NaNs
@@ -781,10 +781,10 @@ def get_empty_reduction_result(
781781

782782
def _get_counts_nanvar(
783783
values_shape: Shape,
784-
mask: np.ndarray | None,
784+
mask: npt.NDArray[np.bool_] | None,
785785
axis: int | None,
786786
ddof: int,
787-
dtype: Dtype = float,
787+
dtype: np.dtype = np.dtype(np.float64),
788788
) -> tuple[int | float | np.ndarray, int | float | np.ndarray]:
789789
"""
790790
Get the count of non-null values along an axis, accounting
@@ -808,7 +808,6 @@ def _get_counts_nanvar(
808808
count : int, np.nan or np.ndarray
809809
d : int, np.nan or np.ndarray
810810
"""
811-
dtype = get_dtype(dtype)
812811
count = _get_counts(values_shape, mask, axis, dtype=dtype)
813812
d = count - dtype.type(ddof)
814813

@@ -931,7 +930,7 @@ def nanvar(values, *, axis=None, skipna=True, ddof=1, mask=None):
931930
# unless we were dealing with a float array, in which case use the same
932931
# precision as the original values array.
933932
if is_float_dtype(dtype):
934-
result = result.astype(dtype)
933+
result = result.astype(dtype, copy=False)
935934
return result
936935

937936

@@ -942,7 +941,7 @@ def nansem(
942941
axis: int | None = None,
943942
skipna: bool = True,
944943
ddof: int = 1,
945-
mask: np.ndarray | None = None,
944+
mask: npt.NDArray[np.bool_] | None = None,
946945
) -> float:
947946
"""
948947
Compute the standard error in the mean along given axis while ignoring NaNs
@@ -993,7 +992,7 @@ def reduction(
993992
*,
994993
axis: int | None = None,
995994
skipna: bool = True,
996-
mask: np.ndarray | None = None,
995+
mask: npt.NDArray[np.bool_] | None = None,
997996
) -> Dtype:
998997

999998
values, mask, dtype, dtype_max, fill_value = _get_values(
@@ -1025,7 +1024,7 @@ def nanargmax(
10251024
*,
10261025
axis: int | None = None,
10271026
skipna: bool = True,
1028-
mask: np.ndarray | None = None,
1027+
mask: npt.NDArray[np.bool_] | None = None,
10291028
) -> int | np.ndarray:
10301029
"""
10311030
Parameters
@@ -1071,7 +1070,7 @@ def nanargmin(
10711070
*,
10721071
axis: int | None = None,
10731072
skipna: bool = True,
1074-
mask: np.ndarray | None = None,
1073+
mask: npt.NDArray[np.bool_] | None = None,
10751074
) -> int | np.ndarray:
10761075
"""
10771076
Parameters
@@ -1117,7 +1116,7 @@ def nanskew(
11171116
*,
11181117
axis: int | None = None,
11191118
skipna: bool = True,
1120-
mask: np.ndarray | None = None,
1119+
mask: npt.NDArray[np.bool_] | None = None,
11211120
) -> float:
11221121
"""
11231122
Compute the sample skewness.
@@ -1185,7 +1184,7 @@ def nanskew(
11851184

11861185
dtype = values.dtype
11871186
if is_float_dtype(dtype):
1188-
result = result.astype(dtype)
1187+
result = result.astype(dtype, copy=False)
11891188

11901189
if isinstance(result, np.ndarray):
11911190
result = np.where(m2 == 0, 0, result)
@@ -1204,7 +1203,7 @@ def nankurt(
12041203
*,
12051204
axis: int | None = None,
12061205
skipna: bool = True,
1207-
mask: np.ndarray | None = None,
1206+
mask: npt.NDArray[np.bool_] | None = None,
12081207
) -> float:
12091208
"""
12101209
Compute the sample excess kurtosis
@@ -1285,7 +1284,7 @@ def nankurt(
12851284

12861285
dtype = values.dtype
12871286
if is_float_dtype(dtype):
1288-
result = result.astype(dtype)
1287+
result = result.astype(dtype, copy=False)
12891288

12901289
if isinstance(result, np.ndarray):
12911290
result = np.where(denominator == 0, 0, result)
@@ -1301,7 +1300,7 @@ def nanprod(
13011300
axis: int | None = None,
13021301
skipna: bool = True,
13031302
min_count: int = 0,
1304-
mask: np.ndarray | None = None,
1303+
mask: npt.NDArray[np.bool_] | None = None,
13051304
) -> float:
13061305
"""
13071306
Parameters
@@ -1339,7 +1338,10 @@ def nanprod(
13391338

13401339

13411340
def _maybe_arg_null_out(
1342-
result: np.ndarray, axis: int | None, mask: np.ndarray | None, skipna: bool
1341+
result: np.ndarray,
1342+
axis: int | None,
1343+
mask: npt.NDArray[np.bool_] | None,
1344+
skipna: bool,
13431345
) -> np.ndarray | int:
13441346
# helper function for nanargmin/nanargmax
13451347
if mask is None:
@@ -1367,10 +1369,10 @@ def _maybe_arg_null_out(
13671369

13681370

13691371
def _get_counts(
1370-
values_shape: tuple[int, ...],
1371-
mask: np.ndarray | None,
1372+
values_shape: Shape,
1373+
mask: npt.NDArray[np.bool_] | None,
13721374
axis: int | None,
1373-
dtype: Dtype = float,
1375+
dtype: np.dtype = np.dtype(np.float64),
13741376
) -> int | float | np.ndarray:
13751377
"""
13761378
Get the count of non-null values along an axis
@@ -1390,7 +1392,6 @@ def _get_counts(
13901392
-------
13911393
count : scalar or array
13921394
"""
1393-
dtype = get_dtype(dtype)
13941395
if axis is None:
13951396
if mask is not None:
13961397
n = mask.size - mask.sum()
@@ -1405,20 +1406,13 @@ def _get_counts(
14051406

14061407
if is_scalar(count):
14071408
return dtype.type(count)
1408-
try:
1409-
return count.astype(dtype)
1410-
except AttributeError:
1411-
# error: Argument "dtype" to "array" has incompatible type
1412-
# "Union[ExtensionDtype, dtype]"; expected "Union[dtype, None, type,
1413-
# _SupportsDtype, str, Tuple[Any, int], Tuple[Any, Union[int,
1414-
# Sequence[int]]], List[Any], _DtypeDict, Tuple[Any, Any]]"
1415-
return np.array(count, dtype=dtype) # type: ignore[arg-type]
1409+
return count.astype(dtype, copy=False)
14161410

14171411

14181412
def _maybe_null_out(
14191413
result: np.ndarray | float | NaTType,
14201414
axis: int | None,
1421-
mask: np.ndarray | None,
1415+
mask: npt.NDArray[np.bool_] | None,
14221416
shape: tuple[int, ...],
14231417
min_count: int = 1,
14241418
) -> np.ndarray | float | NaTType:
@@ -1455,7 +1449,7 @@ def _maybe_null_out(
14551449

14561450

14571451
def check_below_min_count(
1458-
shape: tuple[int, ...], mask: np.ndarray | None, min_count: int
1452+
shape: tuple[int, ...], mask: npt.NDArray[np.bool_] | None, min_count: int
14591453
) -> bool:
14601454
"""
14611455
Check for the `min_count` keyword. Returns True if below `min_count` (when
@@ -1465,7 +1459,7 @@ def check_below_min_count(
14651459
----------
14661460
shape : tuple
14671461
The shape of the values (`values.shape`).
1468-
mask : ndarray or None
1462+
mask : ndarray[bool] or None
14691463
Boolean numpy array (typically of same shape as `shape`) or None.
14701464
min_count : int
14711465
Keyword passed through from sum/prod call.
@@ -1634,7 +1628,11 @@ def f(x, y):
16341628

16351629

16361630
def _nanpercentile_1d(
1637-
values: np.ndarray, mask: np.ndarray, q: np.ndarray, na_value: Scalar, interpolation
1631+
values: np.ndarray,
1632+
mask: npt.NDArray[np.bool_],
1633+
q: np.ndarray,
1634+
na_value: Scalar,
1635+
interpolation,
16381636
) -> Scalar | np.ndarray:
16391637
"""
16401638
Wrapper for np.percentile that skips missing values, specialized to
@@ -1668,7 +1666,7 @@ def nanpercentile(
16681666
q: np.ndarray,
16691667
*,
16701668
na_value,
1671-
mask: np.ndarray,
1669+
mask: npt.NDArray[np.bool_],
16721670
interpolation,
16731671
):
16741672
"""

0 commit comments

Comments
 (0)