Skip to content

Commit 48c8b34

Browse files
committed
Reinfer all types of a unit in a pre-traversal
1 parent f842979 commit 48c8b34

File tree

3 files changed

+67
-87
lines changed

3 files changed

+67
-87
lines changed

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@ object CapturingType:
1111
if refs.isAlwaysEmpty then parent
1212
else AnnotatedType(parent, CaptureAnnotation(refs))
1313

14-
def unapply(tp: Type)(using Context) = tp match
15-
case tp: AnnotatedType => tp.annot match
16-
case ann: CaptureAnnotation =>
17-
Some((tp.parent, ann.refs))
18-
case ann =>
19-
if ann.symbol == defn.RetainsAnnot && ctx.phase == Phases.checkCapturesPhase
20-
then Some((tp.parent, ann.tree.toCaptureSet))
21-
else None
22-
case _ => None
14+
def unapply(tp: AnnotatedType)(using Context) = tp.annot match
15+
case ann: CaptureAnnotation =>
16+
Some((tp.parent, ann.refs))
17+
case ann =>
18+
if ann.symbol == defn.RetainsAnnot && ctx.phase == Phases.checkCapturesPhase
19+
then Some((tp.parent, ann.tree.toCaptureSet))
20+
else None
2321

2422
end CapturingType

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
4343
override def widenSkolems = true
4444

4545
def run(using Context): Unit =
46-
newRechecker().checkUnit(ctx.compilationUnit)
46+
val rechecker = newRechecker()
47+
rechecker.reinferAll.traverse(ctx.compilationUnit.tpdTree)
48+
rechecker.checkUnit(ctx.compilationUnit)
4749

4850
def newRechecker()(using Context): Rechecker
4951

@@ -63,36 +65,41 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
6365
else sym.flags
6466
).installAfter(preRecheckPhase)
6567

66-
/** Hooks to be overridden */
67-
protected def transformType(tp: Type, inferred: Boolean)(using Context): Type = tp
68-
69-
def enterDef(stat: Tree)(using Context): Unit =
70-
val sym = stat.symbol
71-
val info1 = stat match
72-
case stat: ValOrDefDef if stat.tpt.isInstanceOf[InferredTypeTree] =>
73-
stat match
74-
case stat: DefDef => // enter paramss early since result type refers to them
75-
stat.paramss.foreach(_.foreach(enterDef))
76-
case _ =>
77-
def integrateRT(restp: Type, info: Type, psymss: List[List[Symbol]]): Type = info match
78-
case info: MethodOrPoly =>
79-
info.derivedLambdaType(resType =
80-
integrateRT(restp.subst(psymss.head, info.paramRefs), info.resType, psymss.tail))
81-
case info: ExprType =>
82-
info.derivedExprType(resType = restp)
83-
case _ =>
84-
restp
85-
integrateRT(recheck(stat.tpt), sym.info, sym.paramSymss)
86-
case _ =>
87-
sym.info
88-
val info2 = transformType(info1, inferred = stat.isInstanceOf[Bind])
89-
recheckr.println(i"update info $sym: ${sym.info} --> $info1 --> $info2")
90-
sym.updateInfo(info2)
91-
9268
extension (tpe: Type) def rememberFor(tree: Tree)(using Context): Unit =
9369
if (tpe ne tree.tpe) && !tree.hasAttachment(RecheckedType) then
9470
tree.putAttachment(RecheckedType, tpe)
9571

72+
def knownType(tree: Tree) =
73+
tree.attachmentOrElse(RecheckedType, tree.tpe)
74+
75+
def reinfer(tp: Type)(using Context): Type = tp
76+
77+
object reinferAll extends TreeTraverser:
78+
def traverse(tree: Tree)(using Context) =
79+
traverseChildren(tree)
80+
tree match
81+
case tree: InferredTypeTree =>
82+
reinfer(tree.tpe).rememberFor(tree)
83+
case tree: ValOrDefDef if tree.tpt.isInstanceOf[InferredTypeTree] =>
84+
val sym = tree.symbol
85+
def integrateRT(restp: Type, info: Type, psymss: List[List[Symbol]]): Type = info match
86+
case info: MethodOrPoly =>
87+
info.derivedLambdaType(resType =
88+
integrateRT(restp.subst(psymss.head, info.paramRefs), info.resType, psymss.tail))
89+
case info: ExprType =>
90+
info.derivedExprType(resType = restp)
91+
case _ =>
92+
restp
93+
if !sym.isConstructor then
94+
val info1 = integrateRT(knownType(tree.tpt), sym.info, sym.paramSymss)
95+
recheckr.println(i"update info $sym: ${sym.info} --> $info1")
96+
sym.updateInfo(info1)
97+
case tree: Bind =>
98+
val sym = tree.symbol
99+
sym.updateInfo(reinfer(sym.info))
100+
case _ =>
101+
end reinferAll
102+
96103
def constFold(tree: Tree, tp: Type)(using Context): Type =
97104
val tree1 = tree.withType(tp)
98105
val tree2 = ConstFold(tree1)
@@ -115,7 +122,6 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
115122

116123
def recheckBind(tree: Bind, pt: Type)(using Context): Type = tree match
117124
case Bind(name, body) =>
118-
enterDef(tree)
119125
recheck(body, pt)
120126
val sym = tree.symbol
121127
if sym.isType then sym.typeRef else sym.info
@@ -131,9 +137,6 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
131137
sym.termRef
132138

133139
def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Type =
134-
if !tree.tpt.isInstanceOf[InferredTypeTree] then
135-
// otherwise paramss were already entered
136-
tree.paramss.foreach(_.foreach(enterDef))
137140
val rhsCtx = linkConstructorParams(sym)
138141
if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then
139142
inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) }
@@ -256,12 +259,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
256259
seqLitType(tree, TypeComparer.lub(declaredElemType :: elemTypes))
257260

258261
def recheckTypeTree(tree: TypeTree)(using Context): Type =
259-
tree.getAttachment(RecheckedType) match
260-
case Some(tp) => tp
261-
case None =>
262-
val tp = transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree])
263-
tp.rememberFor(tree)
264-
tp
262+
knownType(tree)
265263

266264
def recheckAnnotated(tree: Annotated)(using Context): Type =
267265
tree.tpe match
@@ -278,7 +276,6 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
278276
NoType
279277

280278
def recheckStats(stats: List[Tree])(using Context): Unit =
281-
stats.foreach(enterDef)
282279
stats.foreach(recheck(_))
283280

284281
/** Recheck tree without adapting it, returning its new type.

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 25 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -99,55 +99,40 @@ class CheckCaptures extends Recheck:
9999
class CaptureChecker(ictx: Context) extends Rechecker(ictx):
100100
import ast.tpd.*
101101

102-
override def transformType(tp: Type, inferred: Boolean)(using Context): Type =
102+
override def reinfer(tp: Type)(using Context): Type =
103103

104104
def mapRefined(tp: RefinedType, core1: Type, rinfo1: Type): Type =
105105
if (rinfo1 ne tp.refinedInfo) && defn.isFunctionType(tp)
106106
then rinfo1.toFunctionType(isJava = false)
107107
else tp.derivedRefinedType(core1, tp.refinedName, rinfo1)
108108

109-
val mapType = new TypeMap:
109+
val cleanType = new TypeMap:
110110
def apply(t: Type) = t match
111-
case AnnotatedType(parent, annot)
112-
if annot.symbol == defn.RetainsAnnot && !annot.isInstanceOf[CaptureAnnotation] =>
113-
CapturingType(this(parent), annot.tree.toCaptureSet)
114-
case t @ RefinedType(core, nme.apply, appInfo) =>
115-
mapRefined(t, this(core), this(appInfo))
111+
case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot =>
112+
apply(parent)
116113
case _ =>
117114
mapOver(t)
118115

119-
def reinfer(tp: Type): Type =
120-
val cleanType = new TypeMap:
121-
def apply(t: Type) = t match
122-
case AnnotatedType(parent, annot)
123-
if annot.symbol == defn.RetainsAnnot =>
124-
parent
125-
case _ =>
126-
mapOver(t)
127-
128-
def addInnerVars(tp: Type): Type = tp match
129-
case tp @ AppliedType(tycon, args) =>
130-
tp.derivedAppliedType(tycon, args.map(addVars))
131-
case tp @ RefinedType(core, _, rinfo) =>
132-
mapRefined(tp, addInnerVars(core), addVars(rinfo))
133-
case tp: MethodOrPoly =>
134-
tp.derivedLambdaType(resType = addVars(tp.resType))
135-
case tp: ExprType =>
136-
tp.derivedExprType(resType = addVars(tp.resType))
137-
case _ =>
138-
tp
139-
140-
def addVars(tp: Type): Type =
141-
val tp1 = addInnerVars(tp)
142-
if tp1.canHaveInferredCapture
143-
then CapturingType(tp1, CaptureSet.Var())
144-
else tp1
145-
146-
addVars(cleanType(tp))
147-
end reinfer
148-
149-
if inferred then reinfer(tp) else tp//mapType(tp)
150-
end transformType
116+
def addInnerVars(tp: Type): Type = tp match
117+
case tp @ AppliedType(tycon, args) =>
118+
tp.derivedAppliedType(tycon, args.map(addVars))
119+
case tp @ RefinedType(core, _, rinfo) =>
120+
mapRefined(tp, addInnerVars(core), addVars(rinfo))
121+
case tp: MethodOrPoly =>
122+
tp.derivedLambdaType(resType = addVars(tp.resType))
123+
case tp: ExprType =>
124+
tp.derivedExprType(resType = addVars(tp.resType))
125+
case _ =>
126+
tp
127+
128+
def addVars(tp: Type): Type =
129+
val tp1 = addInnerVars(tp)
130+
if tp1.canHaveInferredCapture
131+
then CapturingType(tp1, CaptureSet.Var())
132+
else tp1
133+
134+
addVars(cleanType(tp))
135+
end reinfer
151136

152137
private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, false, null)
153138

@@ -270,7 +255,7 @@ class CheckCaptures extends Recheck:
270255
tree match
271256
case _: InferredTypeTree =>
272257
case tree: TypeTree =>
273-
transformType(tree.tpe, inferred = false).foreachPart(
258+
tree.tpe.foreachPart(
274259
checkWellformedPost(_, tree.srcPos))
275260
tree.tpe.foreachPart {
276261
case AnnotatedType(_, annot) =>

0 commit comments

Comments
 (0)