Skip to content

Commit 68c2487

Browse files
committed
Add scala.compiletime.requireConst
`requireConst` checks at compiletime that the provided values is a constant after inlining and constant folding. Usage: ```scala inline def twice(inline n: Int): Int = requireConst(n) // static assertion that the parameter `n` is a constant n + n twice(1) val m: Int = ... twice(m) // error: expected a constant value but found: m ``` This check may aslo be used to stop inlining in the branches of an if that does not have a constant condition. Such checks can avoid infinite recursions of inlining when a value is assumed to be a constant. ```scala inline def power(x: Long, n: Int): Long = if requireConst(n == 0) then 1L else if requireConst(n % 2 == 1) then x * power(x, n - 1) else { val y: Long = x * x power(y, n / 2) } val n: Int = ... power(x, n) // error: expected a constant value but found: n == 0 ```
1 parent 8639874 commit 68c2487

File tree

10 files changed

+112
-17
lines changed

10 files changed

+112
-17
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ class Definitions {
226226
@tu lazy val CompiletimePackageObject: Symbol = requiredModule("scala.compiletime.package")
227227
@tu lazy val Compiletime_erasedValue : Symbol = CompiletimePackageObject.requiredMethod("erasedValue")
228228
@tu lazy val Compiletime_error : Symbol = CompiletimePackageObject.requiredMethod(nme.error)
229+
@tu lazy val Compiletime_requireConst: Symbol = CompiletimePackageObject.requiredMethod("requireConst")
229230
@tu lazy val Compiletime_constValue : Symbol = CompiletimePackageObject.requiredMethod("constValue")
230231
@tu lazy val Compiletime_constValueOpt: Symbol = CompiletimePackageObject.requiredMethod("constValueOpt")
231232
@tu lazy val Compiletime_code : Symbol = CompiletimePackageObject.requiredMethod("extension_code")

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
463463
case tree @ Inlined(call, bindings, body) =>
464464
(("/* inlined from " ~ (if (call.isEmpty) "outside" else toText(call)) ~ " */ ") `provided`
465465
!homogenizedView && ctx.settings.XprintInline.value) ~
466-
blockText(bindings :+ body)
466+
(if bindings.isEmpty then toText(body) else blockText(bindings :+ body))
467467
case tpt: untpd.DerivedTypeTree =>
468468
"<derived typetree watching " ~ tpt.watched.showSummary() ~ ">"
469469
case TypeTree() =>

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -620,8 +620,21 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
620620
/** The Inlined node representing the inlined call */
621621
def inlined(sourcePos: SrcPos): Tree = {
622622

623+
// Special handling of `const`
624+
callValueArgss match
625+
case (arg :: Nil) :: Nil =>
626+
if (inlinedMethod == defn.Compiletime_requireConst) {
627+
arg match
628+
case ConstantValue(v) => return arg
629+
case Inlined(_, Nil, Typed(ConstantValue(v), _)) => return arg
630+
case _ =>
631+
report.error(em"expected a constant value but found: $arg", arg.srcPos)
632+
}
633+
case _ =>
634+
623635
// Special handling of `constValue[T]` and `constValueOpt[T]`
624636
if (callTypeArgs.length == 1)
637+
625638
if (inlinedMethod == defn.Compiletime_constValue) {
626639
val constVal = tryConstValue
627640
if (!constVal.isEmpty) return constVal
@@ -890,12 +903,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
890903
}
891904
}
892905

893-
object ConstantValue {
894-
def unapply(tree: Tree)(using Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match {
895-
case ConstantType(Constant(x)) => Some(x)
896-
case _ => None
897-
}
898-
}
899906

900907
def tryInline(tree: Tree)(using Context): Tree = tree match {
901908
case InlineableArg(rhs) =>
@@ -1208,14 +1215,17 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
12081215
}
12091216

12101217
override def typedIf(tree: untpd.If, pt: Type)(using Context): Tree =
1218+
val err = ctx.reporter.errorCount
12111219
typed(tree.cond, defn.BooleanType) match {
12121220
case cond1 @ ConstantValue(b: Boolean) =>
12131221
val selected0 = if (b) tree.thenp else tree.elsep
12141222
val selected = if (selected0.isEmpty) tpd.Literal(Constant(())) else typed(selected0, pt)
12151223
if (isIdempotentExpr(cond1)) selected
12161224
else Block(cond1 :: Nil, selected)
1225+
case cond1: tpd.Inlined if cond1.call.symbol == defn.Compiletime_requireConst =>
1226+
errorTree(tree, em"Expected constant condition but found: ${tree.cond}")
12171227
case cond1 =>
1218-
if (tree.isInline)
1228+
if (tree.isInline) // TODO replace with Compiletime_requireConst value check?
12191229
errorTree(tree, em"""cannot reduce inline if
12201230
| its condition ${tree.cond}
12211231
| is not a constant value""")
@@ -1452,4 +1462,12 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
14521462
foldOver(syms, tree)
14531463
}
14541464
}.apply(Nil, tree)
1465+
1466+
object ConstantValue {
1467+
def unapply(tree: Tree)(using Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match {
1468+
case ConstantType(Constant(x)) => Some(x)
1469+
case _ => None
1470+
}
1471+
}
1472+
14551473
}

library/src/scala/compiletime/package.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ package object compiletime {
5050
${ dotty.internal.CompileTimeMacros.codeExpr('self, 'args) }
5151
end extension
5252

53+
/** Checks at compiletime that the provided values is a constant after inlining and constant folding.
54+
*
55+
* Usage:
56+
* ```scala
57+
* inline def twice(inline n: Int): Int =
58+
* requireConst(n) // static assertion that the parameter `n` is a constant
59+
* n + n
60+
*
61+
* twice(1)
62+
* val m: Int = ...
63+
* twice(m) // error: expected a constant value but found: m
64+
* ```
65+
*
66+
* This check may aslo be used to stop inlining in the branches of an if that does not have a constant condition.
67+
* Such checks can avoid infinite recursions of inlining when a value is assumed to be a constant.
68+
* ```scala
69+
* inline def power(x: Long, n: Int): Long =
70+
* if requireConst(n == 0) then 1L
71+
* else if requireConst(n % 2 == 1) then x * power(x, n - 1)
72+
* else {
73+
* val y: Long = x * x
74+
* power(y, n / 2)
75+
* }
76+
*
77+
* val n: Int = ...
78+
* power(x, n) // error: expected a constant value but found: n == 0
79+
* ```
80+
*/
81+
inline def requireConst(x: Any): x.type = ???
82+
5383
/** Same as `constValue` but returns a `None` if a constant value
5484
* cannot be constructed from the provided type. Otherwise returns
5585
* that value wrapped in `Some`.

tests/invalid/run/typelevel1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ object Test extends App {
3333
// Does not work since it infers `Any` as a type argument for `::`
3434
// and we cannot undo that without a typing from untyped.
3535
transparent inline def concat[T1, T2](xs: HList, ys: HList): HList =
36-
inline if xs.isEmpty then ys
36+
if compiletime.requireConst(xs.isEmpty) then ys
3737
else new ::(xs.head, concat(xs.tail, ys))
3838

3939
val xs = 1 :: "a" :: "b" :: HNil

tests/neg/cannot-reduce-inline-match.check

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
9 | foo("f") // error
33
| ^^^^^^^^
44
| cannot reduce inline match with
5-
| scrutinee: {
6-
| "f"
7-
| } : ("f" : String)
5+
| scrutinee: "f" : ("f" : String)
86
| patterns : case _:Int
97
| This location contains code that was inlined from cannot-reduce-inline-match.scala:3

tests/neg/compiletime-const.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import scala.compiletime.requireConst
2+
3+
object Test {
4+
5+
requireConst(true)
6+
requireConst(1)
7+
requireConst(1L)
8+
requireConst(1d)
9+
requireConst(1f)
10+
requireConst('a')
11+
requireConst("abc")
12+
13+
requireConst(1 + 3)
14+
requireConst("abc" + "cde")
15+
16+
val a: Int = 2
17+
inline val b = 2
18+
19+
requireConst(a) // error: expected a requireConstant value but found: Test.a
20+
requireConst(b)
21+
requireConst(b + b)
22+
requireConst(b - b)
23+
requireConst(b / b)
24+
requireConst(b % b)
25+
26+
inline def f(inline n: Int): Int = 4 + n
27+
28+
requireConst(f(1))
29+
requireConst(f(a)) // error: expected a requireConstant value but found: 4.+(Test.a):Int
30+
requireConst(f(b))
31+
requireConst(f(b + b))
32+
33+
def g(n: Int): n.type = n
34+
35+
requireConst(g(1))
36+
requireConst(g(a)) // error: expected a requireConstant value but found: Test.a
37+
requireConst(g(b))
38+
39+
40+
inline def twice(inline n: Int): Int =
41+
requireConst(n) // static assertion that n is a requireConstant
42+
n + n
43+
44+
twice(1)
45+
twice(a) // error: expected a requireConstant value but found: Test.a
46+
twice(b)
47+
48+
}

tests/neg/inline-error-pos.check

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
8 | val b = foo(2) // error
33
| ^^^^^^
44
| cannot reduce inline match with
5-
| scrutinee: {
6-
| 2
7-
| } : (2 : Int)
5+
| scrutinee: 2 : (2 : Int)
86
| patterns : case 1
97
| This location contains code that was inlined from inline-error-pos.scala:3

tests/neg/inlinevals.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import scala.compiletime.const
2+
13
object Test {
24

35
def power0(x: Double, inline n: Int): Double = ??? // error: inline modifier can only be used for parameters of inline methods
46

57
inline def power(x: Double, inline n: Int): Double = // ok
6-
inline if n == 0 then ??? else ???
8+
if requireConst(n == 0) then ??? else ???
79

810
inline val N = 10
911
def X = 20

tests/run/typeclass-derivation3.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ object typeclasses {
129129

130130
inline def unpickleCase[T, Elems <: Tuple](buf: mutable.ListBuffer[Int], m: Mirror.ProductOf[T]): T = {
131131
inline val size = constValue[Tuple.Size[Elems]]
132-
inline if (size == 0)
132+
if requireConst(size == 0) then
133133
m.fromProduct(EmptyProduct)
134134
else {
135135
val elems = new ArrayProduct(size)

0 commit comments

Comments
 (0)