@@ -259,12 +259,34 @@ object QuoteMatcher {
259
259
// Matches an open term and wraps it into a lambda that provides the free variables
260
260
case Apply (TypeApply (Ident (_), List (TypeTree ())), SeqLiteral (args, _) :: Nil )
261
261
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole ) =>
262
+
263
+ /* Some of method symbols in arguments of higher-order term hole are eta-expanded.
264
+ * e.g.
265
+ * g: (Int) => Int
266
+ * => {
267
+ * def $anonfun(y: Int): Int = g(y)
268
+ * closure($anonfun)
269
+ * }
270
+ *
271
+ * f: (using Int) => Int
272
+ * => f(using x)
273
+ * This function restores the symbol of the original method from
274
+ * the eta-expanded function.
275
+ */
276
+ def getCapturedIdent (arg : Tree )(using Context ): Ident =
277
+ arg match
278
+ case id : Ident => id
279
+ case Apply (fun, _) => getCapturedIdent(fun)
280
+ case Block ((ddef : DefDef ) :: _, _ : Closure ) => getCapturedIdent(ddef.rhs)
281
+ case Typed (expr, _) => getCapturedIdent(expr)
282
+
262
283
val env = summon[Env ]
263
- val capturedArgs = args.map(_.symbol)
264
- val captureEnv = env.filter((k, v) => ! capturedArgs.contains(v))
284
+ val capturedIds = args.map(getCapturedIdent)
285
+ val capturedSymbols = capturedIds.map(_.symbol)
286
+ val captureEnv = env.filter((k, v) => ! capturedSymbols.contains(v))
265
287
withEnv(captureEnv) {
266
288
scrutinee match
267
- case ClosedPatternTerm (scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
289
+ case ClosedPatternTerm (scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe) , env)
268
290
case _ => notMatched
269
291
}
270
292
@@ -394,19 +416,34 @@ object QuoteMatcher {
394
416
case scrutinee @ DefDef (_, paramss1, tpt1, _) =>
395
417
pattern match
396
418
case pattern @ DefDef (_, paramss2, tpt2, _) =>
397
- def rhsEnv : Env =
398
- val paramSyms : List [(Symbol , Symbol )] =
399
- for
400
- (clause1, clause2) <- paramss1.zip(paramss2)
401
- (param1, param2) <- clause1.zip(clause2)
402
- yield
403
- param1.symbol -> param2.symbol
404
- val oldEnv : Env = summon[Env ]
405
- val newEnv : List [(Symbol , Symbol )] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
406
- oldEnv ++ newEnv
407
- matchLists(paramss1, paramss2)(_ =?= _)
408
- &&& tpt1 =?= tpt2
409
- &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
419
+ def matchErasedParams (sctype : Type , pttype : Type ): optional[MatchingExprs ] =
420
+ (sctype, pttype) match
421
+ case (sctpe : MethodType , pttpe : MethodType ) =>
422
+ if sctpe.erasedParams.sameElements(pttpe.erasedParams) then
423
+ matchErasedParams(sctpe.resType, pttpe.resType)
424
+ else
425
+ notMatched
426
+ case _ => matched
427
+
428
+ def matchParamss (scparamss : List [ParamClause ], ptparamss : List [ParamClause ])(using Env ): optional[(Env , MatchingExprs )] =
429
+ (scparamss, ptparamss) match {
430
+ case (scparams :: screst, ptparams :: ptrest) =>
431
+ val mr1 = matchLists(scparams, ptparams)(_ =?= _)
432
+ val newEnv = summon[Env ] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
433
+ val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
434
+ (resEnv, mr1 &&& mrrest)
435
+ case (Nil , Nil ) => (summon[Env ], matched)
436
+ case _ => notMatched
437
+ }
438
+
439
+ val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr)
440
+ val (pEnv, pmatch) = matchParamss(paramss1, paramss2)
441
+ val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol)
442
+
443
+ ematch
444
+ &&& pmatch
445
+ &&& withEnv(defEnv)(tpt1 =?= tpt2)
446
+ &&& withEnv(defEnv)(scrutinee.rhs =?= pattern.rhs)
410
447
case _ => notMatched
411
448
412
449
case Closure (_, _, tpt1) =>
@@ -497,10 +534,11 @@ object QuoteMatcher {
497
534
*
498
535
* @param tree Scrutinee sub-tree that matched
499
536
* @param patternTpe Type of the pattern hole (from the pattern)
500
- * @param args HOAS arguments (from the pattern)
537
+ * @param argIds Identifiers of HOAS arguments (from the pattern)
538
+ * @param argTypes Eta-expanded types of HOAS arguments (from the pattern)
501
539
* @param env Mapping between scrutinee and pattern variables
502
540
*/
503
- case OpenTree (tree : Tree , patternTpe : Type , args : List [Tree ], env : Env )
541
+ case OpenTree (tree : Tree , patternTpe : Type , argIds : List [Tree ], argTypes : List [ Type ], env : Env )
504
542
505
543
/** Return the expression that was extracted from a hole.
506
544
*
@@ -513,19 +551,22 @@ object QuoteMatcher {
513
551
def toExpr (mapTypeHoles : Type => Type , spliceScope : Scope )(using Context ): Expr [Any ] = this match
514
552
case MatchResult .ClosedTree (tree) =>
515
553
new ExprImpl (tree, spliceScope)
516
- case MatchResult .OpenTree (tree, patternTpe, args, env) =>
517
- val names : List [TermName ] = args.map {
518
- case Block (List (DefDef (nme.ANON_FUN , _, _, Apply (Ident (name), _))), _) => name.asTermName
519
- case arg => arg.symbol.name.asTermName
520
- }
521
- val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
554
+ case MatchResult .OpenTree (tree, patternTpe, argIds, argTypes, env) =>
555
+ val names : List [TermName ] = argIds.map(_.symbol.name.asTermName)
556
+ val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
522
557
val methTpe = MethodType (names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
523
558
val meth = newAnonFun(ctx.owner, methTpe)
524
559
def bodyFn (lambdaArgss : List [List [Tree ]]): Tree = {
525
- val argsMap = args .view.map(_.symbol).zip(lambdaArgss.head).toMap
560
+ val argsMap = argIds .view.map(_.symbol).zip(lambdaArgss.head).toMap
526
561
val body = new TreeMap {
527
562
override def transform (tree : Tree )(using Context ): Tree =
528
563
tree match
564
+ /*
565
+ * When matching a method call `f(0)` against a HOAS pattern `p(g)` where
566
+ * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
567
+ * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
568
+ */
569
+ case Apply (fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args)
529
570
case tree : Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
530
571
case tree => super .transform(tree)
531
572
}.transform(tree)
@@ -534,7 +575,7 @@ object QuoteMatcher {
534
575
val hoasClosure = Closure (meth, bodyFn)
535
576
new ExprImpl (hoasClosure, spliceScope)
536
577
537
- private inline def notMatched : optional[MatchingExprs ] =
578
+ private inline def notMatched [ T ] : optional[T ] =
538
579
optional.break()
539
580
540
581
private inline def matched : MatchingExprs =
@@ -543,8 +584,8 @@ object QuoteMatcher {
543
584
private inline def matched (tree : Tree )(using Context ): MatchingExprs =
544
585
Seq (MatchResult .ClosedTree (tree))
545
586
546
- private def matchedOpen (tree : Tree , patternTpe : Type , args : List [Tree ], env : Env )(using Context ): MatchingExprs =
547
- Seq (MatchResult .OpenTree (tree, patternTpe, args , env))
587
+ private def matchedOpen (tree : Tree , patternTpe : Type , argIds : List [Tree ], argTypes : List [ Type ], env : Env )(using Context ): MatchingExprs =
588
+ Seq (MatchResult .OpenTree (tree, patternTpe, argIds, argTypes , env))
548
589
549
590
extension (self : MatchingExprs )
550
591
/** Concatenates the contents of two successful matchings */
0 commit comments