Skip to content

Commit 3af515d

Browse files
committed
Make HOAS Quote pattern match with def method capture
closes #17105
1 parent a68568c commit 3af515d

File tree

8 files changed

+191
-28
lines changed

8 files changed

+191
-28
lines changed

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 73 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import dotty.tools.dotc.core.Types.*
1111
import dotty.tools.dotc.core.StdNames.nme
1212
import dotty.tools.dotc.core.Symbols.*
1313
import dotty.tools.dotc.util.optional
14+
import dotty.tools.dotc.core.Definitions
1415

1516
/** Matches a quoted tree against a quoted pattern tree.
1617
* A quoted pattern tree may have type and term holes in addition to normal terms.
@@ -259,12 +260,34 @@ object QuoteMatcher {
259260
// Matches an open term and wraps it into a lambda that provides the free variables
260261
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
261262
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
263+
264+
/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
265+
* e.g.
266+
* g: (Int) => Int
267+
* => {
268+
* def $anonfun(y: Int): Int = g(y)
269+
* closure($anonfun)
270+
* }
271+
*
272+
* f: (using Int) => Int
273+
* => f(using x)
274+
* This function restores the symbol of the original method from
275+
* the eta-expanded function.
276+
*/
277+
def getCapturedIdent(arg: Tree)(using Context): Ident =
278+
arg match
279+
case id: Ident => id
280+
case Apply(fun, _) => getCapturedIdent(fun)
281+
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
282+
case Typed(expr, _) => getCapturedIdent(expr)
283+
262284
val env = summon[Env]
263-
val capturedArgs = args.map(_.symbol)
264-
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
285+
val capturedIds = args.map(getCapturedIdent)
286+
val capturedSymbols = capturedIds.map(_.symbol)
287+
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
265288
withEnv(captureEnv) {
266289
scrutinee match
267-
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
290+
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env)
268291
case _ => notMatched
269292
}
270293

@@ -394,19 +417,34 @@ object QuoteMatcher {
394417
case scrutinee @ DefDef(_, paramss1, tpt1, _) =>
395418
pattern match
396419
case pattern @ DefDef(_, paramss2, tpt2, _) =>
397-
def rhsEnv: Env =
398-
val paramSyms: List[(Symbol, Symbol)] =
399-
for
400-
(clause1, clause2) <- paramss1.zip(paramss2)
401-
(param1, param2) <- clause1.zip(clause2)
402-
yield
403-
param1.symbol -> param2.symbol
404-
val oldEnv: Env = summon[Env]
405-
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
406-
oldEnv ++ newEnv
407-
matchLists(paramss1, paramss2)(_ =?= _)
408-
&&& tpt1 =?= tpt2
409-
&&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
420+
def matchErasedParams(sctype: Type, pttype: Type): optional[MatchingExprs] =
421+
(sctype, pttype) match
422+
case (sctpe: MethodType, pttpe: MethodType) =>
423+
if sctpe.erasedParams.sameElements(pttpe.erasedParams) then
424+
matchErasedParams(sctpe.resType, pttpe.resType)
425+
else
426+
notMatched
427+
case _ => matched
428+
429+
def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] =
430+
(scparamss, ptparamss) match {
431+
case (scparams :: screst, ptparams :: ptrest) =>
432+
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
433+
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
434+
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
435+
(resEnv, mr1 &&& mrrest)
436+
case (Nil, Nil) => (summon[Env], matched)
437+
case _ => notMatched
438+
}
439+
440+
val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr)
441+
val (pEnv, pmatch) = matchParamss(paramss1, paramss2)
442+
val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol)
443+
444+
ematch
445+
&&& pmatch
446+
&&& withEnv(defEnv)(tpt1 =?= tpt2)
447+
&&& withEnv(defEnv)(scrutinee.rhs =?= pattern.rhs)
410448
case _ => notMatched
411449

412450
case Closure(_, _, tpt1) =>
@@ -497,10 +535,14 @@ object QuoteMatcher {
497535
*
498536
* @param tree Scrutinee sub-tree that matched
499537
* @param patternTpe Type of the pattern hole (from the pattern)
500-
* @param args HOAS arguments (from the pattern)
538+
* @param argIds Identifiers of HOAS arguments (from the pattern)
539+
* @param argTypes Eta-expanded types of HOAS arguments (from the pattern)
501540
* @param env Mapping between scrutinee and pattern variables
502541
*/
503-
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)
542+
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)
543+
544+
/** The Definitions object */
545+
def defn(using Context): Definitions = ctx.definitions
504546

505547
/** Return the expression that was extracted from a hole.
506548
*
@@ -513,19 +555,22 @@ object QuoteMatcher {
513555
def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match
514556
case MatchResult.ClosedTree(tree) =>
515557
new ExprImpl(tree, spliceScope)
516-
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
517-
val names: List[TermName] = args.map {
518-
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
519-
case arg => arg.symbol.name.asTermName
520-
}
521-
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
558+
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) =>
559+
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
560+
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
522561
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
523562
val meth = newAnonFun(ctx.owner, methTpe)
524563
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
525-
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
564+
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
526565
val body = new TreeMap {
527566
override def transform(tree: Tree)(using Context): Tree =
528567
tree match
568+
/*
569+
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
570+
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
571+
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
572+
*/
573+
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args)
529574
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
530575
case tree => super.transform(tree)
531576
}.transform(tree)
@@ -534,7 +579,7 @@ object QuoteMatcher {
534579
val hoasClosure = Closure(meth, bodyFn)
535580
new ExprImpl(hoasClosure, spliceScope)
536581

537-
private inline def notMatched: optional[MatchingExprs] =
582+
private inline def notMatched[T]: optional[T] =
538583
optional.break()
539584

540585
private inline def matched: MatchingExprs =
@@ -543,8 +588,8 @@ object QuoteMatcher {
543588
private inline def matched(tree: Tree)(using Context): MatchingExprs =
544589
Seq(MatchResult.ClosedTree(tree))
545590

546-
private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
547-
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))
591+
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs =
592+
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env))
548593

549594
extension (self: MatchingExprs)
550595
/** Concatenates the contents of two successful matchings */
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
case erased: [erased case]
2+
case erased nested: c
3+
case erased nested 2: d
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import scala.quoted.*
2+
3+
inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
4+
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
5+
body match
6+
// Erased Types
7+
case '{ def erasedfn(y: String) = "placeholder"; $a(erasedfn): String } =>
8+
Expr("This case should not match")
9+
case '{ def erasedfn(erased y: String) = "placeholder"; $a(erasedfn): String } =>
10+
'{ $a((erased z: String) => "[erased case]") }
11+
case '{
12+
def erasedfn(a: String, b: String)(c: String, d: String): String = a
13+
$y(erasedfn): String
14+
} => Expr("This should not match")
15+
case '{
16+
def erasedfn(a: String, erased b: String)(erased c: String, d: String): String = a
17+
$y(erasedfn): String
18+
} =>
19+
'{ $y((a: String, erased b: String) => (erased c: String, d: String) => d) }
20+
case '{
21+
def erasedfn(a: String, erased b: String)(c: String, erased d: String): String = a
22+
$y(erasedfn): String
23+
} =>
24+
'{ $y((a: String, erased b: String) => (c: String, erased d: String) => c) }
25+
case _ => Expr("not matched")
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@main def Test: Unit =
2+
println("case erased: " + testExpr { def erasedfn1(erased x: String) = "placeholder"; erasedfn1("arg1")})
3+
println("case erased nested: " + testExpr {
4+
def erasedfn2(p: String, erased q: String)(r: String, erased s: String) = p
5+
erasedfn2("a", "b")("c", "d")
6+
})
7+
println("case erased nested 2: " + testExpr {
8+
def erasedfn2(p: String, erased q: String)(erased r: String, s: String) = p
9+
erasedfn2("a", "b")("c", "d")
10+
})

tests/run-macros/i17105.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
case single: [1st case] arg1 outside
2+
case no-param-method (will be eta-expanded): [1st case] placeholder 2
3+
case curried: [2nd case] arg1, arg2 outside
4+
case methods from outer scope: [1st case] arg1 outer-method
5+
case refinement: Hoe got 1
6+
case dependent: 1
7+
case dependent2: 1
8+
case dependent3: 1

tests/run-macros/i17105/Lib1.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
// Test case for dependent types
3+
trait DSL {
4+
type N
5+
def toString(n: N): String
6+
val zero: N
7+
def next(n: N): N
8+
}
9+
10+
object IntDSL extends DSL {
11+
type N = Int
12+
def toString(n: N): String = n.toString()
13+
val zero = 0
14+
def next(n: N): N = n + 1
15+
}

tests/run-macros/i17105/Macro_2.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import scala.quoted.*
2+
import language.experimental.erasedDefinitions
3+
4+
inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
5+
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
6+
body match
7+
case '{ def g(y: String) = "placeholder" + y; $a(g): String } =>
8+
'{ $a((z: String) => s"[1st case] ${z}") }
9+
case '{ def g(y: String)(z: String) = "placeholder" + y; $a(g): String } =>
10+
'{ $a((z1: String) => (z2: String) => s"[2nd case] ${z1}, ${z2}") }
11+
// Refined Types
12+
case '{
13+
type t
14+
def refined(a: `t`): String = $x(a): String
15+
$y(refined): String
16+
} =>
17+
'{ $y($x) }
18+
// Dependent Types
19+
case '{
20+
def p(dsl: DSL): dsl.N = dsl.zero
21+
$y(p): String
22+
} =>
23+
'{ $y((dsl1: DSL) => dsl1.next(dsl1.zero)) }
24+
case '{
25+
def p(dsl: DSL)(a: dsl.N): dsl.N = a
26+
$y(p): String
27+
} =>
28+
'{ $y((dsl: DSL) => (b2: dsl.N) => dsl.next(b2)) }
29+
case '{
30+
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero
31+
$y(p): String
32+
} =>
33+
'{ $y((dsl1: DSL) => (dsl2: DSL) => dsl2.next(dsl2.zero)) }
34+
case _ => Expr("not matched")

tests/run-macros/i17105/Test_3.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import reflect.Selectable.reflectiveSelectable
2+
3+
class Hoe { def f(x: Int): String = s"Hoe got ${x}" }
4+
5+
@main def Test: Unit =
6+
println("case single: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + " outside" })
7+
println("case no-param-method (will be eta-expanded): " + testExpr { def f(x: String) = "placeholder" + x; (() => f)()("placeholder 2") })
8+
println("case curried: " + testExpr { def f(x: String)(y: String) = "placeholder" + x; f("arg1")("arg2") + " outside" })
9+
def outer() = " outer-method"
10+
println("case methods from outer scope: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + outer() })
11+
println("case refinement: " + testExpr { def refined(a: { def f(x: Int): String }): String = a.f(1); refined(Hoe()) })
12+
println("case dependent: " + testExpr {
13+
def p(a: DSL): a.N = a.zero
14+
IntDSL.toString(p(IntDSL))
15+
})
16+
println("case dependent2: " + testExpr {
17+
def p(dsl1: DSL)(c: dsl1.N): dsl1.N = c
18+
IntDSL.toString(p(IntDSL)(IntDSL.zero))
19+
})
20+
println("case dependent3: " + testExpr {
21+
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero
22+
IntDSL.toString(p(IntDSL)(IntDSL))
23+
})

0 commit comments

Comments
 (0)