Skip to content

Commit 3a2141b

Browse files
committed
Make incremental compilation aware of synthesized mirrors
A product mirror needs to be resynthesized if any class parameter changes, and a sum mirror needs to be resynthesized if any child of the sealed type changes, but previously this did not reliably work because the dependency recording in ExtractDependencies was unaware of mirrors. Instead of making ExtractDependencies aware of mirrors, we solve this by directly recording the dependencies when the mirror is synthesized, this way we can be sure to always correctly invalidate users of mirrors, even if the synthesized mirror type is not present in the AST at phase ExtractDependencies. This is the first time that we record dependencies outside of the ExtractDependencies phase, in the future we should see if we can extend this mechanism to record more dependencies during typechecking to make incremental compilation more robust (e.g. by keeping track of symbols looked up by macros). Eventually, we might even want to completely get rid of the ExtractDependencies phase and record all dependencies on the fly if it turns out to be faster.
1 parent 54e2f59 commit 3a2141b

File tree

13 files changed

+91
-16
lines changed

13 files changed

+91
-16
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,11 @@ object Contexts {
177177
val local = incCallback
178178
local != null && local.enabled || forceRun
179179

180+
/** Used to record dependencies to invalidate during incremental compilation.
181+
* This is only used if `runZincPhases` is true.
182+
*/
183+
def depRecorder: sbt.DependencyRecorder = base.depRecorder
184+
180185
/** The current plain printer */
181186
def printerFn: Context => Printer = store(printerFnLoc)
182187

@@ -1042,6 +1047,9 @@ object Contexts {
10421047
charArray = new Array[Char](charArray.length * 2)
10431048
charArray
10441049

1050+
// Incremental compilation state
1051+
private[Contexts] val depRecorder: sbt.DependencyRecorder = sbt.DependencyRecorder()
1052+
10451053
def reset(): Unit =
10461054
uniques.clear()
10471055
uniqueAppliedTypes.clear()
@@ -1053,6 +1061,7 @@ object Contexts {
10531061
sources.clear()
10541062
files.clear()
10551063
comparers.clear() // forces re-evaluation of top and bottom classes in TypeComparer
1064+
depRecorder.clear()
10561065

10571066
// Test that access is single threaded
10581067

compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,12 @@ class ExtractDependencies extends Phase {
6969

7070
override def run(using Context): Unit = {
7171
val unit = ctx.compilationUnit
72-
val rec = DependencyRecorder()
73-
val collector = ExtractDependenciesCollector(rec)
72+
val collector = ExtractDependenciesCollector()
7473
collector.traverse(unit.tpdTree)
7574

7675
if (ctx.settings.YdumpSbtInc.value) {
77-
val deps = rec.classDependencies.map(_.toString).toArray[Object]
78-
val names = rec.usedNames.map { case (clazz, names) => s"$clazz: $names" }.toArray[Object]
76+
val deps = ctx.depRecorder.classDependencies.map(_.toString).toArray[Object]
77+
val names = ctx.depRecorder.usedNames.map { case (clazz, names) => s"$clazz: $names" }.toArray[Object]
7978
Arrays.sort(deps)
8079
Arrays.sort(names)
8180

@@ -92,7 +91,7 @@ class ExtractDependencies extends Phase {
9291
} finally pw.close()
9392
}
9493

95-
rec.sendToZinc()
94+
ctx.depRecorder.sendToZinc()
9695
}
9796
}
9897

@@ -116,32 +115,32 @@ object ExtractDependencies {
116115
* specially, see the subsection "Dependencies introduced by member reference and
117116
* inheritance" in the "Name hashing algorithm" section.
118117
*/
119-
private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd.TreeTraverser { thisTreeTraverser =>
118+
private class ExtractDependenciesCollector() extends tpd.TreeTraverser { thisTreeTraverser =>
120119
import tpd._
121120

122121
private def addMemberRefDependency(sym: Symbol)(using Context): Unit =
123122
if (!ignoreDependency(sym)) {
124-
rec.addUsedName(sym)
123+
ctx.depRecorder.addUsedName(sym)
125124
// packages have class symbol. Only record them as used names but not dependency
126125
if (!sym.is(Package)) {
127126
val enclOrModuleClass = if (sym.is(ModuleVal)) sym.moduleClass else sym.enclosingClass
128127
assert(enclOrModuleClass.isClass, s"$enclOrModuleClass, $sym")
129128

130-
rec.addClassDependency(enclOrModuleClass, DependencyByMemberRef)
129+
ctx.depRecorder.addClassDependency(enclOrModuleClass, DependencyByMemberRef)
131130
}
132131
}
133132

134133
private def addInheritanceDependencies(tree: Closure)(using Context): Unit =
135134
// If the tpt is empty, this is a non-SAM lambda, so no need to register
136135
// an inheritance relationship.
137136
if !tree.tpt.isEmpty then
138-
rec.addClassDependency(tree.tpt.tpe.classSymbol, LocalDependencyByInheritance)
137+
ctx.depRecorder.addClassDependency(tree.tpt.tpe.classSymbol, LocalDependencyByInheritance)
139138

140139
private def addInheritanceDependencies(tree: Template)(using Context): Unit =
141140
if (tree.parents.nonEmpty) {
142141
val depContext = depContextOf(tree.symbol.owner)
143142
for parent <- tree.parents do
144-
rec.addClassDependency(parent.tpe.classSymbol, depContext)
143+
ctx.depRecorder.addClassDependency(parent.tpe.classSymbol, depContext)
145144
}
146145

147146
private def depContextOf(cls: Symbol)(using Context): DependencyContext =
@@ -179,7 +178,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd.
179178
for sel <- selectors if !sel.isWildcard do
180179
addImported(sel.name)
181180
if sel.rename != sel.name then
182-
rec.addUsedRawName(sel.rename)
181+
ctx.depRecorder.addUsedRawName(sel.rename)
183182
case exp @ Export(expr, selectors) =>
184183
val dep = expr.tpe.classSymbol
185184
if dep.exists && selectors.exists(_.isWildcard) then
@@ -192,7 +191,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd.
192191
// inheritance dependency in the presence of wildcard exports
193192
// to ensure all new members of `dep` are forwarded to.
194193
val depContext = depContextOf(ctx.owner.lexicallyEnclosingClass)
195-
rec.addClassDependency(dep, depContext)
194+
ctx.depRecorder.addClassDependency(dep, depContext)
196195
case t: TypeTree =>
197196
addTypeDependency(t.tpe)
198197
case ref: RefTree =>
@@ -299,7 +298,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd.
299298
val traverser = new TypeDependencyTraverser {
300299
def addDependency(symbol: Symbol) =
301300
if (!ignoreDependency(symbol) && symbol.is(Sealed)) {
302-
rec.addUsedName(symbol, includeSealedChildren = true)
301+
ctx.depRecorder.addUsedName(symbol, includeSealedChildren = true)
303302
}
304303
}
305304
traverser.traverse(tpe)
@@ -422,8 +421,12 @@ class DependencyRecorder {
422421
case (usedName, scopes) =>
423422
cb.usedName(className, usedName.toString, scopes)
424423
classDependencies.foreach(recordClassDependency(cb, _))
425-
_usedNames.clear()
426-
_classDependencies.clear()
424+
clear()
425+
426+
/** Clear all state. */
427+
def clear(): Unit =
428+
_usedNames.clear()
429+
_classDependencies.clear()
427430

428431
/** Handles dependency on given symbol by trying to figure out if represents a term
429432
* that is coming from either source code (not necessarily compiled in this compilation

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ import ast.Trees.genericEmptyTree
1919
import annotation.{tailrec, constructorOnly}
2020
import ast.tpd._
2121
import Synthesizer._
22+
import sbt.ExtractDependencies.*
23+
import sbt.ClassDependency
24+
import xsbti.api.DependencyContext._
2225

2326
/** Synthesize terms for special classes */
2427
class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
@@ -458,7 +461,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
458461
val reason = s"it reduces to a tuple with arity $arity, expected arity <= $maxArity"
459462
withErrors(i"${defn.PairClass} is not a generic product because $reason")
460463
case MirrorSource.ClassSymbol(pre, cls) =>
461-
if cls.isGenericProduct then makeProductMirror(pre, cls, None)
464+
if cls.isGenericProduct then
465+
if ctx.runZincPhases then
466+
// The mirror should be resynthesized if the constructor of the
467+
// case class `cls` changes. See `sbt-test/source-dependencies/mirror-product`.
468+
ctx.depRecorder.addClassDependency(cls, DependencyByMemberRef)
469+
ctx.depRecorder.addUsedName(cls.primaryConstructor)
470+
makeProductMirror(pre, cls, None)
462471
else withErrors(i"$cls is not a generic product because ${cls.whyNotGenericProduct}")
463472
case Left(msg) =>
464473
withErrors(i"type `$mirroredType` is not a generic product because $msg")
@@ -478,6 +487,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
478487
val clsIsGenericSum = cls.isGenericSum(pre)
479488

480489
if acceptableMsg.isEmpty && clsIsGenericSum then
490+
if ctx.runZincPhases then
491+
// The mirror should be resynthesized if any child of the sealed class
492+
// `cls` changes. See `sbt-test/source-dependencies/mirror-sum`.
493+
ctx.depRecorder.addClassDependency(cls, DependencyByMemberRef)
494+
ctx.depRecorder.addUsedName(cls, includeSealedChildren = true)
495+
481496
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
482497

483498
def internalError(msg: => String)(using Context): Unit =
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
case class MyProduct(x: Int)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.deriving.Mirror
2+
import scala.compiletime.erasedValue
3+
4+
transparent inline def foo[T](using m: Mirror.Of[T]): Int =
5+
inline erasedValue[m.MirroredElemTypes] match
6+
case _: (Int *: EmptyTuple) => 1
7+
case _: (Int *: String *: EmptyTuple) => 2
8+
9+
@main def Test =
10+
assert(foo[MyProduct] == 2)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
scalaVersion := sys.props("plugin.scalaVersion")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
case class MyProduct(x: Int, y: String)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
> compile
2+
3+
# change the case class constructor
4+
$ copy-file changes/MyProduct.scala MyProduct.scala
5+
6+
# Both MyProduct.scala and Test.scala should be recompiled, otherwise the assertion will fail
7+
> run
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
sealed trait Sum
2+
case class Child1() extends Sum
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import scala.deriving.Mirror
2+
import scala.compiletime.erasedValue
3+
4+
object Test:
5+
transparent inline def foo[T](using m: Mirror.Of[T]): Int =
6+
inline erasedValue[m.MirroredElemLabels] match
7+
case _: ("Child1" *: EmptyTuple) => 1
8+
case _: ("Child1" *: "Child2" *: EmptyTuple) => 2
9+
10+
def main(args: Array[String]): Unit =
11+
assert(foo[Sum] == 2)
12+
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
scalaVersion := sys.props("plugin.scalaVersion")
2+
// Use more precise invalidation, otherwise the reference to `Sum` in
3+
// Test.scala is enough to invalidate it when a child is added.
4+
ThisBuild / incOptions ~= { _.withUseOptimizedSealed(true) }
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
sealed trait Sum
2+
case class Child1() extends Sum
3+
case class Child2() extends Sum
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
> compile
2+
3+
# Add a child
4+
$ copy-file changes/Sum.scala Sum.scala
5+
6+
# Both Sum.scala and Test.scala should be recompiled, otherwise the assertion will fail
7+
> run

0 commit comments

Comments
 (0)