Skip to content

Commit fb5f46e

Browse files
committed
Simplify TreeUtils
1 parent f5bf7fc commit fb5f46e

File tree

3 files changed

+57
-49
lines changed

3 files changed

+57
-49
lines changed

library/src/scala/tasty/reflect/TreeUtils.scala

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@ trait TreeUtils
1313

1414
// Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
1515
def foldTree(x: X, tree: Tree)(implicit ctx: Context): X
16-
def foldTypeTree(x: X, tree: TypeOrBoundsTree)(implicit ctx: Context): X
1716
def foldCaseDef(x: X, tree: CaseDef)(implicit ctx: Context): X
1817
def foldTypeCaseDef(x: X, tree: TypeCaseDef)(implicit ctx: Context): X
1918
def foldPattern(x: X, tree: Pattern)(implicit ctx: Context): X
2019

2120
def foldTrees(x: X, trees: Iterable[Tree])(implicit ctx: Context): X = (x /: trees)(foldTree)
22-
def foldTypeTrees(x: X, trees: Iterable[TypeOrBoundsTree])(implicit ctx: Context): X = (x /: trees)(foldTypeTree)
2321
def foldCaseDefs(x: X, trees: Iterable[CaseDef])(implicit ctx: Context): X = (x /: trees)(foldCaseDef)
2422
def foldTypeCaseDefs(x: X, trees: Iterable[TypeCaseDef])(implicit ctx: Context): X = (x /: trees)(foldTypeCaseDef)
2523
def foldPatterns(x: X, trees: Iterable[Pattern])(implicit ctx: Context): X = (x /: trees)(foldPattern)
@@ -39,13 +37,13 @@ trait TreeUtils
3937
case Term.Apply(fun, args) =>
4038
foldTrees(foldTree(x, fun), args)
4139
case Term.TypeApply(fun, args) =>
42-
foldTypeTrees(foldTree(x, fun), args)
40+
foldTrees(foldTree(x, fun), args)
4341
case Term.Literal(const) =>
4442
x
4543
case Term.New(tpt) =>
46-
foldTypeTree(x, tpt)
44+
foldTree(x, tpt)
4745
case Term.Typed(expr, tpt) =>
48-
foldTypeTree(foldTree(x, expr), tpt)
46+
foldTree(foldTree(x, expr), tpt)
4947
case Term.NamedArg(_, arg) =>
5048
foldTree(x, arg)
5149
case Term.Assign(lhs, rhs) =>
@@ -56,94 +54,88 @@ trait TreeUtils
5654
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
5755
case Term.Lambda(meth, tpt) =>
5856
val a = foldTree(x, meth)
59-
tpt.fold(a)(b => foldTypeTree(a, b))
57+
tpt.fold(a)(b => foldTree(a, b))
6058
case Term.Match(selector, cases) =>
6159
foldCaseDefs(foldTree(x, selector), cases)
6260
case Term.Return(expr) =>
6361
foldTree(x, expr)
6462
case Term.Try(block, handler, finalizer) =>
6563
foldTrees(foldCaseDefs(foldTree(x, block), handler), finalizer)
6664
case Term.Repeated(elems, elemtpt) =>
67-
foldTrees(foldTypeTree(x, elemtpt), elems)
65+
foldTrees(foldTree(x, elemtpt), elems)
6866
case Term.Inlined(call, bindings, expansion) =>
6967
foldTree(foldTrees(x, bindings), expansion)
7068
case IsDefinition(vdef @ ValDef(_, tpt, rhs)) =>
7169
implicit val ctx = localCtx(vdef)
72-
foldTrees(foldTypeTree(x, tpt), rhs)
70+
foldTrees(foldTree(x, tpt), rhs)
7371
case IsDefinition(ddef @ DefDef(_, tparams, vparamss, tpt, rhs)) =>
7472
implicit val ctx = localCtx(ddef)
75-
foldTrees(foldTypeTree((foldTrees(x, tparams) /: vparamss)(foldTrees), tpt), rhs)
73+
foldTrees(foldTree((foldTrees(x, tparams) /: vparamss)(foldTrees), tpt), rhs)
7674
case IsDefinition(tdef @ TypeDef(_, rhs)) =>
7775
implicit val ctx = localCtx(tdef)
78-
foldTypeTree(x, rhs)
76+
foldTree(x, rhs)
7977
case IsDefinition(cdef @ ClassDef(_, constr, parents, derived, self, body)) =>
8078
implicit val ctx = localCtx(cdef)
81-
foldTrees(foldTrees(foldTypeTrees(foldParents(foldTree(x, constr), parents), derived), self), body)
79+
foldTrees(foldTrees(foldTrees(foldParents(foldTree(x, constr), parents), derived), self), body)
8280
case Import(_, expr, _) =>
8381
foldTree(x, expr)
8482
case IsPackageClause(clause @ PackageClause(pid, stats)) =>
8583
foldTrees(foldTree(x, pid), stats)(clause.symbol.localContext)
84+
case TypeTree.Inferred() => x
85+
case TypeTree.Ident(_) => x
86+
case TypeTree.Select(qualifier, _) => foldTree(x, qualifier)
87+
case TypeTree.Projection(qualifier, _) => foldTree(x, qualifier)
88+
case TypeTree.Singleton(ref) => foldTree(x, ref)
89+
case TypeTree.Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
90+
case TypeTree.Applied(tpt, args) => foldTrees(foldTree(x, tpt), args)
91+
case TypeTree.ByName(result) => foldTree(x, result)
92+
case TypeTree.Annotated(arg, annot) => foldTree(foldTree(x, arg), annot)
93+
case TypeTree.LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
94+
case TypeTree.TypeBind(_, tbt) => foldTree(x, tbt)
95+
case TypeTree.TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
96+
case TypeTree.MatchType(boundopt, selector, cases) =>
97+
foldTypeCaseDefs(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
98+
case WildcardTypeTree() => x
99+
case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi)
86100
}
87101
}
88102

89-
def foldOverTypeTree(x: X, tree: TypeOrBoundsTree)(implicit ctx: Context): X = tree match {
90-
case TypeTree.Inferred() => x
91-
case TypeTree.Ident(_) => x
92-
case TypeTree.Select(qualifier, _) => foldTree(x, qualifier)
93-
case TypeTree.Projection(qualifier, _) => foldTypeTree(x, qualifier)
94-
case TypeTree.Singleton(ref) => foldTree(x, ref)
95-
case TypeTree.Refined(tpt, refinements) => foldTrees(foldTypeTree(x, tpt), refinements)
96-
case TypeTree.Applied(tpt, args) => foldTypeTrees(foldTypeTree(x, tpt), args)
97-
case TypeTree.ByName(result) => foldTypeTree(x, result)
98-
case TypeTree.Annotated(arg, annot) => foldTree(foldTypeTree(x, arg), annot)
99-
case TypeTree.LambdaTypeTree(typedefs, arg) => foldTypeTree(foldTrees(x, typedefs), arg)
100-
case TypeTree.TypeBind(_, tbt) => foldTypeTree(x, tbt)
101-
case TypeTree.TypeBlock(typedefs, tpt) => foldTypeTree(foldTrees(x, typedefs), tpt)
102-
case TypeTree.MatchType(boundopt, selector, cases) =>
103-
foldTypeCaseDefs(foldTypeTree(boundopt.fold(x)(foldTypeTree(x, _)), selector), cases)
104-
case WildcardTypeTree() => x
105-
case TypeBoundsTree(lo, hi) => foldTypeTree(foldTypeTree(x, lo), hi)
106-
}
107-
108103
def foldOverCaseDef(x: X, tree: CaseDef)(implicit ctx: Context): X = tree match {
109104
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldPattern(x, pat), guard), body)
110105
}
111106

112107
def foldOverTypeCaseDef(x: X, tree: TypeCaseDef)(implicit ctx: Context): X = tree match {
113-
case TypeCaseDef(pat, body) => foldTypeTree(foldTypeTree(x, pat), body)
108+
case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body)
114109
}
115110

116111
def foldOverPattern(x: X, tree: Pattern)(implicit ctx: Context): X = tree match {
117112
case Pattern.Value(v) => foldTree(x, v)
118113
case Pattern.Bind(_, body) => foldPattern(x, body)
119114
case Pattern.Unapply(fun, implicits, patterns) => foldPatterns(foldTrees(foldTree(x, fun), implicits), patterns)
120115
case Pattern.Alternatives(patterns) => foldPatterns(x, patterns)
121-
case Pattern.TypeTest(tpt) => foldTypeTree(x, tpt)
116+
case Pattern.TypeTest(tpt) => foldTree(x, tpt)
122117
}
123118

124119
private def foldTermOrTypeTree(x: X, tree: TermOrTypeTree)(implicit ctx: Context): X = tree match {
125120
case IsTerm(termOrTypeTree) => foldTree(x, termOrTypeTree)
126-
case IsTypeTree(termOrTypeTree) => foldTypeTree(x, termOrTypeTree)
121+
case IsTypeTree(termOrTypeTree) => foldTree(x, termOrTypeTree)
127122
}
128123

129124
}
130125

131126
abstract class TreeTraverser extends TreeAccumulator[Unit] {
132127

133128
def traverseTree(tree: Tree)(implicit ctx: Context): Unit = traverseTreeChildren(tree)
134-
def traverseTypeTree(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = traverseTypeTreeChildren(tree)
135129
def traverseCaseDef(tree: CaseDef)(implicit ctx: Context): Unit = traverseCaseDefChildren(tree)
136130
def traverseTypeCaseDef(tree: TypeCaseDef)(implicit ctx: Context): Unit = traverseTypeCaseDefChildren(tree)
137131
def traversePattern(tree: Pattern)(implicit ctx: Context): Unit = traversePatternChildren(tree)
138132

139133
def foldTree(x: Unit, tree: Tree)(implicit ctx: Context): Unit = traverseTree(tree)
140-
def foldTypeTree(x: Unit, tree: TypeOrBoundsTree)(implicit ctx: Context) = traverseTypeTree(tree)
141134
def foldCaseDef(x: Unit, tree: CaseDef)(implicit ctx: Context) = traverseCaseDef(tree)
142135
def foldTypeCaseDef(x: Unit, tree: TypeCaseDef)(implicit ctx: Context) = traverseTypeCaseDef(tree)
143136
def foldPattern(x: Unit, tree: Pattern)(implicit ctx: Context) = traversePattern(tree)
144137

145138
protected def traverseTreeChildren(tree: Tree)(implicit ctx: Context): Unit = foldOverTree((), tree)
146-
protected def traverseTypeTreeChildren(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = foldOverTypeTree((), tree)
147139
protected def traverseCaseDefChildren(tree: CaseDef)(implicit ctx: Context): Unit = foldOverCaseDef((), tree)
148140
protected def traverseTypeCaseDefChildren(tree: TypeCaseDef)(implicit ctx: Context): Unit = foldOverTypeCaseDef((), tree)
149141
protected def traversePatternChildren(tree: Pattern)(implicit ctx: Context): Unit = foldOverPattern((), tree)
@@ -160,6 +152,9 @@ trait TreeUtils
160152
Import.copy(tree)(tree.impliedOnly, transformTerm(tree.expr), tree.selectors)
161153
case IsStatement(tree) =>
162154
transformStatement(tree)
155+
case IsTypeTree(tree) => transformTypeTree(tree)
156+
case IsTypeBoundsTree(tree) => tree // TODO traverse tree
157+
case IsWildcardTypeTree(tree) => tree // TODO traverse tree
163158
}
164159
}
165160

@@ -234,6 +229,7 @@ trait TreeUtils
234229
def transformTypeOrBoundsTree(tree: TypeOrBoundsTree)(implicit ctx: Context): TypeOrBoundsTree = tree match {
235230
case IsTypeTree(tree) => transformTypeTree(tree)
236231
case IsTypeBoundsTree(tree) => tree // TODO traverse tree
232+
case IsWildcardTypeTree(tree) => tree // TODO traverse tree
237233
}
238234

239235
def transformTypeTree(tree: TypeTree)(implicit ctx: Context): TypeTree = tree match {

semanticdb/src/dotty/semanticdb/SemanticdbConsumer.scala

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,16 @@ class SemanticdbConsumer(sourceFilePath: java.nio.file.Path) extends TastyConsum
4141
object ChildTraverser extends TreeTraverser {
4242
var children: List[Tree] = Nil
4343
var childrenType: List[TypeOrBoundsTree] = Nil
44-
override def traverseTree(tree: Tree)(implicit ctx: Context): Unit =
45-
children = tree :: children
44+
override def traverseTree(tree: Tree)(implicit ctx: Context): Unit = tree match {
45+
case IsTypeTree(tree) =>
46+
traverseTypeTree(tree)
47+
case IsTypeBoundsTree(tree) =>
48+
traverseTypeTree(tree)
49+
case _ => children = tree :: children
50+
}
4651
override def traversePattern(pattern: Pattern)(
4752
implicit ctx: Context): Unit = ()
48-
override def traverseTypeTree(tree: TypeOrBoundsTree)(
53+
def traverseTypeTree(tree: TypeOrBoundsTree)(
4954
implicit ctx: Context): Unit =
5055
childrenType = tree :: childrenType
5156
override def traverseCaseDef(tree: CaseDef)(implicit ctx: Context): Unit =
@@ -61,7 +66,7 @@ class SemanticdbConsumer(sourceFilePath: java.nio.file.Path) extends TastyConsum
6166
}
6267
def getChildrenType(tree: TypeOrBoundsTree)(implicit ctx: Context): List[TypeOrBoundsTree] = {
6368
childrenType = Nil
64-
traverseTypeTreeChildren(tree)(ctx)
69+
traverseTreeChildren(tree)(ctx)
6570
return childrenType
6671
}
6772
}
@@ -643,7 +648,7 @@ class SemanticdbConsumer(sourceFilePath: java.nio.file.Path) extends TastyConsum
643648
})
644649
}
645650

646-
override def traverseTypeTree(tree: TypeOrBoundsTree)(
651+
def traverseTypeTree(tree: TypeOrBoundsTree)(
647652
implicit ctx: Context): Unit = {
648653
tree match {
649654
case TypeTree.Ident(_) => {
@@ -658,7 +663,7 @@ class SemanticdbConsumer(sourceFilePath: java.nio.file.Path) extends TastyConsum
658663
addOccurenceTypeTree(typetree,
659664
s.SymbolOccurrence.Role.REFERENCE,
660665
range)
661-
super.traverseTypeTree(typetree)
666+
super.traverseTree(typetree)
662667
}
663668

664669
case TypeTree.Projection(qualifier, x) => {
@@ -667,7 +672,7 @@ class SemanticdbConsumer(sourceFilePath: java.nio.file.Path) extends TastyConsum
667672
addOccurenceTypeTree(typetree,
668673
s.SymbolOccurrence.Role.REFERENCE,
669674
range)
670-
super.traverseTypeTree(typetree)
675+
super.traverseTree(typetree)
671676
}
672677

673678
case TypeTree.Inferred() => {
@@ -703,7 +708,7 @@ class SemanticdbConsumer(sourceFilePath: java.nio.file.Path) extends TastyConsum
703708
}
704709

705710
case _ => {
706-
super.traverseTypeTree(tree)
711+
super.traverseTree(tree)
707712
}
708713
}
709714
}
@@ -805,7 +810,7 @@ class SemanticdbConsumer(sourceFilePath: java.nio.file.Path) extends TastyConsum
805810

806811
// we add the parents to the symbol list
807812
parents.foreach(_ match {
808-
case IsTypeTree(t) => traverseTypeTree(t)
813+
case IsTypeTree(t) => traverseTree(t)
809814
case IsTerm(t) => traverseTree(t)
810815
})
811816

tests/run/tasty-extractors-3/quoted_1.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@ object Macros {
1212

1313
val buff = new StringBuilder
1414
val traverser = new TreeTraverser {
15-
override def traverseTypeTree(tree: TypeOrBoundsTree)(implicit ctx: Context): Unit = {
16-
buff.append(tree.tpe.show)
17-
buff.append("\n\n")
18-
traverseTypeTreeChildren(tree)
15+
override def traverseTree(tree: Tree)(implicit ctx: Context): Unit = tree match {
16+
case IsTypeBoundsTree(tree) =>
17+
buff.append(tree.tpe.show)
18+
buff.append("\n\n")
19+
traverseTreeChildren(tree)
20+
case IsTypeTree(tree) =>
21+
buff.append(tree.tpe.show)
22+
buff.append("\n\n")
23+
traverseTreeChildren(tree)
24+
case _ =>
25+
super.traverseTree(tree)
1926
}
2027
}
2128

0 commit comments

Comments
 (0)