Skip to content

Commit 02e6f88

Browse files
committed
WIP Add runtime.quoted.Matcher
This allows to match a quote against another quote while extracting the contents of defined holes
1 parent 07847f8 commit 02e6f88

File tree

5 files changed

+317
-1
lines changed

5 files changed

+317
-1
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
200200
type Term = tpd.Tree
201201

202202
def matchTerm(tree: Tree)(implicit ctx: Context): Option[Term] =
203-
if (tree.isTerm) Some(tree) else None
203+
if (tree.isTerm) Some(tree) else matchRepeated(tree)
204204

205205
// TODO move to Kernel and use isTerm directly with a cast
206206
def matchTermNotTypeTree(termOrTypeTree: TermOrTypeTree)(implicit ctx: Context): Option[Term] =
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package scala.runtime.quoted
2+
3+
import scala.quoted._
4+
import scala.tasty._
5+
6+
object Matcher {
7+
8+
type Hole[T /* <: AnyKind */] = T
9+
10+
def hole[T]: T = ???
11+
12+
def unapplySeq(scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Seq[Any]] = {
13+
import reflection._
14+
15+
def treeMatches(scrutinee: Tree, pattern: Tree): Option[Seq[Any]] = {
16+
import Term._
17+
(scrutinee, pattern) match {
18+
// Normalize blocks without statements
19+
case (Block(Nil, expr), _) => treeMatches(expr, pattern)
20+
case (_, Block(Nil, pat)) => treeMatches(scrutinee, pat)
21+
22+
case (IsTerm(scrutinee), TypeApply(Ident("hole"), tpt :: Nil))
23+
if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.hole" && // TODO check symbol equality instead of its name
24+
scrutinee.tpe <:< tpt.tpe =>
25+
Some(Seq(scrutinee))
26+
27+
case (Inlined(None, Nil, scr), _) =>
28+
treeMatches(scr, pattern)
29+
case (_, Inlined(None, Nil, pat)) =>
30+
treeMatches(scrutinee, pat)
31+
32+
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
33+
Some(Seq.empty)
34+
35+
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol =>
36+
Some(Seq.empty)
37+
38+
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
39+
foldMatchings(treeMatches(expr1, expr2) :: typeTreeMatches(tpt1, tpt2) :: Nil)
40+
41+
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
42+
treeMatches(qual1, qual2)
43+
44+
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
45+
foldMatchings(treeMatches(fn1, fn2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield treeMatches(arg1, arg2)))
46+
47+
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
48+
foldMatchings(treeMatches(fn1, fn2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield typeTreeMatches(arg1, arg2)))
49+
50+
case (Block(stats1, expr1), Block(stats2, expr2)) =>
51+
// TODO handle bindings
52+
foldMatchings((for ((stat1, stat2) <- stats1.zip(stats2)) yield treeMatches(stat1, stat2)) ::: treeMatches(expr1, expr2) :: Nil)
53+
54+
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
55+
foldMatchings(treeMatches(cond1, cond2) :: treeMatches(thenp1, thenp2) :: treeMatches(elsep1, elsep2) :: Nil)
56+
57+
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
58+
// TODO how to handle LHS?
59+
if (treeMatches(lhs1, lhs2).isDefined) treeMatches(rhs1, rhs2)
60+
else None
61+
62+
case (While(cond1, body1), While(cond2, body2)) =>
63+
foldMatchings(treeMatches(cond1, cond2) :: treeMatches(body1, body2) :: Nil)
64+
65+
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
66+
treeMatches(expr1, expr2)
67+
68+
case (New(tpt1), New(tpt2)) =>
69+
typeTreeMatches(tpt1, tpt2)
70+
71+
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
72+
Some(Seq.empty)
73+
74+
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
75+
foldMatchings(for ((elem1, elem2) <- elems1.zip(elems2)) yield treeMatches(elem1, elem2))
76+
77+
case _ =>
78+
// println(
79+
// s"""Scrutinee ${scrutinee.showCode}
80+
// |${scrutinee.show}
81+
// |did not match pattern ${pattern.showCode}
82+
// |${pattern.show}
83+
// |
84+
// |""".stripMargin)
85+
None
86+
}
87+
}
88+
89+
def typeTreeMatches(scrutinee: TypeOrBoundsTree, pattern: TypeOrBoundsTree): Option[Seq[Any]] = {
90+
import TypeTree._
91+
(scrutinee, pattern) match {
92+
case (IsTypeTree(scrutinee), IsTypeTree(pattern @ Applied(Ident("Hole"), IsTypeTree(tpt) :: Nil)))
93+
if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.Hole" && // TODO check symbol equality instead of its name
94+
scrutinee.tpe <:< tpt.tpe => // Is the subtype check required?
95+
Some(Seq(scrutinee))
96+
97+
case (IsTypeTree(scrutinee @ Ident(_)), IsTypeTree(pattern @ Ident(_))) if scrutinee.symbol == pattern.symbol =>
98+
Some(Seq.empty)
99+
100+
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
101+
Some(Seq.empty)
102+
103+
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
104+
val matchings: List[Option[Seq[Any]]] =
105+
typeTreeMatches(tycon1, tycon2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield typeTreeMatches(arg1, arg2))
106+
foldMatchings(matchings)
107+
108+
case _ =>
109+
// println(
110+
// s"""Scrutinee ${scrutinee.showCode}
111+
// |${scrutinee.show}
112+
// |did not match pattern ${pattern.showCode}
113+
// |${pattern.show}
114+
// |
115+
// |""".stripMargin)
116+
None
117+
}
118+
}
119+
120+
treeMatches(scrutineeExpr.unseal, patternExpr.unseal)
121+
}
122+
123+
private def foldMatchings(matchings: List[Option[Seq[Any]]]): Option[Seq[Any]] = {
124+
matchings.foldLeft[Option[Seq[Any]]](Some(Seq.empty)) {
125+
case (Some(acc), Some(holes)) => Some(acc ++ holes)
126+
case (_, _) => None
127+
}
128+
}
129+
130+
}

tests/run/quote-matcher-runtime.check

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
Matches
2+
Some(List())
3+
Some(List(Literal(Constant(3))))
4+
Some(List(Ident(x)))
5+
Some(List(Literal(Constant(5))))
6+
Some(List(Apply(Select(Literal(Constant(6)),+),List(Ident(x)))))
7+
Some(List(Ident(x)))
8+
Some(List(Literal(Constant(6))))
9+
Some(List(Literal(Constant(6)), Ident(x)))
10+
Some(List(Ident(x)))
11+
Some(List(Apply(Ident(f),List(Literal(Constant(4))))))
12+
Some(List(Literal(Constant(5))))
13+
Some(List(TypeApply(Ident(g),List(Ident(Int)))))
14+
Some(List(Apply(TypeApply(Ident(h),List(Ident(Int))),List(Literal(Constant(7))))))
15+
Some(List(Literal(Constant(8))))
16+
Some(List())
17+
Some(List(This(Ident(Test$))))
18+
Some(List())
19+
Some(List(Apply(Select(New(Ident(Foo)),<init>),List(Literal(Constant(1))))))
20+
Some(List(Literal(Constant(1))))
21+
Some(List())
22+
Some(List(If(Ident(b),Ident(x),Ident(y))))
23+
Some(List(Ident(b), Ident(x), Ident(y)))
24+
Some(List())
25+
Some(List(WhileDo(Ident(b),Block(List(Ident(x)),Literal(Constant(()))))))
26+
Some(List(Ident(b), Ident(x)))
27+
Some(List())
28+
Some(List(Assign(Ident(z),Literal(Constant(4)))))
29+
Some(List(Literal(Constant(4))))
30+
Some(List())
31+
Some(List())
32+
Some(List())
33+
Some(List(SeqLiteral(List(),TypeTree[TypeRef(TermRef(ThisType(TypeRef(NoPrefix,module class <root>)),module scala),class Int)])))
34+
Some(List())
35+
Some(List(Literal(Constant(1)), Literal(Constant(2))))
36+
Some(List(SeqLiteral(List(Literal(Constant(1)), Literal(Constant(2)), Literal(Constant(3))),TypeTree[TypeRef(TermRef(ThisType(TypeRef(NoPrefix,module class <root>)),module scala),class Int)])))
37+
Some(List())
38+
Some(List())
39+
Some(List(Literal(Constant(1)), Literal(Constant(2))))
40+
41+
Matches type
42+
Some(List())
43+
Some(List(Ident(Int)))
44+
Some(List(Ident(Int)))
45+
Some(List(Ident(Int)))
46+
Some(List(Ident(Int)))
47+
Some(List(SingletonTypeTree(Literal(Constant(6)))))
48+
Some(List(SingletonTypeTree(Literal(Constant(6)))))
49+
Some(List(SingletonTypeTree(This(Ident()))))
50+
51+
No match
52+
None
53+
None
54+
None
55+
None
56+
None
57+
None
58+
None
59+
None
60+
None
61+
62+
No match type
63+
None
64+
None
65+
None
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted._
2+
import scala.tasty.Reflection
3+
4+
object Macros {
5+
6+
inline def matches[A, B](a: => A, b: => B): Unit = ${impl('a, 'b)}
7+
8+
private def impl[A, B](a: Expr[A], b: Expr[B])(implicit reflect: Reflection): Expr[Unit] = {
9+
import reflect._
10+
11+
val res = scala.runtime.quoted.Matcher.unapplySeq(a)(b, reflect)
12+
13+
'{ println(${res.toString.toExpr}) }
14+
}
15+
16+
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
2+
import Macros._
3+
4+
import scala.runtime.quoted.Matcher.Hole
5+
import scala.runtime.quoted.Matcher.hole
6+
7+
object Test {
8+
9+
def main(args: Array[String]): Unit = {
10+
val b: Boolean = true
11+
val x: Int = 42
12+
val y: Int = 52
13+
var z: Int = 62
14+
var z2: Int = 62
15+
def f(a: Int): Int = 72
16+
def f2(a: Int, b: Int): Int = 72
17+
def g[A]: A = ???
18+
def h[A](a: A): A = a
19+
def fs(a: Int*): Int = 72
20+
21+
// Matches
22+
println("Matches")
23+
matches(1, 1)
24+
matches(3, hole[Int])
25+
matches(x, hole[Int])
26+
matches(5, hole[Any])
27+
matches(6 + x, hole[Int])
28+
matches(6 + x, 6 + hole[Int])
29+
matches(6 + x, hole[Int] + x)
30+
matches(6 + x, hole[Int] + hole[Int])
31+
matches(6 + x + y, 6 + hole[Int] + y)
32+
matches(f(4), hole[Int])
33+
matches(f(5), f(hole[Int]))
34+
matches(g[Int], hole[Int])
35+
matches(h[Int](7), hole[Int])
36+
matches(h[Int](8), h[Int](hole[Int]))
37+
matches(this, this)
38+
matches(this, hole[this.type])
39+
matches(new Foo(1), new Foo(1))
40+
matches(new Foo(1), hole[Foo])
41+
matches(new Foo(1), new Foo(hole[Int]))
42+
matches(if (b) x else y, if (b) x else y)
43+
matches(if (b) x else y, hole[Int])
44+
matches(if (b) x else y, if (hole[Boolean]) hole[Int] else hole[Int])
45+
matches(while (b) x, while (b) x)
46+
matches(while (b) x, hole[Unit])
47+
matches(while (b) x, while (hole[Boolean]) hole[Int])
48+
matches({z = 4}, {z = 4})
49+
matches({z = 4}, hole[Unit])
50+
matches({z = 4}, {z = hole[Int]})
51+
matches(1, {1})
52+
matches({1}, 1)
53+
// Should these match?
54+
// matches({(); 1}, 1)
55+
// matches(1, {(); 1})
56+
matches(fs(), fs())
57+
matches(fs(), fs(hole[Seq[Int]]: _*))
58+
matches(fs(1, 2, 3), fs(1, 2, 3))
59+
matches(fs(1, 2, 3), fs(hole[Int], hole[Int], 3))
60+
matches(fs(1, 2, 3), fs(hole[Seq[Int]]: _*))
61+
matches(f2(1, 2), f2(1, 2))
62+
matches(f2(a = 1, b = 2), f2(a = 1, b = 2))
63+
matches(f2(a = 1, b = 2), f2(a = hole[Int], b = hole[Int]))
64+
// Should these match?
65+
// matches(f2(a = 1, b = 2), f2(1, 2))
66+
// matches(f2(b = 2, a = 1), f2(1, 2))
67+
68+
println()
69+
println("Matches type")
70+
matches(1: Int, 1: Int)
71+
matches(1: Int, 1: Hole[Int])
72+
matches(Nil: List[Int], Nil: List[Hole[Int]])
73+
matches(g[Int], g[Hole[Int]])
74+
matches(h[Int](6), h[Hole[Int]](6))
75+
matches(h[6](6), h[Hole[6]](6))
76+
matches(h[6](6), h[Hole[Int]](6))
77+
matches(g[this.type], g[Hole[this.type]])
78+
79+
// TODO add lots of tests
80+
81+
// No match
82+
println()
83+
println("No match")
84+
matches(1, 2)
85+
matches(4, hole[String])
86+
matches(6 + x, 7 + hole[Int])
87+
matches(6 + x, hole[Int] + 4)
88+
matches(g[Int], hole[String])
89+
matches(h[Int](7), h[String](hole[String]))
90+
matches(h[Int](6), h[Hole[Int]](7))
91+
matches({z = 4}, {z = 5})
92+
matches({z = 4}, {z2 = 4})
93+
94+
println()
95+
println("No match type")
96+
matches(??? : Int, ??? : Hole[String])
97+
matches(g[Int], g[Hole[String]])
98+
matches(h[Int](6), h[Hole[String]]("abc"))
99+
100+
// TODO add lots of tests
101+
102+
}
103+
}
104+
105+
class Foo(a: Int)

0 commit comments

Comments
 (0)