32
32
Shape ,
33
33
npt ,
34
34
)
35
+ from pandas .compat import (
36
+ IS64 ,
37
+ is_platform_windows ,
38
+ )
35
39
from pandas .errors import AbstractMethodError
36
40
from pandas .util ._decorators import doc
37
41
from pandas .util ._validators import validate_fillna_kwargs
@@ -1081,21 +1085,31 @@ def _quantile(
1081
1085
# ------------------------------------------------------------------
1082
1086
# Reductions
1083
1087
1084
- def _reduce (self , name : str , * , skipna : bool = True , ** kwargs ):
1088
+ def _reduce (
1089
+ self , name : str , * , skipna : bool = True , keepdims : bool = False , ** kwargs
1090
+ ):
1085
1091
if name in {"any" , "all" , "min" , "max" , "sum" , "prod" , "mean" , "var" , "std" }:
1086
- return getattr (self , name )(skipna = skipna , ** kwargs )
1087
-
1088
- data = self ._data
1089
- mask = self ._mask
1090
-
1091
- # median, skew, kurt, sem
1092
- op = getattr (nanops , f"nan{ name } " )
1093
- result = op (data , axis = 0 , skipna = skipna , mask = mask , ** kwargs )
1092
+ result = getattr (self , name )(skipna = skipna , ** kwargs )
1093
+ else :
1094
+ # median, skew, kurt, sem
1095
+ data = self ._data
1096
+ mask = self ._mask
1097
+ op = getattr (nanops , f"nan{ name } " )
1098
+ axis = kwargs .pop ("axis" , None )
1099
+ result = op (data , axis = axis , skipna = skipna , mask = mask , ** kwargs )
1100
+
1101
+ if keepdims :
1102
+ if isna (result ):
1103
+ return self ._wrap_na_result (name = name , axis = 0 , mask_size = (1 ,))
1104
+ else :
1105
+ result = result .reshape (1 )
1106
+ mask = np .zeros (1 , dtype = bool )
1107
+ return self ._maybe_mask_result (result , mask )
1094
1108
1095
- if np . isnan (result ):
1109
+ if isna (result ):
1096
1110
return libmissing .NA
1097
-
1098
- return result
1111
+ else :
1112
+ return result
1099
1113
1100
1114
def _wrap_reduction_result (self , name : str , result , * , skipna , axis ):
1101
1115
if isinstance (result , np .ndarray ):
@@ -1108,6 +1122,32 @@ def _wrap_reduction_result(self, name: str, result, *, skipna, axis):
1108
1122
return self ._maybe_mask_result (result , mask )
1109
1123
return result
1110
1124
1125
+ def _wrap_na_result (self , * , name , axis , mask_size ):
1126
+ mask = np .ones (mask_size , dtype = bool )
1127
+
1128
+ float_dtyp = "float32" if self .dtype == "Float32" else "float64"
1129
+ if name in ["mean" , "median" , "var" , "std" , "skew" ]:
1130
+ np_dtype = float_dtyp
1131
+ elif name in ["min" , "max" ] or self .dtype .itemsize == 8 :
1132
+ np_dtype = self .dtype .numpy_dtype .name
1133
+ else :
1134
+ is_windows_or_32bit = is_platform_windows () or not IS64
1135
+ int_dtyp = "int32" if is_windows_or_32bit else "int64"
1136
+ uint_dtyp = "uint32" if is_windows_or_32bit else "uint64"
1137
+ np_dtype = {"b" : int_dtyp , "i" : int_dtyp , "u" : uint_dtyp , "f" : float_dtyp }[
1138
+ self .dtype .kind
1139
+ ]
1140
+
1141
+ value = np .array ([1 ], dtype = np_dtype )
1142
+ return self ._maybe_mask_result (value , mask = mask )
1143
+
1144
+ def _wrap_min_count_reduction_result (
1145
+ self , name : str , result , * , skipna , min_count , axis
1146
+ ):
1147
+ if min_count == 0 and isinstance (result , np .ndarray ):
1148
+ return self ._maybe_mask_result (result , np .zeros (result .shape , dtype = bool ))
1149
+ return self ._wrap_reduction_result (name , result , skipna = skipna , axis = axis )
1150
+
1111
1151
def sum (
1112
1152
self ,
1113
1153
* ,
@@ -1125,7 +1165,9 @@ def sum(
1125
1165
min_count = min_count ,
1126
1166
axis = axis ,
1127
1167
)
1128
- return self ._wrap_reduction_result ("sum" , result , skipna = skipna , axis = axis )
1168
+ return self ._wrap_min_count_reduction_result (
1169
+ "sum" , result , skipna = skipna , min_count = min_count , axis = axis
1170
+ )
1129
1171
1130
1172
def prod (
1131
1173
self ,
@@ -1136,14 +1178,17 @@ def prod(
1136
1178
** kwargs ,
1137
1179
):
1138
1180
nv .validate_prod ((), kwargs )
1181
+
1139
1182
result = masked_reductions .prod (
1140
1183
self ._data ,
1141
1184
self ._mask ,
1142
1185
skipna = skipna ,
1143
1186
min_count = min_count ,
1144
1187
axis = axis ,
1145
1188
)
1146
- return self ._wrap_reduction_result ("prod" , result , skipna = skipna , axis = axis )
1189
+ return self ._wrap_min_count_reduction_result (
1190
+ "prod" , result , skipna = skipna , min_count = min_count , axis = axis
1191
+ )
1147
1192
1148
1193
def mean (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
1149
1194
nv .validate_mean ((), kwargs )
@@ -1183,23 +1228,25 @@ def std(
1183
1228
1184
1229
def min (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
1185
1230
nv .validate_min ((), kwargs )
1186
- return masked_reductions .min (
1231
+ result = masked_reductions .min (
1187
1232
self ._data ,
1188
1233
self ._mask ,
1189
1234
skipna = skipna ,
1190
1235
axis = axis ,
1191
1236
)
1237
+ return self ._wrap_reduction_result ("min" , result , skipna = skipna , axis = axis )
1192
1238
1193
1239
def max (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
1194
1240
nv .validate_max ((), kwargs )
1195
- return masked_reductions .max (
1241
+ result = masked_reductions .max (
1196
1242
self ._data ,
1197
1243
self ._mask ,
1198
1244
skipna = skipna ,
1199
1245
axis = axis ,
1200
1246
)
1247
+ return self ._wrap_reduction_result ("max" , result , skipna = skipna , axis = axis )
1201
1248
1202
- def any (self , * , skipna : bool = True , ** kwargs ):
1249
+ def any (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
1203
1250
"""
1204
1251
Return whether any element is truthy.
1205
1252
@@ -1218,6 +1265,7 @@ def any(self, *, skipna: bool = True, **kwargs):
1218
1265
If `skipna` is False, the result will still be True if there is
1219
1266
at least one element that is truthy, otherwise NA will be returned
1220
1267
if there are NA's present.
1268
+ axis : int, optional, default 0
1221
1269
**kwargs : any, default None
1222
1270
Additional keywords have no effect but might be accepted for
1223
1271
compatibility with NumPy.
@@ -1261,7 +1309,6 @@ def any(self, *, skipna: bool = True, **kwargs):
1261
1309
>>> pd.array([0, 0, pd.NA]).any(skipna=False)
1262
1310
<NA>
1263
1311
"""
1264
- kwargs .pop ("axis" , None )
1265
1312
nv .validate_any ((), kwargs )
1266
1313
1267
1314
values = self ._data .copy ()
@@ -1280,7 +1327,7 @@ def any(self, *, skipna: bool = True, **kwargs):
1280
1327
else :
1281
1328
return self .dtype .na_value
1282
1329
1283
- def all (self , * , skipna : bool = True , ** kwargs ):
1330
+ def all (self , * , skipna : bool = True , axis : AxisInt | None = 0 , ** kwargs ):
1284
1331
"""
1285
1332
Return whether all elements are truthy.
1286
1333
@@ -1299,6 +1346,7 @@ def all(self, *, skipna: bool = True, **kwargs):
1299
1346
If `skipna` is False, the result will still be False if there is
1300
1347
at least one element that is falsey, otherwise NA will be returned
1301
1348
if there are NA's present.
1349
+ axis : int, optional, default 0
1302
1350
**kwargs : any, default None
1303
1351
Additional keywords have no effect but might be accepted for
1304
1352
compatibility with NumPy.
@@ -1342,7 +1390,6 @@ def all(self, *, skipna: bool = True, **kwargs):
1342
1390
>>> pd.array([1, 0, pd.NA]).all(skipna=False)
1343
1391
False
1344
1392
"""
1345
- kwargs .pop ("axis" , None )
1346
1393
nv .validate_all ((), kwargs )
1347
1394
1348
1395
values = self ._data .copy ()
@@ -1352,7 +1399,7 @@ def all(self, *, skipna: bool = True, **kwargs):
1352
1399
# bool, int, float, complex, str, bytes,
1353
1400
# _NestedSequence[Union[bool, int, float, complex, str, bytes]]]"
1354
1401
np .putmask (values , self ._mask , self ._truthy_value ) # type: ignore[arg-type]
1355
- result = values .all ()
1402
+ result = values .all (axis = axis )
1356
1403
1357
1404
if skipna :
1358
1405
return result
0 commit comments