Skip to content

Commit 97af132

Browse files
committed
Avoid infinite recursion if typechecking non constant code string
1 parent 48d88bc commit 97af132

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -200,23 +200,29 @@ object Inliner {
200200
/** Expand call to scala.compiletime.testing.typeChecks */
201201
def typeChecks(tree: Tree)(implicit ctx: Context): Tree = {
202202
assert(tree.symbol == defn.CompiletimeTesting_typeChecks)
203-
def getCodeArgValue(t: Tree): String = t match {
204-
case Literal(Constant(code: String)) => code
203+
def getCodeArgValue(t: Tree): Option[String] = t match {
204+
case Literal(Constant(code: String)) => Some(code)
205205
case Typed(t2, _) => getCodeArgValue(t2)
206206
case Inlined(_, Nil, t2) => getCodeArgValue(t2)
207207
case Block(Nil, t2) => getCodeArgValue(t2)
208+
case _ => None
208209
}
209210
val Apply(_, codeArg :: Nil) = tree
210-
val code = getCodeArgValue(codeArg.underlyingArgument)
211-
val ctx2 = ctx.fresh.setNewTyperState().setTyper(new Typer)
212-
val tree2 = new Parser(SourceFile.virtual("tasty-reflect", code))(ctx2).block()
213-
val res =
214-
if (ctx2.reporter.hasErrors) false
215-
else {
216-
ctx2.typer.typed(tree2)(ctx2)
217-
!ctx2.reporter.hasErrors
218-
}
219-
Literal(Constant(res))
211+
getCodeArgValue(codeArg.underlyingArgument) match {
212+
case Some(code) =>
213+
val ctx2 = ctx.fresh.setNewTyperState().setTyper(new Typer)
214+
val tree2 = new Parser(SourceFile.virtual("tasty-reflect", code))(ctx2).block()
215+
val res =
216+
if (ctx2.reporter.hasErrors) false
217+
else {
218+
ctx2.typer.typed(tree2)(ctx2)
219+
!ctx2.reporter.hasErrors
220+
}
221+
Literal(Constant(res))
222+
case _ =>
223+
EmptyTree
224+
}
225+
220226
}
221227

222228
}

tests/neg/typeChecks.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
import scala.compiletime.testing.typeChecks
3+
4+
object Test {
5+
6+
def f(s: String) = typeChecks(s) // error
7+
}

0 commit comments

Comments
 (0)