diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 8263bab70aca..ff731f171dc4 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -226,6 +226,7 @@ class Definitions { @tu lazy val CompiletimePackageObject: Symbol = requiredModule("scala.compiletime.package") @tu lazy val Compiletime_erasedValue : Symbol = CompiletimePackageObject.requiredMethod("erasedValue") @tu lazy val Compiletime_error : Symbol = CompiletimePackageObject.requiredMethod(nme.error) + @tu lazy val Compiletime_requireConst: Symbol = CompiletimePackageObject.requiredMethod("requireConst") @tu lazy val Compiletime_constValue : Symbol = CompiletimePackageObject.requiredMethod("constValue") @tu lazy val Compiletime_constValueOpt: Symbol = CompiletimePackageObject.requiredMethod("constValueOpt") @tu lazy val Compiletime_code : Symbol = CompiletimePackageObject.requiredMethod("extension_code") diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 4ac48335afa7..20396c08f631 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -620,6 +620,14 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) { /** The Inlined node representing the inlined call */ def inlined(sourcePos: SrcPos): Tree = { + // Special handling of `requireConst` + callValueArgss match + case (arg :: Nil) :: Nil if inlinedMethod == defn.Compiletime_requireConst => + arg match + case ConstantValue(_) | Inlined(_, Nil, Typed(ConstantValue(_), _)) => // ok + case _ => report.error(em"expected a constant value but found: $arg", arg.srcPos) + case _ => + // Special handling of `constValue[T]` and `constValueOpt[T]` if (callTypeArgs.length == 1) if (inlinedMethod == defn.Compiletime_constValue) { @@ -890,13 +898,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) { } } - object ConstantValue { - def unapply(tree: Tree)(using Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match { - case ConstantType(Constant(x)) => Some(x) - case _ => None - } - } - def tryInline(tree: Tree)(using Context): Tree = tree match { case InlineableArg(rhs) => inlining.println(i"inline arg $tree -> $rhs") @@ -1453,4 +1454,12 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) { foldOver(syms, tree) } }.apply(Nil, tree) + + object ConstantValue { + def unapply(tree: Tree)(using Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match { + case ConstantType(Constant(x)) => Some(x) + case _ => None + } + } + } diff --git a/library/src/scala/compiletime/package.scala b/library/src/scala/compiletime/package.scala index ecc6abcd2e55..4c1e4027ce39 100644 --- a/library/src/scala/compiletime/package.scala +++ b/library/src/scala/compiletime/package.scala @@ -50,6 +50,22 @@ package object compiletime { ${ dotty.internal.CompileTimeMacros.codeExpr('self, 'args) } end extension + /** Checks at compiletime that the provided values is a constant after + * inlining and constant folding. + * + * Usage: + * ```scala + * inline def twice(inline n: Int): Int = + * requireConst(n) // compile-time assertion that the parameter `n` is a constant + * n + n + * + * twice(1) + * val m: Int = ... + * twice(m) // error: expected a constant value but found: m + * ``` + */ + inline def requireConst(inline x: Boolean | Byte | Short | Int | Long | Float | Double | Char | String): Unit = () + /** Same as `constValue` but returns a `None` if a constant value * cannot be constructed from the provided type. Otherwise returns * that value wrapped in `Some`. diff --git a/tests/neg/compiletime-const.scala b/tests/neg/compiletime-const.scala new file mode 100644 index 000000000000..2fa5c06056c0 --- /dev/null +++ b/tests/neg/compiletime-const.scala @@ -0,0 +1,48 @@ +import scala.compiletime.requireConst + +object Test { + + requireConst(true) + requireConst(1) + requireConst(1L) + requireConst(1d) + requireConst(1f) + requireConst('a') + requireConst("abc") + + requireConst(1 + 3) + requireConst("abc" + "cde") + + val a: Int = 2 + inline val b = 2 + + requireConst(a) // error: expected a requireConstant value but found: Test.a + requireConst(b) + requireConst(b + b) + requireConst(b - b) + requireConst(b / b) + requireConst(b % b) + + inline def f(inline n: Int): Int = 4 + n + + requireConst(f(1)) + requireConst(f(a)) // error: expected a requireConstant value but found: 4.+(Test.a):Int + requireConst(f(b)) + requireConst(f(b + b)) + + def g(n: Int): n.type = n + + requireConst(g(1)) + requireConst(g(a)) // error: expected a requireConstant value but found: Test.a + requireConst(g(b)) + + + inline def twice(inline n: Int): Int = + requireConst(n) // static assertion that n is a requireConstant + n + n + + twice(1) + twice(a) // error: expected a requireConstant value but found: Test.a + twice(b) + +}