@@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
27
27
import org .apache .spark .unsafe .types .UTF8String
28
28
29
29
class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
30
+
31
+ val allSupportedRoundModes = Seq (ROUND_HALF_UP , ROUND_HALF_EVEN , ROUND_CEILING , ROUND_FLOOR )
32
+
30
33
/** Check that a Decimal has the given string representation, precision and scale */
31
34
private def checkDecimal (d : Decimal , string : String , precision : Int , scale : Int ): Unit = {
32
35
assert(d.toString === string)
@@ -222,7 +225,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
222
225
}
223
226
224
227
test(" changePrecision/toPrecision on compact decimal should respect rounding mode" ) {
225
- Seq ( ROUND_FLOOR , ROUND_CEILING , ROUND_HALF_UP , ROUND_HALF_EVEN ) .foreach { mode =>
228
+ allSupportedRoundModes .foreach { mode =>
226
229
Seq (" 0.4" , " 0.5" , " 0.6" , " 1.0" , " 1.1" , " 1.6" , " 2.5" , " 5.5" ).foreach { n =>
227
230
Seq (" " , " -" ).foreach { sign =>
228
231
val bd = BigDecimal (sign + n)
@@ -315,4 +318,52 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
315
318
}
316
319
}
317
320
}
321
+
322
+ // 18 is a max number of digits in Decimal's compact long
323
+ test(" SPARK-41554: decrease/increase scale by 18 and more on compact decimal" ) {
324
+ val unscaledNums = Seq (
325
+ 0L , 1L , 10L , 51L , 123L , 523L ,
326
+ // 18 digits
327
+ 912345678901234567L ,
328
+ 112345678901234567L ,
329
+ 512345678901234567L
330
+ )
331
+ val precision = 38
332
+ // generate some (from, to) scale pairs, e.g. (38, 18), (-20, -2), etc
333
+ val scalePairs = for {
334
+ scale <- Seq (38 , 20 , 19 , 18 )
335
+ delta <- Seq (38 , 20 , 19 , 18 )
336
+ a = scale
337
+ b = scale - delta
338
+ } yield {
339
+ Seq ((a, b), (- a, - b), (b, a), (- b, - a))
340
+ }
341
+
342
+ for {
343
+ unscaled <- unscaledNums
344
+ mode <- allSupportedRoundModes
345
+ (scaleFrom, scaleTo) <- scalePairs.flatten
346
+ sign <- Seq (1L , - 1L )
347
+ } {
348
+ val unscaledWithSign = unscaled * sign
349
+ if (scaleFrom < 0 || scaleTo < 0 ) {
350
+ withSQLConf(SQLConf .LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED .key -> " true" ) {
351
+ checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
352
+ }
353
+ } else {
354
+ checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
355
+ }
356
+ }
357
+
358
+ def checkScaleChange (unscaled : Long , scaleFrom : Int , scaleTo : Int ,
359
+ roundMode : BigDecimal .RoundingMode .Value ): Unit = {
360
+ val decimal = Decimal (unscaled, precision, scaleFrom)
361
+ checkCompact(decimal, true )
362
+ decimal.changePrecision(precision, scaleTo, roundMode)
363
+ val bd = BigDecimal (unscaled, scaleFrom).setScale(scaleTo, roundMode)
364
+ assert(decimal.toBigDecimal === bd,
365
+ s " unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: $roundMode" )
366
+ }
367
+ }
368
+
318
369
}
0 commit comments