-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Make HOAS Quote pattern match with def method capture #17567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -259,12 +259,34 @@ object QuoteMatcher { | |
// Matches an open term and wraps it into a lambda that provides the free variables | ||
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil) | ||
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) => | ||
|
||
/* Some of method symbols in arguments of higher-order term hole are eta-expanded. | ||
* e.g. | ||
* g: (Int) => Int | ||
* => { | ||
* def $anonfun(y: Int): Int = g(y) | ||
* closure($anonfun) | ||
* } | ||
* | ||
* f: (using Int) => Int | ||
* => f(using x) | ||
* This function restores the symbol of the original method from | ||
* the eta-expanded function. | ||
*/ | ||
def getCapturedIdent(arg: Tree)(using Context): Ident = | ||
arg match | ||
case id: Ident => id | ||
case Apply(fun, _) => getCapturedIdent(fun) | ||
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs) | ||
case Typed(expr, _) => getCapturedIdent(expr) | ||
|
||
val env = summon[Env] | ||
val capturedArgs = args.map(_.symbol) | ||
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v)) | ||
val capturedIds = args.map(getCapturedIdent) | ||
val capturedSymbols = capturedIds.map(_.symbol) | ||
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v)) | ||
withEnv(captureEnv) { | ||
scrutinee match | ||
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env) | ||
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env) | ||
case _ => notMatched | ||
} | ||
|
||
|
@@ -394,19 +416,34 @@ object QuoteMatcher { | |
case scrutinee @ DefDef(_, paramss1, tpt1, _) => | ||
pattern match | ||
case pattern @ DefDef(_, paramss2, tpt2, _) => | ||
def rhsEnv: Env = | ||
val paramSyms: List[(Symbol, Symbol)] = | ||
for | ||
(clause1, clause2) <- paramss1.zip(paramss2) | ||
(param1, param2) <- clause1.zip(clause2) | ||
yield | ||
param1.symbol -> param2.symbol | ||
val oldEnv: Env = summon[Env] | ||
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms | ||
oldEnv ++ newEnv | ||
matchLists(paramss1, paramss2)(_ =?= _) | ||
&&& tpt1 =?= tpt2 | ||
&&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs) | ||
def matchErasedParams(sctype: Type, pttype: Type): optional[MatchingExprs] = | ||
(sctype, pttype) match | ||
case (sctpe: MethodType, pttpe: MethodType) => | ||
if sctpe.erasedParams.sameElements(pttpe.erasedParams) then | ||
matchErasedParams(sctpe.resType, pttpe.resType) | ||
else | ||
notMatched | ||
case _ => matched | ||
|
||
def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] = | ||
zeptometer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
(scparamss, ptparamss) match { | ||
case (scparams :: screst, ptparams :: ptrest) => | ||
val mr1 = matchLists(scparams, ptparams)(_ =?= _) | ||
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)) | ||
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest)) | ||
(resEnv, mr1 &&& mrrest) | ||
case (Nil, Nil) => (summon[Env], matched) | ||
case _ => notMatched | ||
} | ||
|
||
val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr) | ||
val (pEnv, pmatch) = matchParamss(paramss1, paramss2) | ||
val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol) | ||
|
||
ematch | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
&&& pmatch | ||
&&& withEnv(defEnv)(tpt1 =?= tpt2) | ||
&&& withEnv(defEnv)(scrutinee.rhs =?= pattern.rhs) | ||
case _ => notMatched | ||
|
||
case Closure(_, _, tpt1) => | ||
|
@@ -497,10 +534,11 @@ object QuoteMatcher { | |
* | ||
* @param tree Scrutinee sub-tree that matched | ||
* @param patternTpe Type of the pattern hole (from the pattern) | ||
* @param args HOAS arguments (from the pattern) | ||
* @param argIds Identifiers of HOAS arguments (from the pattern) | ||
* @param argTypes Eta-expanded types of HOAS arguments (from the pattern) | ||
* @param env Mapping between scrutinee and pattern variables | ||
*/ | ||
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env) | ||
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env) | ||
|
||
/** Return the expression that was extracted from a hole. | ||
* | ||
|
@@ -513,19 +551,22 @@ object QuoteMatcher { | |
def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match | ||
case MatchResult.ClosedTree(tree) => | ||
new ExprImpl(tree, spliceScope) | ||
case MatchResult.OpenTree(tree, patternTpe, args, env) => | ||
val names: List[TermName] = args.map { | ||
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName | ||
case arg => arg.symbol.name.asTermName | ||
} | ||
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr)) | ||
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) => | ||
val names: List[TermName] = argIds.map(_.symbol.name.asTermName) | ||
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) | ||
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) | ||
val meth = newAnonFun(ctx.owner, methTpe) | ||
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { | ||
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap | ||
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap | ||
val body = new TreeMap { | ||
override def transform(tree: Tree)(using Context): Tree = | ||
tree match | ||
/* | ||
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where | ||
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold | ||
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion. | ||
*/ | ||
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args) | ||
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) | ||
case tree => super.transform(tree) | ||
}.transform(tree) | ||
|
@@ -534,7 +575,7 @@ object QuoteMatcher { | |
val hoasClosure = Closure(meth, bodyFn) | ||
new ExprImpl(hoasClosure, spliceScope) | ||
|
||
private inline def notMatched: optional[MatchingExprs] = | ||
private inline def notMatched[T]: optional[T] = | ||
optional.break() | ||
|
||
private inline def matched: MatchingExprs = | ||
|
@@ -543,8 +584,8 @@ object QuoteMatcher { | |
private inline def matched(tree: Tree)(using Context): MatchingExprs = | ||
Seq(MatchResult.ClosedTree(tree)) | ||
|
||
private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs = | ||
Seq(MatchResult.OpenTree(tree, patternTpe, args, env)) | ||
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs = | ||
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env)) | ||
|
||
extension (self: MatchingExprs) | ||
/** Concatenates the contents of two successful matchings */ | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
case erased: [erased case] | ||
case erased nested: c | ||
case erased nested 2: d |
25 changes: 25 additions & 0 deletions
25
tests/run-custom-args/run-macros-erased/i17105/Macro_1.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import scala.quoted.* | ||
|
||
inline def testExpr(inline body: Any) = ${ testExprImpl('body) } | ||
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] = | ||
body match | ||
// Erased Types | ||
case '{ def erasedfn(y: String) = "placeholder"; $a(erasedfn): String } => | ||
Expr("This case should not match") | ||
case '{ def erasedfn(erased y: String) = "placeholder"; $a(erasedfn): String } => | ||
'{ $a((erased z: String) => "[erased case]") } | ||
case '{ | ||
def erasedfn(a: String, b: String)(c: String, d: String): String = a | ||
$y(erasedfn): String | ||
} => Expr("This should not match") | ||
case '{ | ||
def erasedfn(a: String, erased b: String)(erased c: String, d: String): String = a | ||
$y(erasedfn): String | ||
} => | ||
'{ $y((a: String, erased b: String) => (erased c: String, d: String) => d) } | ||
case '{ | ||
def erasedfn(a: String, erased b: String)(c: String, erased d: String): String = a | ||
$y(erasedfn): String | ||
} => | ||
'{ $y((a: String, erased b: String) => (c: String, erased d: String) => c) } | ||
case _ => Expr("not matched") |
10 changes: 10 additions & 0 deletions
10
tests/run-custom-args/run-macros-erased/i17105/Test_2.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
@main def Test: Unit = | ||
println("case erased: " + testExpr { def erasedfn1(erased x: String) = "placeholder"; erasedfn1("arg1")}) | ||
println("case erased nested: " + testExpr { | ||
def erasedfn2(p: String, erased q: String)(r: String, erased s: String) = p | ||
erasedfn2("a", "b")("c", "d") | ||
}) | ||
println("case erased nested 2: " + testExpr { | ||
def erasedfn2(p: String, erased q: String)(erased r: String, s: String) = p | ||
erasedfn2("a", "b")("c", "d") | ||
}) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
case single: [1st case] arg1 outside | ||
case no-param-method (will be eta-expanded): [1st case] placeholder 2 | ||
case curried: [2nd case] arg1, arg2 outside | ||
case methods from outer scope: [1st case] arg1 outer-method | ||
case refinement: Hoe got 1 | ||
case dependent: 1 | ||
case dependent2: 1 | ||
case dependent3: 1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
// Test case for dependent types | ||
trait DSL { | ||
type N | ||
def toString(n: N): String | ||
val zero: N | ||
def next(n: N): N | ||
} | ||
|
||
object IntDSL extends DSL { | ||
type N = Int | ||
def toString(n: N): String = n.toString() | ||
val zero = 0 | ||
def next(n: N): N = n + 1 | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import scala.quoted.* | ||
import language.experimental.erasedDefinitions | ||
|
||
inline def testExpr(inline body: Any) = ${ testExprImpl('body) } | ||
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] = | ||
body match | ||
case '{ def g(y: String) = "placeholder" + y; $a(g): String } => | ||
'{ $a((z: String) => s"[1st case] ${z}") } | ||
case '{ def g(y: String)(z: String) = "placeholder" + y; $a(g): String } => | ||
'{ $a((z1: String) => (z2: String) => s"[2nd case] ${z1}, ${z2}") } | ||
// Refined Types | ||
case '{ | ||
type t | ||
def refined(a: `t`): String = $x(a): String | ||
$y(refined): String | ||
} => | ||
'{ $y($x) } | ||
// Dependent Types | ||
case '{ | ||
def p(dsl: DSL): dsl.N = dsl.zero | ||
$y(p): String | ||
} => | ||
'{ $y((dsl1: DSL) => dsl1.next(dsl1.zero)) } | ||
case '{ | ||
def p(dsl: DSL)(a: dsl.N): dsl.N = a | ||
$y(p): String | ||
} => | ||
'{ $y((dsl: DSL) => (b2: dsl.N) => dsl.next(b2)) } | ||
case '{ | ||
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero | ||
$y(p): String | ||
} => | ||
'{ $y((dsl1: DSL) => (dsl2: DSL) => dsl2.next(dsl2.zero)) } | ||
case _ => Expr("not matched") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import reflect.Selectable.reflectiveSelectable | ||
|
||
class Hoe { def f(x: Int): String = s"Hoe got ${x}" } | ||
|
||
@main def Test: Unit = | ||
println("case single: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + " outside" }) | ||
println("case no-param-method (will be eta-expanded): " + testExpr { def f(x: String) = "placeholder" + x; (() => f)()("placeholder 2") }) | ||
println("case curried: " + testExpr { def f(x: String)(y: String) = "placeholder" + x; f("arg1")("arg2") + " outside" }) | ||
def outer() = " outer-method" | ||
println("case methods from outer scope: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + outer() }) | ||
println("case refinement: " + testExpr { def refined(a: { def f(x: Int): String }): String = a.f(1); refined(Hoe()) }) | ||
println("case dependent: " + testExpr { | ||
def p(a: DSL): a.N = a.zero | ||
IntDSL.toString(p(IntDSL)) | ||
}) | ||
println("case dependent2: " + testExpr { | ||
def p(dsl1: DSL)(c: dsl1.N): dsl1.N = c | ||
IntDSL.toString(p(IntDSL)(IntDSL.zero)) | ||
}) | ||
println("case dependent3: " + testExpr { | ||
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero | ||
IntDSL.toString(p(IntDSL)(IntDSL)) | ||
}) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getCapturedIdent
could return theSymbol
directly. This way, we can avoid this extramap
.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I intended to use
capturedIds
as a parameter tomatchedOpen
and we cannot omit it (we get compiler errors if we useargs
formatchedOpen
instead).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to your other comment on i17105.check, it's likely I need to change this part anyway.