@@ -11,6 +11,7 @@ import dotty.tools.dotc.core.Types.*
11
11
import dotty .tools .dotc .core .StdNames .nme
12
12
import dotty .tools .dotc .core .Symbols .*
13
13
import dotty .tools .dotc .util .optional
14
+ import dotty .tools .dotc .core .Definitions
14
15
15
16
/** Matches a quoted tree against a quoted pattern tree.
16
17
* A quoted pattern tree may have type and term holes in addition to normal terms.
@@ -259,12 +260,34 @@ object QuoteMatcher {
259
260
// Matches an open term and wraps it into a lambda that provides the free variables
260
261
case Apply (TypeApply (Ident (_), List (TypeTree ())), SeqLiteral (args, _) :: Nil )
261
262
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole ) =>
263
+
264
+ /* Some of method symbols in arguments of higher-order term hole are eta-expanded.
265
+ * e.g.
266
+ * g: (Int) => Int
267
+ * => {
268
+ * def $anonfun(y: Int): Int = g(y)
269
+ * closure($anonfun)
270
+ * }
271
+ *
272
+ * f: (using Int) => Int
273
+ * => f(using x)
274
+ * This function restores the symbol of the original method from
275
+ * the eta-expanded function.
276
+ */
277
+ def getCapturedIdent (arg : Tree )(using Context ): Ident =
278
+ arg match
279
+ case id : Ident => id
280
+ case Apply (fun, _) => getCapturedIdent(fun)
281
+ case Block ((ddef : DefDef ) :: _, _ : Closure ) => getCapturedIdent(ddef.rhs)
282
+ case Typed (expr, _) => getCapturedIdent(expr)
283
+
262
284
val env = summon[Env ]
263
- val capturedArgs = args.map(_.symbol)
264
- val captureEnv = env.filter((k, v) => ! capturedArgs.contains(v))
285
+ val capturedIds = args.map(getCapturedIdent)
286
+ val capturedSymbols = capturedIds.map(_.symbol)
287
+ val captureEnv = env.filter((k, v) => ! capturedSymbols.contains(v))
265
288
withEnv(captureEnv) {
266
289
scrutinee match
267
- case ClosedPatternTerm (scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
290
+ case ClosedPatternTerm (scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe) , env)
268
291
case _ => notMatched
269
292
}
270
293
@@ -394,19 +417,34 @@ object QuoteMatcher {
394
417
case scrutinee @ DefDef (_, paramss1, tpt1, _) =>
395
418
pattern match
396
419
case pattern @ DefDef (_, paramss2, tpt2, _) =>
397
- def rhsEnv : Env =
398
- val paramSyms : List [(Symbol , Symbol )] =
399
- for
400
- (clause1, clause2) <- paramss1.zip(paramss2)
401
- (param1, param2) <- clause1.zip(clause2)
402
- yield
403
- param1.symbol -> param2.symbol
404
- val oldEnv : Env = summon[Env ]
405
- val newEnv : List [(Symbol , Symbol )] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
406
- oldEnv ++ newEnv
407
- matchLists(paramss1, paramss2)(_ =?= _)
408
- &&& tpt1 =?= tpt2
409
- &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
420
+ def matchErasedParams (sctype : Type , pttype : Type ): optional[MatchingExprs ] =
421
+ (sctype, pttype) match
422
+ case (sctpe : MethodType , pttpe : MethodType ) =>
423
+ if sctpe.erasedParams.sameElements(pttpe.erasedParams) then
424
+ matchErasedParams(sctpe.resType, pttpe.resType)
425
+ else
426
+ notMatched
427
+ case _ => matched
428
+
429
+ def matchParamss (scparamss : List [ParamClause ], ptparamss : List [ParamClause ])(using Env ): optional[(Env , MatchingExprs )] =
430
+ (scparamss, ptparamss) match {
431
+ case (scparams :: screst, ptparams :: ptrest) =>
432
+ val mr1 = matchLists(scparams, ptparams)(_ =?= _)
433
+ val newEnv = summon[Env ] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
434
+ val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
435
+ (resEnv, mr1 &&& mrrest)
436
+ case (Nil , Nil ) => (summon[Env ], matched)
437
+ case _ => notMatched
438
+ }
439
+
440
+ val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr)
441
+ val (pEnv, pmatch) = matchParamss(paramss1, paramss2)
442
+ val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol)
443
+
444
+ ematch
445
+ &&& pmatch
446
+ &&& withEnv(defEnv)(tpt1 =?= tpt2)
447
+ &&& withEnv(defEnv)(scrutinee.rhs =?= pattern.rhs)
410
448
case _ => notMatched
411
449
412
450
case Closure (_, _, tpt1) =>
@@ -497,10 +535,14 @@ object QuoteMatcher {
497
535
*
498
536
* @param tree Scrutinee sub-tree that matched
499
537
* @param patternTpe Type of the pattern hole (from the pattern)
500
- * @param args HOAS arguments (from the pattern)
538
+ * @param argIds Identifiers of HOAS arguments (from the pattern)
539
+ * @param argTypes Eta-expanded types of HOAS arguments (from the pattern)
501
540
* @param env Mapping between scrutinee and pattern variables
502
541
*/
503
- case OpenTree (tree : Tree , patternTpe : Type , args : List [Tree ], env : Env )
542
+ case OpenTree (tree : Tree , patternTpe : Type , argIds : List [Tree ], argTypes : List [Type ], env : Env )
543
+
544
+ /** The Definitions object */
545
+ def defn (using Context ): Definitions = ctx.definitions
504
546
505
547
/** Return the expression that was extracted from a hole.
506
548
*
@@ -513,19 +555,22 @@ object QuoteMatcher {
513
555
def toExpr (mapTypeHoles : Type => Type , spliceScope : Scope )(using Context ): Expr [Any ] = this match
514
556
case MatchResult .ClosedTree (tree) =>
515
557
new ExprImpl (tree, spliceScope)
516
- case MatchResult .OpenTree (tree, patternTpe, args, env) =>
517
- val names : List [TermName ] = args.map {
518
- case Block (List (DefDef (nme.ANON_FUN , _, _, Apply (Ident (name), _))), _) => name.asTermName
519
- case arg => arg.symbol.name.asTermName
520
- }
521
- val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
558
+ case MatchResult .OpenTree (tree, patternTpe, argIds, argTypes, env) =>
559
+ val names : List [TermName ] = argIds.map(_.symbol.name.asTermName)
560
+ val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
522
561
val methTpe = MethodType (names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
523
562
val meth = newAnonFun(ctx.owner, methTpe)
524
563
def bodyFn (lambdaArgss : List [List [Tree ]]): Tree = {
525
- val argsMap = args .view.map(_.symbol).zip(lambdaArgss.head).toMap
564
+ val argsMap = argIds .view.map(_.symbol).zip(lambdaArgss.head).toMap
526
565
val body = new TreeMap {
527
566
override def transform (tree : Tree )(using Context ): Tree =
528
567
tree match
568
+ /*
569
+ * When matching a method call `f(0)` against a HOAS pattern `p(g)` where
570
+ * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
571
+ * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
572
+ */
573
+ case Apply (fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args)
529
574
case tree : Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
530
575
case tree => super .transform(tree)
531
576
}.transform(tree)
@@ -534,7 +579,7 @@ object QuoteMatcher {
534
579
val hoasClosure = Closure (meth, bodyFn)
535
580
new ExprImpl (hoasClosure, spliceScope)
536
581
537
- private inline def notMatched : optional[MatchingExprs ] =
582
+ private inline def notMatched [ T ] : optional[T ] =
538
583
optional.break()
539
584
540
585
private inline def matched : MatchingExprs =
@@ -543,8 +588,8 @@ object QuoteMatcher {
543
588
private inline def matched (tree : Tree )(using Context ): MatchingExprs =
544
589
Seq (MatchResult .ClosedTree (tree))
545
590
546
- private def matchedOpen (tree : Tree , patternTpe : Type , args : List [Tree ], env : Env )(using Context ): MatchingExprs =
547
- Seq (MatchResult .OpenTree (tree, patternTpe, args , env))
591
+ private def matchedOpen (tree : Tree , patternTpe : Type , argIds : List [Tree ], argTypes : List [ Type ], env : Env )(using Context ): MatchingExprs =
592
+ Seq (MatchResult .OpenTree (tree, patternTpe, argIds, argTypes , env))
548
593
549
594
extension (self : MatchingExprs )
550
595
/** Concatenates the contents of two successful matchings */
0 commit comments