diff --git a/compiler/src/scala/quoted/runtime/impl/Matcher.scala b/compiler/src/scala/quoted/runtime/impl/Matcher.scala deleted file mode 100644 index c6176cbcf830..000000000000 --- a/compiler/src/scala/quoted/runtime/impl/Matcher.scala +++ /dev/null @@ -1,438 +0,0 @@ -package scala.quoted -package runtime.impl - -import scala.annotation.internal.sharable -import scala.annotation.{Annotation, compileTimeOnly} - -/** Matches a quoted tree against a quoted pattern tree. - * A quoted pattern tree may have type and term holes in addition to normal terms. - * - * - * Semantics: - * - * We use `'{..}` for expression, `'[..]` for types and `⟨..⟩` for patterns nested in expressions. - * The semantics are defined as a list of reduction rules that are tried one by one until one matches. - * - * Operations: - * - `s =?= p` checks if a scrutinee `s` matches the pattern `p` while accumulating extracted parts of the code. - * - `isColosedUnder(x1, .., xn)('{e})` returns true if and only if all the references in `e` to names defined in the patttern are contained in the set `{x1, ... xn}`. - * - `lift(x1, .., xn)('{e})` returns `(y1, ..., yn) => [xi = $yi]'{e}` where `yi` is an `Expr` of the type of `xi`. - * - `withEnv(x1 -> y1, ..., xn -> yn)(matching)` evaluates mathing recording that `xi` is equivalent to `yi`. - * - `matched` denotes that the the match succedded and `matched('{e})` denotes that a matech succeded and extracts `'{e}` - * - `&&&` matches if both sides match. Concatenates the extracted expressions of both sides. - * - * Note: that not all quoted terms bellow are valid expressions - * - * ```scala - * /* Term hole */ - * '{ e } =?= '{ hole[T] } && typeOf('{e}) <:< T && isColosedUnder()('{e}) ===> matched('{e}) - * - * /* Higher order term hole */ - * '{ e } =?= '{ hole[(T1, ..., Tn) => T](x1, ..., xn) } && isColosedUnder(x1, ... xn)('{e}) ===> matched(lift(x1, ..., xn)('{e})) - * - * /* Match literal */ - * '{ lit } =?= '{ lit } ===> matched - * - * /* Match type ascription (a) */ - * '{ e: T } =?= '{ p } ===> '{e} =?= '{p} - * - * /* Match type ascription (b) */ - * '{ e } =?= '{ p: P } ===> '{e} =?= '{p} - * - * /* Match selection */ - * '{ e.x } =?= '{ p.x } ===> '{e} =?= '{p} - * - * /* Match reference */ - * '{ x } =?= '{ x } ===> matched - * - * /* Match application */ - * '{e0(e1, ..., en)} =?= '{p0(p1, ..., p2)} ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} &&& ... %% '{en} =?= '{pn} - * - * /* Match type application */ - * '{e[T1, ..., Tn]} =?= '{p[P1, ..., Pn]} ===> '{e} =?= '{p} &&& '[T1] =?= '{P1} &&& ... %% '[Tn] =?= '[Pn] - * - * /* Match block flattening */ - * '{ {e0; e1; ...; en}; em } =?= '{ {p0; p1; ...; pm}; em } ===> '{ e0; {e1; ...; en; em} } =?= '{ p0; {p1; ...; pm; em} } - * - * /* Match block */ - * '{ e1; e2 } =?= '{ p1; p2 } ===> '{e1} =?= '{p1} &&& '{e2} =?= '{p2} - * - * /* Match def block */ - * '{ e1; e2 } =?= '{ p1; p2 } ===> withEnv(symOf(e1) -> symOf(p1))('{e1} =?= '{p1} &&& '{e2} =?= '{p2}) - * - * /* Match if */ - * '{ if e0 then e1 else e2 } =?= '{ if p0 then p1 else p2 } ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2} - * - * /* Match while */ - * '{ while e0 do e1 } =?= '{ while p0 do p1 } ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} - * - * /* Match assign */ - * '{ e0 = e1 } =?= '{ p0 = p1 } ==> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} - * - * /* Match new */ - * '{ new T } =?= '{ new T } ===> matched - * - * /* Match this */ - * '{ C.this } =?= '{ C.this } ===> matched - * - * /* Match super */ - * '{ e.super } =?= '{ p.super } ===> '{e} =?= '{p} - * - * /* Match varargs */ - * '{ e: _* } =?= '{ p: _* } ===> '{e} =?= '{p} - * - * /* Match val */ - * '{ val x: T = e1; e2 } =?= '{ val y: P = p1; p2 } ===> withEnv(x -> y)('[T] =?= '[P] &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2}) - * - * /* Match def */ - * '{ def x0(x1: T1, ..., xn: Tn): T0 = e1; e2 } =?= '{ def y0(y1: P1, ..., yn: Pn): P0 = p1; p2 } ===> withEnv(x0 -> y0, ..., xn -> yn)('[T0] =?= '[P0] &&& ... &&& '[Tn] =?= '[Pn] &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2}) - * - * // Types - * - * /* Match type */ - * '[T] =?= '[P] && T <:< P ===> matched - * - * ``` - */ -object Matcher { - - abstract class QuoteMatcher[QCtx <: Quotes & Singleton](val qctx: QCtx) { - - // TODO improve performance - - // TODO use flag from qctx.reflect. Maybe -debug or add -debug-macros - private inline val debug = false - - import qctx.reflect._ - import Matching._ - - def patternHoleSymbol: Symbol - def higherOrderHoleSymbol: Symbol - - /** A map relating equivalent symbols from the scrutinee and the pattern - * For example in - * ``` - * '{val a = 4; a * a} match case '{ val x = 4; x * x } - * ``` - * when matching `a * a` with `x * x` the environment will contain `Map(a -> x)`. - */ - private type Env = Map[Symbol, Symbol] - - inline private def withEnv[T](env: Env)(inline body: Env ?=> T): T = body(using env) - - def termMatch(scrutineeTerm: Term, patternTerm: Term): Option[Tuple] = - given Env = Map.empty - scrutineeTerm =?= patternTerm - - def typeTreeMatch(scrutineeTypeTree: TypeTree, patternTypeTree: TypeTree): Option[Tuple] = - given Env = Map.empty - scrutineeTypeTree =?= patternTypeTree - - /** Check that all trees match with `mtch` and concatenate the results with &&& */ - private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match { - case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch) - case (Nil, Nil) => matched - case _ => notMatched - } - - extension (scrutinees: List[Tree]) - /** Check that all trees match with =?= and concatenate the results with &&& */ - private def =?= (patterns: List[Tree])(using Env): Matching = - matchLists(scrutinees, patterns)(_ =?= _) - - extension (scrutinee0: Tree) - /** Check that the trees match and return the contents from the pattern holes. - * Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes. - * - * @param scrutinee The tree beeing matched - * @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes. - * @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. - * @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. - */ - private def =?= (pattern0: Tree)(using Env): Matching = { - - /* Match block flattening */ // TODO move to cases - /** Normalize the tree */ - def normalize(tree: Tree): Tree = tree match { - case Block(Nil, expr) => normalize(expr) - case Block(stats1, Block(stats2, expr)) => - expr match - case _: Closure => tree - case _ => normalize(Block(stats1 ::: stats2, expr)) - case Inlined(_, Nil, expr) => normalize(expr) - case _ => tree - } - - val scrutinee = normalize(scrutinee0) - val pattern = normalize(pattern0) - - /** Check that both are `val` or both are `lazy val` or both are `var` **/ - def checkValFlags(): Boolean = { - import Flags._ - val sFlags = scrutinee.symbol.flags - val pFlags = pattern.symbol.flags - sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable) - } - - (scrutinee, pattern) match { - - /* Term hole */ - // Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree - case (scrutinee @ Typed(s, tpt1), Typed(TypeApply(patternHole, tpt :: Nil), tpt2)) - if patternHole.symbol == patternHoleSymbol && - s.tpe <:< tpt.tpe && - tpt2.tpe.derivesFrom(defn.RepeatedParamClass) => - matched(scrutinee.asExpr) - - /* Term hole */ - // Match a scala.internal.Quoted.patternHole and return the scrutinee tree - case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil)) - if patternHole.symbol == patternHoleSymbol && - scrutinee.tpe <:< tpt.tpe => - matched(scrutinee.asExpr) - - /* Higher order term hole */ - // Matches an open term and wraps it into a lambda that provides the free variables - case (scrutinee, pattern @ Apply(TypeApply(Ident("higherOrderHole"), List(Inferred())), Repeated(args, _) :: Nil)) - if pattern.symbol == higherOrderHoleSymbol => - - def bodyFn(lambdaArgs: List[Tree]): Tree = { - val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap - new TreeMap { - override def transformTerm(tree: Term)(owner: Symbol): Term = - tree match - case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) - case tree => super.transformTerm(tree)(owner) - }.transformTree(scrutinee)(Symbol.spliceOwner) - } - val names = args.map { - case Block(List(DefDef("$anonfun", _, _, Some(Apply(Ident(name), _)))), _) => name - case arg => arg.symbol.name - } - val argTypes = args.map(x => x.tpe.widenTermRefByName) - val resType = pattern.tpe - val res = Lambda(Symbol.spliceOwner, MethodType(names)(_ => argTypes, _ => resType), (meth, x) => bodyFn(x).changeOwner(meth)) - matched(res.asExpr) - - // - // Match two equivalent trees - // - - /* Match literal */ - case (Literal(constant1), Literal(constant2)) if constant1 == constant2 => - matched - - /* Match type ascription (a) */ - case (Typed(expr1, _), pattern) => - expr1 =?= pattern - - /* Match type ascription (b) */ - case (scrutinee, Typed(expr2, _)) => - scrutinee =?= expr2 - - /* Match selection */ - case (ref: Ref, Select(qual2, _)) if symbolMatch(scrutinee, pattern) => - ref match - case Select(qual1, _) => qual1 =?= qual2 - case ref: Ident => - ref.tpe match - case TermRef(qual: TermRef, _) => Ref.term(qual) =?= qual2 - case _ => matched - - /* Match reference */ - case (_: Ref, _: Ident) if symbolMatch(scrutinee, pattern) => - matched - - /* Match application */ - case (Apply(fn1, args1), Apply(fn2, args2)) => - fn1 =?= fn2 &&& args1 =?= args2 - - /* Match type application */ - case (TypeApply(fn1, args1), TypeApply(fn2, args2)) => - fn1 =?= fn2 &&& args1 =?= args2 - - /* Match block */ - case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) => - val newEnv = (stat1, stat2) match { - case (stat1: Definition, stat2: Definition) => - summon[Env] + (stat1.symbol -> stat2.symbol) - case _ => - summon[Env] - } - withEnv(newEnv) { - stat1 =?= stat2 &&& Block(stats1, expr1) =?= Block(stats2, expr2) - } - - /* Match if */ - case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) => - cond1 =?= cond2 &&& thenp1 =?= thenp2 &&& elsep1 =?= elsep2 - - /* Match while */ - case (While(cond1, body1), While(cond2, body2)) => - cond1 =?= cond2 &&& body1 =?= body2 - - /* Match assign */ - case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) => - lhs1 =?= lhs2 &&& rhs1 =?= rhs2 - - /* Match new */ - case (New(tpt1), New(tpt2)) if tpt1.tpe.typeSymbol == tpt2.tpe.typeSymbol => - matched - - /* Match this */ - case (This(_), This(_)) if scrutinee.symbol == pattern.symbol => - matched - - /* Match super */ - case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 => - qual1 =?= qual2 - - /* Match varargs */ - case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size => - elems1 =?= elems2 - - /* Match type */ - // TODO remove this? - case (scrutinee: TypeTree, pattern: TypeTree) if scrutinee.tpe <:< pattern.tpe => - matched - - /* Match val */ - case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() => - def rhsEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol) - tpt1 =?= tpt2 &&& treeOptMatches(rhs1, rhs2)(using rhsEnv) - - /* Match def */ - case (DefDef(_, paramss1, tpt1, Some(rhs1)), DefDef(_, paramss2, tpt2, Some(rhs2))) => - def rhsEnv = - val paramSyms: List[(Symbol, Symbol)] = - for - (clause1, clause2) <- paramss1.zip(paramss2) - (param1, param2) <- clause1.params.zip(clause2.params) - yield - param1.symbol -> param2.symbol - val oldEnv: Env = summon[Env] - val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms - oldEnv ++ newEnv - - matchLists(paramss1, paramss2)(_ =?= _) - &&& tpt1 =?= tpt2 - &&& withEnv(rhsEnv)(rhs1 =?= rhs2) - - case (Closure(_, tpt1), Closure(_, tpt2)) => - // TODO match tpt1 with tpt2? - matched - - case (NamedArg(name1, arg1), NamedArg(name2, arg2)) if name1 == name2 => - arg1 =?= arg2 - - // No Match - case _ => - if (debug) - println( - s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - |Scrutinee - | ${scrutinee.show} - |did not match pattern - | ${pattern.show} - | - |with environment: ${summon[Env]} - | - |Scrutinee: ${scrutinee.show(using Printer.TreeStructure)} - |Pattern: ${pattern.show(using Printer.TreeStructure)} - | - |""".stripMargin) - notMatched - } - } - end extension - - extension (scrutinee: ParamClause) - /** Check that all parameters in the clauses clauses match with =?= and concatenate the results with &&& */ - private def =?= (pattern: ParamClause)(using Env)(using DummyImplicit): Matching = - (scrutinee, pattern) match - case (TermParamClause(params1), TermParamClause(params2)) => matchLists(params1, params2)(_ =?= _) - case (TypeParamClause(params1), TypeParamClause(params2)) => matchLists(params1, params2)(_ =?= _) - case _ => notMatched - - /** Does the scrutenne symbol match the pattern symbol? It matches if: - * - They are the same symbol - * - The scrutinee has is in the environment and they are equivalent - * - The scrutinee overrides the symbol of the pattern - */ - private def symbolMatch(scrutineeTree: Tree, patternTree: Tree)(using Env): Boolean = - val scrutinee = scrutineeTree.symbol - val devirtualizedScrutinee = scrutineeTree match - case Select(qual, _) => - val sym = scrutinee.overridingSymbol(qual.tpe.typeSymbol) - if sym.exists then sym - else scrutinee - case _ => scrutinee - val pattern = patternTree.symbol - - devirtualizedScrutinee == pattern - || summon[Env].get(devirtualizedScrutinee).contains(pattern) - || devirtualizedScrutinee.allOverriddenSymbols.contains(pattern) - - private object ClosedPatternTerm { - /** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */ - def unapply(term: Term)(using Env): Option[term.type] = - if freePatternVars(term).isEmpty then Some(term) else None - - /** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */ - def freePatternVars(term: Term)(using env: Env): Set[Symbol] = - val accumulator = new TreeAccumulator[Set[Symbol]] { - def foldTree(x: Set[Symbol], tree: Tree)(owner: Symbol): Set[Symbol] = - tree match - case tree: Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)(owner) - case _ => foldOverTree(x, tree)(owner) - } - accumulator.foldTree(Set.empty, term)(Symbol.spliceOwner) - } - - private object IdentArgs { - def unapply(args: List[Term]): Option[List[Ident]] = - args.foldRight(Option(List.empty[Ident])) { - case (id: Ident, Some(acc)) => Some(id :: acc) - case (Block(List(DefDef("$anonfun", TermParamClause(params) :: Nil, Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc)) - if params.zip(args).forall(_.symbol == _.symbol) => - Some(id :: acc) - case _ => None - } - } - - private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(using Env): Matching = { - (scrutinee, pattern) match { - case (Some(x), Some(y)) => x =?= y - case (None, None) => matched - case _ => notMatched - } - } - - } - - /** Result of matching a part of an expression */ - private opaque type Matching = Option[Tuple] - - private object Matching { - - def notMatched: Matching = None - val matched: Matching = Some(Tuple()) - def matched(x: Any): Matching = Some(Tuple1(x)) - - extension (self: Matching) - def asOptionOfTuple: Option[Tuple] = self - - /** Concatenates the contents of two successful matchings or return a `notMatched` */ - def &&& (that: => Matching): Matching = self match { - case Some(x) => - that match { - case Some(y) => Some(x ++ y) - case _ => None - } - case _ => None - } - end extension - - } - -} diff --git a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala new file mode 100644 index 000000000000..72c0f289c456 --- /dev/null +++ b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala @@ -0,0 +1,428 @@ +package scala.quoted +package runtime.impl + +import scala.annotation.internal.sharable +import scala.annotation.{Annotation, compileTimeOnly} + +import dotty.tools.dotc +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.Symbols.* + +/** Matches a quoted tree against a quoted pattern tree. + * A quoted pattern tree may have type and term holes in addition to normal terms. + * + * + * Semantics: + * + * We use `'{..}` for expression, `'[..]` for types and `⟨..⟩` for patterns nested in expressions. + * The semantics are defined as a list of reduction rules that are tried one by one until one matches. + * + * Operations: + * - `s =?= p` checks if a scrutinee `s` matches the pattern `p` while accumulating extracted parts of the code. + * - `isClosedUnder(x1, .., xn)('{e})` returns true if and only if all the references in `e` to names defined in the pattern are contained in the set `{x1, ... xn}`. + * - `lift(x1, .., xn)('{e})` returns `(y1, ..., yn) => [xi = $yi]'{e}` where `yi` is an `Expr` of the type of `xi`. + * - `withEnv(x1 -> y1, ..., xn -> yn)(matching)` evaluates matching recording that `xi` is equivalent to `yi`. + * - `matched` denotes that the the match succeeded and `matched('{e})` denotes that a match succeeded and extracts `'{e}` + * - `&&&` matches if both sides match. Concatenates the extracted expressions of both sides. + * + * Note: that not all quoted terms bellow are valid expressions + * + * ```scala + * /* Term hole */ + * '{ e } =?= '{ hole[T] } && typeOf('{e}) <:< T && isClosedUnder()('{e}) ===> matched('{e}) + * + * /* Higher order term hole */ + * '{ e } =?= '{ hole[(T1, ..., Tn) => T](x1, ..., xn) } && isClosedUnder(x1, ... xn)('{e}) ===> matched(lift(x1, ..., xn)('{e})) + * + * /* Match literal */ + * '{ lit } =?= '{ lit } ===> matched + * + * /* Match type ascription (a) */ + * '{ e: T } =?= '{ p } ===> '{e} =?= '{p} + * + * /* Match type ascription (b) */ + * '{ e } =?= '{ p: P } ===> '{e} =?= '{p} + * + * /* Match selection */ + * '{ e.x } =?= '{ p.x } ===> '{e} =?= '{p} + * + * /* Match reference */ + * '{ x } =?= '{ x } ===> matched + * + * /* Match application */ + * '{e0(e1, ..., en)} =?= '{p0(p1, ..., p2)} ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} &&& ... %% '{en} =?= '{pn} + * + * /* Match type application */ + * '{e[T1, ..., Tn]} =?= '{p[P1, ..., Pn]} ===> '{e} =?= '{p} &&& '[T1] =?= '{P1} &&& ... %% '[Tn] =?= '[Pn] + * + * /* Match block flattening */ + * '{ {e0; e1; ...; en}; em } =?= '{ {p0; p1; ...; pm}; em } ===> '{ e0; {e1; ...; en; em} } =?= '{ p0; {p1; ...; pm; em} } + * + * /* Match block */ + * '{ e1; e2 } =?= '{ p1; p2 } ===> '{e1} =?= '{p1} &&& '{e2} =?= '{p2} + * + * /* Match def block */ + * '{ e1; e2 } =?= '{ p1; p2 } ===> withEnv(symOf(e1) -> symOf(p1))('{e1} =?= '{p1} &&& '{e2} =?= '{p2}) + * + * /* Match if */ + * '{ if e0 then e1 else e2 } =?= '{ if p0 then p1 else p2 } ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2} + * + * /* Match while */ + * '{ while e0 do e1 } =?= '{ while p0 do p1 } ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} + * + * /* Match assign */ + * '{ e0 = e1 } =?= '{ p0 = p1 } ==> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} + * + * /* Match new */ + * '{ new T } =?= '{ new T } ===> matched + * + * /* Match this */ + * '{ C.this } =?= '{ C.this } ===> matched + * + * /* Match super */ + * '{ e.super } =?= '{ p.super } ===> '{e} =?= '{p} + * + * /* Match varargs */ + * '{ e: _* } =?= '{ p: _* } ===> '{e} =?= '{p} + * + * /* Match val */ + * '{ val x: T = e1; e2 } =?= '{ val y: P = p1; p2 } ===> withEnv(x -> y)('[T] =?= '[P] &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2}) + * + * /* Match def */ + * '{ def x0(x1: T1, ..., xn: Tn): T0 = e1; e2 } =?= '{ def y0(y1: P1, ..., yn: Pn): P0 = p1; p2 } ===> withEnv(x0 -> y0, ..., xn -> yn)('[T0] =?= '[P0] &&& ... &&& '[Tn] =?= '[Pn] &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2}) + * + * // Types + * + * /* Match type */ + * '[T] =?= '[P] && T <:< P ===> matched + * + * ``` + */ +object QuoteMatcher { + import tpd.* + + // TODO improve performance + + // TODO use flag from Context. Maybe -debug or add -debug-macros + private inline val debug = false + + import Matching._ + + /** A map relating equivalent symbols from the scrutinee and the pattern + * For example in + * ``` + * '{val a = 4; a * a} match case '{ val x = 4; x * x } + * ``` + * when matching `a * a` with `x * x` the environment will contain `Map(a -> x)`. + */ + private type Env = Map[Symbol, Symbol] + + private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env) + + def treeMatch(scrutineeTerm: Tree, patternTerm: Tree)(using Context): Option[Tuple] = + given Env = Map.empty + scrutineeTerm =?= patternTerm + + /** Check that all trees match with `mtch` and concatenate the results with &&& */ + private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match { + case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch) + case (Nil, Nil) => matched + case _ => notMatched + } + + extension (scrutinees: List[Tree]) + private def =?= (patterns: List[Tree])(using Env, Context): Matching = + matchLists(scrutinees, patterns)(_ =?= _) + + extension (scrutinee0: Tree) + + /** Check that the trees match and return the contents from the pattern holes. + * Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes. + * + * @param scrutinee The tree being matched + * @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes. + * @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. + * @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. + */ + private def =?= (pattern0: Tree)(using Env, Context): Matching = + + /* Match block flattening */ // TODO move to cases + /** Normalize the tree */ + def normalize(tree: Tree): Tree = tree match { + case Block(Nil, expr) => normalize(expr) + case Block(stats1, Block(stats2, expr)) => + expr match + case _: Closure => tree + case _ => normalize(Block(stats1 ::: stats2, expr)) + case Inlined(_, Nil, expr) => normalize(expr) + case _ => tree + } + + val scrutinee = normalize(scrutinee0) + val pattern = normalize(pattern0) + + /** Check that both are `val` or both are `lazy val` or both are `var` **/ + def checkValFlags(): Boolean = { + val sFlags = scrutinee.symbol.flags + val pFlags = pattern.symbol.flags + sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable) + } + + // TODO remove + object TypeTreeTypeTest: + def unapply(x: Tree): Option[Tree & x.type] = x match + case x: (TypeBoundsTree & x.type) => None + case x: (Tree & x.type) if x.isType => Some(x) + case _ => None + end TypeTreeTypeTest + + (scrutinee, pattern) match + + /* Term hole */ + // Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree + case (scrutinee @ Typed(s, tpt1), Typed(TypeApply(patternHole, tpt :: Nil), tpt2)) + if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) && + s.tpe <:< tpt.tpe && + tpt2.tpe.derivesFrom(defn.RepeatedParamClass) => + matched(scrutinee) + + /* Term hole */ + // Match a scala.internal.Quoted.patternHole and return the scrutinee tree + case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil)) + if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) && + scrutinee.tpe <:< tpt.tpe => + matched(scrutinee) + + /* Higher order term hole */ + // Matches an open term and wraps it into a lambda that provides the free variables + case (scrutinee, pattern @ Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)) + if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) => + val names: List[TermName] = args.map { + case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName + case arg => arg.symbol.name.asTermName + } + val argTypes = args.map(x => x.tpe.widenTermRefExpr) + val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe) + val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, methTpe) + def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { + val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap + val body = new TreeMap { + override def transform(tree: Tree)(using Context): Tree = + tree match + case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) + case tree => super.transform(tree) + }.transform(scrutinee) + TreeOps(body).changeNonLocalOwners(meth) + } + matched(Closure(meth, bodyFn)) + + // + // Match two equivalent trees + // + + /* Match literal */ + case (Literal(constant1), Literal(constant2)) if constant1 == constant2 => + matched + + /* Match type ascription (a) */ + case (Typed(expr1, _), pattern) => + expr1 =?= pattern + + /* Match type ascription (b) */ + case (scrutinee, Typed(expr2, _)) => + scrutinee =?= expr2 + + /* Match selection */ + case (ref: RefTree, Select(qual2, _)) if symbolMatch(scrutinee, pattern) => + ref match + case Select(qual1, _) => qual1 =?= qual2 + case ref: Ident => + ref.tpe match + case TermRef(qual: TermRef, _) => tpd.ref(qual) =?= qual2 + case _ => matched + + /* Match reference */ + case (_: RefTree, _: Ident) if symbolMatch(scrutinee, pattern) => + matched + + /* Match application */ + case (Apply(fn1, args1), Apply(fn2, args2)) => + fn1 =?= fn2 &&& args1 =?= args2 + + /* Match type application */ + case (TypeApply(fn1, args1), TypeApply(fn2, args2)) => + fn1 =?= fn2 &&& args1 =?= args2 + + /* Match block */ + case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) => + val newEnv = (stat1, stat2) match { + case (stat1: MemberDef, stat2: MemberDef) => + summon[Env] + (stat1.symbol -> stat2.symbol) + case _ => + summon[Env] + } + withEnv(newEnv) { + stat1 =?= stat2 &&& Block(stats1, expr1) =?= Block(stats2, expr2) + } + + /* Match if */ + case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) => + cond1 =?= cond2 &&& thenp1 =?= thenp2 &&& elsep1 =?= elsep2 + + /* Match while */ + case (WhileDo(cond1, body1), WhileDo(cond2, body2)) => + cond1 =?= cond2 &&& body1 =?= body2 + + /* Match assign */ + case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) => + lhs1 =?= lhs2 &&& rhs1 =?= rhs2 + + /* Match new */ + case (New(tpt1), New(tpt2)) if tpt1.tpe.typeSymbol == tpt2.tpe.typeSymbol => + matched + + /* Match this */ + case (This(_), This(_)) if scrutinee.symbol == pattern.symbol => + matched + + /* Match super */ + case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 => + qual1 =?= qual2 + + /* Match varargs */ + case (SeqLiteral(elems1, _), SeqLiteral(elems2, _)) if elems1.size == elems2.size => + elems1 =?= elems2 + + /* Match type */ + // TODO remove this? + case (TypeTreeTypeTest(scrutinee), TypeTreeTypeTest(pattern)) if scrutinee.tpe <:< pattern.tpe => + matched + + /* Match val */ + case (scrutinee @ ValDef(_, tpt1, _), pattern @ ValDef(_, tpt2, _)) if checkValFlags() => + def rhsEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol) + tpt1 =?= tpt2 &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs) + + /* Match def */ + case (scrutinee @ DefDef(_, paramss1, tpt1, _), pattern @ DefDef(_, paramss2, tpt2, _)) => + def rhsEnv: Env = + val paramSyms: List[(Symbol, Symbol)] = + for + (clause1, clause2) <- paramss1.zip(paramss2) + (param1, param2) <- clause1.zip(clause2) + yield + param1.symbol -> param2.symbol + val oldEnv: Env = summon[Env] + val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms + oldEnv ++ newEnv + + matchLists(paramss1, paramss2)(_ =?= _) + &&& tpt1 =?= tpt2 + &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs) + + case (Closure(_, _, tpt1), Closure(_, _, tpt2)) => + // TODO match tpt1 with tpt2? + matched + + case (NamedArg(name1, arg1), NamedArg(name2, arg2)) if name1 == name2 => + arg1 =?= arg2 + + case (EmptyTree, EmptyTree) => + matched + + // No Match + case _ => + if (debug) + val quotes = QuotesImpl() + println( + s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + |Scrutinee + | ${scrutinee.show} + |did not match pattern + | ${pattern.show} + | + |with environment: ${summon[Env]} + | + |Scrutinee: ${quotes.reflect.Printer.TreeStructure.show(scrutinee.asInstanceOf)} + |Pattern: ${quotes.reflect.Printer.TreeStructure.show(pattern.asInstanceOf)} + | + |""".stripMargin) + notMatched + + end extension + + /** Does the scrutinee symbol match the pattern symbol? It matches if: + * - They are the same symbol + * - The scrutinee has is in the environment and they are equivalent + * - The scrutinee overrides the symbol of the pattern + */ + private def symbolMatch(scrutineeTree: Tree, patternTree: Tree)(using Env, Context): Boolean = + val scrutinee = scrutineeTree.symbol + + def overridingSymbol(ofclazz: Symbol): Symbol = + if ofclazz.isClass then scrutinee.denot.overridingSymbol(ofclazz.asClass) + else NoSymbol + + val devirtualizedScrutinee = scrutineeTree match + case Select(qual, _) => + val sym = overridingSymbol(qual.tpe.typeSymbol) + if sym.exists then sym + else scrutinee + case _ => scrutinee + val pattern = patternTree.symbol + + + devirtualizedScrutinee == pattern + || summon[Env].get(devirtualizedScrutinee).contains(pattern) + || devirtualizedScrutinee.allOverriddenSymbols.contains(pattern) + + private object ClosedPatternTerm { + /** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */ + def unapply(term: Tree)(using Env, Context): Option[term.type] = + if freePatternVars(term).isEmpty then Some(term) else None + + /** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */ + def freePatternVars(term: Tree)(using Env, Context): Set[Symbol] = + val accumulator = new TreeAccumulator[Set[Symbol]] { + def apply(x: Set[Symbol], tree: Tree)(using Context): Set[Symbol] = + tree match + case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(x + tree.symbol, tree) + case _ => foldOver(x, tree) + } + accumulator.apply(Set.empty, term) + } + + /** Result of matching a part of an expression */ + private opaque type Matching = Option[Tuple] + + private object Matching { + + def notMatched: Matching = None + + val matched: Matching = Some(Tuple()) + + def matched(tree: Tree)(using Context): Matching = + Some(Tuple1(new ExprImpl(tree, SpliceScope.getCurrent))) + + extension (self: Matching) + def asOptionOfTuple: Option[Tuple] = self + + /** Concatenates the contents of two successful matchings or return a `notMatched` */ + def &&& (that: => Matching): Matching = self match { + case Some(x) => + that match { + case Some(y) => Some(x ++ y) + case _ => None + } + case _ => None + } + end extension + + } + +} diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index fed05b4492fd..7a984bbc3492 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -2935,24 +2935,16 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler ctx1.gadt.addToConstraint(typeHoles) ctx1 - val qctx1 = QuotesImpl()(using ctx1) + val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1) - val matcher = new Matcher.QuoteMatcher[qctx1.type](qctx1) { - def patternHoleSymbol: qctx1.reflect.Symbol = dotc.core.Symbols.defn.QuotedRuntimePatterns_patternHole.asInstanceOf - def higherOrderHoleSymbol: qctx1.reflect.Symbol = dotc.core.Symbols.defn.QuotedRuntimePatterns_higherOrderHole.asInstanceOf - } - - val matchings = - if pat1.isType then matcher.termMatch(scrutinee.asInstanceOf[matcher.qctx.reflect.Term], pat1.asInstanceOf[matcher.qctx.reflect.Term]) - else matcher.termMatch(scrutinee.asInstanceOf[matcher.qctx.reflect.Term], pat1.asInstanceOf[matcher.qctx.reflect.Term]) - - // val matchings = matcher.termMatch(scrutinee, pattern) if typeHoles.isEmpty then matchings else { // After matching and doing all subtype checks, we have to approximate all the type bindings // that we have found, seal them in a quoted.Type and add them to the result def typeHoleApproximation(sym: Symbol) = - ctx1.gadt.approximation(sym, !sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)).asInstanceOf[qctx1.reflect.TypeRepr].asType + val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot) + val approx = ctx1.gadt.approximation(sym, !fromAboveAnnot) + reflect.TypeReprMethods.asType(approx) matchings.map { tup => Tuple.fromIArray(typeHoles.map(typeHoleApproximation).toArray.asInstanceOf[IArray[Object]]) ++ tup }