diff --git a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala index 31678a9481e4..a6ec28a85f3a 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala @@ -259,12 +259,34 @@ object QuoteMatcher { // Matches an open term and wraps it into a lambda that provides the free variables case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil) if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) => + + /* Some of method symbols in arguments of higher-order term hole are eta-expanded. + * e.g. + * g: (Int) => Int + * => { + * def $anonfun(y: Int): Int = g(y) + * closure($anonfun) + * } + * + * f: (using Int) => Int + * => f(using x) + * This function restores the symbol of the original method from + * the eta-expanded function. + */ + def getCapturedIdent(arg: Tree)(using Context): Ident = + arg match + case id: Ident => id + case Apply(fun, _) => getCapturedIdent(fun) + case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs) + case Typed(expr, _) => getCapturedIdent(expr) + val env = summon[Env] - val capturedArgs = args.map(_.symbol) - val captureEnv = env.filter((k, v) => !capturedArgs.contains(v)) + val capturedIds = args.map(getCapturedIdent) + val capturedSymbols = capturedIds.map(_.symbol) + val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v)) withEnv(captureEnv) { scrutinee match - case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env) + case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env) case _ => notMatched } @@ -394,19 +416,34 @@ object QuoteMatcher { case scrutinee @ DefDef(_, paramss1, tpt1, _) => pattern match case pattern @ DefDef(_, paramss2, tpt2, _) => - def rhsEnv: Env = - val paramSyms: List[(Symbol, Symbol)] = - for - (clause1, clause2) <- paramss1.zip(paramss2) - (param1, param2) <- clause1.zip(clause2) - yield - param1.symbol -> param2.symbol - val oldEnv: Env = summon[Env] - val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms - oldEnv ++ newEnv - matchLists(paramss1, paramss2)(_ =?= _) - &&& tpt1 =?= tpt2 - &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs) + def matchErasedParams(sctype: Type, pttype: Type): optional[MatchingExprs] = + (sctype, pttype) match + case (sctpe: MethodType, pttpe: MethodType) => + if sctpe.erasedParams.sameElements(pttpe.erasedParams) then + matchErasedParams(sctpe.resType, pttpe.resType) + else + notMatched + case _ => matched + + def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] = + (scparamss, ptparamss) match { + case (scparams :: screst, ptparams :: ptrest) => + val mr1 = matchLists(scparams, ptparams)(_ =?= _) + val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)) + val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest)) + (resEnv, mr1 &&& mrrest) + case (Nil, Nil) => (summon[Env], matched) + case _ => notMatched + } + + val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr) + val (pEnv, pmatch) = matchParamss(paramss1, paramss2) + val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol) + + ematch + &&& pmatch + &&& withEnv(defEnv)(tpt1 =?= tpt2) + &&& withEnv(defEnv)(scrutinee.rhs =?= pattern.rhs) case _ => notMatched case Closure(_, _, tpt1) => @@ -497,10 +534,11 @@ object QuoteMatcher { * * @param tree Scrutinee sub-tree that matched * @param patternTpe Type of the pattern hole (from the pattern) - * @param args HOAS arguments (from the pattern) + * @param argIds Identifiers of HOAS arguments (from the pattern) + * @param argTypes Eta-expanded types of HOAS arguments (from the pattern) * @param env Mapping between scrutinee and pattern variables */ - case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env) + case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env) /** Return the expression that was extracted from a hole. * @@ -513,19 +551,22 @@ object QuoteMatcher { def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match case MatchResult.ClosedTree(tree) => new ExprImpl(tree, spliceScope) - case MatchResult.OpenTree(tree, patternTpe, args, env) => - val names: List[TermName] = args.map { - case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName - case arg => arg.symbol.name.asTermName - } - val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr)) + case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) => + val names: List[TermName] = argIds.map(_.symbol.name.asTermName) + val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) val meth = newAnonFun(ctx.owner, methTpe) def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { - val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap + val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap val body = new TreeMap { override def transform(tree: Tree)(using Context): Tree = tree match + /* + * When matching a method call `f(0)` against a HOAS pattern `p(g)` where + * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold + * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion. + */ + case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args) case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) case tree => super.transform(tree) }.transform(tree) @@ -534,7 +575,7 @@ object QuoteMatcher { val hoasClosure = Closure(meth, bodyFn) new ExprImpl(hoasClosure, spliceScope) - private inline def notMatched: optional[MatchingExprs] = + private inline def notMatched[T]: optional[T] = optional.break() private inline def matched: MatchingExprs = @@ -543,8 +584,8 @@ object QuoteMatcher { private inline def matched(tree: Tree)(using Context): MatchingExprs = Seq(MatchResult.ClosedTree(tree)) - private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs = - Seq(MatchResult.OpenTree(tree, patternTpe, args, env)) + private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs = + Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env)) extension (self: MatchingExprs) /** Concatenates the contents of two successful matchings */ diff --git a/tests/run-custom-args/run-macros-erased/i17105.check b/tests/run-custom-args/run-macros-erased/i17105.check new file mode 100644 index 000000000000..40b736596a70 --- /dev/null +++ b/tests/run-custom-args/run-macros-erased/i17105.check @@ -0,0 +1,3 @@ +case erased: [erased case] +case erased nested: c +case erased nested 2: d diff --git a/tests/run-custom-args/run-macros-erased/i17105/Macro_1.scala b/tests/run-custom-args/run-macros-erased/i17105/Macro_1.scala new file mode 100644 index 000000000000..37a48a0ececd --- /dev/null +++ b/tests/run-custom-args/run-macros-erased/i17105/Macro_1.scala @@ -0,0 +1,25 @@ +import scala.quoted.* + +inline def testExpr(inline body: Any) = ${ testExprImpl('body) } +def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] = + body match + // Erased Types + case '{ def erasedfn(y: String) = "placeholder"; $a(erasedfn): String } => + Expr("This case should not match") + case '{ def erasedfn(erased y: String) = "placeholder"; $a(erasedfn): String } => + '{ $a((erased z: String) => "[erased case]") } + case '{ + def erasedfn(a: String, b: String)(c: String, d: String): String = a + $y(erasedfn): String + } => Expr("This should not match") + case '{ + def erasedfn(a: String, erased b: String)(erased c: String, d: String): String = a + $y(erasedfn): String + } => + '{ $y((a: String, erased b: String) => (erased c: String, d: String) => d) } + case '{ + def erasedfn(a: String, erased b: String)(c: String, erased d: String): String = a + $y(erasedfn): String + } => + '{ $y((a: String, erased b: String) => (c: String, erased d: String) => c) } + case _ => Expr("not matched") diff --git a/tests/run-custom-args/run-macros-erased/i17105/Test_2.scala b/tests/run-custom-args/run-macros-erased/i17105/Test_2.scala new file mode 100644 index 000000000000..bfb6967e1e00 --- /dev/null +++ b/tests/run-custom-args/run-macros-erased/i17105/Test_2.scala @@ -0,0 +1,10 @@ +@main def Test: Unit = + println("case erased: " + testExpr { def erasedfn1(erased x: String) = "placeholder"; erasedfn1("arg1")}) + println("case erased nested: " + testExpr { + def erasedfn2(p: String, erased q: String)(r: String, erased s: String) = p + erasedfn2("a", "b")("c", "d") + }) + println("case erased nested 2: " + testExpr { + def erasedfn2(p: String, erased q: String)(erased r: String, s: String) = p + erasedfn2("a", "b")("c", "d") + }) diff --git a/tests/run-macros/i17105.check b/tests/run-macros/i17105.check new file mode 100644 index 000000000000..17c45e97b888 --- /dev/null +++ b/tests/run-macros/i17105.check @@ -0,0 +1,8 @@ +case single: [1st case] arg1 outside +case no-param-method (will be eta-expanded): [1st case] placeholder 2 +case curried: [2nd case] arg1, arg2 outside +case methods from outer scope: [1st case] arg1 outer-method +case refinement: Hoe got 1 +case dependent: 1 +case dependent2: 1 +case dependent3: 1 diff --git a/tests/run-macros/i17105/Lib1.scala b/tests/run-macros/i17105/Lib1.scala new file mode 100644 index 000000000000..ed2b145f7914 --- /dev/null +++ b/tests/run-macros/i17105/Lib1.scala @@ -0,0 +1,15 @@ + +// Test case for dependent types +trait DSL { + type N + def toString(n: N): String + val zero: N + def next(n: N): N +} + +object IntDSL extends DSL { + type N = Int + def toString(n: N): String = n.toString() + val zero = 0 + def next(n: N): N = n + 1 +} diff --git a/tests/run-macros/i17105/Macro_2.scala b/tests/run-macros/i17105/Macro_2.scala new file mode 100644 index 000000000000..add0c29f95d3 --- /dev/null +++ b/tests/run-macros/i17105/Macro_2.scala @@ -0,0 +1,34 @@ +import scala.quoted.* +import language.experimental.erasedDefinitions + +inline def testExpr(inline body: Any) = ${ testExprImpl('body) } +def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] = + body match + case '{ def g(y: String) = "placeholder" + y; $a(g): String } => + '{ $a((z: String) => s"[1st case] ${z}") } + case '{ def g(y: String)(z: String) = "placeholder" + y; $a(g): String } => + '{ $a((z1: String) => (z2: String) => s"[2nd case] ${z1}, ${z2}") } + // Refined Types + case '{ + type t + def refined(a: `t`): String = $x(a): String + $y(refined): String + } => + '{ $y($x) } + // Dependent Types + case '{ + def p(dsl: DSL): dsl.N = dsl.zero + $y(p): String + } => + '{ $y((dsl1: DSL) => dsl1.next(dsl1.zero)) } + case '{ + def p(dsl: DSL)(a: dsl.N): dsl.N = a + $y(p): String + } => + '{ $y((dsl: DSL) => (b2: dsl.N) => dsl.next(b2)) } + case '{ + def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero + $y(p): String + } => + '{ $y((dsl1: DSL) => (dsl2: DSL) => dsl2.next(dsl2.zero)) } + case _ => Expr("not matched") diff --git a/tests/run-macros/i17105/Test_3.scala b/tests/run-macros/i17105/Test_3.scala new file mode 100644 index 000000000000..c19ac507e1a4 --- /dev/null +++ b/tests/run-macros/i17105/Test_3.scala @@ -0,0 +1,23 @@ +import reflect.Selectable.reflectiveSelectable + +class Hoe { def f(x: Int): String = s"Hoe got ${x}" } + +@main def Test: Unit = + println("case single: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + " outside" }) + println("case no-param-method (will be eta-expanded): " + testExpr { def f(x: String) = "placeholder" + x; (() => f)()("placeholder 2") }) + println("case curried: " + testExpr { def f(x: String)(y: String) = "placeholder" + x; f("arg1")("arg2") + " outside" }) + def outer() = " outer-method" + println("case methods from outer scope: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + outer() }) + println("case refinement: " + testExpr { def refined(a: { def f(x: Int): String }): String = a.f(1); refined(Hoe()) }) + println("case dependent: " + testExpr { + def p(a: DSL): a.N = a.zero + IntDSL.toString(p(IntDSL)) + }) + println("case dependent2: " + testExpr { + def p(dsl1: DSL)(c: dsl1.N): dsl1.N = c + IntDSL.toString(p(IntDSL)(IntDSL.zero)) + }) + println("case dependent3: " + testExpr { + def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero + IntDSL.toString(p(IntDSL)(IntDSL)) + })