
Commit 7bc364b

sigmod authored and gengliangwang committed
[SPARK-35621][SQL] Add rule id pruning to the TypeCoercion rule
### What changes were proposed in this pull request?

- Added `TreeNode.transformUpWithBeforeAndAfterRuleOnChildren(...)`;
- Call `transformUpWithBeforeAndAfterRuleOnChildren` in `TypeCoercionRule`.

### Why are the changes needed?

Reduce the number of tree traversals and hence improve query compilation latency.

### How was this patch tested?

Existing tests. Performance diff:

| | Baseline | Experiment (w/o ruleId) | Experiment (w/o ruleId)/Baseline | Experiment (w/ ruleId) | Experiment (w/ ruleId)/Baseline |
| -- | -- | -- | -- | -- | -- |
| CombinedTypeCoercionRule | 665020354 | 567320034 | 0.85 | 330798240 | 0.50 |

Closes #32761 from sigmod/transform.

Authored-by: Yingyi Bu <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
1 parent b5678be · commit 7bc364b
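For intuition about the `CombinedTypeCoercionRule` row in the benchmark above: that rule chains the individual coercion rules so a single pass over the expression tree tries the whole train. Below is a minimal, self-contained Scala sketch of that chaining idea; the `Int` "expressions" and all names (`CombinedRuleTrainSketch`, `combine`, the toy rules) are invented for illustration, not Spark's actual code.

```scala
// Sketch of the rule-train idea behind CombinedTypeCoercionRule: fold several
// PartialFunctions into one, so a single tree traversal can try every rule.
// The Int stand-in for Expression and all names here are illustrative assumptions.
object CombinedRuleTrainSketch {
  type Rule = PartialFunction[Int, Int] // stand-in for PartialFunction[Expression, Expression]

  val promoteNegatives: Rule = { case n if n < 0 => 0 } // toy "coercion" rule
  val capAtTen: Rule = { case n if n > 10 => 10 }       // another toy rule

  // Report a match only if some rule in the train actually changed the value,
  // so callers can tell "a rule applied" from "the whole train was a no-op".
  def combine(rules: Seq[Rule]): Rule = Function.unlift { (e: Int) =>
    val result = rules.foldLeft(e)((cur, r) => r.applyOrElse(cur, identity[Int]))
    if (result != e) Some(result) else None
  }

  def main(args: Array[String]): Unit = {
    val train = combine(Seq(promoteNegatives, capAtTen))
    println(Seq(-5, 3, 42).map(n => train.applyOrElse(n, identity[Int])))
    // List(0, 3, 10) -- one pass, every rule tried per element.
  }
}
```

The no-op detection in `combine` is what rule-id pruning builds on: when the combined rule changes nothing in a subtree, that subtree can be skipped on later invocations.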

File tree: 5 files changed, +91 −15 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala (+11 −14)
```diff
@@ -184,8 +184,6 @@ abstract class TypeCoercionBase {
         }
       }
     }
-
-    override val ruleName: String = rules.map(_.ruleName).mkString("Combined[", ", ", "]")
   }
 
   /**
@@ -1157,21 +1155,20 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
    */
  def apply(plan: LogicalPlan): LogicalPlan = {
    val typeCoercionFn = transform
-    def rewrite(plan: LogicalPlan): LogicalPlan = {
-      val withNewChildren = plan.mapChildren(rewrite)
-      if (!withNewChildren.childrenResolved) {
-        withNewChildren
-      } else {
-        // Only propagate types if the children have changed.
-        val withPropagatedTypes = if (withNewChildren ne plan) {
-          propagateTypes(withNewChildren)
+    plan.transformUpWithBeforeAndAfterRuleOnChildren(!_.analyzed, ruleId) {
+      case (beforeMapChildren, afterMapChildren) =>
+        if (!afterMapChildren.childrenResolved) {
+          afterMapChildren
        } else {
-          plan
+          // Only propagate types if the children have changed.
+          val withPropagatedTypes = if (beforeMapChildren ne afterMapChildren) {
+            propagateTypes(afterMapChildren)
+          } else {
+            beforeMapChildren
+          }
+          withPropagatedTypes.transformExpressionsUp(typeCoercionFn)
        }
-        withPropagatedTypes.transformExpressionsUp(typeCoercionFn)
-      }
    }
-    rewrite(plan)
  }
 
  def transform: PartialFunction[Expression, Expression]
```
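The rewrite above replaces the hand-rolled recursive `rewrite` with the new traversal; the `(beforeMapChildren, afterMapChildren)` pair lets the rule skip type propagation whenever no child changed. A hedged sketch of just that per-node decision, with an invented toy `Plan` type — `propagateTypes` and `coerceExpressions` here are placeholders standing in for Spark's methods:

```scala
// Toy rendering of the rewritten TypeCoercionRule.apply decision at each node.
// Plan, propagateTypes and coerceExpressions are illustrative stand-ins.
final case class Plan(name: String, resolved: Boolean, children: List[Plan] = Nil)

object ApplyControlFlowSketch {
  private def propagateTypes(p: Plan): Plan = p.copy(name = p.name + ":propagated")
  private def coerceExpressions(p: Plan): Plan = p // would run transformExpressionsUp

  // The rule sees both images of the node: before and after its children were rewritten.
  def onNode(beforeMapChildren: Plan, afterMapChildren: Plan): Plan = {
    if (!afterMapChildren.children.forall(_.resolved)) {
      afterMapChildren // children unresolved: nothing to coerce yet
    } else {
      // Only propagate types if the children changed (reference inequality,
      // mirroring `beforeMapChildren ne afterMapChildren` in the diff).
      val withPropagatedTypes =
        if (beforeMapChildren ne afterMapChildren) propagateTypes(afterMapChildren)
        else beforeMapChildren
      coerceExpressions(withPropagatedTypes)
    }
  }

  def main(args: Array[String]): Unit = {
    val child = Plan("child", resolved = true)
    val before = Plan("parent", resolved = true, List(child))
    val after = before.copy(children = List(child.copy(name = "child2")))
    println(onNode(before, after).name)  // parent:propagated -- children changed
    println(onNode(before, before).name) // parent -- unchanged, propagation skipped
  }
}
```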

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala (+28 −1)
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.rules
 import scala.collection.mutable
 
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.util.Utils
 
 // Represent unique rule ids for rules that are invoked multiple times.
 case class RuleId(id: Int) {
@@ -40,7 +41,7 @@ object RuleIdCollection {
   // invoked multiple times by Analyzer/Optimizer/Planner need a rule id to prune unnecessary
   // tree traversals in the transform function family. Note that those rules should not depend on
   // a changing, external state. Rules here are in alphabetical order.
-  private val rulesNeedingIds: Seq[String] = {
+  private var rulesNeedingIds: Seq[String] = {
     // Catalyst Analyzer rules
     "org.apache.spark.sql.catalyst.analysis.Analyzer$AddMetadataColumns" ::
     "org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractGenerator" ::
@@ -88,6 +89,7 @@ object RuleIdCollection {
     "org.apache.spark.sql.catalyst.analysis.ResolveUnion" ::
     "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" ::
     "org.apache.spark.sql.catalyst.analysis.TimeWindowing" ::
+    "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule" ::
     "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
     "org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability" ::
     // Catalyst Optimizer rules
@@ -152,6 +154,31 @@ object RuleIdCollection {
     "org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison" :: Nil
   }
 
+  if (Utils.isTesting) {
+    rulesNeedingIds = rulesNeedingIds ++ {
+      // In the production code path, the following rules are run in CombinedTypeCoercionRule, and
+      // hence we only need to add them for unit testing.
+      "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$PromoteStringLiterals" ::
+      "org.apache.spark.sql.catalyst.analysis.DecimalPrecision" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercion$BooleanEquality" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CaseWhenCoercion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$ConcatCoercion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$DateTimeOperations" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$Division" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$EltCoercion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$FunctionArgumentConversion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$IfCoercion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$ImplicitTypeCasts" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$InConversion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$IntegralDivision" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$MapZipWithCoercion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercion$PromoteStrings" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$StackCoercion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$StringLiteralCoercion" ::
+      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$WindowFrameCoercion" :: Nil
+    }
+  }
+
   // Maps rule names to ids. Rule ids are continuous natural numbers starting from 0.
   private val ruleToId = new mutable.HashMap[String, RuleId]
```
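For illustration, the registry mechanics this file relies on can be sketched as follows: rule names map to consecutive ids, so per-node "this rule was ineffective here" bits can be indexed by rule id. The names `RuleIdRegistrySketch` and `register` are assumptions made for the sketch, not `RuleIdCollection`'s actual internals:

```scala
import scala.collection.mutable

// Illustrative registry: consecutive ids assigned in registration order, so a
// tree node can record "rule ineffective" in a bitset indexed by rule id.
object RuleIdRegistrySketch {
  final case class RuleId(id: Int)

  private val ruleToId = new mutable.HashMap[String, RuleId]

  private def register(names: Seq[String]): Unit =
    names.foreach(n => ruleToId.getOrElseUpdate(n, RuleId(ruleToId.size)))

  def getRuleId(ruleName: String): RuleId =
    ruleToId.getOrElse(ruleName,
      throw new NoSuchElementException(s"Rule $ruleName has not been registered"))

  def main(args: Array[String]): Unit = {
    register(Seq(
      "org.apache.spark.sql.catalyst.analysis.TimeWindowing",
      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule"))
    println(getRuleId(
      "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule"))
    // RuleId(1): ids are continuous natural numbers starting from 0.
  }
}
```

This also shows why the `Utils.isTesting` block above is safe to gate: registering extra names in tests only appends more ids; production ids are unaffected because they are assigned before the test-only block runs.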

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala (+38)
```diff
@@ -549,6 +549,44 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
     }
   }
 
+  /**
+   * Returns a copy of this node where `rule` has been recursively applied first to all of its
+   * children and then itself (post-order). When `rule` does not apply to a given node, it is left
+   * unchanged.
+   *
+   * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
+   *             on a TreeNode T, skips processing T and its subtree; otherwise, processes
+   *             T and its subtree recursively.
+   * @param rule the function used to transform this node and its descendant nodes. The function
+   *             takes a tuple as its input, where the first/second field is the before/after
+   *             image of applying the rule on the node's children.
+   * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
+   *               UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
+   *               has been marked as ineffective on a TreeNode T, skips processing T and its
+   *               subtree. Do not pass it if the rule is not purely functional and reads a
+   *               varying initial state for different invocations.
+   */
+  def transformUpWithBeforeAndAfterRuleOnChildren(
+      cond: BaseType => Boolean, ruleId: RuleId = UnknownRuleId)(
+      rule: PartialFunction[(BaseType, BaseType), BaseType]): BaseType = {
+    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
+      return this
+    }
+    val afterRuleOnChildren =
+      mapChildren(_.transformUpWithBeforeAndAfterRuleOnChildren(cond, ruleId)(rule))
+    val newNode = CurrentOrigin.withOrigin(origin) {
+      rule.applyOrElse((this, afterRuleOnChildren), { t: (BaseType, BaseType) => t._2 })
+    }
+    if (this eq newNode) {
+      this.markRuleAsIneffective(ruleId)
+      this
+    } else {
+      // If the transform function replaces this node with a new one, carry over the tags.
+      newNode.copyTagsFrom(this)
+      newNode
+    }
+  }
+
   /**
    * Returns a copy of this node where `f` has been applied to all the nodes in `children`.
    */
```
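To see the new traversal and its pruning end to end, here is a minimal, self-contained re-implementation over an invented `Node` type. It is a sketch under the assumption that the ineffective-rule bits live directly on each node; Spark's real `TreeNode` stores them via `markRuleAsIneffective`/`isRuleIneffective` internals not shown in this diff.

```scala
// Toy re-implementation of transformUpWithBeforeAndAfterRuleOnChildren on an
// invented Node type, including the "mark rule ineffective, skip next time" memo.
object TransformUpSketch {
  final case class RuleId(id: Int)
  val UnknownRuleId: RuleId = RuleId(-1)

  final class Node(val value: Int, val children: List[Node]) {
    private var ineffective: Set[Int] = Set.empty // memo: rules that were no-ops here

    def isRuleIneffective(ruleId: RuleId): Boolean =
      ruleId != UnknownRuleId && ineffective.contains(ruleId.id)

    def markRuleAsIneffective(ruleId: RuleId): Unit =
      if (ruleId != UnknownRuleId) ineffective += ruleId.id

    def mapChildren(f: Node => Node): Node = {
      val newChildren = children.map(f)
      // Reuse this node when no child changed, so `eq` detects "nothing happened".
      if (newChildren.zip(children).forall { case (a, b) => a eq b }) this
      else new Node(value, newChildren)
    }

    def transformUpWithBeforeAndAfter(
        cond: Node => Boolean, ruleId: RuleId = UnknownRuleId)(
        rule: PartialFunction[(Node, Node), Node]): Node = {
      if (!cond(this) || isRuleIneffective(ruleId)) return this // pruned
      val afterChildren = mapChildren(_.transformUpWithBeforeAndAfter(cond, ruleId)(rule))
      // The rule receives the (before, after) images of rewriting the children.
      val newNode = rule.applyOrElse((this, afterChildren), (t: (Node, Node)) => t._2)
      if (this eq newNode) { markRuleAsIneffective(ruleId); this } else newNode
    }

    override def toString: String = s"Node($value, $children)"
  }

  def main(args: Array[String]): Unit = {
    val tree = new Node(1, List(new Node(-2, Nil), new Node(3, List(new Node(-4, Nil)))))
    val negFix = RuleId(0)
    def run(t: Node): Node = t.transformUpWithBeforeAndAfter(_ => true, negFix) {
      case (_, after) if after.value < 0 => new Node(-after.value, after.children)
    }
    val once = run(tree)         // rewrites the two negative nodes
    val twice = run(once)        // no-op: marks negFix ineffective on every node
    println(run(twice) eq twice) // true: the third run is pruned without a traversal
  }
}
```

The key invariant mirrors the diff: a node is marked ineffective only when the rule left its entire subtree untouched (`this eq newNode`), which is exactly when a later invocation of the same rule can safely be skipped.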

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala (+7)
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import java.sql.Timestamp
 
+import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -27,10 +28,16 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 class AnsiTypeCoercionSuite extends AnalysisTest {
   import TypeCoercionSuite._
 
+  // When Utils.isTesting is true, RuleIdCollection adds individual type coercion rules. Otherwise,
+  // RuleIdCollection doesn't add them because they are called in a train inside
+  // CombinedTypeCoercionRule.
+  assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true")
+
   // scalastyle:off line.size.limit
   // The following table shows all implicit data type conversions that are not visible to the user.
   // +----------------------+----------+-----------+-------------+----------+------------+------------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
```
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala (+7)
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import java.sql.Timestamp
 
+import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -27,10 +28,16 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 class TypeCoercionSuite extends AnalysisTest {
   import TypeCoercionSuite._
 
+  // When Utils.isTesting is true, RuleIdCollection adds individual type coercion rules. Otherwise,
+  // RuleIdCollection doesn't add them because they are called in a train inside
+  // CombinedTypeCoercionRule.
+  assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true")
+
   // scalastyle:off line.size.limit
   // The following table shows all implicit data type conversions that are not visible to the user.
   // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
```
