Skip to content

Commit d16bfac

Browse files
committed
Fix capturing conditions of HOAS quote patterns
We did not properly check that the scrutinee was closed under the environment minus the captured arguments.
1 parent 616308d commit d16bfac

File tree

4 files changed

+84
-16
lines changed

4 files changed

+84
-16
lines changed

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -203,24 +203,33 @@ object QuoteMatcher {
203203
// Matches an open term and wraps it into a lambda that provides the free variables
204204
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
205205
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
206-
val names: List[TermName] = args.map {
207-
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
208-
case arg => arg.symbol.name.asTermName
206+
def hoasClosure = {
207+
val names: List[TermName] = args.map {
208+
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
209+
case arg => arg.symbol.name.asTermName
210+
}
211+
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
212+
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
213+
val meth = newAnonFun(ctx.owner, methTpe)
214+
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
215+
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
216+
val body = new TreeMap {
217+
override def transform(tree: Tree)(using Context): Tree =
218+
tree match
219+
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
220+
case tree => super.transform(tree)
221+
}.transform(scrutinee)
222+
TreeOps(body).changeNonLocalOwners(meth)
223+
}
224+
Closure(meth, bodyFn)
209225
}
210-
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
211-
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
212-
val meth = newAnonFun(ctx.owner, methTpe)
213-
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
214-
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
215-
val body = new TreeMap {
216-
override def transform(tree: Tree)(using Context): Tree =
217-
tree match
218-
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
219-
case tree => super.transform(tree)
220-
}.transform(scrutinee)
221-
TreeOps(body).changeNonLocalOwners(meth)
226+
val capturedArgs = args.map(_.symbol)
227+
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v))
228+
withEnv(captureEnv) {
229+
scrutinee match
230+
case ClosedPatternTerm(scrutinee) => matched(hoasClosure)
231+
case _ => notMatched
222232
}
223-
matched(Closure(meth, bodyFn))
224233

225234
/* Match type ascription (b) */
226235
case Typed(expr2, _) =>
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
(42: scala.Int)
2+
((x: scala.Int) => x.+(x)).apply(1)
3+
4+
(42: scala.Int)
5+
((x: scala.Int) => x.+(x)).apply(1)
6+
((y: scala.Int) => y.+(y)).apply(2)
7+
((x: scala.Int, y: scala.Int) => x.+(y)).apply(1, 2)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.quoted.*
2+
3+
inline def test1(inline x: Int): String = ${ test1Expr('x) }
4+
inline def test2(inline x: Int): String = ${ test2Expr('x) }
5+
6+
private def test1Expr(x: Expr[Int])(using Quotes) : Expr[String] =
7+
x match
8+
case '{ val x: Int = 1; $z: Int } => Expr(z.show)
9+
case '{ val x: Int = 1; $z(x): Int } => Expr('{$z(1)}.show)
10+
case _ => '{"No match"}
11+
12+
private def test2Expr(x: Expr[Int])(using Quotes) : Expr[String] =
13+
x match
14+
case '{ val x: Int = 1; val y: Int = 2; $z: Int } => Expr(z.show)
15+
case '{ val x: Int = 1; val y: Int = 2; $z(x): Int } => Expr('{$z(1)}.show)
16+
case '{ val x: Int = 1; val y: Int = 2; $z(y): Int } => Expr('{$z(2)}.show)
17+
case '{ val x: Int = 1; val y: Int = 2; $z(x,y): Int } => Expr('{$z(1, 2)}.show)
18+
case _ => '{"No match"}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
@main def Test =
2+
println(test1 {
3+
val a: Int = 1
4+
42: Int
5+
})
6+
println(test1 {
7+
val a: Int = 1
8+
a + a
9+
})
10+
11+
println()
12+
13+
println(test2 {
14+
val a: Int = 1
15+
val b: Int = 2
16+
42: Int
17+
})
18+
println(test2 {
19+
val a: Int = 1
20+
val b: Int = 2
21+
a + a
22+
})
23+
24+
println(test2 {
25+
val a: Int = 1
26+
val b: Int = 2
27+
b + b
28+
})
29+
30+
println(test2 {
31+
val a: Int = 1
32+
val b: Int = 2
33+
a + b
34+
})

0 commit comments

Comments
 (0)