@@ -4194,37 +4194,76 @@ object Types {
4194
4194
4195
4195
def tryCompiletimeConstantFold (using Context ): Type = tycon match {
4196
4196
case tycon : TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
4197
- def constValue (tp : Type ): Option [Any ] = tp.dealias match {
4197
+ extension (tp : Type ) def fixForEvaluation : Type =
4198
+ tp.normalized.dealias match {
4199
+ case tp : TermRef => tp.underlying
4200
+ case tp => tp
4201
+ }
4202
+
4203
+ def constValue (tp : Type ): Option [Any ] = tp.fixForEvaluation match {
4198
4204
case ConstantType (Constant (n)) => Some (n)
4199
4205
case _ => None
4200
4206
}
4201
4207
4202
- def boolValue (tp : Type ): Option [Boolean ] = tp.dealias match {
4208
+ def boolValue (tp : Type ): Option [Boolean ] = tp.fixForEvaluation match {
4203
4209
case ConstantType (Constant (n : Boolean )) => Some (n)
4204
4210
case _ => None
4205
4211
}
4206
4212
4207
- def intValue (tp : Type ): Option [Int ] = tp.dealias match {
4213
+ def intValue (tp : Type ): Option [Int ] = tp.fixForEvaluation match {
4208
4214
case ConstantType (Constant (n : Int )) => Some (n)
4209
4215
case _ => None
4210
4216
}
4211
4217
4212
- def stringValue (tp : Type ): Option [String ] = tp.dealias match {
4213
- case ConstantType (Constant (n : String )) => Some (n)
4218
+ def longValue (tp : Type ): Option [Long ] = tp.fixForEvaluation match {
4219
+ case ConstantType (Constant (n : Long )) => Some (n)
4220
+ case _ => None
4221
+ }
4222
+
4223
+ def floatValue (tp : Type ): Option [Float ] = tp.fixForEvaluation match {
4224
+ case ConstantType (Constant (n : Float )) => Some (n)
4225
+ case _ => None
4226
+ }
4227
+
4228
+ def doubleValue (tp : Type ): Option [Double ] = tp.fixForEvaluation match {
4229
+ case ConstantType (Constant (n : Double )) => Some (n)
4214
4230
case _ => None
4215
4231
}
4216
4232
4233
+ def stringValue (tp : Type ): Option [String ] = tp.fixForEvaluation match {
4234
+ case ConstantType (Constant (n : String )) => Some (n)
4235
+ case _ => None
4236
+ }
4237
+ def isConst : Option [Type ] = args.head.fixForEvaluation match {
4238
+ case ConstantType (_) => Some (ConstantType (Constant (true )))
4239
+ case _ => Some (ConstantType (Constant (false )))
4240
+ }
4217
4241
def natValue (tp : Type ): Option [Int ] = intValue(tp).filter(n => n >= 0 && n < Int .MaxValue )
4218
4242
4219
4243
def constantFold1 [T ](extractor : Type => Option [T ], op : T => Any ): Option [Type ] =
4220
- extractor(args.head.normalized ).map(a => ConstantType (Constant (op(a))))
4244
+ extractor(args.head).map(a => ConstantType (Constant (op(a))))
4221
4245
4222
4246
def constantFold2 [T ](extractor : Type => Option [T ], op : (T , T ) => Any ): Option [Type ] =
4247
+ constantFold2AB(extractor, extractor, op)
4248
+
4249
+ def constantFold2AB [TA , TB ](extractorA : Type => Option [TA ], extractorB : Type => Option [TB ], op : (TA , TB ) => Any ): Option [Type ] =
4223
4250
for {
4224
- a <- extractor (args.head.normalized )
4225
- b <- extractor (args.tail.head.normalized )
4251
+ a <- extractorA (args.head)
4252
+ b <- extractorB (args.last )
4226
4253
} yield ConstantType (Constant (op(a, b)))
4227
4254
4255
+ def constantFold3 [TA , TB , TC ](
4256
+ extractorA : Type => Option [TA ],
4257
+ extractorB : Type => Option [TB ],
4258
+ extractorC : Type => Option [TC ],
4259
+ op : (TA , TB , TC ) => Any
4260
+ ): Option [Type ] =
4261
+ for {
4262
+ a <- extractorA(args.head)
4263
+ b <- extractorB(args(1 ))
4264
+ c <- extractorC(args.last)
4265
+ } yield ConstantType (Constant (op(a, b, c)))
4266
+
4228
4267
trace(i " compiletime constant fold $this" , typr, show = true ) {
4229
4268
val name = tycon.symbol.name
4230
4269
val owner = tycon.symbol.owner
@@ -4236,10 +4275,13 @@ object Types {
4236
4275
} else if (owner == defn.CompiletimeOpsAnyModuleClass ) name match {
4237
4276
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
4238
4277
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
4278
+ case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
4279
+ case tpnme.IsConst if nArgs == 1 => isConst
4239
4280
case _ => None
4240
4281
} else if (owner == defn.CompiletimeOpsIntModuleClass ) name match {
4241
4282
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
4242
4283
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => - x)
4284
+ // ToString is deprecated for ops.int, and moved to ops.any
4243
4285
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
4244
4286
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
4245
4287
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
@@ -4264,9 +4306,85 @@ object Types {
4264
4306
case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
4265
4307
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
4266
4308
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
4309
+ case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer .numberOfLeadingZeros(_))
4310
+ case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong)
4311
+ case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat)
4312
+ case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble)
4313
+ case _ => None
4314
+ } else if (owner == defn.CompiletimeOpsLongModuleClass ) name match {
4315
+ case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
4316
+ case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => - x)
4317
+ case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
4318
+ case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
4319
+ case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
4320
+ case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
4321
+ case (_, 0L ) => throw new TypeError (" Division by 0" )
4322
+ case (a, b) => a / b
4323
+ })
4324
+ case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
4325
+ case (_, 0L ) => throw new TypeError (" Modulo by 0" )
4326
+ case (a, b) => a % b
4327
+ })
4328
+ case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
4329
+ case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
4330
+ case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
4331
+ case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
4332
+ case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
4333
+ case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
4334
+ case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
4335
+ case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
4336
+ case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
4337
+ case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
4338
+ case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
4339
+ case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
4340
+ case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
4341
+ constantFold1(longValue, java.lang.Long .numberOfLeadingZeros(_))
4342
+ case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt)
4343
+ case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat)
4344
+ case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble)
4345
+ case _ => None
4346
+ } else if (owner == defn.CompiletimeOpsFloatModuleClass ) name match {
4347
+ case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs)
4348
+ case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => - x)
4349
+ case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _)
4350
+ case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _)
4351
+ case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _)
4352
+ case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _)
4353
+ case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _)
4354
+ case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _)
4355
+ case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _)
4356
+ case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _)
4357
+ case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _)
4358
+ case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _)
4359
+ case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _)
4360
+ case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt)
4361
+ case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong)
4362
+ case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble)
4363
+ case _ => None
4364
+ } else if (owner == defn.CompiletimeOpsDoubleModuleClass ) name match {
4365
+ case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs)
4366
+ case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => - x)
4367
+ case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _)
4368
+ case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _)
4369
+ case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _)
4370
+ case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _)
4371
+ case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _)
4372
+ case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _)
4373
+ case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _)
4374
+ case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _)
4375
+ case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _)
4376
+ case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _)
4377
+ case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _)
4378
+ case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt)
4379
+ case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong)
4380
+ case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat)
4267
4381
case _ => None
4268
4382
} else if (owner == defn.CompiletimeOpsStringModuleClass ) name match {
4269
4383
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
4384
+ case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
4385
+ case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
4386
+ case tpnme.Substring if nArgs == 3 =>
4387
+ constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
4270
4388
case _ => None
4271
4389
} else if (owner == defn.CompiletimeOpsBooleanModuleClass ) name match {
4272
4390
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => ! x)
0 commit comments