
Commit ef63275

itholic authored and zhengruifeng committed
[SPARK-44842][SPARK-43812][PS] Support stat functions for pandas 2.0.0 and enabling tests
### What changes were proposed in this pull request?

This PR proposes to match the behavior with pandas 2.0.0 and above for stat functions, such as `sum`, `quantile`, `prod`, etc. See pandas-dev/pandas#41480 and pandas-dev/pandas#47500 for more detail.

### Why are the changes needed?

To match the behavior with the latest pandas.

### Does this PR introduce _any_ user-facing change?

Yes, the behaviors for stat functions are now matched with pandas 2.0.0 and above.

### How was this patch tested?

Enabled and updated the existing UTs.

Closes #42526 from itholic/pandas_stat.

Authored-by: itholic <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 52c4673 commit ef63275

File tree: 9 files changed, +80 -131 lines

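Before the per-file diffs, a minimal sketch of the alignment this commit targets. This is illustrative only, not code from the patch; it assumes a running Spark session and pandas >= 2.0, and the data is made up:

```python
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})
psdf = ps.from_pandas(pdf)

# pandas 2.x no longer silently drops non-numeric columns in stat
# functions; callers must opt in with numeric_only=True. pandas-on-Spark
# keeps numeric_only=True as the default, so the two now agree.
print(psdf.quantile(0.5))                    # A    2.0
print(pdf.quantile(0.5, numeric_only=True))  # A    2.0
```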

python/pyspark/pandas/frame.py (+3 -7)

```diff
@@ -83,6 +83,7 @@
     DecimalType,
     TimestampType,
     TimestampNTZType,
+    NullType,
 )
 from pyspark.sql.window import Window
 
@@ -797,7 +798,7 @@ def _reduce_for_stat_function(
             new_column_labels.append(label)
 
         if len(exprs) == 1:
-            return Series([])
+            return Series([], dtype="float64")
 
         sdf = self._internal.spark_frame.select(*exprs)
 
@@ -12128,11 +12129,6 @@ def quantile(
         0.50  3.0  7.0
         0.75  4.0  8.0
         """
-        warnings.warn(
-            "Default value of `numeric_only` will be changed to `False` "
-            "instead of `True` in 4.0.0.",
-            FutureWarning,
-        )
        axis = validate_axis(axis)
        if axis != 0:
            raise NotImplementedError('axis should be either 0 or "index" currently.')
@@ -12155,7 +12151,7 @@ def quantile(
        def quantile(psser: "Series") -> PySparkColumn:
            spark_type = psser.spark.data_type
            spark_column = psser.spark.column
-           if isinstance(spark_type, (BooleanType, NumericType)):
+           if isinstance(spark_type, (BooleanType, NumericType, NullType)):
                return F.percentile_approx(spark_column.cast(DoubleType()), qq, accuracy)
            else:
                raise TypeError(
```
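Why the empty-result branch now pins `dtype="float64"`: pandas 2.x changed the default dtype of an empty Series from float64 to object, so the pin preserves the numeric result the reduction used to return. A pandas-only illustration (assumes pandas >= 2.0):

```python
import pandas as pd

# Under pandas 2.x an empty Series defaults to object dtype...
print(pd.Series([]).dtype)                   # object
# ...so the reduction must request float64 explicitly to keep
# the old numeric behavior for "no matching columns" results.
print(pd.Series([], dtype="float64").dtype)  # float64
```

The `NullType` addition works the same way in both quantile hunks: an all-null column is now cast to `DoubleType` and handed to `F.percentile_approx` instead of raising `TypeError`.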

python/pyspark/pandas/generic.py (-5)

```diff
@@ -1419,11 +1419,6 @@ def product(
         nan
         """
         axis = validate_axis(axis)
-        warnings.warn(
-            "Default value of `numeric_only` will be changed to `False` "
-            "instead of `None` in 4.0.0.",
-            FutureWarning,
-        )
 
         if numeric_only is None and axis == 0:
             numeric_only = True
```
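With the FutureWarning gone, the default resolution is unchanged: `numeric_only=None` still becomes `True` on axis 0, so non-numeric columns are simply excluded. A small sketch (assumes a running Spark session; data is illustrative):

```python
import pyspark.pandas as ps

psdf = ps.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]})

# numeric_only defaults to None and resolves to True on axis 0,
# so the string column "B" is excluded from the product.
print(psdf.prod())  # A    6
```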

python/pyspark/pandas/groupby.py (+12 -8)

```diff
@@ -614,9 +614,10 @@ def mean(self, numeric_only: Optional[bool] = True) -> FrameLike:
 
         Parameters
         ----------
-        numeric_only : bool, default False
+        numeric_only : bool, default True
             Include only float, int, boolean columns. If None, will attempt to use
-            everything, then use only numeric data.
+            everything, then use only numeric data. False is not supported.
+            This parameter is mainly for pandas compatibility.
 
         .. versionadded:: 3.4.0
 
@@ -646,11 +647,6 @@ def mean(self, numeric_only: Optional[bool] = True) -> FrameLike:
         2  4.0  1.500000  1.000000
         """
         self._validate_agg_columns(numeric_only=numeric_only, function_name="median")
-        warnings.warn(
-            "Default value of `numeric_only` will be changed to `False` "
-            "instead of `True` in 4.0.0.",
-            FutureWarning,
-        )
 
         return self._reduce_for_stat_function(
             F.mean, accepted_spark_types=(NumericType,), bool_to_numeric=True
@@ -920,7 +916,7 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL
         )
 
     # TODO: sync the doc.
-    def var(self, ddof: int = 1) -> FrameLike:
+    def var(self, ddof: int = 1, numeric_only: Optional[bool] = True) -> FrameLike:
         """
         Compute variance of groups, excluding missing values.
 
@@ -935,6 +931,13 @@ def var(self, ddof: int = 1) -> FrameLike:
             .. versionchanged:: 3.4.0
                 Supported including arbitary integers.
 
+        numeric_only : bool, default True
+            Include only float, int, boolean columns. If None, will attempt to use
+            everything, then use only numeric data. False is not supported.
+            This parameter is mainly for pandas compatibility.
+
+            .. versionadded:: 4.0.0
+
         Examples
         --------
         >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
@@ -961,6 +964,7 @@ def var(col: Column) -> Column:
             var,
             accepted_spark_types=(NumericType,),
             bool_to_numeric=True,
+            numeric_only=numeric_only,
         )
 
     def skew(self) -> FrameLike:
```
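A sketch of the extended `GroupBy.var` signature (assumes a running Spark session; the data and the printed values are illustrative, not taken from the patch):

```python
import pyspark.pandas as ps

psdf = ps.DataFrame({"A": [1, 2, 1, 2], "B": [3.1, 4.1, 4.1, 6.1]})

# numeric_only is new on GroupBy.var in 4.0.0; True (the default) keeps
# only numeric aggregation columns, mirroring what pandas 2.x makes
# callers request explicitly.
print(psdf.groupby("A").var(ddof=1, numeric_only=True).sort_index())
#      B
# A
# 1  0.5
# 2  2.0
```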

python/pyspark/pandas/series.py (+5 -11)

```diff
@@ -67,6 +67,7 @@
     Row,
     StructType,
     TimestampType,
+    NullType,
 )
 from pyspark.sql.window import Window
 from pyspark.sql.utils import get_column_class, get_window_class
@@ -4024,7 +4025,7 @@ def quantile(
         def quantile(psser: Series) -> PySparkColumn:
             spark_type = psser.spark.data_type
             spark_column = psser.spark.column
-            if isinstance(spark_type, (BooleanType, NumericType)):
+            if isinstance(spark_type, (BooleanType, NumericType, NullType)):
                 return F.percentile_approx(spark_column.cast(DoubleType()), q_float, accuracy)
             else:
                 raise TypeError(
@@ -4059,7 +4060,8 @@ def rank(
         ascending : boolean, default True
             False for ranks by high (1) to low (N)
         numeric_only : bool, optional
-            If set to True, rank numeric Series, or return an empty Series for non-numeric Series
+            If set to True, rank numeric Series, or raise TypeError for non-numeric Series.
+            False is not supported. This parameter is mainly for pandas compatibility.
 
         Returns
         -------
@@ -4127,18 +4129,10 @@ def rank(
         y    b
         z    c
         Name: A, dtype: object
-
-        >>> s.rank(numeric_only=True)
-        Series([], Name: A, dtype: float64)
         """
-        warnings.warn(
-            "Default value of `numeric_only` will be changed to `False` "
-            "instead of `None` in 4.0.0.",
-            FutureWarning,
-        )
         is_numeric = isinstance(self.spark.data_type, (NumericType, BooleanType))
         if numeric_only and not is_numeric:
-            return ps.Series([], dtype="float64", name=self.name)
+            raise TypeError("Series.rank does not allow numeric_only=True with non-numeric dtype.")
         else:
             return self._rank(method, ascending).spark.analyzed
```
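The user-visible effect of the rank change, as a short sketch (assumes a running Spark session):

```python
import pyspark.pandas as ps

psser = ps.Series(["a", "b", "c"], name="A")

# Before this patch, numeric_only=True on a non-numeric Series returned
# an empty float64 Series; it now raises, matching pandas 2.x.
try:
    psser.rank(numeric_only=True)
except TypeError as e:
    print(e)  # Series.rank does not allow numeric_only=True with non-numeric dtype.
```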

python/pyspark/pandas/tests/computation/test_any_all.py (+8 -6)

```diff
@@ -39,10 +39,6 @@ def df_pair(self):
         psdf = ps.from_pandas(pdf)
         return pdf, psdf
 
-    @unittest.skipIf(
-        LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
-        "TODO(SPARK-43812): Enable DataFrameTests.test_all for pandas 2.0.0.",
-    )
     def test_all(self):
         pdf = pd.DataFrame(
             {
@@ -105,9 +101,15 @@ def test_all(self):
         self.assert_eq(psdf.all(skipna=True), pdf.all(skipna=True))
         self.assert_eq(psdf.all(), pdf.all())
         self.assert_eq(
-            ps.DataFrame([np.nan]).all(skipna=False), pd.DataFrame([np.nan]).all(skipna=False)
+            ps.DataFrame([np.nan]).all(skipna=False),
+            pd.DataFrame([np.nan]).all(skipna=False),
+            almost=True,
+        )
+        self.assert_eq(
+            ps.DataFrame([None]).all(skipna=True),
+            pd.DataFrame([None]).all(skipna=True),
+            almost=True,
         )
-        self.assert_eq(ps.DataFrame([None]).all(skipna=True), pd.DataFrame([None]).all(skipna=True))
 
     def test_any(self):
         pdf = pd.DataFrame(
```

python/pyspark/pandas/tests/computation/test_compute.py (+16 -22)

```diff
@@ -283,10 +283,6 @@ def test_nunique(self):
         self.assert_eq(psdf.nunique(), pdf.nunique())
         self.assert_eq(psdf.nunique(dropna=False), pdf.nunique(dropna=False))
 
-    @unittest.skipIf(
-        LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
-        "TODO(SPARK-43810): Enable DataFrameSlowTests.test_quantile for pandas 2.0.0.",
-    )
     def test_quantile(self):
         pdf, psdf = self.df_pair
 
@@ -332,59 +328,57 @@ def test_quantile(self):
         pdf = pd.DataFrame({"x": ["a", "b", "c"]})
         psdf = ps.from_pandas(pdf)
 
-        self.assert_eq(psdf.quantile(0.5), pdf.quantile(0.5))
-        self.assert_eq(psdf.quantile([0.25, 0.5, 0.75]), pdf.quantile([0.25, 0.5, 0.75]))
+        self.assert_eq(psdf.quantile(0.5), pdf.quantile(0.5, numeric_only=True))
+        self.assert_eq(
+            psdf.quantile([0.25, 0.5, 0.75]), pdf.quantile([0.25, 0.5, 0.75], numeric_only=True)
+        )
 
         with self.assertRaisesRegex(TypeError, "Could not convert object \\(string\\) to numeric"):
             psdf.quantile(0.5, numeric_only=False)
         with self.assertRaisesRegex(TypeError, "Could not convert object \\(string\\) to numeric"):
             psdf.quantile([0.25, 0.5, 0.75], numeric_only=False)
 
-    @unittest.skipIf(
-        LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
-        "TODO(SPARK-43558): Enable DataFrameSlowTests.test_product for pandas 2.0.0.",
-    )
     def test_product(self):
         pdf = pd.DataFrame(
             {"A": [1, 2, 3, 4, 5], "B": [10, 20, 30, 40, 50], "C": ["a", "b", "c", "d", "e"]}
         )
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index())
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())
 
         # Named columns
         pdf.columns.name = "Koalas"
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index())
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())
 
         # MultiIndex columns
         pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index())
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())
 
         # Named MultiIndex columns
         pdf.columns.names = ["Hello", "Koalas"]
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index())
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())
 
         # No numeric columns
         pdf = pd.DataFrame({"key": ["a", "b", "c"], "val": ["x", "y", "z"]})
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index())
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())
 
         # No numeric named columns
         pdf.columns.name = "Koalas"
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index(), almost=True)
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), almost=True)
 
         # No numeric MultiIndex columns
         pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y")])
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index(), almost=True)
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), almost=True)
 
         # No numeric named MultiIndex columns
         pdf.columns.names = ["Hello", "Koalas"]
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index(), almost=True)
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), almost=True)
 
         # All NaN columns
         pdf = pd.DataFrame(
@@ -395,22 +389,22 @@ def test_product(self):
             }
         )
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index(), check_exact=False)
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), check_exact=False)
 
         # All NaN named columns
         pdf.columns.name = "Koalas"
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index(), check_exact=False)
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), check_exact=False)
 
         # All NaN MultiIndex columns
         pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index(), check_exact=False)
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), check_exact=False)
 
         # All NaN named MultiIndex columns
         pdf.columns.names = ["Hello", "Koalas"]
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(pdf.prod(), psdf.prod().sort_index(), check_exact=False)
+        self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), check_exact=False)
 
 
 class FrameComputeTests(FrameComputeMixin, ComparisonTestBase, SQLTestUtils):
```

python/pyspark/pandas/tests/groupby/test_stat.py (+7 -26)

```diff
@@ -58,12 +58,10 @@ def _test_stat_func(self, func, check_exact=True):
             check_exact=check_exact,
         )
 
-    @unittest.skipIf(
-        LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
-        "TODO(SPARK-43554): Enable GroupByTests.test_basic_stat_funcs for pandas 2.0.0.",
-    )
     def test_basic_stat_funcs(self):
-        self._test_stat_func(lambda groupby_obj: groupby_obj.var(), check_exact=False)
+        self._test_stat_func(
+            lambda groupby_obj: groupby_obj.var(numeric_only=True), check_exact=False
+        )
 
         pdf, psdf = self.pdf, self.psdf
 
@@ -102,30 +100,24 @@ def test_basic_stat_funcs(self):
 
         self.assert_eq(
             psdf.groupby("A").std().sort_index(),
-            pdf.groupby("A").std().sort_index(),
+            pdf.groupby("A").std(numeric_only=True).sort_index(),
             check_exact=False,
         )
         self.assert_eq(
             psdf.groupby("A").sem().sort_index(),
-            pdf.groupby("A").sem().sort_index(),
+            pdf.groupby("A").sem(numeric_only=True).sort_index(),
             check_exact=False,
         )
 
         # TODO: fix bug of `sum` and re-enable the test below
         # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
         self.assert_eq(
             psdf.groupby("A").sum().sort_index(),
-            pdf.groupby("A").sum().sort_index(),
+            pdf.groupby("A").sum(numeric_only=True).sort_index(),
             check_exact=False,
         )
 
-    @unittest.skipIf(
-        LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
-        "TODO(SPARK-43706): Enable GroupByTests.test_mean for pandas 2.0.0.",
-    )
     def test_mean(self):
-        self._test_stat_func(lambda groupby_obj: groupby_obj.mean())
-        self._test_stat_func(lambda groupby_obj: groupby_obj.mean(numeric_only=None))
         self._test_stat_func(lambda groupby_obj: groupby_obj.mean(numeric_only=True))
         psdf = self.psdf
         with self.assertRaises(TypeError):
@@ -267,10 +259,6 @@ def test_nth(self):
         with self.assertRaisesRegex(TypeError, "Invalid index"):
             self.psdf.groupby("B").nth("x")
 
-    @unittest.skipIf(
-        LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
-        "TODO(SPARK-43551): Enable GroupByTests.test_prod for pandas 2.0.0.",
-    )
     def test_prod(self):
         pdf = pd.DataFrame(
             {
@@ -286,19 +274,12 @@ def test_prod(self):
         psdf = ps.from_pandas(pdf)
 
         for n in [0, 1, 2, 128, -1, -2, -128]:
-            self._test_stat_func(
-                lambda groupby_obj: groupby_obj.prod(min_count=n), check_exact=False
-            )
-            self._test_stat_func(
-                lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n),
-                check_exact=False,
-            )
             self._test_stat_func(
                 lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
                 check_exact=False,
             )
             self.assert_eq(
-                pdf.groupby("A").prod(min_count=n).sort_index(),
+                pdf.groupby("A").prod(min_count=n, numeric_only=True).sort_index(),
                 psdf.groupby("A").prod(min_count=n).sort_index(),
                 almost=True,
             )
```
