Skip to content

Commit c20af53

Browse files
committed
[SPARK-36373][SQL] DecimalPrecision only add necessary cast
### What changes were proposed in this pull request? This pr makes `DecimalPrecision` only add necessary cast similar to [`ImplicitTypeCasts`](https://github.com/apache/spark/blob/96c2919988ddf78d104103876d8d8221e8145baa/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala#L675-L678). For example: ``` EqualTo(AttributeReference("d1", DecimalType(5, 2))(), AttributeReference("d2", DecimalType(2, 1))()) ``` It will add a useless cast to _d1_: ``` (cast(d1#6 as decimal(5,2)) = cast(d2#7 as decimal(5,2))) ``` ### Why are the changes needed? 1. Avoid adding unnecessary cast. Although it will be removed by `SimplifyCasts` later. 2. I'm trying to add an extended rule similar to `PullOutGroupingExpressions`. The current behavior will introduce additional alias. For example: `cast(d1 as decimal(5,2)) as cast_d1`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #33602 from wangyum/SPARK-36373. Authored-by: Yuming Wang <[email protected]> Signed-off-by: Yuming Wang <[email protected]>
1 parent 7a27f8a commit c20af53

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ object DecimalPrecision extends TypeCoercionRule {
204204
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
205205
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
206206
val resultType = widerDecimalType(p1, s1, p2, s2)
207-
b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
207+
val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
208+
val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
209+
b.makeCopy(Array(newE1, newE2))
208210
}
209211

210212
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala

+3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
5959
val comparison = analyzer.execute(plan).collect {
6060
case Project(Alias(e: BinaryComparison, _) :: Nil, _) => e
6161
}.head
62+
// Only add necessary cast.
63+
assert(comparison.left.children.forall(_.dataType !== expectedType))
64+
assert(comparison.right.children.forall(_.dataType !== expectedType))
6265
assert(comparison.left.dataType === expectedType)
6366
assert(comparison.right.dataType === expectedType)
6467
}

0 commit comments

Comments
 (0)