Skip to content

Commit fb64041

Browse files
committed
[SPARK-40949][CONNECT][PYTHON] Implement DataFrame.sortWithinPartitions
### What changes were proposed in this pull request? Implement `DataFrame.sortWithinPartitions` ### Why are the changes needed? for api coverage ### Does this PR introduce _any_ user-facing change? yes, new method ### How was this patch tested? added UT Closes #38423 from zhengruifeng/connect_df_sortWithinPartitions. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent c09c779 commit fb64041

File tree

10 files changed

+147
-36
lines changed

10 files changed

+147
-36
lines changed

connector/connect/src/main/protobuf/spark/connect/relations.proto

+1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ message Aggregate {
173173
message Sort {
174174
Relation input = 1;
175175
repeated SortField sort_fields = 2;
176+
bool is_global = 3;
176177

177178
message SortField {
178179
Expression expression = 1;

connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala

+39
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,45 @@ package object dsl {
331331
.build()
332332
}
333333

334+
def createDefaultSortField(col: String): Sort.SortField = {
335+
Sort.SortField
336+
.newBuilder()
337+
.setNulls(Sort.SortNulls.SORT_NULLS_FIRST)
338+
.setDirection(Sort.SortDirection.SORT_DIRECTION_ASCENDING)
339+
.setExpression(
340+
Expression.newBuilder
341+
.setUnresolvedAttribute(
342+
Expression.UnresolvedAttribute.newBuilder.setUnparsedIdentifier(col).build())
343+
.build())
344+
.build()
345+
}
346+
347+
def sort(columns: String*): Relation = {
348+
Relation
349+
.newBuilder()
350+
.setSort(
351+
Sort
352+
.newBuilder()
353+
.setInput(logicalPlan)
354+
.addAllSortFields(columns.map(createDefaultSortField).asJava)
355+
.setIsGlobal(true)
356+
.build())
357+
.build()
358+
}
359+
360+
def sortWithinPartitions(columns: String*): Relation = {
361+
Relation
362+
.newBuilder()
363+
.setSort(
364+
Sort
365+
.newBuilder()
366+
.setInput(logicalPlan)
367+
.addAllSortFields(columns.map(createDefaultSortField).asJava)
368+
.setIsGlobal(false)
369+
.build())
370+
.build()
371+
}
372+
334373
def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
335374
val agg = Aggregate.newBuilder()
336375
agg.setInput(logicalPlan)

connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
364364
assert(rel.getSortFieldsCount > 0, "'sort_fields' must be present and contain elements.")
365365
logical.Sort(
366366
child = transformRelation(rel.getInput),
367-
global = true,
367+
global = rel.getIsGlobal,
368368
order = rel.getSortFieldsList.asScala.map(transformSortOrderExpression).toSeq)
369369
}
370370

connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala

+19-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.connect.proto
2424
import org.apache.spark.connect.proto.Expression.UnresolvedStar
2525
import org.apache.spark.sql.catalyst.expressions.AttributeReference
26-
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
26+
import org.apache.spark.sql.catalyst.plans.logical
2727
import org.apache.spark.sql.test.SharedSparkSession
2828

2929
/**
@@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession
3232
*/
3333
trait SparkConnectPlanTest extends SharedSparkSession {
3434

35-
def transform(rel: proto.Relation): LogicalPlan = {
35+
def transform(rel: proto.Relation): logical.LogicalPlan = {
3636
new SparkConnectPlanner(rel, spark).transform()
3737
}
3838

@@ -149,9 +149,25 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
149149

150150
val res = transform(
151151
proto.Relation.newBuilder
152-
.setSort(proto.Sort.newBuilder.addAllSortFields(Seq(f).asJava).setInput(readRel))
152+
.setSort(
153+
proto.Sort.newBuilder
154+
.addAllSortFields(Seq(f).asJava)
155+
.setInput(readRel)
156+
.setIsGlobal(true))
153157
.build())
154158
assert(res.nodeName == "Sort")
159+
assert(res.asInstanceOf[logical.Sort].global)
160+
161+
val res2 = transform(
162+
proto.Relation.newBuilder
163+
.setSort(
164+
proto.Sort.newBuilder
165+
.addAllSortFields(Seq(f).asJava)
166+
.setInput(readRel)
167+
.setIsGlobal(false))
168+
.build())
169+
assert(res2.nodeName == "Sort")
170+
assert(!res2.asInstanceOf[logical.Sort].global)
155171
}
156172

157173
test("Simple Union") {

connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala

+10
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
118118
comparePlans(connectPlan, sparkPlan)
119119
}
120120

121+
test("Test sort") {
122+
val connectPlan = connectTestRelation.sort("id", "name")
123+
val sparkPlan = sparkTestRelation.sort("id", "name")
124+
comparePlans(connectPlan, sparkPlan)
125+
126+
val connectPlan2 = connectTestRelation.sortWithinPartitions("id", "name")
127+
val sparkPlan2 = sparkTestRelation.sortWithinPartitions("id", "name")
128+
comparePlans(connectPlan2, sparkPlan2)
129+
}
130+
121131
test("column alias") {
122132
val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
123133
val sparkPlan = sparkTestRelation.select(Column("id").alias("id2"))

python/pyspark/sql/connect/dataframe.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,15 @@ def offset(self, n: int) -> "DataFrame":
242242

243243
def sort(self, *cols: "ColumnOrString") -> "DataFrame":
244244
"""Sort by a specific column"""
245-
return DataFrame.withPlan(plan.Sort(self._plan, *cols), session=self._session)
245+
return DataFrame.withPlan(
246+
plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session
247+
)
248+
249+
def sortWithinPartitions(self, *cols: "ColumnOrString") -> "DataFrame":
250+
"""Sort within each partition by a specific column"""
251+
return DataFrame.withPlan(
252+
plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session
253+
)
246254

247255
def sample(
248256
self,

python/pyspark/sql/connect/plan.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,19 @@ def _repr_html_(self) -> str:
360360

361361
class Sort(LogicalPlan):
362362
def __init__(
363-
self, child: Optional["LogicalPlan"], *columns: Union[SortOrder, ColumnRef, str]
363+
self,
364+
child: Optional["LogicalPlan"],
365+
columns: List[Union[SortOrder, ColumnRef, str]],
366+
is_global: bool,
364367
) -> None:
365368
super().__init__(child)
366-
self.columns = list(columns)
369+
self.columns = columns
370+
self.is_global = is_global
367371

368372
def col_to_sort_field(
369373
self, col: Union[SortOrder, ColumnRef, str], session: Optional["RemoteSparkSession"]
370374
) -> proto.Sort.SortField:
371-
if type(col) is SortOrder:
375+
if isinstance(col, SortOrder):
372376
sf = proto.Sort.SortField()
373377
sf.expression.CopyFrom(col.ref.to_plan(session))
374378
sf.direction = (
@@ -385,10 +389,10 @@ def col_to_sort_field(
385389
else:
386390
sf = proto.Sort.SortField()
387391
# Check string
388-
if type(col) is ColumnRef:
392+
if isinstance(col, ColumnRef):
389393
sf.expression.CopyFrom(col.to_plan(session))
390394
else:
391-
sf.expression.CopyFrom(self.unresolved_attr(cast(str, col)))
395+
sf.expression.CopyFrom(self.unresolved_attr(col))
392396
sf.direction = proto.Sort.SortDirection.SORT_DIRECTION_ASCENDING
393397
sf.nulls = proto.Sort.SortNulls.SORT_NULLS_LAST
394398
return sf
@@ -398,18 +402,20 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
398402
plan = proto.Relation()
399403
plan.sort.input.CopyFrom(self._child.plan(session))
400404
plan.sort.sort_fields.extend([self.col_to_sort_field(x, session) for x in self.columns])
405+
plan.sort.is_global = self.is_global
401406
return plan
402407

403408
def print(self, indent: int = 0) -> str:
404409
c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
405-
return f"{' ' * indent}<Sort columns={self.columns}>\n{c_buf}"
410+
return f"{' ' * indent}<Sort columns={self.columns}, global={self.is_global}>\n{c_buf}"
406411

407412
def _repr_html_(self) -> str:
408413
return f"""
409414
<ul>
410415
<li>
411416
<b>Sort</b><br />
412417
{", ".join([str(c) for c in self.columns])}
418+
global: {self.is_global} <br />
413419
{self._child_repr_()}
414420
</li>
415421
</uL>

0 commit comments

Comments
 (0)