Skip to content

Commit 4624e59

Browse files
committed
[SPARK-36359][SQL] Coalesce drop all expressions after the first non nullable expression
### What changes were proposed in this pull request?

`Coalesce` now drops all expressions after the first non-nullable expression. For example:

```scala
sql("create table t1(a string, b string) using parquet")
sql("select a, Coalesce(count(b), 0) from t1 group by a").explain(true)
```

Before this PR:

```
== Optimized Logical Plan ==
Aggregate [a#0], [a#0, coalesce(count(b#1), 0) AS coalesce(count(b), 0)#3L]
+- Relation default.t1[a#0,b#1] parquet
```

After this PR:

```
== Optimized Logical Plan ==
Aggregate [a#0], [a#0, count(b#1) AS coalesce(count(b), 0)#3L]
+- Relation default.t1[a#0,b#1] parquet
```

### Why are the changes needed?

Improve query performance. `Coalesce` returns its first non-null argument, so any argument that follows a non-nullable expression can never be selected and can be removed without changing the result.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #33590 from wangyum/SPARK-36359.

Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Yuming Wang <[email protected]>
1 parent 6e72951 commit 4624e59

File tree

5 files changed

+35
-6
lines changed

5 files changed

+35
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
2323
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
24-
import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern}
2525
import org.apache.spark.sql.catalyst.util.TypeUtils
2626
import org.apache.spark.sql.types._
2727

@@ -55,6 +55,8 @@ case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpress
5555
// Coalesce is foldable if all children are foldable.
5656
override def foldable: Boolean = children.forall(_.foldable)
5757

58+
final override val nodePatterns: Seq[TreePattern] = Seq(COALESCE)
59+
5860
override def checkInputDataTypes(): TypeCheckResult = {
5961
if (children.length < 1) {
6062
TypeCheckResult.TypeCheckFailure(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

+8-3
Original file line numberDiff line numberDiff line change
@@ -752,10 +752,10 @@ object NullPropagation extends Rule[LogicalPlan] {
752752
}
753753

754754
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
755-
t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
755+
t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT, COALESCE)
756756
|| t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
757757
case q: LogicalPlan => q.transformExpressionsUpWithPruning(
758-
t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
758+
t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT, COALESCE)
759759
|| t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
760760
case e @ WindowExpression(Cast(Literal(0L, _), _, _, _), _) =>
761761
Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
@@ -781,7 +781,12 @@ object NullPropagation extends Rule[LogicalPlan] {
781781
} else if (newChildren.length == 1) {
782782
newChildren.head
783783
} else {
784-
Coalesce(newChildren)
784+
val nonNullableIndex = newChildren.indexWhere(e => !e.nullable)
785+
if (nonNullableIndex > -1) {
786+
Coalesce(newChildren.take(nonNullableIndex + 1))
787+
} else {
788+
Coalesce(newChildren)
789+
}
785790
}
786791

787792
// If the value expression is NULL then transform the In expression to null literal.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ object TreePattern extends Enumeration {
3636
val BOOL_AGG: Value = Value
3737
val CASE_WHEN: Value = Value
3838
val CAST: Value = Value
39+
val COALESCE: Value = Value
3940
val CONCAT: Value = Value
4041
val COUNT: Value = Value
4142
val COUNT_IF: Value = Value

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala

+21
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,25 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
173173
}
174174
}
175175
}
176+
177+
test("SPARK-36359: Coalesce drop all expressions after the first non nullable expression") {
178+
val testRelation = LocalRelation(
179+
'a.int.withNullability(false),
180+
'b.int.withNullability(true),
181+
'c.int.withNullability(false),
182+
'd.int.withNullability(true))
183+
184+
comparePlans(
185+
Optimize.execute(testRelation.select(Coalesce(Seq('a, 'b, 'c, 'd)).as("out")).analyze),
186+
testRelation.select('a.as("out")).analyze)
187+
comparePlans(
188+
Optimize.execute(testRelation.select(Coalesce(Seq('a, 'c)).as("out")).analyze),
189+
testRelation.select('a.as("out")).analyze)
190+
comparePlans(
191+
Optimize.execute(testRelation.select(Coalesce(Seq('b, 'c, 'd)).as("out")).analyze),
192+
testRelation.select(Coalesce(Seq('b, 'c)).as("out")).analyze)
193+
comparePlans(
194+
Optimize.execute(testRelation.select(Coalesce(Seq('b, 'd)).as("out")).analyze),
195+
testRelation.select(Coalesce(Seq('b, 'd)).as("out")).analyze)
196+
}
176197
}

sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
235235
val df = sql("select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') " +
236236
"from range(2)")
237237
checkKeywordsExistsInExplain(df,
238-
"Project [coalesce(cast(id#xL as string), x) AS ifnull(id, x)#x, " +
239-
"id#xL AS nullif(id, x)#xL, coalesce(cast(id#xL as string), x) AS nvl(id, x)#x, " +
238+
"Project [cast(id#xL as string) AS ifnull(id, x)#x, " +
239+
"id#xL AS nullif(id, x)#xL, cast(id#xL as string) AS nvl(id, x)#x, " +
240240
"x AS nvl2(id, x, y)#x]")
241241
}
242242

0 commit comments

Comments
 (0)