From 96d507367612e5f68266b4ecd656301a408b1fc0 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Tue, 1 Dec 2020 11:29:13 +0100 Subject: [PATCH] Fix #10573: Devirtualize member selection when matching --- .../scala/quoted/runtime/impl/Matcher.scala | 21 +++++++++++++------ .../quoted/runtime/impl/QuotesImpl.scala | 3 +++ library/src/scala/quoted/Quotes.scala | 6 ++++++ tests/pos-macros/i10573/Macro_1.scala | 16 ++++++++++++++ tests/pos-macros/i10573/Test_2.scala | 2 ++ 5 files changed, 42 insertions(+), 6 deletions(-) create mode 100644 tests/pos-macros/i10573/Macro_1.scala create mode 100644 tests/pos-macros/i10573/Test_2.scala diff --git a/compiler/src/scala/quoted/runtime/impl/Matcher.scala b/compiler/src/scala/quoted/runtime/impl/Matcher.scala index 21fb5bb6d74f..89086df98156 100644 --- a/compiler/src/scala/quoted/runtime/impl/Matcher.scala +++ b/compiler/src/scala/quoted/runtime/impl/Matcher.scala @@ -231,7 +231,7 @@ object Matcher { scrutinee =?= expr2 /* Match selection */ - case (ref: Ref, Select(qual2, _)) if symbolMatch(scrutinee.symbol, pattern.symbol) => + case (ref: Ref, Select(qual2, _)) if symbolMatch(scrutinee, pattern) => ref match case Select(qual1, _) => qual1 =?= qual2 case ref: Ident => @@ -240,7 +240,7 @@ object Matcher { case _ => matched /* Match reference */ - case (_: Ref, _: Ident) if symbolMatch(scrutinee.symbol, pattern.symbol) => + case (_: Ref, _: Ident) if symbolMatch(scrutinee, pattern) => matched /* Match application */ @@ -348,10 +348,19 @@ object Matcher { * - The scrutinee has is in the environment and they are equivalent * - The scrutinee overrides the symbol of the pattern */ - private def symbolMatch(scrutinee: Symbol, pattern: Symbol)(using Env): Boolean = - scrutinee == pattern - || summon[Env].get(scrutinee).contains(pattern) - || scrutinee.allOverriddenSymbols.contains(pattern) + private def symbolMatch(scrutineeTree: Tree, patternTree: Tree)(using Env): Boolean = + val scrutinee = scrutineeTree.symbol + val devirtualizedScrutinee = scrutineeTree match + case Select(qual, _) => + val sym = scrutinee.overridingSymbol(qual.tpe.typeSymbol) + if sym.exists then sym + else scrutinee + case _ => scrutinee + val pattern = patternTree.symbol + + devirtualizedScrutinee == pattern + || summon[Env].get(devirtualizedScrutinee).contains(pattern) + || devirtualizedScrutinee.allOverriddenSymbols.contains(pattern) private object ClosedPatternTerm { /** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */ diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 6e197f5462fc..cd6f94458eb7 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -2349,6 +2349,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler def paramSymss: List[List[Symbol]] = self.denot.paramSymss def primaryConstructor: Symbol = self.denot.primaryConstructor def allOverriddenSymbols: Iterator[Symbol] = self.denot.allOverriddenSymbols + def overridingSymbol(ofclazz: Symbol): Symbol = + if ofclazz.isClass then self.denot.overridingSymbol(ofclazz.asClass) + else dotc.core.Symbols.NoSymbol def caseFields: List[Symbol] = if !self.isClass then Nil diff --git a/library/src/scala/quoted/Quotes.scala b/library/src/scala/quoted/Quotes.scala index a8dd2317e5d8..4c0f1bdb6d96 100644 --- a/library/src/scala/quoted/Quotes.scala +++ b/library/src/scala/quoted/Quotes.scala @@ -3624,6 +3624,12 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => /** Returns all symbols overridden by this symbol. */ def allOverriddenSymbols: Iterator[Symbol] + /** The symbol overriding this symbol in given subclass `ofclazz`. + * + * @param ofclazz is a subclass of this symbol's owner + */ + def overridingSymbol(ofclazz: Symbol): Symbol + /** The primary constructor of a class or trait, `noSymbol` if not applicable. */ def primaryConstructor: Symbol diff --git a/tests/pos-macros/i10573/Macro_1.scala b/tests/pos-macros/i10573/Macro_1.scala new file mode 100644 index 000000000000..7c49a8e5ca13 --- /dev/null +++ b/tests/pos-macros/i10573/Macro_1.scala @@ -0,0 +1,16 @@ +import scala.quoted._ + + +trait A { def foo: Int } +class B extends A { def foo: Int = 1 } + +inline def test(): Unit = ${ testExpr() } + +def testExpr()(using Quotes): Expr[Unit] = { + val e0: Expr[A] = '{ new B } + val e1: Expr[Int] = '{ $e0.foo } + e1 match + case '{ ($x: B).foo } => '{ val b: B = $x; () } + case _ => quotes.reflect.report.throwError("did not match") +} + diff --git a/tests/pos-macros/i10573/Test_2.scala b/tests/pos-macros/i10573/Test_2.scala new file mode 100644 index 000000000000..1e5f437d6bd9 --- /dev/null +++ b/tests/pos-macros/i10573/Test_2.scala @@ -0,0 +1,2 @@ + +def Test = test()