Skip to content

Commit 7b15934

Browse files
committed
Fix scala#10573: Devirtualize member selection when matching
1 parent 90f44d4 commit 7b15934

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

compiler/src/scala/quoted/runtime/impl/Matcher.scala

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ object Matcher {
231231
scrutinee =?= expr2
232232

233233
/* Match selection */
234-
case (ref: Ref, Select(qual2, _)) if symbolMatch(scrutinee.symbol, pattern.symbol) =>
234+
case (ref: Ref, Select(qual2, _)) if symbolMatch(scrutinee, pattern) =>
235235
ref match
236236
case Select(qual1, _) => qual1 =?= qual2
237237
case ref: Ident =>
@@ -240,7 +240,7 @@ object Matcher {
240240
case _ => matched
241241

242242
/* Match reference */
243-
case (_: Ref, _: Ident) if symbolMatch(scrutinee.symbol, pattern.symbol) =>
243+
case (_: Ref, _: Ident) if symbolMatch(scrutinee, pattern) =>
244244
matched
245245

246246
/* Match application */
@@ -348,10 +348,21 @@ object Matcher {
348348
* - The scrutinee has is in the environment and they are equivalent
349349
* - The scrutinee overrides the symbol of the pattern
350350
*/
351-
private def symbolMatch(scrutinee: Symbol, pattern: Symbol)(using Env): Boolean =
352-
scrutinee == pattern
353-
|| summon[Env].get(scrutinee).contains(pattern)
354-
|| scrutinee.allOverriddenSymbols.contains(pattern)
351+
private def symbolMatch(scrutineeTree: Tree, patternTree: Tree)(using Env): Boolean =
352+
val scrutinee = scrutineeTree.symbol
353+
val devirtualizedScrutinee = scrutineeTree match
354+
case Select(qual, _) =>
355+
// TODO improve performance
356+
qual.tpe.typeSymbol.methods.find( sym =>
357+
sym.name == scrutinee.name &&
358+
sym.allOverriddenSymbols.contains(scrutinee)
359+
).getOrElse(scrutinee)
360+
case _ => scrutinee
361+
val pattern = patternTree.symbol
362+
363+
devirtualizedScrutinee == pattern
364+
|| summon[Env].get(devirtualizedScrutinee).contains(pattern)
365+
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern)
355366

356367
private object ClosedPatternTerm {
357368
/** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */

tests/pos-macros/i10573/Macro_1.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted._
2+
3+
4+
trait A { def foo: Int }
5+
class B extends A { def foo: Int = 1 }
6+
7+
inline def test(): Unit = ${ testExpr() }
8+
9+
def testExpr()(using Quotes): Expr[Unit] = {
10+
val e0: Expr[A] = '{ new B }
11+
val e1: Expr[Int] = '{ $e0.foo }
12+
e1 match
13+
case '{ ($x: B).foo } => '{ val b: B = $x; () }
14+
case _ => quotes.reflect.report.throwError("did not match")
15+
}
16+

tests/pos-macros/i10573/Test_2.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
def Test = test()

0 commit comments

Comments
 (0)