Skip to content

Commit f0a5769

Browse files
Merge pull request #7570 from dotty-staging/add-ExprMap
Add quoted.util.ExprMap
2 parents a356535 + 4eb132a commit f0a5769

File tree

12 files changed

+408
-93
lines changed

12 files changed

+408
-93
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package scala.quoted.util
2+
3+
import scala.quoted._
4+
5+
trait ExprMap {
6+
7+
/** Map an expression `e` with a type `tpe` */
8+
def transform[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T]
9+
10+
/** Map subexpressions an expression `e` with a type `tpe` */
11+
def transformChildren[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T] = {
12+
import qctx.tasty.{_, given}
13+
final class MapChildren() {
14+
15+
def transformStatement(tree: Statement)(given ctx: Context): Statement = {
16+
def localCtx(definition: Definition): Context = definition.symbol.localContext
17+
tree match {
18+
case tree: Term =>
19+
transformTerm(tree, defn.AnyType)
20+
case tree: Definition =>
21+
transformDefinition(tree)
22+
case tree: Import =>
23+
tree
24+
}
25+
}
26+
27+
def transformDefinition(tree: Definition)(given ctx: Context): Definition = {
28+
def localCtx(definition: Definition): Context = definition.symbol.localContext
29+
tree match {
30+
case tree: ValDef =>
31+
implicit val ctx = localCtx(tree)
32+
val rhs1 = tree.rhs.map(x => transformTerm(x, tree.tpt.tpe))
33+
ValDef.copy(tree)(tree.name, tree.tpt, rhs1)
34+
case tree: DefDef =>
35+
implicit val ctx = localCtx(tree)
36+
DefDef.copy(tree)(tree.name, tree.typeParams, tree.paramss, tree.returnTpt, tree.rhs.map(x => transformTerm(x, tree.returnTpt.tpe)))
37+
case tree: TypeDef =>
38+
tree
39+
case tree: ClassDef =>
40+
val newBody = transformStats(tree.body)
41+
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, newBody)
42+
}
43+
}
44+
45+
def transformTermChildren(tree: Term, tpe: Type)(given ctx: Context): Term = tree match {
46+
case Ident(name) =>
47+
tree
48+
case Select(qualifier, name) =>
49+
Select.copy(tree)(transformTerm(qualifier, qualifier.tpe), name)
50+
case This(qual) =>
51+
tree
52+
case Super(qual, mix) =>
53+
tree
54+
case tree @ Apply(fun, args) =>
55+
val MethodType(_, tpes, _) = fun.tpe.widen
56+
Apply.copy(tree)(transformTerm(fun, defn.AnyType), transformTerms(args, tpes))
57+
case TypeApply(fun, args) =>
58+
TypeApply.copy(tree)(transformTerm(fun, defn.AnyType), args)
59+
case _: Literal =>
60+
tree
61+
case New(tpt) =>
62+
New.copy(tree)(transformTypeTree(tpt))
63+
case Typed(expr, tpt) =>
64+
val tp = tpt.tpe match
65+
// TODO improve code
66+
case AppliedType(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "<repeated>"), List(tp0: Type)) =>
67+
type T
68+
val a = tp0.seal.asInstanceOf[quoted.Type[T]]
69+
'[Seq[$a]].unseal.tpe
70+
case tp => tp
71+
Typed.copy(tree)(transformTerm(expr, tp), transformTypeTree(tpt))
72+
case tree: NamedArg =>
73+
NamedArg.copy(tree)(tree.name, transformTerm(tree.value, tpe))
74+
case Assign(lhs, rhs) =>
75+
Assign.copy(tree)(lhs, transformTerm(rhs, lhs.tpe.widen))
76+
case Block(stats, expr) =>
77+
Block.copy(tree)(transformStats(stats), transformTerm(expr, tpe))
78+
case If(cond, thenp, elsep) =>
79+
If.copy(tree)(
80+
transformTerm(cond, defn.BooleanType),
81+
transformTerm(thenp, tpe),
82+
transformTerm(elsep, tpe))
83+
case _: Closure =>
84+
tree
85+
case Match(selector, cases) =>
86+
Match.copy(tree)(transformTerm(selector, selector.tpe), transformCaseDefs(cases, tpe))
87+
case Return(expr) =>
88+
// FIXME
89+
// ctx.owner seems to be set to the wrong symbol
90+
// Return.copy(tree)(transformTerm(expr, expr.tpe))
91+
tree
92+
case While(cond, body) =>
93+
While.copy(tree)(transformTerm(cond, defn.BooleanType), transformTerm(body, defn.AnyType))
94+
case Try(block, cases, finalizer) =>
95+
Try.copy(tree)(transformTerm(block, tpe), transformCaseDefs(cases, defn.AnyType), finalizer.map(x => transformTerm(x, defn.AnyType)))
96+
case Repeated(elems, elemtpt) =>
97+
Repeated.copy(tree)(transformTerms(elems, elemtpt.tpe), elemtpt)
98+
case Inlined(call, bindings, expansion) =>
99+
Inlined.copy(tree)(call, transformDefinitions(bindings), transformTerm(expansion, tpe)/*()call.symbol.localContext)*/)
100+
}
101+
102+
def transformTerm(tree: Term, tpe: Type)(given ctx: Context): Term =
103+
tree match {
104+
case _: Closure =>
105+
tree
106+
case _: Inlined =>
107+
transformTermChildren(tree, tpe)
108+
case _ =>
109+
tree.tpe.widen match {
110+
case _: MethodType | _: PolyType =>
111+
transformTermChildren(tree, tpe)
112+
case _ =>
113+
type X
114+
val expr = tree.seal.asInstanceOf[Expr[X]]
115+
val t = tpe.seal.asInstanceOf[quoted.Type[X]]
116+
transform(expr)(given qctx, t).unseal
117+
}
118+
}
119+
120+
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree
121+
122+
def transformCaseDef(tree: CaseDef, tpe: Type)(given ctx: Context): CaseDef =
123+
CaseDef.copy(tree)(tree.pattern, tree.guard.map(x => transformTerm(x, defn.BooleanType)), transformTerm(tree.rhs, tpe))
124+
125+
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = {
126+
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
127+
}
128+
129+
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
130+
trees mapConserve (transformStatement(_))
131+
132+
def transformDefinitions(trees: List[Definition])(given ctx: Context): List[Definition] =
133+
trees mapConserve (transformDefinition(_))
134+
135+
def transformTerms(trees: List[Term], tpes: List[Type])(given ctx: Context): List[Term] =
136+
var tpes2 = tpes // TODO use proper zipConserve
137+
trees mapConserve { x =>
138+
val tpe :: tail = tpes2
139+
tpes2 = tail
140+
transformTerm(x, tpe)
141+
}
142+
143+
def transformTerms(trees: List[Term], tpe: Type)(given ctx: Context): List[Term] =
144+
trees.mapConserve(x => transformTerm(x, tpe))
145+
146+
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
147+
trees mapConserve (transformTypeTree(_))
148+
149+
def transformCaseDefs(trees: List[CaseDef], tpe: Type)(given ctx: Context): List[CaseDef] =
150+
trees mapConserve (x => transformCaseDef(x, tpe))
151+
152+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
153+
trees mapConserve (transformTypeCaseDef(_))
154+
155+
}
156+
new MapChildren().transformTermChildren(e.unseal, tpe.unseal.tpe).seal.cast[T] // Cast will only fail if this implementation has a bug
157+
}
158+
159+
}

library/src/scala/quoted/Expr.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,17 @@ package quoted {
1919
*/
2020
final def getValue[U >: T](given qctx: QuoteContext, valueOf: ValueOfExpr[U]): Option[U] = valueOf(this)
2121

22+
/** Pattern matches `this` against `that`. Effectively performing a deep equality check.
23+
* It does the equivalent of
24+
* ```
25+
* this match
26+
* case '{...} => true // where the contens of the pattern are the contents of `that`
27+
* case _ => false
28+
* ```
29+
*/
30+
final def matches(that: Expr[Any])(given qctx: QuoteContext): Boolean =
31+
!scala.internal.quoted.Expr.unapply[Unit, Unit](this)(given that, false, qctx).isEmpty
32+
2233
}
2334

2435
object Expr {

tests/run-macros/expr-map-1.check

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
oof
2+
oofoof
3+
ylppa
4+
kcolb
5+
kcolb
6+
neht
7+
esle
8+
lav
9+
vals
10+
fed
11+
defs
12+
fed
13+
rab
14+
yrt
15+
yllanif
16+
hctac
17+
elihw
18+
wen
19+
depyt
20+
depyt
21+
grAdeman
22+
qual
23+
adbmal
24+
ravsgra
25+
hctam
26+
def
27+
ooF wen
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def rewrite[T](x: => Any): Any = ${ stringRewriter('x) }
5+
6+
private def stringRewriter(e: Expr[Any])(given QuoteContext): Expr[Any] =
7+
StringRewriter.transform(e)
8+
9+
private object StringRewriter extends util.ExprMap {
10+
11+
def transform[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = e match
12+
case Const(s: String) =>
13+
Expr(s.reverse) match
14+
case '{ $x: T } => x
15+
case _ => e // e had a singlton String type
16+
case _ => transformChildren(e)
17+
18+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
println(rewrite("foo"))
5+
println(rewrite("foo" + "foo"))
6+
7+
rewrite {
8+
println("apply")
9+
}
10+
11+
rewrite {
12+
println("block")
13+
println("block")
14+
}
15+
16+
val b: Boolean = true
17+
rewrite {
18+
if b then println("then")
19+
else println("else")
20+
}
21+
22+
rewrite {
23+
if !b then println("then")
24+
else println("else")
25+
}
26+
27+
rewrite {
28+
val s: String = "val"
29+
println(s)
30+
}
31+
32+
rewrite {
33+
val s: "vals" = "vals"
34+
println(s) // prints "foo" not "oof"
35+
}
36+
37+
rewrite {
38+
def s: String = "def"
39+
println(s)
40+
}
41+
42+
rewrite {
43+
def s: "defs" = "defs"
44+
println(s) // prints "foo" not "oof"
45+
}
46+
47+
rewrite {
48+
def s(x: String): String = x
49+
println(s("def"))
50+
}
51+
52+
rewrite {
53+
var s: String = "var"
54+
s = "bar"
55+
println(s)
56+
}
57+
58+
rewrite {
59+
try println("try")
60+
finally println("finally")
61+
}
62+
63+
rewrite {
64+
try throw new Exception()
65+
catch case x: Exception => println("catch")
66+
}
67+
68+
rewrite {
69+
var x = true
70+
while (x) {
71+
println("while")
72+
x = false
73+
}
74+
}
75+
76+
rewrite {
77+
val t = new Tuple1("new")
78+
println(t._1)
79+
}
80+
81+
rewrite {
82+
println("typed": String)
83+
println("typed": Any)
84+
}
85+
86+
rewrite {
87+
val f = new Foo(foo = "namedArg")
88+
println(f.foo)
89+
}
90+
91+
rewrite {
92+
println("qual".reverse)
93+
}
94+
95+
rewrite {
96+
val f = () => "lambda"
97+
println(f())
98+
}
99+
100+
rewrite {
101+
def f(args: String*): String = args.mkString
102+
println(f("var", "args"))
103+
}
104+
105+
rewrite {
106+
"match" match {
107+
case "match" => println("match")
108+
case x => println("x")
109+
}
110+
}
111+
112+
// FIXME should print fed
113+
rewrite {
114+
def s: String = return "def"
115+
println(s)
116+
}
117+
118+
rewrite {
119+
class Foo {
120+
println("new Foo")
121+
}
122+
new Foo
123+
}
124+
125+
126+
}
127+
128+
}
129+
130+
class Foo(val foo: String)

tests/run-macros/expr-map-2.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Foo(2)
2+
4
3+
4
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def rewrite[T](x: => Any): Any = ${ stringRewriter('x) }
5+
6+
private def stringRewriter(e: Expr[Any])(given QuoteContext): Expr[Any] =
7+
StringRewriter.transform(e)
8+
9+
private object StringRewriter extends util.ExprMap {
10+
11+
def transform[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = e match
12+
case '{ ($x: Foo).x } =>
13+
'{ new Foo(4).x } match case '{ $e: T } => e
14+
case _ =>
15+
transformChildren(e)
16+
17+
}
18+
19+
case class Foo(x: Int)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
println(rewrite(new Foo(2)))
5+
println(rewrite(new Foo(2).x))
6+
7+
rewrite {
8+
val foo = new Foo(2)
9+
println(foo.x)
10+
}
11+
12+
}
13+
}

0 commit comments

Comments
 (0)