Skip to content

Commit 2d539c5

Browse files
oleksii.diagiliev, srowen
oleksii.diagiliev
authored and committed
[SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
This is a backport PR for #39099. Closes #39813 from fe2s/branch-3.3-fix-decimal-scaling. Authored-by: oleksii.diagiliev <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 6e0dfa9 commit 2d539c5

File tree

2 files changed

+88
-25
lines changed

2 files changed

+88
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

+36-24
Original file line numberDiff line numberDiff line change
@@ -397,30 +397,42 @@ final class Decimal extends Ordered[Decimal] with Serializable {
397397
if (scale < _scale) {
398398
// Easier case: we just need to divide our scale down
399399
val diff = _scale - scale
400-
val pow10diff = POW_10(diff)
401-
// % and / always round to 0
402-
val droppedDigits = longVal % pow10diff
403-
longVal /= pow10diff
404-
roundMode match {
405-
case ROUND_FLOOR =>
406-
if (droppedDigits < 0) {
407-
longVal += -1L
408-
}
409-
case ROUND_CEILING =>
410-
if (droppedDigits > 0) {
411-
longVal += 1L
412-
}
413-
case ROUND_HALF_UP =>
414-
if (math.abs(droppedDigits) * 2 >= pow10diff) {
415-
longVal += (if (droppedDigits < 0) -1L else 1L)
416-
}
417-
case ROUND_HALF_EVEN =>
418-
val doubled = math.abs(droppedDigits) * 2
419-
if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
420-
longVal += (if (droppedDigits < 0) -1L else 1L)
421-
}
422-
case _ =>
423-
throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
400+
// If diff is greater than max number of digits we store in Long, then
401+
// value becomes 0. Otherwise we calculate new value dividing by power of 10.
402+
// In both cases we apply rounding after that.
403+
if (diff > MAX_LONG_DIGITS) {
404+
longVal = roundMode match {
405+
case ROUND_FLOOR => if (longVal < 0) -1L else 0L
406+
case ROUND_CEILING => if (longVal > 0) 1L else 0L
407+
case ROUND_HALF_UP | ROUND_HALF_EVEN => 0L
408+
case _ => sys.error(s"Not supported rounding mode: $roundMode")
409+
}
410+
} else {
411+
val pow10diff = POW_10(diff)
412+
// % and / always round to 0
413+
val droppedDigits = longVal % pow10diff
414+
longVal /= pow10diff
415+
roundMode match {
416+
case ROUND_FLOOR =>
417+
if (droppedDigits < 0) {
418+
longVal += -1L
419+
}
420+
case ROUND_CEILING =>
421+
if (droppedDigits > 0) {
422+
longVal += 1L
423+
}
424+
case ROUND_HALF_UP =>
425+
if (math.abs(droppedDigits) * 2 >= pow10diff) {
426+
longVal += (if (droppedDigits < 0) -1L else 1L)
427+
}
428+
case ROUND_HALF_EVEN =>
429+
val doubled = math.abs(droppedDigits) * 2
430+
if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
431+
longVal += (if (droppedDigits < 0) -1L else 1L)
432+
}
433+
case _ =>
434+
throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
435+
}
424436
}
425437
} else if (scale > _scale) {
426438
// We might be able to multiply longVal by a power of 10 and not overflow, but if not,

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala

+52-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
2727
import org.apache.spark.unsafe.types.UTF8String
2828

2929
class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
30+
31+
val allSupportedRoundModes = Seq(ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING, ROUND_FLOOR)
32+
3033
/** Check that a Decimal has the given string representation, precision and scale */
3134
private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
3235
assert(d.toString === string)
@@ -222,7 +225,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
222225
}
223226

224227
test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
225-
Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
228+
allSupportedRoundModes.foreach { mode =>
226229
Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
227230
Seq("", "-").foreach { sign =>
228231
val bd = BigDecimal(sign + n)
@@ -315,4 +318,52 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
315318
}
316319
}
317320
}
321+
322+
// 18 is a max number of digits in Decimal's compact long
323+
test("SPARK-41554: decrease/increase scale by 18 and more on compact decimal") {
324+
val unscaledNums = Seq(
325+
0L, 1L, 10L, 51L, 123L, 523L,
326+
// 18 digits
327+
912345678901234567L,
328+
112345678901234567L,
329+
512345678901234567L
330+
)
331+
val precision = 38
332+
// generate some (from, to) scale pairs, e.g. (38, 18), (-20, -2), etc
333+
val scalePairs = for {
334+
scale <- Seq(38, 20, 19, 18)
335+
delta <- Seq(38, 20, 19, 18)
336+
a = scale
337+
b = scale - delta
338+
} yield {
339+
Seq((a, b), (-a, -b), (b, a), (-b, -a))
340+
}
341+
342+
for {
343+
unscaled <- unscaledNums
344+
mode <- allSupportedRoundModes
345+
(scaleFrom, scaleTo) <- scalePairs.flatten
346+
sign <- Seq(1L, -1L)
347+
} {
348+
val unscaledWithSign = unscaled * sign
349+
if (scaleFrom < 0 || scaleTo < 0) {
350+
withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
351+
checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
352+
}
353+
} else {
354+
checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
355+
}
356+
}
357+
358+
def checkScaleChange(unscaled: Long, scaleFrom: Int, scaleTo: Int,
359+
roundMode: BigDecimal.RoundingMode.Value): Unit = {
360+
val decimal = Decimal(unscaled, precision, scaleFrom)
361+
checkCompact(decimal, true)
362+
decimal.changePrecision(precision, scaleTo, roundMode)
363+
val bd = BigDecimal(unscaled, scaleFrom).setScale(scaleTo, roundMode)
364+
assert(decimal.toBigDecimal === bd,
365+
s"unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: $roundMode")
366+
}
367+
}
368+
318369
}

0 commit comments

Comments
 (0)