Skip to content

Commit 4f1160e

Browse files
committed
Simplify TreeUtils
1 parent f5bf7fc commit 4f1160e

File tree

2 files changed

+43
-40
lines changed

2 files changed

+43
-40
lines changed

library/src/scala/tasty/reflect/TreeUtils.scala

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@ trait TreeUtils
1313

1414
// Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
1515
def foldTree(x: X, tree: Tree)(implicit ctx: Context): X
16-
def foldTypeTree(x: X, tree: TypeOrBoundsTree)(implicit ctx: Context): X
1716
def foldCaseDef(x: X, tree: CaseDef)(implicit ctx: Context): X
1817
def foldTypeCaseDef(x: X, tree: TypeCaseDef)(implicit ctx: Context): X
1918
def foldPattern(x: X, tree: Pattern)(implicit ctx: Context): X
2019

2120
def foldTrees(x: X, trees: Iterable[Tree])(implicit ctx: Context): X = (x /: trees)(foldTree)
22-
def foldTypeTrees(x: X, trees: Iterable[TypeOrBoundsTree])(implicit ctx: Context): X = (x /: trees)(foldTypeTree)
2321
def foldCaseDefs(x: X, trees: Iterable[CaseDef])(implicit ctx: Context): X = (x /: trees)(foldCaseDef)
2422
def foldTypeCaseDefs(x: X, trees: Iterable[TypeCaseDef])(implicit ctx: Context): X = (x /: trees)(foldTypeCaseDef)
2523
def foldPatterns(x: X, trees: Iterable[Pattern])(implicit ctx: Context): X = (x /: trees)(foldPattern)
@@ -39,13 +37,13 @@ trait TreeUtils
3937
case Term.Apply(fun, args) =>
4038
foldTrees(foldTree(x, fun), args)
4139
case Term.TypeApply(fun, args) =>
42-
foldTypeTrees(foldTree(x, fun), args)
40+
foldTrees(foldTree(x, fun), args)
4341
case Term.Literal(const) =>
4442
x
4543
case Term.New(tpt) =>
46-
foldTypeTree(x, tpt)
44+
foldTree(x, tpt)
4745
case Term.Typed(expr, tpt) =>
48-
foldTypeTree(foldTree(x, expr), tpt)
46+
foldTree(foldTree(x, expr), tpt)
4947
case Term.NamedArg(_, arg) =>
5048
foldTree(x, arg)
5149
case Term.Assign(lhs, rhs) =>
@@ -56,94 +54,88 @@ trait TreeUtils
5654
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
5755
case Term.Lambda(meth, tpt) =>
5856
val a = foldTree(x, meth)
59-
tpt.fold(a)(b => foldTypeTree(a, b))
57+
tpt.fold(a)(b => foldTree(a, b))
6058
case Term.Match(selector, cases) =>
6159
foldCaseDefs(foldTree(x, selector), cases)
6260
case Term.Return(expr) =>
6361
foldTree(x, expr)
6462
case Term.Try(block, handler, finalizer) =>
6563
foldTrees(foldCaseDefs(foldTree(x, block), handler), finalizer)
6664
case Term.Repeated(elems, elemtpt) =>
67-
foldTrees(foldTypeTree(x, elemtpt), elems)
65+
foldTrees(foldTree(x, elemtpt), elems)
6866
case Term.Inlined(call, bindings, expansion) =>
6967
foldTree(foldTrees(x, bindings), expansion)
7068
case IsDefinition(vdef @ ValDef(_, tpt, rhs)) =>
7169
implicit val ctx = localCtx(vdef)
72-
foldTrees(foldTypeTree(x, tpt), rhs)
70+
foldTrees(foldTree(x, tpt), rhs)
7371
case IsDefinition(ddef @ DefDef(_, tparams, vparamss, tpt, rhs)) =>
7472
implicit val ctx = localCtx(ddef)
75-
foldTrees(foldTypeTree((foldTrees(x, tparams) /: vparamss)(foldTrees), tpt), rhs)
73+
foldTrees(foldTree((foldTrees(x, tparams) /: vparamss)(foldTrees), tpt), rhs)
7674
case IsDefinition(tdef @ TypeDef(_, rhs)) =>
7775
implicit val ctx = localCtx(tdef)
78-
foldTypeTree(x, rhs)
76+
foldTree(x, rhs)
7977
case IsDefinition(cdef @ ClassDef(_, constr, parents, derived, self, body)) =>
8078
implicit val ctx = localCtx(cdef)
81-
foldTrees(foldTrees(foldTypeTrees(foldParents(foldTree(x, constr), parents), derived), self), body)
79+
foldTrees(foldTrees(foldTrees(foldParents(foldTree(x, constr), parents), derived), self), body)
8280
case Import(_, expr, _) =>
8381
foldTree(x, expr)
8482
case IsPackageClause(clause @ PackageClause(pid, stats)) =>
8583
foldTrees(foldTree(x, pid), stats)(clause.symbol.localContext)
84+
case TypeTree.Inferred() => x
85+
case TypeTree.Ident(_) => x
86+
case TypeTree.Select(qualifier, _) => foldTree(x, qualifier)
87+
case TypeTree.Projection(qualifier, _) => foldTree(x, qualifier)
88+
case TypeTree.Singleton(ref) => foldTree(x, ref)
89+
case TypeTree.Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
90+
case TypeTree.Applied(tpt, args) => foldTrees(foldTree(x, tpt), args)
91+
case TypeTree.ByName(result) => foldTree(x, result)
92+
case TypeTree.Annotated(arg, annot) => foldTree(foldTree(x, arg), annot)
93+
case TypeTree.LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
94+
case TypeTree.TypeBind(_, tbt) => foldTree(x, tbt)
95+
case TypeTree.TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
96+
case TypeTree.MatchType(boundopt, selector, cases) =>
97+
foldTypeCaseDefs(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
98+
case WildcardTypeTree() => x
99+
case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi)
86100
}
87101
}
88102

89-
def foldOverTypeTree(x: X, tree: TypeOrBoundsTree)(implicit ctx: Context): X = tree match {
90-
case TypeTree.Inferred() => x
91-
case TypeTree.Ident(_) => x
92-
case TypeTree.Select(qualifier, _) => foldTree(x, qualifier)
93-
case TypeTree.Projection(qualifier, _) => foldTypeTree(x, qualifier)
94-
case TypeTree.Singleton(ref) => foldTree(x, ref)
95-
case TypeTree.Refined(tpt, refinements) => foldTrees(foldTypeTree(x, tpt), refinements)
96-
case TypeTree.Applied(tpt, args) => foldTypeTrees(foldTypeTree(x, tpt), args)
97-
case TypeTree.ByName(result) => foldTypeTree(x, result)
98-
case TypeTree.Annotated(arg, annot) => foldTree(foldTypeTree(x, arg), annot)
99-
case TypeTree.LambdaTypeTree(typedefs, arg) => foldTypeTree(foldTrees(x, typedefs), arg)
100-
case TypeTree.TypeBind(_, tbt) => foldTypeTree(x, tbt)
101-
case TypeTree.TypeBlock(typedefs, tpt) => foldTypeTree(foldTrees(x, typedefs), tpt)
102-
case TypeTree.MatchType(boundopt, selector, cases) =>
103-
foldTypeCaseDefs(foldTypeTree(boundopt.fold(x)(foldTypeTree(x, _)), selector), cases)
104-
case WildcardTypeTree() => x
105-
case TypeBoundsTree(lo, hi) => foldTypeTree(foldTypeTree(x, lo), hi)
106-
}
107-
108103
def foldOverCaseDef(x: X, tree: CaseDef)(implicit ctx: Context): X = tree match {
109104
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldPattern(x, pat), guard), body)
110105
}
111106

112107
def foldOverTypeCaseDef(x: X, tree: TypeCaseDef)(implicit ctx: Context): X = tree match {
113-
case TypeCaseDef(pat, body) => foldTypeTree(foldTypeTree(x, pat), body)
108+
case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body)
114109
}
115110

116111
def foldOverPattern(x: X, tree: Pattern)(implicit ctx: Context): X = tree match {
117112
case Pattern.Value(v) => foldTree(x, v)
118113
case Pattern.Bind(_, body) => foldPattern(x, body)
119114
case Pattern.Unapply(fun, implicits, patterns) => foldPatterns(foldTrees(foldTree(x, fun), implicits), patterns)
120115
case Pattern.Alternatives(patterns) => foldPatterns(x, patterns)
121-
case Pattern.TypeTest(tpt) => foldTypeTree(x, tpt)
116+
case Pattern.TypeTest(tpt) => foldTree(x, tpt)
122117
}
123118

124119
private def foldTermOrTypeTree(x: X, tree: TermOrTypeTree)(implicit ctx: Context): X = tree match {
125120
case IsTerm(termOrTypeTree) => foldTree(x, termOrTypeTree)
126-
case IsTypeTree(termOrTypeTree) => foldTypeTree(x, termOrTypeTree)
121+
case IsTypeTree(termOrTypeTree) => foldTree(x, termOrTypeTree)
127122
}
128123

129124
}
130125

131126
abstract class TreeTraverser extends TreeAccumulator[Unit] {
132127

133128
def traverseTree(tree: Tree)(implicit ctx: Context): Unit = traverseTreeChildren(tree)
134-
def traverseTypeTree(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = traverseTypeTreeChildren(tree)
135129
def traverseCaseDef(tree: CaseDef)(implicit ctx: Context): Unit = traverseCaseDefChildren(tree)
136130
def traverseTypeCaseDef(tree: TypeCaseDef)(implicit ctx: Context): Unit = traverseTypeCaseDefChildren(tree)
137131
def traversePattern(tree: Pattern)(implicit ctx: Context): Unit = traversePatternChildren(tree)
138132

139133
def foldTree(x: Unit, tree: Tree)(implicit ctx: Context): Unit = traverseTree(tree)
140-
def foldTypeTree(x: Unit, tree: TypeOrBoundsTree)(implicit ctx: Context) = traverseTypeTree(tree)
141134
def foldCaseDef(x: Unit, tree: CaseDef)(implicit ctx: Context) = traverseCaseDef(tree)
142135
def foldTypeCaseDef(x: Unit, tree: TypeCaseDef)(implicit ctx: Context) = traverseTypeCaseDef(tree)
143136
def foldPattern(x: Unit, tree: Pattern)(implicit ctx: Context) = traversePattern(tree)
144137

145138
protected def traverseTreeChildren(tree: Tree)(implicit ctx: Context): Unit = foldOverTree((), tree)
146-
protected def traverseTypeTreeChildren(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = foldOverTypeTree((), tree)
147139
protected def traverseCaseDefChildren(tree: CaseDef)(implicit ctx: Context): Unit = foldOverCaseDef((), tree)
148140
protected def traverseTypeCaseDefChildren(tree: TypeCaseDef)(implicit ctx: Context): Unit = foldOverTypeCaseDef((), tree)
149141
protected def traversePatternChildren(tree: Pattern)(implicit ctx: Context): Unit = foldOverPattern((), tree)
@@ -160,6 +152,9 @@ trait TreeUtils
160152
Import.copy(tree)(tree.impliedOnly, transformTerm(tree.expr), tree.selectors)
161153
case IsStatement(tree) =>
162154
transformStatement(tree)
155+
case IsTypeTree(tree) => transformTypeTree(tree)
156+
case IsTypeBoundsTree(tree) => tree // TODO traverse tree
157+
case IsWildcardTypeTree(tree) => tree // TODO traverse tree
163158
}
164159
}
165160

@@ -234,6 +229,7 @@ trait TreeUtils
234229
def transformTypeOrBoundsTree(tree: TypeOrBoundsTree)(implicit ctx: Context): TypeOrBoundsTree = tree match {
235230
case IsTypeTree(tree) => transformTypeTree(tree)
236231
case IsTypeBoundsTree(tree) => tree // TODO traverse tree
232+
case IsWildcardTypeTree(tree) => tree // TODO traverse tree
237233
}
238234

239235
def transformTypeTree(tree: TypeTree)(implicit ctx: Context): TypeTree = tree match {

tests/run/tasty-extractors-3/quoted_1.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@ object Macros {
1212

1313
val buff = new StringBuilder
1414
val traverser = new TreeTraverser {
15-
override def traverseTypeTree(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = {
16-
buff.append(tree.tpe.show)
17-
buff.append("\n\n")
18-
traverseTypeTreeChildren(tree)
15+
override def traverseTree(tree: Tree)(implicit ctx: Context): Unit = tree match {
16+
case IsTypeBoundsTree(tree) =>
17+
buff.append(tree.tpe.show)
18+
buff.append("\n\n")
19+
traverseTreeChildren(tree)
20+
case IsTypeTree(tree) =>
21+
buff.append(tree.tpe.show)
22+
buff.append("\n\n")
23+
traverseTreeChildren(tree)
24+
case _ =>
25+
super.traverseTree(tree)
1926
}
2027
}
2128

0 commit comments

Comments
 (0)