Skip to content

Commit 96d5073

Browse files
committed
Fix #10573: Devirtualize member selection when matching
1 parent 90f44d4 commit 96d5073

File tree

5 files changed

+42
-6
lines changed

5 files changed

+42
-6
lines changed

compiler/src/scala/quoted/runtime/impl/Matcher.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ object Matcher {
231231
scrutinee =?= expr2
232232

233233
/* Match selection */
234-
case (ref: Ref, Select(qual2, _)) if symbolMatch(scrutinee.symbol, pattern.symbol) =>
234+
case (ref: Ref, Select(qual2, _)) if symbolMatch(scrutinee, pattern) =>
235235
ref match
236236
case Select(qual1, _) => qual1 =?= qual2
237237
case ref: Ident =>
@@ -240,7 +240,7 @@ object Matcher {
240240
case _ => matched
241241

242242
/* Match reference */
243-
case (_: Ref, _: Ident) if symbolMatch(scrutinee.symbol, pattern.symbol) =>
243+
case (_: Ref, _: Ident) if symbolMatch(scrutinee, pattern) =>
244244
matched
245245

246246
/* Match application */
@@ -348,10 +348,19 @@ object Matcher {
348348
* - The scrutinee has is in the environment and they are equivalent
349349
* - The scrutinee overrides the symbol of the pattern
350350
*/
351-
private def symbolMatch(scrutinee: Symbol, pattern: Symbol)(using Env): Boolean =
352-
scrutinee == pattern
353-
|| summon[Env].get(scrutinee).contains(pattern)
354-
|| scrutinee.allOverriddenSymbols.contains(pattern)
351+
private def symbolMatch(scrutineeTree: Tree, patternTree: Tree)(using Env): Boolean =
352+
val scrutinee = scrutineeTree.symbol
353+
val devirtualizedScrutinee = scrutineeTree match
354+
case Select(qual, _) =>
355+
val sym = scrutinee.overridingSymbol(qual.tpe.typeSymbol)
356+
if sym.exists then sym
357+
else scrutinee
358+
case _ => scrutinee
359+
val pattern = patternTree.symbol
360+
361+
devirtualizedScrutinee == pattern
362+
|| summon[Env].get(devirtualizedScrutinee).contains(pattern)
363+
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern)
355364

356365
private object ClosedPatternTerm {
357366
/** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,6 +2349,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
23492349
def paramSymss: List[List[Symbol]] = self.denot.paramSymss
23502350
def primaryConstructor: Symbol = self.denot.primaryConstructor
23512351
def allOverriddenSymbols: Iterator[Symbol] = self.denot.allOverriddenSymbols
2352+
def overridingSymbol(ofclazz: Symbol): Symbol =
2353+
if ofclazz.isClass then self.denot.overridingSymbol(ofclazz.asClass)
2354+
else dotc.core.Symbols.NoSymbol
23522355

23532356
def caseFields: List[Symbol] =
23542357
if !self.isClass then Nil

library/src/scala/quoted/Quotes.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3624,6 +3624,12 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
36243624
/** Returns all symbols overridden by this symbol. */
36253625
def allOverriddenSymbols: Iterator[Symbol]
36263626

3627+
/** The symbol overriding this symbol in given subclass `ofclazz`.
3628+
*
3629+
* @param ofclazz is a subclass of this symbol's owner
3630+
*/
3631+
def overridingSymbol(ofclazz: Symbol): Symbol
3632+
36273633
/** The primary constructor of a class or trait, `noSymbol` if not applicable. */
36283634
def primaryConstructor: Symbol
36293635

tests/pos-macros/i10573/Macro_1.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted._
2+
3+
4+
trait A { def foo: Int }
5+
class B extends A { def foo: Int = 1 }
6+
7+
inline def test(): Unit = ${ testExpr() }
8+
9+
def testExpr()(using Quotes): Expr[Unit] = {
10+
val e0: Expr[A] = '{ new B }
11+
val e1: Expr[Int] = '{ $e0.foo }
12+
e1 match
13+
case '{ ($x: B).foo } => '{ val b: B = $x; () }
14+
case _ => quotes.reflect.report.throwError("did not match")
15+
}
16+

tests/pos-macros/i10573/Test_2.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
def Test = test()

0 commit comments

Comments
 (0)