Skip to content

Commit 6947983

Browse files
committed
WIP Add runtime.quoted.Matcher
This allows to match a quote against another quote while extracting the contents of defined holes
1 parent 07847f8 commit 6947983

File tree

6 files changed

+323
-2
lines changed

6 files changed

+323
-2
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
200200
type Term = tpd.Tree
201201

202202
def matchTerm(tree: Tree)(implicit ctx: Context): Option[Term] =
203-
if (tree.isTerm) Some(tree) else None
203+
if (tree.isTerm) Some(tree) else matchRepeated(tree)
204204

205205
// TODO move to Kernel and use isTerm directly with a cast
206206
def matchTermNotTypeTree(termOrTypeTree: TermOrTypeTree)(implicit ctx: Context): Option[Term] =
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package scala.runtime.quoted
2+
3+
import scala.quoted._
4+
import scala.tasty._
5+
6+
object Matcher {
7+
8+
type Hole[T /* <: AnyKind */] = T
9+
10+
def hole[T]: T = ???
11+
12+
def unapplySeq(scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Seq[Any]] = {
13+
import reflection._
14+
15+
def treeMatches(scrutinee: Tree, pattern: Tree): Option[Seq[Any]] = {
16+
import Term._
17+
(scrutinee, pattern) match {
18+
// Normalize blocks without statements
19+
case (Block(Nil, expr), _) => treeMatches(expr, pattern)
20+
case (_, Block(Nil, pat)) => treeMatches(scrutinee, pat)
21+
22+
case (IsTerm(scrutinee), TypeApply(Ident("hole"), tpt :: Nil))
23+
if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.hole" && // TODO check symbol equality instead of its name
24+
scrutinee.tpe <:< tpt.tpe =>
25+
Some(Seq(scrutinee))
26+
27+
case (Inlined(None, Nil, scr), _) =>
28+
treeMatches(scr, pattern)
29+
case (_, Inlined(None, Nil, pat)) =>
30+
treeMatches(scrutinee, pat)
31+
32+
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
33+
Some(Seq.empty)
34+
35+
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol =>
36+
Some(Seq.empty)
37+
38+
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
39+
foldMatchings(treeMatches(expr1, expr2) :: typeTreeMatches(tpt1, tpt2) :: Nil)
40+
41+
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
42+
treeMatches(qual1, qual2)
43+
44+
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
45+
foldMatchings(treeMatches(fn1, fn2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield treeMatches(arg1, arg2)))
46+
47+
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
48+
foldMatchings(treeMatches(fn1, fn2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield typeTreeMatches(arg1, arg2)))
49+
50+
case (Block(stats1, expr1), Block(stats2, expr2)) =>
51+
// TODO handle bindings
52+
foldMatchings((for ((stat1, stat2) <- stats1.zip(stats2)) yield treeMatches(stat1, stat2)) ::: treeMatches(expr1, expr2) :: Nil)
53+
54+
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
55+
foldMatchings(treeMatches(cond1, cond2) :: treeMatches(thenp1, thenp2) :: treeMatches(elsep1, elsep2) :: Nil)
56+
57+
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
58+
// TODO how to handle LHS?
59+
if (treeMatches(lhs1, lhs2).isDefined) treeMatches(rhs1, rhs2)
60+
else None
61+
62+
case (While(cond1, body1), While(cond2, body2)) =>
63+
foldMatchings(treeMatches(cond1, cond2) :: treeMatches(body1, body2) :: Nil)
64+
65+
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
66+
treeMatches(expr1, expr2)
67+
68+
case (New(tpt1), New(tpt2)) =>
69+
typeTreeMatches(tpt1, tpt2)
70+
71+
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
72+
Some(Seq.empty)
73+
74+
case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
75+
treeMatches(qual1, qual2)
76+
77+
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
78+
foldMatchings(for ((elem1, elem2) <- elems1.zip(elems2)) yield treeMatches(elem1, elem2))
79+
80+
case _ =>
81+
// println(
82+
// s"""Scrutinee ${scrutinee.showCode}
83+
// |${scrutinee.show}
84+
// |did not match pattern ${pattern.showCode}
85+
// |${pattern.show}
86+
// |
87+
// |""".stripMargin)
88+
None
89+
}
90+
}
91+
92+
def typeTreeMatches(scrutinee: TypeOrBoundsTree, pattern: TypeOrBoundsTree): Option[Seq[Any]] = {
93+
import TypeTree._
94+
(scrutinee, pattern) match {
95+
case (IsTypeTree(scrutinee), IsTypeTree(pattern @ Applied(Ident("Hole"), IsTypeTree(tpt) :: Nil)))
96+
if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.Hole" && // TODO check symbol equality instead of its name
97+
scrutinee.tpe <:< tpt.tpe => // Is the subtype check required?
98+
Some(Seq(scrutinee))
99+
100+
case (IsTypeTree(scrutinee @ Ident(_)), IsTypeTree(pattern @ Ident(_))) if scrutinee.symbol == pattern.symbol =>
101+
Some(Seq.empty)
102+
103+
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
104+
Some(Seq.empty)
105+
106+
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
107+
val matchings: List[Option[Seq[Any]]] =
108+
typeTreeMatches(tycon1, tycon2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield typeTreeMatches(arg1, arg2))
109+
foldMatchings(matchings)
110+
111+
case _ =>
112+
// println(
113+
// s"""Scrutinee ${scrutinee.showCode}
114+
// |${scrutinee.show}
115+
// |did not match pattern ${pattern.showCode}
116+
// |${pattern.show}
117+
// |
118+
// |""".stripMargin)
119+
None
120+
}
121+
}
122+
123+
treeMatches(scrutineeExpr.unseal, patternExpr.unseal)
124+
}
125+
126+
private def foldMatchings(matchings: List[Option[Seq[Any]]]): Option[Seq[Any]] = {
127+
matchings.foldLeft[Option[Seq[Any]]](Some(Seq.empty)) {
128+
case (Some(acc), Some(holes)) => Some(acc ++ holes)
129+
case (_, _) => None
130+
}
131+
}
132+
133+
}

library/src/scala/tasty/reflect/Printers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ trait Printers
181181
case Term.This(qual) =>
182182
this += "Term.This(" += qual += ")"
183183
case Term.Super(qual, mix) =>
184-
this += "Term.TypeApply(" += qual += ", " += mix += ")"
184+
this += "Term.Super(" += qual += ", " += mix += ")"
185185
case Term.Apply(fun, args) =>
186186
this += "Term.Apply(" += fun += ", " ++= args += ")"
187187
case Term.TypeApply(fun, args) =>

tests/run/quote-matcher-runtime.check

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
Matches
2+
Some(List())
3+
Some(List(Literal(Constant(3))))
4+
Some(List(Ident(x)))
5+
Some(List(Literal(Constant(5))))
6+
Some(List(Apply(Select(Literal(Constant(6)),+),List(Ident(x)))))
7+
Some(List(Ident(x)))
8+
Some(List(Literal(Constant(6))))
9+
Some(List(Literal(Constant(6)), Ident(x)))
10+
Some(List(Ident(x)))
11+
Some(List(Apply(Ident(f),List(Literal(Constant(4))))))
12+
Some(List(Literal(Constant(5))))
13+
Some(List(TypeApply(Ident(g),List(Ident(Int)))))
14+
Some(List(Apply(TypeApply(Ident(h),List(Ident(Int))),List(Literal(Constant(7))))))
15+
Some(List(Literal(Constant(8))))
16+
Some(List())
17+
Some(List(This(Ident(Test$))))
18+
Some(List())
19+
Some(List(Apply(Select(New(Ident(Foo)),<init>),List(Literal(Constant(1))))))
20+
Some(List(Literal(Constant(1))))
21+
Some(List())
22+
Some(List(If(Ident(b),Ident(x),Ident(y))))
23+
Some(List(Ident(b), Ident(x), Ident(y)))
24+
Some(List())
25+
Some(List(WhileDo(Ident(b),Block(List(Ident(x)),Literal(Constant(()))))))
26+
Some(List(Ident(b), Ident(x)))
27+
Some(List())
28+
Some(List(Assign(Ident(z),Literal(Constant(4)))))
29+
Some(List(Literal(Constant(4))))
30+
Some(List())
31+
Some(List())
32+
Some(List())
33+
Some(List(SeqLiteral(List(),TypeTree[TypeRef(TermRef(ThisType(TypeRef(NoPrefix,module class <root>)),module scala),class Int)])))
34+
Some(List())
35+
Some(List(Literal(Constant(1)), Literal(Constant(2))))
36+
Some(List(SeqLiteral(List(Literal(Constant(1)), Literal(Constant(2)), Literal(Constant(3))),TypeTree[TypeRef(TermRef(ThisType(TypeRef(NoPrefix,module class <root>)),module scala),class Int)])))
37+
Some(List())
38+
Some(List())
39+
Some(List(Literal(Constant(1)), Literal(Constant(2))))
40+
Some(List())
41+
42+
Matches type
43+
Some(List())
44+
Some(List(Ident(Int)))
45+
Some(List(Ident(Int)))
46+
Some(List(Ident(Int)))
47+
Some(List(Ident(Int)))
48+
Some(List(SingletonTypeTree(Literal(Constant(6)))))
49+
Some(List(SingletonTypeTree(Literal(Constant(6)))))
50+
Some(List(SingletonTypeTree(This(Ident()))))
51+
52+
No match
53+
None
54+
None
55+
None
56+
None
57+
None
58+
None
59+
None
60+
None
61+
None
62+
63+
No match type
64+
None
65+
None
66+
None
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted._
2+
import scala.tasty.Reflection
3+
4+
object Macros {
5+
6+
inline def matches[A, B](a: => A, b: => B): Unit = ${impl('a, 'b)}
7+
8+
private def impl[A, B](a: Expr[A], b: Expr[B])(implicit reflect: Reflection): Expr[Unit] = {
9+
import reflect._
10+
11+
val res = scala.runtime.quoted.Matcher.unapplySeq(a)(b, reflect)
12+
13+
'{ println(${res.toString.toExpr}) }
14+
}
15+
16+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
2+
import Macros._
3+
4+
import scala.runtime.quoted.Matcher.Hole
5+
import scala.runtime.quoted.Matcher.hole
6+
7+
object Test {
8+
9+
def main(args: Array[String]): Unit = {
10+
val b: Boolean = true
11+
val x: Int = 42
12+
val y: Int = 52
13+
var z: Int = 62
14+
var z2: Int = 62
15+
def f(a: Int): Int = 72
16+
def f2(a: Int, b: Int): Int = 72
17+
def g[A]: A = ???
18+
def h[A](a: A): A = a
19+
def fs(a: Int*): Int = 72
20+
21+
// Matches
22+
println("Matches")
23+
matches(1, 1)
24+
matches(3, hole[Int])
25+
matches(x, hole[Int])
26+
matches(5, hole[Any])
27+
matches(6 + x, hole[Int])
28+
matches(6 + x, 6 + hole[Int])
29+
matches(6 + x, hole[Int] + x)
30+
matches(6 + x, hole[Int] + hole[Int])
31+
matches(6 + x + y, 6 + hole[Int] + y)
32+
matches(f(4), hole[Int])
33+
matches(f(5), f(hole[Int]))
34+
matches(g[Int], hole[Int])
35+
matches(h[Int](7), hole[Int])
36+
matches(h[Int](8), h[Int](hole[Int]))
37+
matches(this, this)
38+
matches(this, hole[this.type])
39+
matches(new Foo(1), new Foo(1))
40+
matches(new Foo(1), hole[Foo])
41+
matches(new Foo(1), new Foo(hole[Int]))
42+
matches(if (b) x else y, if (b) x else y)
43+
matches(if (b) x else y, hole[Int])
44+
matches(if (b) x else y, if (hole[Boolean]) hole[Int] else hole[Int])
45+
matches(while (b) x, while (b) x)
46+
matches(while (b) x, hole[Unit])
47+
matches(while (b) x, while (hole[Boolean]) hole[Int])
48+
matches({z = 4}, {z = 4})
49+
matches({z = 4}, hole[Unit])
50+
matches({z = 4}, {z = hole[Int]})
51+
matches(1, {1})
52+
matches({1}, 1)
53+
// Should these match?
54+
// matches({(); 1}, 1)
55+
// matches(1, {(); 1})
56+
matches(fs(), fs())
57+
matches(fs(), fs(hole[Seq[Int]]: _*))
58+
matches(fs(1, 2, 3), fs(1, 2, 3))
59+
matches(fs(1, 2, 3), fs(hole[Int], hole[Int], 3))
60+
matches(fs(1, 2, 3), fs(hole[Seq[Int]]: _*))
61+
matches(f2(1, 2), f2(1, 2))
62+
matches(f2(a = 1, b = 2), f2(a = 1, b = 2))
63+
matches(f2(a = 1, b = 2), f2(a = hole[Int], b = hole[Int]))
64+
// Should these match?
65+
// matches(f2(a = 1, b = 2), f2(1, 2))
66+
// matches(f2(b = 2, a = 1), f2(1, 2))
67+
matches(super.toString, super.toString)
68+
69+
println()
70+
println("Matches type")
71+
matches(1: Int, 1: Int)
72+
matches(1: Int, 1: Hole[Int])
73+
matches(Nil: List[Int], Nil: List[Hole[Int]])
74+
matches(g[Int], g[Hole[Int]])
75+
matches(h[Int](6), h[Hole[Int]](6))
76+
matches(h[6](6), h[Hole[6]](6))
77+
matches(h[6](6), h[Hole[Int]](6))
78+
matches(g[this.type], g[Hole[this.type]])
79+
80+
// TODO add lots of tests
81+
82+
// No match
83+
println()
84+
println("No match")
85+
matches(1, 2)
86+
matches(4, hole[String])
87+
matches(6 + x, 7 + hole[Int])
88+
matches(6 + x, hole[Int] + 4)
89+
matches(g[Int], hole[String])
90+
matches(h[Int](7), h[String](hole[String]))
91+
matches(h[Int](6), h[Hole[Int]](7))
92+
matches({z = 4}, {z = 5})
93+
matches({z = 4}, {z2 = 4})
94+
95+
println()
96+
println("No match type")
97+
matches(??? : Int, ??? : Hole[String])
98+
matches(g[Int], g[Hole[String]])
99+
matches(h[Int](6), h[Hole[String]]("abc"))
100+
101+
// TODO add lots of tests
102+
103+
}
104+
}
105+
106+
class Foo(a: Int)

0 commit comments

Comments
 (0)