Skip to content

Commit 89af58f

Browse files
Merge pull request #9764 from dotty-staging/add-compiletime-const
Add scala.compiletime.requireConst
2 parents 62ec194 + 8809972 commit 89af58f

File tree

4 files changed

+81
-7
lines changed

4 files changed

+81
-7
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ class Definitions {
226226
@tu lazy val CompiletimePackageObject: Symbol = requiredModule("scala.compiletime.package")
227227
@tu lazy val Compiletime_erasedValue : Symbol = CompiletimePackageObject.requiredMethod("erasedValue")
228228
@tu lazy val Compiletime_error : Symbol = CompiletimePackageObject.requiredMethod(nme.error)
229+
@tu lazy val Compiletime_requireConst: Symbol = CompiletimePackageObject.requiredMethod("requireConst")
229230
@tu lazy val Compiletime_constValue : Symbol = CompiletimePackageObject.requiredMethod("constValue")
230231
@tu lazy val Compiletime_constValueOpt: Symbol = CompiletimePackageObject.requiredMethod("constValueOpt")
231232
@tu lazy val Compiletime_code : Symbol = CompiletimePackageObject.requiredMethod("extension_code")

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,14 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
620620
/** The Inlined node representing the inlined call */
621621
def inlined(sourcePos: SrcPos): Tree = {
622622

623+
// Special handling of `requireConst`
624+
callValueArgss match
625+
case (arg :: Nil) :: Nil if inlinedMethod == defn.Compiletime_requireConst =>
626+
arg match
627+
case ConstantValue(_) | Inlined(_, Nil, Typed(ConstantValue(_), _)) => // ok
628+
case _ => report.error(em"expected a constant value but found: $arg", arg.srcPos)
629+
case _ =>
630+
623631
// Special handling of `constValue[T]` and `constValueOpt[T]`
624632
if (callTypeArgs.length == 1)
625633
if (inlinedMethod == defn.Compiletime_constValue) {
@@ -890,13 +898,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
890898
}
891899
}
892900

893-
object ConstantValue {
894-
def unapply(tree: Tree)(using Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match {
895-
case ConstantType(Constant(x)) => Some(x)
896-
case _ => None
897-
}
898-
}
899-
900901
def tryInline(tree: Tree)(using Context): Tree = tree match {
901902
case InlineableArg(rhs) =>
902903
inlining.println(i"inline arg $tree -> $rhs")
@@ -1453,4 +1454,12 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
14531454
foldOver(syms, tree)
14541455
}
14551456
}.apply(Nil, tree)
1457+
1458+
object ConstantValue {
1459+
def unapply(tree: Tree)(using Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match {
1460+
case ConstantType(Constant(x)) => Some(x)
1461+
case _ => None
1462+
}
1463+
}
1464+
14561465
}

library/src/scala/compiletime/package.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ package object compiletime {
5050
${ dotty.internal.CompileTimeMacros.codeExpr('self, 'args) }
5151
end extension
5252

53+
/** Checks at compiletime that the provided values is a constant after
54+
* inlining and constant folding.
55+
*
56+
* Usage:
57+
* ```scala
58+
* inline def twice(inline n: Int): Int =
59+
* requireConst(n) // compile-time assertion that the parameter `n` is a constant
60+
* n + n
61+
*
62+
* twice(1)
63+
* val m: Int = ...
64+
* twice(m) // error: expected a constant value but found: m
65+
* ```
66+
*/
67+
inline def requireConst(inline x: Boolean | Byte | Short | Int | Long | Float | Double | Char | String): Unit = ()
68+
5369
/** Same as `constValue` but returns a `None` if a constant value
5470
* cannot be constructed from the provided type. Otherwise returns
5571
* that value wrapped in `Some`.

tests/neg/compiletime-const.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import scala.compiletime.requireConst
2+
3+
object Test {
4+
5+
requireConst(true)
6+
requireConst(1)
7+
requireConst(1L)
8+
requireConst(1d)
9+
requireConst(1f)
10+
requireConst('a')
11+
requireConst("abc")
12+
13+
requireConst(1 + 3)
14+
requireConst("abc" + "cde")
15+
16+
val a: Int = 2
17+
inline val b = 2
18+
19+
requireConst(a) // error: expected a requireConstant value but found: Test.a
20+
requireConst(b)
21+
requireConst(b + b)
22+
requireConst(b - b)
23+
requireConst(b / b)
24+
requireConst(b % b)
25+
26+
inline def f(inline n: Int): Int = 4 + n
27+
28+
requireConst(f(1))
29+
requireConst(f(a)) // error: expected a requireConstant value but found: 4.+(Test.a):Int
30+
requireConst(f(b))
31+
requireConst(f(b + b))
32+
33+
def g(n: Int): n.type = n
34+
35+
requireConst(g(1))
36+
requireConst(g(a)) // error: expected a requireConstant value but found: Test.a
37+
requireConst(g(b))
38+
39+
40+
inline def twice(inline n: Int): Int =
41+
requireConst(n) // static assertion that n is a requireConstant
42+
n + n
43+
44+
twice(1)
45+
twice(a) // error: expected a requireConstant value but found: Test.a
46+
twice(b)
47+
48+
}

0 commit comments

Comments
 (0)