Skip to content

Commit 956e140

Browse files
zhengruifeng and SandishKumarHN
authored and committed
[SPARK-40852][CONNECT][PYTHON] Introduce StatFunction in proto and implement DataFrame.summary
### What changes were proposed in this pull request? Implement `DataFrame.summary` there is a set of DataFrame APIs implemented in [`StatFunctions`](https://github.com/apache/spark/blob/9cae423075145d3dd81d53f4b82d4f2af6fe7c15/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala), [`DataFrameStatFunctions`](https://github.com/apache/spark/blob/b69c26833c99337bb17922f21dd72ee3a12e0c0a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala) and [`DataFrameNaFunctions`](https://github.com/apache/spark/blob/5d74ace648422e7a9bff7774ac266372934023b9/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala), which I think can not be implemented in connect client due to: 1. depend on Catalyst's analysis (most of them); ~~2. implemented in RDD operations (like `summary`,`approxQuantile`);~~ (resolved by reimpl) ~~3. internally trigger jobs (like `summary`);~~ (resolved by reimpl) This PR introduced a new proto `StatFunction` to support `StatFunctions` method ### Why are the changes needed? for Connect API coverage ### Does this PR introduce _any_ user-facing change? yes, new API ### How was this patch tested? added UT Closes apache#38318 from zhengruifeng/connect_df_summary. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 896a465 commit 956e140

File tree

9 files changed

+252
-59
lines changed

9 files changed

+252
-59
lines changed

connector/connect/src/main/protobuf/spark/connect/relations.proto

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ message Relation {
4848
SubqueryAlias subquery_alias = 16;
4949
Repartition repartition = 17;
5050

51+
StatFunction stat_function = 100;
52+
5153
Unknown unknown = 999;
5254
}
5355
}
@@ -254,3 +256,21 @@ message Repartition {
254256
// Optional. Default value is false.
255257
bool shuffle = 3;
256258
}
259+
260+
// StatFunction
261+
message StatFunction {
262+
// Required. The input relation.
263+
Relation input = 1;
264+
// Required. The function and its parameters.
265+
oneof function {
266+
Summary summary = 2;
267+
268+
Unknown unknown = 999;
269+
}
270+
271+
// StatFunctions.summary
272+
message Summary {
273+
repeated string statistics = 1;
274+
}
275+
}
276+

connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,22 @@ package object dsl {
441441
Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
442442
.build()
443443

444+
def summary(statistics: String*): Relation = {
445+
Relation
446+
.newBuilder()
447+
.setStatFunction(
448+
proto.StatFunction
449+
.newBuilder()
450+
.setInput(logicalPlan)
451+
.setSummary(
452+
proto.StatFunction.Summary
453+
.newBuilder()
454+
.addAllStatistics(statistics.toSeq.asJava)
455+
.build())
456+
.build())
457+
.build()
458+
}
459+
444460
private def createSetOperation(
445461
left: Relation,
446462
right: Relation,

connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.annotation.elidable.byName
2121
import scala.collection.JavaConverters._
2222

2323
import org.apache.spark.connect.proto
24-
import org.apache.spark.sql.SparkSession
24+
import org.apache.spark.sql.{Dataset, SparkSession}
2525
import org.apache.spark.sql.catalyst.AliasIdentifier
2626
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
2727
import org.apache.spark.sql.catalyst.expressions
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType,
3232
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LogicalPlan, Sample, SubqueryAlias, Union}
3333
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
3434
import org.apache.spark.sql.execution.QueryExecution
35+
import org.apache.spark.sql.execution.stat.StatFunctions
3536
import org.apache.spark.sql.types._
3637
import org.apache.spark.util.Utils
3738

@@ -73,6 +74,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
7374
case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
7475
transformSubqueryAlias(rel.getSubqueryAlias)
7576
case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
77+
case proto.Relation.RelTypeCase.STAT_FUNCTION =>
78+
transformStatFunction(rel.getStatFunction)
7679
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
7780
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
7881
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -124,6 +127,19 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
124127
logical.Range(start, end, step, numPartitions)
125128
}
126129

130+
private def transformStatFunction(rel: proto.StatFunction): LogicalPlan = {
131+
val child = transformRelation(rel.getInput)
132+
133+
rel.getFunctionCase match {
134+
case proto.StatFunction.FunctionCase.SUMMARY =>
135+
StatFunctions
136+
.summary(Dataset.ofRows(session, child), rel.getSummary.getStatisticsList.asScala.toSeq)
137+
.logicalPlan
138+
139+
case _ => throw InvalidPlanInput(s"StatFunction ${rel.getUnknown} not supported.")
140+
}
141+
}
142+
127143
private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
128144
if (!rel.hasInput) {
129145
throw InvalidPlanInput("Deduplicate needs a plan input")

connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
261261
comparePlans(connectPlan2, sparkPlan2)
262262
}
263263

264+
test("Test summary") {
265+
comparePlans(
266+
connectTestRelation.summary("count", "mean", "stddev"),
267+
sparkTestRelation.summary("count", "mean", "stddev"))
268+
}
269+
264270
private def createLocalRelationProtoByQualifiedAttributes(
265271
attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
266272
val localRelationBuilder = proto.LocalRelation.newBuilder()

python/pyspark/sql/connect/dataframe.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,16 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) ->
376376
def where(self, condition: Expression) -> "DataFrame":
377377
return self.filter(condition)
378378

379+
def summary(self, *statistics: str) -> "DataFrame":
380+
_statistics: List[str] = list(statistics)
381+
for s in _statistics:
382+
if not isinstance(s, str):
383+
raise TypeError(f"'statistics' must be list[str], but got {type(s).__name__}")
384+
return DataFrame.withPlan(
385+
plan.StatFunction(child=self._plan, function="summary", statistics=_statistics),
386+
session=self._session,
387+
)
388+
379389
def _get_alias(self) -> Optional[str]:
380390
p = self._plan
381391
while p is not None:

python/pyspark/sql/connect/plan.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
from typing import (
19+
Any,
1920
List,
2021
Optional,
2122
Sequence,
@@ -750,3 +751,40 @@ def _repr_html_(self) -> str:
750751
</li>
751752
</uL>
752753
"""
754+
755+
756+
class StatFunction(LogicalPlan):
757+
def __init__(self, child: Optional["LogicalPlan"], function: str, **kwargs: Any) -> None:
758+
super().__init__(child)
759+
assert function in ["summary"]
760+
self.function = function
761+
self.kwargs = kwargs
762+
763+
def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
764+
assert self._child is not None
765+
766+
plan = proto.Relation()
767+
plan.stat_function.input.CopyFrom(self._child.plan(session))
768+
769+
if self.function == "summary":
770+
plan.stat_function.summary.statistics.extend(self.kwargs.get("statistics", []))
771+
else:
772+
raise Exception(f"Unknown function ${self.function}.")
773+
774+
return plan
775+
776+
def print(self, indent: int = 0) -> str:
777+
i = " " * indent
778+
return f"""{i}<StatFunction function='{self.function}' augments='{self.kwargs}'>"""
779+
780+
def _repr_html_(self) -> str:
781+
return f"""
782+
<ul>
783+
<li>
784+
<b>StatFunction</b><br />
785+
Function: {self.function} <br />
786+
Augments: {self.kwargs} <br />
787+
{self._child_repr_()}
788+
</li>
789+
</ul>
790+
"""

0 commit comments

Comments
 (0)