Skip to content

Commit 9570a88

Browse files
committed
Adds compiletime.ops.{long, float, double}, adds other ops, and fixes termref type not being considered.
fix check file more ops wip Added ops.float and ops.double
1 parent 7a6cabe commit 9570a88

16 files changed

+1099
-39
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ class Definitions {
246246
@tu lazy val CompiletimeOpsPackage: Symbol = requiredPackage("scala.compiletime.ops")
247247
@tu lazy val CompiletimeOpsAnyModuleClass: Symbol = requiredModule("scala.compiletime.ops.any").moduleClass
248248
@tu lazy val CompiletimeOpsIntModuleClass: Symbol = requiredModule("scala.compiletime.ops.int").moduleClass
249+
@tu lazy val CompiletimeOpsLongModuleClass: Symbol = requiredModule("scala.compiletime.ops.long").moduleClass
250+
@tu lazy val CompiletimeOpsFloatModuleClass: Symbol = requiredModule("scala.compiletime.ops.float").moduleClass
251+
@tu lazy val CompiletimeOpsDoubleModuleClass: Symbol = requiredModule("scala.compiletime.ops.double").moduleClass
249252
@tu lazy val CompiletimeOpsStringModuleClass: Symbol = requiredModule("scala.compiletime.ops.string").moduleClass
250253
@tu lazy val CompiletimeOpsBooleanModuleClass: Symbol = requiredModule("scala.compiletime.ops.boolean").moduleClass
251254

@@ -1071,19 +1074,40 @@ class Definitions {
10711074
final def isCompiletime_S(sym: Symbol)(using Context): Boolean =
10721075
sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass
10731076

1074-
private val compiletimePackageAnyTypes: Set[Name] = Set(tpnme.Equals, tpnme.NotEquals)
1075-
private val compiletimePackageIntTypes: Set[Name] = Set(
1077+
private val compiletimePackageAnyTypes: Set[Name] = Set(
1078+
tpnme.Equals, tpnme.NotEquals, tpnme.IsConst, tpnme.ToString
1079+
)
1080+
private val compiletimePackageNumericTypes: Set[Name] = Set(
10761081
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
10771082
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
1078-
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max, tpnme.ToString,
1083+
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max
1084+
)
1085+
private val compiletimePackageIntTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
1086+
tpnme.ToString, //ToString is moved to ops.any and deprecated for ops.int
1087+
tpnme.NumberOfLeadingZeros, tpnme.ToLong, tpnme.ToFloat, tpnme.ToDouble,
1088+
tpnme.Xor, tpnme.BitwiseAnd, tpnme.BitwiseOr, tpnme.ASR, tpnme.LSL, tpnme.LSR
1089+
)
1090+
private val compiletimePackageLongTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
1091+
tpnme.NumberOfLeadingZeros, tpnme.ToInt, tpnme.ToFloat, tpnme.ToDouble,
10791092
tpnme.Xor, tpnme.BitwiseAnd, tpnme.BitwiseOr, tpnme.ASR, tpnme.LSL, tpnme.LSR
10801093
)
1094+
private val compiletimePackageFloatTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
1095+
tpnme.ToInt, tpnme.ToLong, tpnme.ToDouble
1096+
)
1097+
private val compiletimePackageDoubleTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
1098+
tpnme.ToInt, tpnme.ToLong, tpnme.ToFloat
1099+
)
10811100
private val compiletimePackageBooleanTypes: Set[Name] = Set(tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or)
1082-
private val compiletimePackageStringTypes: Set[Name] = Set(tpnme.Plus)
1101+
private val compiletimePackageStringTypes: Set[Name] = Set(
1102+
tpnme.Plus, tpnme.Length, tpnme.Substring, tpnme.Matches
1103+
)
10831104
private val compiletimePackageOpTypes: Set[Name] =
10841105
Set(tpnme.S)
10851106
++ compiletimePackageAnyTypes
10861107
++ compiletimePackageIntTypes
1108+
++ compiletimePackageLongTypes
1109+
++ compiletimePackageFloatTypes
1110+
++ compiletimePackageDoubleTypes
10871111
++ compiletimePackageBooleanTypes
10881112
++ compiletimePackageStringTypes
10891113

@@ -1093,6 +1117,9 @@ class Definitions {
10931117
isCompiletime_S(sym)
10941118
|| sym.owner == CompiletimeOpsAnyModuleClass && compiletimePackageAnyTypes.contains(sym.name)
10951119
|| sym.owner == CompiletimeOpsIntModuleClass && compiletimePackageIntTypes.contains(sym.name)
1120+
|| sym.owner == CompiletimeOpsLongModuleClass && compiletimePackageLongTypes.contains(sym.name)
1121+
|| sym.owner == CompiletimeOpsFloatModuleClass && compiletimePackageFloatTypes.contains(sym.name)
1122+
|| sym.owner == CompiletimeOpsDoubleModuleClass && compiletimePackageDoubleTypes.contains(sym.name)
10961123
|| sym.owner == CompiletimeOpsBooleanModuleClass && compiletimePackageBooleanTypes.contains(sym.name)
10971124
|| sym.owner == CompiletimeOpsStringModuleClass && compiletimePackageStringTypes.contains(sym.name)
10981125
)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -211,29 +211,38 @@ object StdNames {
211211
final val IOOBException: N = "IndexOutOfBoundsException"
212212
final val FunctionXXL: N = "FunctionXXL"
213213

214-
final val Abs: N = "Abs"
215-
final val And: N = "&&"
216-
final val BitwiseAnd: N = "BitwiseAnd"
217-
final val BitwiseOr: N = "BitwiseOr"
218-
final val Div: N = "/"
219-
final val Equals: N = "=="
220-
final val Ge: N = ">="
221-
final val Gt: N = ">"
222-
final val Le: N = "<="
223-
final val Lt: N = "<"
224-
final val Max: N = "Max"
225-
final val Min: N = "Min"
226-
final val Minus: N = "-"
227-
final val Mod: N = "%"
228-
final val Negate: N = "Negate"
229-
final val Not: N = "!"
230-
final val NotEquals: N = "!="
231-
final val Or: N = "||"
232-
final val Plus: N = "+"
233-
final val S: N = "S"
234-
final val Times: N = "*"
235-
final val ToString: N = "ToString"
236-
final val Xor: N = "^"
214+
final val Abs: N = "Abs"
215+
final val And: N = "&&"
216+
final val BitwiseAnd: N = "BitwiseAnd"
217+
final val BitwiseOr: N = "BitwiseOr"
218+
final val Div: N = "/"
219+
final val Equals: N = "=="
220+
final val Ge: N = ">="
221+
final val Gt: N = ">"
222+
final val IsConst: N = "IsConst"
223+
final val Le: N = "<="
224+
final val Length: N = "Length"
225+
final val Lt: N = "<"
226+
final val Matches: N = "Matches"
227+
final val Max: N = "Max"
228+
final val Min: N = "Min"
229+
final val Minus: N = "-"
230+
final val Mod: N = "%"
231+
final val Negate: N = "Negate"
232+
final val Not: N = "!"
233+
final val NotEquals: N = "!="
234+
final val NumberOfLeadingZeros: N = "NumberOfLeadingZeros"
235+
final val Or: N = "||"
236+
final val Plus: N = "+"
237+
final val S: N = "S"
238+
final val Substring: N = "Substring"
239+
final val Times: N = "*"
240+
final val ToInt: N = "ToInt"
241+
final val ToLong: N = "ToLong"
242+
final val ToFloat: N = "ToFloat"
243+
final val ToDouble: N = "ToDouble"
244+
final val ToString: N = "ToString"
245+
final val Xor: N = "^"
237246

238247
final val ClassfileAnnotation: N = "ClassfileAnnotation"
239248
final val ClassManifest: N = "ClassManifest"

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 126 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4174,37 +4174,76 @@ object Types {
41744174

41754175
def tryCompiletimeConstantFold(using Context): Type = tycon match {
41764176
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
4177-
def constValue(tp: Type): Option[Any] = tp.dealias match {
4177+
extension (tp : Type) def fixForEvaluation : Type =
4178+
tp.normalized.dealias match {
4179+
case tp : TermRef => tp.underlying
4180+
case tp => tp
4181+
}
4182+
4183+
def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match {
41784184
case ConstantType(Constant(n)) => Some(n)
41794185
case _ => None
41804186
}
41814187

4182-
def boolValue(tp: Type): Option[Boolean] = tp.dealias match {
4188+
def boolValue(tp: Type): Option[Boolean] = tp.fixForEvaluation match {
41834189
case ConstantType(Constant(n: Boolean)) => Some(n)
41844190
case _ => None
41854191
}
41864192

4187-
def intValue(tp: Type): Option[Int] = tp.dealias match {
4193+
def intValue(tp: Type): Option[Int] = tp.fixForEvaluation match {
41884194
case ConstantType(Constant(n: Int)) => Some(n)
41894195
case _ => None
41904196
}
41914197

4192-
def stringValue(tp: Type): Option[String] = tp.dealias match {
4193-
case ConstantType(Constant(n: String)) => Some(n)
4198+
def longValue(tp: Type): Option[Long] = tp.fixForEvaluation match {
4199+
case ConstantType(Constant(n: Long)) => Some(n)
4200+
case _ => None
4201+
}
4202+
4203+
def floatValue(tp: Type): Option[Float] = tp.fixForEvaluation match {
4204+
case ConstantType(Constant(n: Float)) => Some(n)
4205+
case _ => None
4206+
}
4207+
4208+
def doubleValue(tp: Type): Option[Double] = tp.fixForEvaluation match {
4209+
case ConstantType(Constant(n: Double)) => Some(n)
41944210
case _ => None
41954211
}
41964212

4213+
def stringValue(tp: Type): Option[String] = tp.fixForEvaluation match {
4214+
case ConstantType(Constant(n: String)) => Some(n)
4215+
case _ => None
4216+
}
4217+
def isConst : Option[Type] = args.head.fixForEvaluation match {
4218+
case ConstantType(_) => Some(ConstantType(Constant(true)))
4219+
case _ => Some(ConstantType(Constant(false)))
4220+
}
41974221
def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)
41984222

41994223
def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
4200-
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))
4224+
extractor(args.head).map(a => ConstantType(Constant(op(a))))
42014225

42024226
def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
4227+
constantFold2AB(extractor, extractor, op)
4228+
4229+
def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] =
42034230
for {
4204-
a <- extractor(args.head.normalized)
4205-
b <- extractor(args.tail.head.normalized)
4231+
a <- extractorA(args.head)
4232+
b <- extractorB(args.last)
42064233
} yield ConstantType(Constant(op(a, b)))
42074234

4235+
def constantFold3[TA, TB, TC](
4236+
extractorA: Type => Option[TA],
4237+
extractorB: Type => Option[TB],
4238+
extractorC: Type => Option[TC],
4239+
op: (TA, TB, TC) => Any
4240+
): Option[Type] =
4241+
for {
4242+
a <- extractorA(args.head)
4243+
b <- extractorB(args(1))
4244+
c <- extractorC(args.last)
4245+
} yield ConstantType(Constant(op(a, b, c)))
4246+
42084247
trace(i"compiletime constant fold $this", typr, show = true) {
42094248
val name = tycon.symbol.name
42104249
val owner = tycon.symbol.owner
@@ -4216,10 +4255,13 @@ object Types {
42164255
} else if (owner == defn.CompiletimeOpsAnyModuleClass) name match {
42174256
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
42184257
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
4258+
case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
4259+
case tpnme.IsConst if nArgs == 1 => isConst
42194260
case _ => None
42204261
} else if (owner == defn.CompiletimeOpsIntModuleClass) name match {
42214262
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
42224263
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x)
4264+
//ToString is deprecated for ops.int, and moved to ops.any
42234265
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
42244266
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
42254267
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
@@ -4244,9 +4286,85 @@ object Types {
42444286
case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
42454287
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
42464288
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
4289+
case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
4290+
case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong)
4291+
case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat)
4292+
case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble)
4293+
case _ => None
4294+
} else if (owner == defn.CompiletimeOpsLongModuleClass) name match {
4295+
case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
4296+
case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => -x)
4297+
case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
4298+
case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
4299+
case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
4300+
case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
4301+
case (_, 0L) => throw new TypeError("Division by 0")
4302+
case (a, b) => a / b
4303+
})
4304+
case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
4305+
case (_, 0L) => throw new TypeError("Modulo by 0")
4306+
case (a, b) => a % b
4307+
})
4308+
case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
4309+
case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
4310+
case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
4311+
case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
4312+
case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
4313+
case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
4314+
case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
4315+
case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
4316+
case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
4317+
case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
4318+
case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
4319+
case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
4320+
case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
4321+
constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_))
4322+
case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt)
4323+
case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat)
4324+
case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble)
4325+
case _ => None
4326+
} else if (owner == defn.CompiletimeOpsFloatModuleClass) name match {
4327+
case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs)
4328+
case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => -x)
4329+
case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _)
4330+
case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _)
4331+
case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _)
4332+
case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _)
4333+
case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _)
4334+
case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _)
4335+
case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _)
4336+
case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _)
4337+
case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _)
4338+
case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _)
4339+
case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _)
4340+
case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt)
4341+
case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong)
4342+
case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble)
4343+
case _ => None
4344+
} else if (owner == defn.CompiletimeOpsDoubleModuleClass) name match {
4345+
case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs)
4346+
case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => -x)
4347+
case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _)
4348+
case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _)
4349+
case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _)
4350+
case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _)
4351+
case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _)
4352+
case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _)
4353+
case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _)
4354+
case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _)
4355+
case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _)
4356+
case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _)
4357+
case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _)
4358+
case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt)
4359+
case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong)
4360+
case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat)
42474361
case _ => None
42484362
} else if (owner == defn.CompiletimeOpsStringModuleClass) name match {
42494363
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
4364+
case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
4365+
case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
4366+
case tpnme.Substring if nArgs == 3 =>
4367+
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
42504368
case _ => None
42514369
} else if (owner == defn.CompiletimeOpsBooleanModuleClass) name match {
42524370
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x)

library/src/scala/compiletime/ops/any.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,22 @@ object any:
2121
* @syntax markdown
2222
*/
2323
type !=[X, Y] <: Boolean
24+
25+
/** Tests if a type is a constant.
26+
* ```scala
27+
* val c1: IsConst[1] = true
28+
* val c2: IsConst["hi"] = true
29+
* val c3: IsConst[false] = true
30+
* ```
31+
* @syntax markdown
32+
*/
33+
type IsConst[X] <: Boolean
34+
35+
/** String conversion of a constant singleton type.
36+
* ```scala
37+
* val s1: ToString[1] = "1"
38+
* val sTrue: ToString[true] = "true"
39+
* ```
40+
* @syntax markdown
41+
*/
42+
type ToString[X] <: String

0 commit comments

Comments
 (0)