Skip to content

Commit 0d8bdd6

Browse files
Make HOAS Quote pattern match with def method capture (#17567)
This PR will fix #17105 by extracting symbols from eta-expanded identifiers. This fix enables the use of patterns such as ```scala case '{ def f(...): T = ...; $g(f): U } => ``` where `g` will match any expression that may contain references to `f`.
2 parents 4dbcf09 + d819d9f commit 0d8bdd6

File tree

8 files changed

+187
-28
lines changed

8 files changed

+187
-28
lines changed

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 69 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,34 @@ object QuoteMatcher {
259259
// Matches an open term and wraps it into a lambda that provides the free variables
260260
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
261261
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
262+
263+
/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
264+
* e.g.
265+
* g: (Int) => Int
266+
* => {
267+
* def $anonfun(y: Int): Int = g(y)
268+
* closure($anonfun)
269+
* }
270+
*
271+
* f: (using Int) => Int
272+
* => f(using x)
273+
* This function restores the symbol of the original method from
274+
* the eta-expanded function.
275+
*/
276+
def getCapturedIdent(arg: Tree)(using Context): Ident =
277+
arg match
278+
case id: Ident => id
279+
case Apply(fun, _) => getCapturedIdent(fun)
280+
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
281+
case Typed(expr, _) => getCapturedIdent(expr)
282+
262283
val env = summon[Env]
263-
val capturedArgs = args.map(_.symbol)
264-
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
284+
val capturedIds = args.map(getCapturedIdent)
285+
val capturedSymbols = capturedIds.map(_.symbol)
286+
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
265287
withEnv(captureEnv) {
266288
scrutinee match
267-
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
289+
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env)
268290
case _ => notMatched
269291
}
270292

@@ -394,19 +416,34 @@ object QuoteMatcher {
394416
case scrutinee @ DefDef(_, paramss1, tpt1, _) =>
395417
pattern match
396418
case pattern @ DefDef(_, paramss2, tpt2, _) =>
397-
def rhsEnv: Env =
398-
val paramSyms: List[(Symbol, Symbol)] =
399-
for
400-
(clause1, clause2) <- paramss1.zip(paramss2)
401-
(param1, param2) <- clause1.zip(clause2)
402-
yield
403-
param1.symbol -> param2.symbol
404-
val oldEnv: Env = summon[Env]
405-
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
406-
oldEnv ++ newEnv
407-
matchLists(paramss1, paramss2)(_ =?= _)
408-
&&& tpt1 =?= tpt2
409-
&&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
419+
def matchErasedParams(sctype: Type, pttype: Type): optional[MatchingExprs] =
420+
(sctype, pttype) match
421+
case (sctpe: MethodType, pttpe: MethodType) =>
422+
if sctpe.erasedParams.sameElements(pttpe.erasedParams) then
423+
matchErasedParams(sctpe.resType, pttpe.resType)
424+
else
425+
notMatched
426+
case _ => matched
427+
428+
def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] =
429+
(scparamss, ptparamss) match {
430+
case (scparams :: screst, ptparams :: ptrest) =>
431+
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
432+
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
433+
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
434+
(resEnv, mr1 &&& mrrest)
435+
case (Nil, Nil) => (summon[Env], matched)
436+
case _ => notMatched
437+
}
438+
439+
val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr)
440+
val (pEnv, pmatch) = matchParamss(paramss1, paramss2)
441+
val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol)
442+
443+
ematch
444+
&&& pmatch
445+
&&& withEnv(defEnv)(tpt1 =?= tpt2)
446+
&&& withEnv(defEnv)(scrutinee.rhs =?= pattern.rhs)
410447
case _ => notMatched
411448

412449
case Closure(_, _, tpt1) =>
@@ -497,10 +534,11 @@ object QuoteMatcher {
497534
*
498535
* @param tree Scrutinee sub-tree that matched
499536
* @param patternTpe Type of the pattern hole (from the pattern)
500-
* @param args HOAS arguments (from the pattern)
537+
* @param argIds Identifiers of HOAS arguments (from the pattern)
538+
* @param argTypes Eta-expanded types of HOAS arguments (from the pattern)
501539
* @param env Mapping between scrutinee and pattern variables
502540
*/
503-
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)
541+
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)
504542

505543
/** Return the expression that was extracted from a hole.
506544
*
@@ -513,19 +551,22 @@ object QuoteMatcher {
513551
def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match
514552
case MatchResult.ClosedTree(tree) =>
515553
new ExprImpl(tree, spliceScope)
516-
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
517-
val names: List[TermName] = args.map {
518-
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
519-
case arg => arg.symbol.name.asTermName
520-
}
521-
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
554+
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) =>
555+
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
556+
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
522557
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
523558
val meth = newAnonFun(ctx.owner, methTpe)
524559
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
525-
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
560+
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
526561
val body = new TreeMap {
527562
override def transform(tree: Tree)(using Context): Tree =
528563
tree match
564+
/*
565+
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
566+
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
567+
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
568+
*/
569+
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args)
529570
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
530571
case tree => super.transform(tree)
531572
}.transform(tree)
@@ -534,7 +575,7 @@ object QuoteMatcher {
534575
val hoasClosure = Closure(meth, bodyFn)
535576
new ExprImpl(hoasClosure, spliceScope)
536577

537-
private inline def notMatched: optional[MatchingExprs] =
578+
private inline def notMatched[T]: optional[T] =
538579
optional.break()
539580

540581
private inline def matched: MatchingExprs =
@@ -543,8 +584,8 @@ object QuoteMatcher {
543584
private inline def matched(tree: Tree)(using Context): MatchingExprs =
544585
Seq(MatchResult.ClosedTree(tree))
545586

546-
private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
547-
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))
587+
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs =
588+
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env))
548589

549590
extension (self: MatchingExprs)
550591
/** Concatenates the contents of two successful matchings */
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
case erased: [erased case]
2+
case erased nested: c
3+
case erased nested 2: d
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import scala.quoted.*
2+
3+
inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
4+
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
5+
body match
6+
// Erased Types
7+
case '{ def erasedfn(y: String) = "placeholder"; $a(erasedfn): String } =>
8+
Expr("This case should not match")
9+
case '{ def erasedfn(erased y: String) = "placeholder"; $a(erasedfn): String } =>
10+
'{ $a((erased z: String) => "[erased case]") }
11+
case '{
12+
def erasedfn(a: String, b: String)(c: String, d: String): String = a
13+
$y(erasedfn): String
14+
} => Expr("This should not match")
15+
case '{
16+
def erasedfn(a: String, erased b: String)(erased c: String, d: String): String = a
17+
$y(erasedfn): String
18+
} =>
19+
'{ $y((a: String, erased b: String) => (erased c: String, d: String) => d) }
20+
case '{
21+
def erasedfn(a: String, erased b: String)(c: String, erased d: String): String = a
22+
$y(erasedfn): String
23+
} =>
24+
'{ $y((a: String, erased b: String) => (c: String, erased d: String) => c) }
25+
case _ => Expr("not matched")
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@main def Test: Unit =
2+
println("case erased: " + testExpr { def erasedfn1(erased x: String) = "placeholder"; erasedfn1("arg1")})
3+
println("case erased nested: " + testExpr {
4+
def erasedfn2(p: String, erased q: String)(r: String, erased s: String) = p
5+
erasedfn2("a", "b")("c", "d")
6+
})
7+
println("case erased nested 2: " + testExpr {
8+
def erasedfn2(p: String, erased q: String)(erased r: String, s: String) = p
9+
erasedfn2("a", "b")("c", "d")
10+
})

tests/run-macros/i17105.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
case single: [1st case] arg1 outside
2+
case no-param-method (will be eta-expanded): [1st case] placeholder 2
3+
case curried: [2nd case] arg1, arg2 outside
4+
case methods from outer scope: [1st case] arg1 outer-method
5+
case refinement: Hoe got 1
6+
case dependent: 1
7+
case dependent2: 1
8+
case dependent3: 1

tests/run-macros/i17105/Lib1.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
// Test case for dependent types
3+
trait DSL {
4+
type N
5+
def toString(n: N): String
6+
val zero: N
7+
def next(n: N): N
8+
}
9+
10+
object IntDSL extends DSL {
11+
type N = Int
12+
def toString(n: N): String = n.toString()
13+
val zero = 0
14+
def next(n: N): N = n + 1
15+
}

tests/run-macros/i17105/Macro_2.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import scala.quoted.*
2+
import language.experimental.erasedDefinitions
3+
4+
inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
5+
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
6+
body match
7+
case '{ def g(y: String) = "placeholder" + y; $a(g): String } =>
8+
'{ $a((z: String) => s"[1st case] ${z}") }
9+
case '{ def g(y: String)(z: String) = "placeholder" + y; $a(g): String } =>
10+
'{ $a((z1: String) => (z2: String) => s"[2nd case] ${z1}, ${z2}") }
11+
// Refined Types
12+
case '{
13+
type t
14+
def refined(a: `t`): String = $x(a): String
15+
$y(refined): String
16+
} =>
17+
'{ $y($x) }
18+
// Dependent Types
19+
case '{
20+
def p(dsl: DSL): dsl.N = dsl.zero
21+
$y(p): String
22+
} =>
23+
'{ $y((dsl1: DSL) => dsl1.next(dsl1.zero)) }
24+
case '{
25+
def p(dsl: DSL)(a: dsl.N): dsl.N = a
26+
$y(p): String
27+
} =>
28+
'{ $y((dsl: DSL) => (b2: dsl.N) => dsl.next(b2)) }
29+
case '{
30+
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero
31+
$y(p): String
32+
} =>
33+
'{ $y((dsl1: DSL) => (dsl2: DSL) => dsl2.next(dsl2.zero)) }
34+
case _ => Expr("not matched")

tests/run-macros/i17105/Test_3.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import reflect.Selectable.reflectiveSelectable
2+
3+
class Hoe { def f(x: Int): String = s"Hoe got ${x}" }
4+
5+
@main def Test: Unit =
6+
println("case single: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + " outside" })
7+
println("case no-param-method (will be eta-expanded): " + testExpr { def f(x: String) = "placeholder" + x; (() => f)()("placeholder 2") })
8+
println("case curried: " + testExpr { def f(x: String)(y: String) = "placeholder" + x; f("arg1")("arg2") + " outside" })
9+
def outer() = " outer-method"
10+
println("case methods from outer scope: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + outer() })
11+
println("case refinement: " + testExpr { def refined(a: { def f(x: Int): String }): String = a.f(1); refined(Hoe()) })
12+
println("case dependent: " + testExpr {
13+
def p(a: DSL): a.N = a.zero
14+
IntDSL.toString(p(IntDSL))
15+
})
16+
println("case dependent2: " + testExpr {
17+
def p(dsl1: DSL)(c: dsl1.N): dsl1.N = c
18+
IntDSL.toString(p(IntDSL)(IntDSL.zero))
19+
})
20+
println("case dependent3: " + testExpr {
21+
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero
22+
IntDSL.toString(p(IntDSL)(IntDSL))
23+
})

0 commit comments

Comments
 (0)