Skip to content

Commit a210b7f

Browse files
authored
More fixes to prepare dotc compilation with capture checking (#16251)
2 parents 2bc3d23 + d581015 commit a210b7f

File tree

13 files changed

+104
-30
lines changed

13 files changed

+104
-30
lines changed

compiler/src/dotty/tools/dotc/CompilationUnit.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ import util.{FreshNameCreator, SourceFile, NoSource}
99
import util.Spans.Span
1010
import ast.{tpd, untpd}
1111
import tpd.{Tree, TreeTraverser}
12+
import ast.Trees.{Import, Ident}
1213
import typer.Nullables
1314
import transform.SymUtils._
1415
import core.Decorators._
15-
import config.SourceVersion
16+
import config.{SourceVersion, Feature}
17+
import StdNames.nme
1618
import scala.annotation.internal.sharable
1719

1820
class CompilationUnit protected (val source: SourceFile) {
@@ -51,6 +53,12 @@ class CompilationUnit protected (val source: SourceFile) {
5153
*/
5254
var needsStaging: Boolean = false
5355

56+
/** Will be set to true if the unit contains a captureChecking language import */
57+
var needsCaptureChecking: Boolean = false
58+
59+
/** Will be set to true if the unit contains a pureFunctions language import */
60+
var knowsPureFuns: Boolean = false
61+
5462
var suspended: Boolean = false
5563
var suspendedAtInliningPhase: Boolean = false
5664

@@ -138,11 +146,20 @@ object CompilationUnit {
138146
private class Force extends TreeTraverser {
139147
var containsQuote = false
140148
var containsInline = false
149+
var containsCaptureChecking = false
141150
def traverse(tree: Tree)(using Context): Unit = {
142151
if (tree.symbol.isQuote)
143152
containsQuote = true
144153
if tree.symbol.is(Flags.Inline) then
145154
containsInline = true
155+
tree match
156+
case Import(qual, selectors) =>
157+
tpd.languageImport(qual) match
158+
case Some(prefix) =>
159+
for case untpd.ImportSelector(untpd.Ident(imported), untpd.EmptyTree, _) <- selectors do
160+
Feature.handleGlobalLanguageImport(prefix, imported)
161+
case _ =>
162+
case _ =>
146163
traverseChildren(tree)
147164
}
148165
}

compiler/src/dotty/tools/dotc/Run.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
163163
/** Actions that need to be performed at the end of the current compilation run */
164164
private var finalizeActions = mutable.ListBuffer[() => Unit]()
165165

166+
/** Will be set to true if any of the compiled compilation units contains
167+
* a pureFunctions or captureChecking language import.
168+
*/
169+
var pureFunsImportEncountered = false
170+
166171
def compile(files: List[AbstractFile]): Unit =
167172
try
168173
val sources = files.map(runContext.getSource(_))

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ extension (tree: Tree)
4141
tree.putAttachment(Captures, refs)
4242
refs
4343

44+
/** Under pureFunctions, add a @retainsByName(*)` annotation to the argument of
45+
* a by name parameter type, turning the latter into an impure by name parameter type.
46+
*/
47+
def adaptByNameArgUnderPureFuns(using Context): Tree =
48+
if Feature.pureFunsEnabledSomewhere then
49+
val rbn = defn.RetainsByNameAnnot
50+
Annotated(tree,
51+
New(rbn.typeRef).select(rbn.primaryConstructor).appliedTo(
52+
Typed(
53+
SeqLiteral(ref(defn.captureRoot) :: Nil, TypeTree(defn.AnyType)),
54+
TypeTree(defn.RepeatedParamType.appliedTo(defn.AnyType))
55+
)
56+
)
57+
)
58+
else tree
59+
4460
extension (tp: Type)
4561

4662
/** @pre `tp` is a CapturingType */
@@ -125,7 +141,7 @@ extension (tp: Type)
125141
*/
126142
def adaptFunctionTypeUnderPureFuns(using Context): Type = tp match
127143
case AppliedType(fn, args)
128-
if Feature.pureFunsEnabled && defn.isFunctionClass(fn.typeSymbol) =>
144+
if Feature.pureFunsEnabledSomewhere && defn.isFunctionClass(fn.typeSymbol) =>
129145
val fname = fn.typeSymbol.name
130146
defn.FunctionType(
131147
fname.functionArity,
@@ -135,6 +151,16 @@ extension (tp: Type)
135151
case _ =>
136152
tp
137153

154+
/** Under pureFunctions, add a @retainsByName(*)` annotation to the argument of
155+
* a by name parameter type, turning the latter into an impure by name parameter type.
156+
*/
157+
def adaptByNameArgUnderPureFuns(using Context): Type =
158+
if Feature.pureFunsEnabledSomewhere then
159+
AnnotatedType(tp,
160+
CaptureAnnotation(CaptureSet.universal, boxed = false)(defn.RetainsByNameAnnot))
161+
else
162+
tp
163+
138164
def isCapturingType(using Context): Boolean =
139165
tp match
140166
case CapturingType(_, _) => true

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ object CheckCaptures:
2626

2727
class Pre extends PreRecheck, SymTransformer:
2828

29-
override def isEnabled(using Context) = Feature.ccEnabled
29+
override def isEnabled(using Context) = true
3030

3131
/** Reset `private` flags of parameter accessors so that we can refine them
3232
* in Setup if they have non-empty capture sets. Special handling of some
@@ -133,13 +133,14 @@ class CheckCaptures extends Recheck, SymTransformer:
133133
import CheckCaptures.*
134134

135135
def phaseName: String = "cc"
136-
override def isEnabled(using Context) = Feature.ccEnabled
136+
override def isEnabled(using Context) = true
137137

138138
def newRechecker()(using Context) = CaptureChecker(ctx)
139139

140140
override def run(using Context): Unit =
141-
checkOverrides.traverse(ctx.compilationUnit.tpdTree)
142-
super.run
141+
if Feature.ccEnabled then
142+
checkOverrides.traverse(ctx.compilationUnit.tpdTree)
143+
super.run
143144

144145
override def transformSym(sym: SymDenotation)(using Context): SymDenotation =
145146
if Synthetics.needsTransform(sym) then Synthetics.transformFromCC(sym)

compiler/src/dotty/tools/dotc/config/Feature.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,18 @@ object Feature:
8080
def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros)
8181

8282
def pureFunsEnabled(using Context) =
83-
enabled(pureFunctions) || ccEnabled
83+
enabledBySetting(pureFunctions)
84+
|| ctx.compilationUnit.knowsPureFuns
85+
|| ccEnabled
8486

85-
def ccEnabled(using Context) = enabled(captureChecking)
87+
def ccEnabled(using Context) =
88+
enabledBySetting(captureChecking)
89+
|| ctx.compilationUnit.needsCaptureChecking
90+
91+
def pureFunsEnabledSomewhere(using Context) =
92+
enabledBySetting(pureFunctions)
93+
|| enabledBySetting(captureChecking)
94+
|| ctx.run != null && ctx.run.nn.pureFunsImportEncountered
8695

8796
def sourceVersionSetting(using Context): SourceVersion =
8897
SourceVersion.valueOf(ctx.settings.source.value)
@@ -130,4 +139,16 @@ object Feature:
130139
def isExperimentalEnabled(using Context): Boolean =
131140
Properties.experimental && !ctx.settings.YnoExperimental.value
132141

142+
def handleGlobalLanguageImport(prefix: TermName, imported: Name)(using Context): Boolean =
143+
val fullFeatureName = QualifiedName(prefix, imported.asTermName)
144+
if fullFeatureName == pureFunctions then
145+
ctx.compilationUnit.knowsPureFuns = true
146+
if ctx.run != null then ctx.run.nn.pureFunsImportEncountered = true
147+
true
148+
else if fullFeatureName == captureChecking then
149+
ctx.compilationUnit.needsCaptureChecking = true
150+
if ctx.run != null then ctx.run.nn.pureFunsImportEncountered = true
151+
true
152+
else
153+
false
133154
end Feature

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ object StdNames {
426426
val canEqual_ : N = "canEqual"
427427
val canEqualAny : N = "canEqualAny"
428428
val caps: N = "caps"
429+
val captureChecking: N = "captureChecking"
429430
val checkInitialized: N = "checkInitialized"
430431
val classOf: N = "classOf"
431432
val classType: N = "classType"

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import ast.{Trees, tpd, untpd}
3232
import Trees._
3333
import Decorators._
3434
import transform.SymUtils._
35-
import cc.adaptFunctionTypeUnderPureFuns
35+
import cc.{adaptFunctionTypeUnderPureFuns, adaptByNameArgUnderPureFuns}
3636

3737
import dotty.tools.tasty.{TastyBuffer, TastyReader}
3838
import TastyBuffer._
@@ -455,7 +455,8 @@ class TreeUnpickler(reader: TastyReader,
455455
val ref = readAddr()
456456
typeAtAddr.getOrElseUpdate(ref, forkAt(ref).readType())
457457
case BYNAMEtype =>
458-
ExprType(readType())
458+
val arg = readType()
459+
ExprType(if knowsPureFuns then arg else arg.adaptByNameArgUnderPureFuns)
459460
case _ =>
460461
ConstantType(readConstant(tag))
461462
}
@@ -1178,7 +1179,8 @@ class TreeUnpickler(reader: TastyReader,
11781179
case SINGLETONtpt =>
11791180
SingletonTypeTree(readTerm())
11801181
case BYNAMEtpt =>
1181-
ByNameTypeTree(readTpt())
1182+
val arg = readTpt()
1183+
ByNameTypeTree(if knowsPureFuns then arg else arg.adaptByNameArgUnderPureFuns)
11821184
case NAMEDARG =>
11831185
NamedArg(readName(), readTerm())
11841186
case _ =>

compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import scala.collection.mutable
3333
import scala.collection.mutable.ListBuffer
3434
import scala.annotation.switch
3535
import reporting._
36-
import cc.adaptFunctionTypeUnderPureFuns
36+
import cc.{adaptFunctionTypeUnderPureFuns, adaptByNameArgUnderPureFuns}
3737

3838
object Scala2Unpickler {
3939

@@ -817,7 +817,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
817817
}
818818
val tycon = select(pre, sym)
819819
val args = until(end, () => readTypeRef())
820-
if (sym == defn.ByNameParamClass2x) ExprType(args.head)
820+
if (sym == defn.ByNameParamClass2x) ExprType(args.head.adaptByNameArgUnderPureFuns)
821821
else if (ctx.settings.scalajs.value && args.length == 2 &&
822822
sym.owner == JSDefinitions.jsdefn.ScalaJSJSPackageClass && sym == JSDefinitions.jsdefn.PseudoUnionClass) {
823823
// Treat Scala.js pseudo-unions as real unions, this requires a

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ object Parsers {
196196

197197
def isIdent = in.isIdent
198198
def isIdent(name: Name) = in.isIdent(name)
199-
def isPureArrow(name: Name): Boolean = in.pureFunsEnabled && isIdent(name)
199+
def isPureArrow(name: Name): Boolean = isIdent(name) && Feature.pureFunsEnabled
200200
def isPureArrow: Boolean = isPureArrow(nme.PUREARROW) || isPureArrow(nme.PURECTXARROW)
201201
def isErased = isIdent(nme.erased) && in.erasedEnabled
202202
def isSimpleLiteral =
@@ -972,7 +972,7 @@ object Parsers {
972972
* capture set `{ref1, ..., refN}` followed by a token that can start a type?
973973
*/
974974
def followingIsCaptureSet(): Boolean =
975-
in.featureEnabled(Feature.captureChecking) && {
975+
Feature.ccEnabled && {
976976
val lookahead = in.LookaheadScanner()
977977
def followingIsTypeStart() =
978978
lookahead.nextToken()
@@ -1485,7 +1485,7 @@ object Parsers {
14851485
if !imods.flags.isEmpty || params.isEmpty then
14861486
syntaxError(em"illegal parameter list for type lambda", start)
14871487
token = ARROW
1488-
else if in.pureFunsEnabled then
1488+
else if Feature.pureFunsEnabled then
14891489
// `=>` means impure function under pureFunctions or captureChecking
14901490
// language imports, whereas `->` is then a regular function.
14911491
imods |= Impure
@@ -1891,7 +1891,7 @@ object Parsers {
18911891
if in.token == ARROW || isPureArrow(nme.PUREARROW) then
18921892
val isImpure = in.token == ARROW
18931893
val tp = atSpan(in.skipToken()) { ByNameTypeTree(core()) }
1894-
if isImpure && in.pureFunsEnabled then ImpureByNameTypeTree(tp) else tp
1894+
if isImpure && Feature.pureFunsEnabled then ImpureByNameTypeTree(tp) else tp
18951895
else if in.token == LBRACE && followingIsCaptureSet() then
18961896
val start = in.offset
18971897
val cs = captureSet()
@@ -3308,10 +3308,8 @@ object Parsers {
33083308
languageImport(tree) match
33093309
case Some(prefix) =>
33103310
in.languageImportContext = in.languageImportContext.importContext(imp, NoSymbol)
3311-
for
3312-
case ImportSelector(id @ Ident(imported), EmptyTree, _) <- selectors
3313-
do
3314-
if globalOnlyImports.contains(QualifiedName(prefix, imported.asTermName)) && !outermost then
3311+
for case ImportSelector(id @ Ident(imported), EmptyTree, _) <- selectors do
3312+
if Feature.handleGlobalLanguageImport(prefix, imported) && !outermost then
33153313
syntaxError(i"this language import is only allowed at the toplevel", id.span)
33163314
if allSourceVersionNames.contains(imported) && prefix.isEmpty then
33173315
if !outermost then

compiler/src/dotty/tools/dotc/parsing/Scanners.scala

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,6 @@ object Scanners {
230230
postfixOpsEnabledCtx = myLanguageImportContext
231231
postfixOpsEnabledCache
232232

233-
private var pureFunsEnabledCache = false
234-
private var pureFunsEnabledCtx: Context = NoContext
235-
236-
def pureFunsEnabled =
237-
if pureFunsEnabledCtx ne myLanguageImportContext then
238-
pureFunsEnabledCache = featureEnabled(Feature.pureFunctions) || featureEnabled(Feature.captureChecking)
239-
pureFunsEnabledCtx = myLanguageImportContext
240-
pureFunsEnabledCache
241-
242233
/** All doc comments kept by their end position in a `Map`.
243234
*
244235
* Note: the map is necessary since the comments are looked up after an

compiler/test/dotc/pos-test-pickling.blacklist

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ i13842.scala
8181
# Position change under captureChecking
8282
boxmap-paper.scala
8383

84+
# Function types print differnt after unpickling since test mispredicts Feature.preFundsEnabled
85+
caps-universal.scala
86+
87+
8488
# GADT cast applied to singleton type difference
8589
i4176-gadt.scala
8690

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
object A:
2+
def f(x: => Int) = ()
3+
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import language.experimental.captureChecking
2+
object B:
3+
def test(x: => Int) = A.f(x)
4+
5+

0 commit comments

Comments
 (0)