Skip to content

Commit f213eff

Browse files
committed
Adds compiletime.ops.{long, float, double}, adds other ops, and fixes termref type not being considered.
fix check file more ops wip Added ops.float and ops.double
1 parent 968dd1b commit f213eff

16 files changed

+1099
-39
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ class Definitions {
246246
@tu lazy val CompiletimeOpsPackage: Symbol = requiredPackage("scala.compiletime.ops")
247247
@tu lazy val CompiletimeOpsAnyModuleClass: Symbol = requiredModule("scala.compiletime.ops.any").moduleClass
248248
@tu lazy val CompiletimeOpsIntModuleClass: Symbol = requiredModule("scala.compiletime.ops.int").moduleClass
249+
@tu lazy val CompiletimeOpsLongModuleClass: Symbol = requiredModule("scala.compiletime.ops.long").moduleClass
250+
@tu lazy val CompiletimeOpsFloatModuleClass: Symbol = requiredModule("scala.compiletime.ops.float").moduleClass
251+
@tu lazy val CompiletimeOpsDoubleModuleClass: Symbol = requiredModule("scala.compiletime.ops.double").moduleClass
249252
@tu lazy val CompiletimeOpsStringModuleClass: Symbol = requiredModule("scala.compiletime.ops.string").moduleClass
250253
@tu lazy val CompiletimeOpsBooleanModuleClass: Symbol = requiredModule("scala.compiletime.ops.boolean").moduleClass
251254

@@ -1077,19 +1080,40 @@ class Definitions {
10771080
final def isCompiletime_S(sym: Symbol)(using Context): Boolean =
10781081
sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass
10791082

1080-
private val compiletimePackageAnyTypes: Set[Name] = Set(tpnme.Equals, tpnme.NotEquals)
1081-
private val compiletimePackageIntTypes: Set[Name] = Set(
1083+
private val compiletimePackageAnyTypes: Set[Name] = Set(
1084+
tpnme.Equals, tpnme.NotEquals, tpnme.IsConst, tpnme.ToString
1085+
)
1086+
private val compiletimePackageNumericTypes: Set[Name] = Set(
10821087
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
10831088
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
1084-
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max, tpnme.ToString,
1089+
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max
1090+
)
1091+
private val compiletimePackageIntTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
1092+
tpnme.ToString, //ToString is moved to ops.any and deprecated for ops.int
1093+
tpnme.NumberOfLeadingZeros, tpnme.ToLong, tpnme.ToFloat, tpnme.ToDouble,
1094+
tpnme.Xor, tpnme.BitwiseAnd, tpnme.BitwiseOr, tpnme.ASR, tpnme.LSL, tpnme.LSR
1095+
)
1096+
private val compiletimePackageLongTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
1097+
tpnme.NumberOfLeadingZeros, tpnme.ToInt, tpnme.ToFloat, tpnme.ToDouble,
10851098
tpnme.Xor, tpnme.BitwiseAnd, tpnme.BitwiseOr, tpnme.ASR, tpnme.LSL, tpnme.LSR
10861099
)
1100+
private val compiletimePackageFloatTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
1101+
tpnme.ToInt, tpnme.ToLong, tpnme.ToDouble
1102+
)
1103+
private val compiletimePackageDoubleTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
1104+
tpnme.ToInt, tpnme.ToLong, tpnme.ToFloat
1105+
)
10871106
private val compiletimePackageBooleanTypes: Set[Name] = Set(tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or)
1088-
private val compiletimePackageStringTypes: Set[Name] = Set(tpnme.Plus)
1107+
private val compiletimePackageStringTypes: Set[Name] = Set(
1108+
tpnme.Plus, tpnme.Length, tpnme.Substring, tpnme.Matches
1109+
)
10891110
private val compiletimePackageOpTypes: Set[Name] =
10901111
Set(tpnme.S)
10911112
++ compiletimePackageAnyTypes
10921113
++ compiletimePackageIntTypes
1114+
++ compiletimePackageLongTypes
1115+
++ compiletimePackageFloatTypes
1116+
++ compiletimePackageDoubleTypes
10931117
++ compiletimePackageBooleanTypes
10941118
++ compiletimePackageStringTypes
10951119

@@ -1099,6 +1123,9 @@ class Definitions {
10991123
isCompiletime_S(sym)
11001124
|| sym.owner == CompiletimeOpsAnyModuleClass && compiletimePackageAnyTypes.contains(sym.name)
11011125
|| sym.owner == CompiletimeOpsIntModuleClass && compiletimePackageIntTypes.contains(sym.name)
1126+
|| sym.owner == CompiletimeOpsLongModuleClass && compiletimePackageLongTypes.contains(sym.name)
1127+
|| sym.owner == CompiletimeOpsFloatModuleClass && compiletimePackageFloatTypes.contains(sym.name)
1128+
|| sym.owner == CompiletimeOpsDoubleModuleClass && compiletimePackageDoubleTypes.contains(sym.name)
11021129
|| sym.owner == CompiletimeOpsBooleanModuleClass && compiletimePackageBooleanTypes.contains(sym.name)
11031130
|| sym.owner == CompiletimeOpsStringModuleClass && compiletimePackageStringTypes.contains(sym.name)
11041131
)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -211,29 +211,38 @@ object StdNames {
211211
final val IOOBException: N = "IndexOutOfBoundsException"
212212
final val FunctionXXL: N = "FunctionXXL"
213213

214-
final val Abs: N = "Abs"
215-
final val And: N = "&&"
216-
final val BitwiseAnd: N = "BitwiseAnd"
217-
final val BitwiseOr: N = "BitwiseOr"
218-
final val Div: N = "/"
219-
final val Equals: N = "=="
220-
final val Ge: N = ">="
221-
final val Gt: N = ">"
222-
final val Le: N = "<="
223-
final val Lt: N = "<"
224-
final val Max: N = "Max"
225-
final val Min: N = "Min"
226-
final val Minus: N = "-"
227-
final val Mod: N = "%"
228-
final val Negate: N = "Negate"
229-
final val Not: N = "!"
230-
final val NotEquals: N = "!="
231-
final val Or: N = "||"
232-
final val Plus: N = "+"
233-
final val S: N = "S"
234-
final val Times: N = "*"
235-
final val ToString: N = "ToString"
236-
final val Xor: N = "^"
214+
final val Abs: N = "Abs"
215+
final val And: N = "&&"
216+
final val BitwiseAnd: N = "BitwiseAnd"
217+
final val BitwiseOr: N = "BitwiseOr"
218+
final val Div: N = "/"
219+
final val Equals: N = "=="
220+
final val Ge: N = ">="
221+
final val Gt: N = ">"
222+
final val IsConst: N = "IsConst"
223+
final val Le: N = "<="
224+
final val Length: N = "Length"
225+
final val Lt: N = "<"
226+
final val Matches: N = "Matches"
227+
final val Max: N = "Max"
228+
final val Min: N = "Min"
229+
final val Minus: N = "-"
230+
final val Mod: N = "%"
231+
final val Negate: N = "Negate"
232+
final val Not: N = "!"
233+
final val NotEquals: N = "!="
234+
final val NumberOfLeadingZeros: N = "NumberOfLeadingZeros"
235+
final val Or: N = "||"
236+
final val Plus: N = "+"
237+
final val S: N = "S"
238+
final val Substring: N = "Substring"
239+
final val Times: N = "*"
240+
final val ToInt: N = "ToInt"
241+
final val ToLong: N = "ToLong"
242+
final val ToFloat: N = "ToFloat"
243+
final val ToDouble: N = "ToDouble"
244+
final val ToString: N = "ToString"
245+
final val Xor: N = "^"
237246

238247
final val ClassfileAnnotation: N = "ClassfileAnnotation"
239248
final val ClassManifest: N = "ClassManifest"

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 126 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4194,37 +4194,76 @@ object Types {
41944194

41954195
def tryCompiletimeConstantFold(using Context): Type = tycon match {
41964196
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
4197-
def constValue(tp: Type): Option[Any] = tp.dealias match {
4197+
extension (tp : Type) def fixForEvaluation : Type =
4198+
tp.normalized.dealias match {
4199+
case tp : TermRef => tp.underlying
4200+
case tp => tp
4201+
}
4202+
4203+
def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match {
41984204
case ConstantType(Constant(n)) => Some(n)
41994205
case _ => None
42004206
}
42014207

4202-
def boolValue(tp: Type): Option[Boolean] = tp.dealias match {
4208+
def boolValue(tp: Type): Option[Boolean] = tp.fixForEvaluation match {
42034209
case ConstantType(Constant(n: Boolean)) => Some(n)
42044210
case _ => None
42054211
}
42064212

4207-
def intValue(tp: Type): Option[Int] = tp.dealias match {
4213+
def intValue(tp: Type): Option[Int] = tp.fixForEvaluation match {
42084214
case ConstantType(Constant(n: Int)) => Some(n)
42094215
case _ => None
42104216
}
42114217

4212-
def stringValue(tp: Type): Option[String] = tp.dealias match {
4213-
case ConstantType(Constant(n: String)) => Some(n)
4218+
def longValue(tp: Type): Option[Long] = tp.fixForEvaluation match {
4219+
case ConstantType(Constant(n: Long)) => Some(n)
4220+
case _ => None
4221+
}
4222+
4223+
def floatValue(tp: Type): Option[Float] = tp.fixForEvaluation match {
4224+
case ConstantType(Constant(n: Float)) => Some(n)
4225+
case _ => None
4226+
}
4227+
4228+
def doubleValue(tp: Type): Option[Double] = tp.fixForEvaluation match {
4229+
case ConstantType(Constant(n: Double)) => Some(n)
42144230
case _ => None
42154231
}
42164232

4233+
def stringValue(tp: Type): Option[String] = tp.fixForEvaluation match {
4234+
case ConstantType(Constant(n: String)) => Some(n)
4235+
case _ => None
4236+
}
4237+
def isConst : Option[Type] = args.head.fixForEvaluation match {
4238+
case ConstantType(_) => Some(ConstantType(Constant(true)))
4239+
case _ => Some(ConstantType(Constant(false)))
4240+
}
42174241
def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)
42184242

42194243
def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
4220-
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))
4244+
extractor(args.head).map(a => ConstantType(Constant(op(a))))
42214245

42224246
def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
4247+
constantFold2AB(extractor, extractor, op)
4248+
4249+
def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] =
42234250
for {
4224-
a <- extractor(args.head.normalized)
4225-
b <- extractor(args.tail.head.normalized)
4251+
a <- extractorA(args.head)
4252+
b <- extractorB(args.last)
42264253
} yield ConstantType(Constant(op(a, b)))
42274254

4255+
def constantFold3[TA, TB, TC](
4256+
extractorA: Type => Option[TA],
4257+
extractorB: Type => Option[TB],
4258+
extractorC: Type => Option[TC],
4259+
op: (TA, TB, TC) => Any
4260+
): Option[Type] =
4261+
for {
4262+
a <- extractorA(args.head)
4263+
b <- extractorB(args(1))
4264+
c <- extractorC(args.last)
4265+
} yield ConstantType(Constant(op(a, b, c)))
4266+
42284267
trace(i"compiletime constant fold $this", typr, show = true) {
42294268
val name = tycon.symbol.name
42304269
val owner = tycon.symbol.owner
@@ -4236,10 +4275,13 @@ object Types {
42364275
} else if (owner == defn.CompiletimeOpsAnyModuleClass) name match {
42374276
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
42384277
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
4278+
case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
4279+
case tpnme.IsConst if nArgs == 1 => isConst
42394280
case _ => None
42404281
} else if (owner == defn.CompiletimeOpsIntModuleClass) name match {
42414282
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
42424283
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x)
4284+
//ToString is deprecated for ops.int, and moved to ops.any
42434285
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
42444286
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
42454287
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
@@ -4264,9 +4306,85 @@ object Types {
42644306
case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
42654307
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
42664308
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
4309+
case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
4310+
case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong)
4311+
case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat)
4312+
case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble)
4313+
case _ => None
4314+
} else if (owner == defn.CompiletimeOpsLongModuleClass) name match {
4315+
case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
4316+
case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => -x)
4317+
case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
4318+
case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
4319+
case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
4320+
case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
4321+
case (_, 0L) => throw new TypeError("Division by 0")
4322+
case (a, b) => a / b
4323+
})
4324+
case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
4325+
case (_, 0L) => throw new TypeError("Modulo by 0")
4326+
case (a, b) => a % b
4327+
})
4328+
case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
4329+
case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
4330+
case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
4331+
case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
4332+
case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
4333+
case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
4334+
case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
4335+
case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
4336+
case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
4337+
case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
4338+
case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
4339+
case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
4340+
case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
4341+
constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_))
4342+
case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt)
4343+
case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat)
4344+
case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble)
4345+
case _ => None
4346+
} else if (owner == defn.CompiletimeOpsFloatModuleClass) name match {
4347+
case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs)
4348+
case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => -x)
4349+
case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _)
4350+
case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _)
4351+
case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _)
4352+
case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _)
4353+
case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _)
4354+
case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _)
4355+
case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _)
4356+
case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _)
4357+
case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _)
4358+
case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _)
4359+
case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _)
4360+
case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt)
4361+
case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong)
4362+
case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble)
4363+
case _ => None
4364+
} else if (owner == defn.CompiletimeOpsDoubleModuleClass) name match {
4365+
case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs)
4366+
case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => -x)
4367+
case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _)
4368+
case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _)
4369+
case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _)
4370+
case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _)
4371+
case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _)
4372+
case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _)
4373+
case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _)
4374+
case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _)
4375+
case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _)
4376+
case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _)
4377+
case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _)
4378+
case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt)
4379+
case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong)
4380+
case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat)
42674381
case _ => None
42684382
} else if (owner == defn.CompiletimeOpsStringModuleClass) name match {
42694383
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
4384+
case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
4385+
case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
4386+
case tpnme.Substring if nArgs == 3 =>
4387+
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
42704388
case _ => None
42714389
} else if (owner == defn.CompiletimeOpsBooleanModuleClass) name match {
42724390
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x)

library/src/scala/compiletime/ops/any.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,22 @@ object any:
2121
* @syntax markdown
2222
*/
2323
type !=[X, Y] <: Boolean
24+
25+
/** Tests if a type is a constant.
26+
* ```scala
27+
* val c1: IsConst[1] = true
28+
* val c2: IsConst["hi"] = true
29+
* val c3: IsConst[false] = true
30+
* ```
31+
* @syntax markdown
32+
*/
33+
type IsConst[X] <: Boolean
34+
35+
/** String conversion of a constant singleton type.
36+
* ```scala
37+
* val s1: ToString[1] = "1"
38+
* val sTrue: ToString[true] = "true"
39+
* ```
40+
* @syntax markdown
41+
*/
42+
type ToString[X] <: String

0 commit comments

Comments
 (0)