Skip to content

Commit f37afdf

Browse files
committed
Handle captures in by-name parameters
1. Infrastructure to deal with capturesets in byname parameters 2. Handle retainsByName annotations in ElimByName Convert them to regular annotations on the generated function types. This enables capture checking on by-name parameters. 3. Add a style warning for misleading by-name parameter type formatting. By-name types should be formatted `{...}-> T`. `{...} -> T` looks too much like a function type.
1 parent 301b772 commit f37afdf

24 files changed

+233
-84
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ object desugar {
468468

469469
if mods.is(Trait) then
470470
for vparams <- originalVparamss; vparam <- vparams do
471-
if vparam.tpt.isInstanceOf[ByNameTypeTree] then
471+
if isByNameType(vparam.tpt) then
472472
report.error(em"implementation restriction: traits cannot have by name parameters", vparam.srcPos)
473473

474474
// Annotations on class _type_ parameters are set on the derived parameters
@@ -576,9 +576,8 @@ object desugar {
576576
appliedTypeTree(tycon, targs)
577577
}
578578

579-
def isRepeated(tree: Tree): Boolean = tree match {
579+
def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match {
580580
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
581-
case ByNameTypeTree(tree1) => isRepeated(tree1)
582581
case _ => false
583582
}
584583

@@ -1779,8 +1778,13 @@ object desugar {
17791778
case ext: ExtMethods =>
17801779
Block(List(ext), Literal(Constant(())).withSpan(ext.span))
17811780
case CapturingTypeTree(refs, parent) =>
1782-
val annot = New(scalaDot(tpnme.retains), List(refs))
1783-
Annotated(parent, annot)
1781+
def annotate(annotName: TypeName, tp: Tree) =
1782+
Annotated(tp, New(scalaDot(annotName), List(refs)))
1783+
parent match
1784+
case ByNameTypeTree(restpt) =>
1785+
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
1786+
case _ =>
1787+
annotate(tpnme.retains, parent)
17841788
}
17851789
desugared.withSpan(tree.span)
17861790
}

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
172172
}
173173

174174
/** Is tpt a vararg type of the form T* or => T*? */
175-
def isRepeatedParamType(tpt: Tree)(using Context): Boolean = tpt match {
176-
case ByNameTypeTree(tpt1) => isRepeatedParamType(tpt1)
175+
def isRepeatedParamType(tpt: Tree)(using Context): Boolean = stripByNameType(tpt) match {
177176
case tpt: TypeTree => tpt.typeOpt.isRepeatedParam
178177
case AppliedTypeTree(Select(_, tpnme.REPEATED_PARAM_CLASS), _) => true
179178
case _ => false
@@ -190,6 +189,16 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
190189
case arg => arg.typeOpt.widen.isRepeatedParam
191190
}
192191

192+
def isByNameType(tree: Tree)(using Context): Boolean =
193+
stripByNameType(tree) ne tree
194+
195+
def stripByNameType(tree: Tree)(using Context): Tree = unsplice(tree) match
196+
case ByNameTypeTree(t1) => t1
197+
case untpd.CapturingTypeTree(_, parent) =>
198+
val parent1 = stripByNameType(parent)
199+
if parent1 eq parent then tree else parent1
200+
case _ => tree
201+
193202
/** All type and value parameter symbols of this DefDef */
194203
def allParamSyms(ddef: DefDef)(using Context): List[Symbol] =
195204
ddef.paramss.flatten.map(_.symbol)
@@ -382,6 +391,16 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
382391
case _ => None
383392
}
384393
}
394+
395+
object ImpureByNameTypeTree:
396+
def apply(tp: ByNameTypeTree)(using Context): untpd.CapturingTypeTree =
397+
untpd.CapturingTypeTree(
398+
Ident(nme.CAPTURE_ROOT).withSpan(tp.span.startPos) :: Nil, tp)
399+
def unapply(tp: Tree)(using Context): Option[ByNameTypeTree] = tp match
400+
case untpd.CapturingTypeTree(id @ Ident(nme.CAPTURE_ROOT) :: Nil, bntp: ByNameTypeTree)
401+
if id.span == bntp.span.startPos => Some(bntp)
402+
case _ => None
403+
end ImpureByNameTypeTree
385404
}
386405

387406
trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>

compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import printing.Printer
1212
import printing.Texts.Text
1313

1414

15-
case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation:
15+
case class CaptureAnnotation(refs: CaptureSet, kind: CapturingKind) extends Annotation:
1616
import CaptureAnnotation.*
1717
import tpd.*
1818

@@ -25,25 +25,26 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio
2525
val arg = repeated(elems, TypeTree(defn.AnyType))
2626
New(symbol.typeRef, arg :: Nil)
2727

28-
override def symbol(using Context) = defn.RetainsAnnot
28+
override def symbol(using Context) =
29+
if kind == CapturingKind.ByName then defn.RetainsByNameAnnot else defn.RetainsAnnot
2930

3031
override def derivedAnnotation(tree: Tree)(using Context): Annotation =
3132
unsupported("derivedAnnotation(Tree)")
3233

33-
def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation =
34-
if (this.refs eq refs) && (this.boxed == boxed) then this
35-
else CaptureAnnotation(refs, boxed)
34+
def derivedAnnotation(refs: CaptureSet, kind: CapturingKind)(using Context): Annotation =
35+
if (this.refs eq refs) && (this.kind == kind) then this
36+
else CaptureAnnotation(refs, kind)
3637

3738
override def sameAnnotation(that: Annotation)(using Context): Boolean = that match
38-
case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2
39+
case CaptureAnnotation(refs2, kind2) => refs == refs2 && kind == kind2
3940
case _ => false
4041

4142
override def mapWith(tp: TypeMap)(using Context) =
4243
val elems = refs.elems.toList
4344
val elems1 = elems.mapConserve(tp)
4445
if elems1 eq elems then this
4546
else if elems1.forall(_.isInstanceOf[CaptureRef])
46-
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed)
47+
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), kind)
4748
else EmptyAnnotation
4849

4950
override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
@@ -54,10 +55,11 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio
5455

5556
override def toText(printer: Printer): Text = refs.toText(printer)
5657

57-
override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0)
58+
override def hash: Int =
59+
(refs.hashCode << 1) | (if kind == CapturingKind.Regular then 0 else 1)
5860

5961
override def eql(that: Annotation) = that match
60-
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed)
62+
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.kind == kind)
6163
case _ => false
6264

6365
end CaptureAnnotation

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ extension (tree: Tree)
4343
extension (tp: Type)
4444

4545
def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match
46-
case CapturingType(p, r, b) =>
46+
case CapturingType(p, r, k) =>
4747
if (parent eq p) && (refs eq r) then tp
48-
else CapturingType(parent, refs, b)
48+
else CapturingType(parent, refs, k)
4949

5050
/** If this is type variable instantiated or upper bounded with a capturing type,
5151
* the capture set associated with that type. Extended to and-or types and
@@ -54,7 +54,8 @@ extension (tp: Type)
5454
*/
5555
def boxedCaptured(using Context): CaptureSet =
5656
def getBoxed(tp: Type): CaptureSet = tp match
57-
case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty
57+
case CapturingType(_, refs, CapturingKind.Boxed) => refs
58+
case CapturingType(_, _, _) => CaptureSet.empty
5859
case tp: TypeProxy => getBoxed(tp.superType)
5960
case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2)
6061
case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2)

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ sealed abstract class CaptureSet extends Showable:
209209
((NoType: Type) /: elems) ((tp, ref) =>
210210
if tp.exists then OrType(tp, ref, soft = false) else ref)
211211

212-
def toRegularAnnotation(using Context): Annotation =
213-
Annotation(CaptureAnnotation(this, boxed = false).tree)
212+
def toRegularAnnotation(byName: Boolean)(using Context): Annotation =
213+
val kind = if byName then CapturingKind.ByName else CapturingKind.Regular
214+
Annotation(CaptureAnnotation(this, kind).tree)
214215

215216
override def toText(printer: Printer): Text =
216217
Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}")
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
5+
/** Possible kinds of captures */
6+
enum CapturingKind:
7+
case Regular // normal capture
8+
case Boxed // capture under box
9+
case ByName // capture applies to enclosing by-name type (only possible before ElimByName)

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,46 @@ package cc
55
import core.*
66
import Types.*, Symbols.*, Contexts.*
77

8+
/** A capturing type. This is internally represented as an annotated type with a `retains`
9+
* annotation, but the extractor will succeed only at phase CheckCaptures.
10+
* Annotated types with `@retainsByName` annotation can also be created that way, by
11+
* giving a `CapturingKind.ByName` as `kind` argument, but they are never extracted,
12+
* since they have already been converted to regular capturing types before CheckCaptures.
13+
*/
814
object CapturingType:
915

10-
def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type =
16+
def apply(parent: Type, refs: CaptureSet, kind: CapturingKind)(using Context): Type =
1117
if refs.isAlwaysEmpty then parent
12-
else AnnotatedType(parent, CaptureAnnotation(refs, boxed))
13-
14-
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
15-
if ctx.phase == Phases.checkCapturesPhase then EventuallyCapturingType.unapply(tp)
18+
else AnnotatedType(parent, CaptureAnnotation(refs, kind))
19+
20+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] =
21+
if ctx.phase == Phases.checkCapturesPhase then
22+
val r = EventuallyCapturingType.unapply(tp)
23+
r match
24+
case Some((_, _, CapturingKind.ByName)) => None
25+
case _ => r
1626
else None
1727

1828
end CapturingType
1929

30+
/** An extractor for types that will be capturing types at phase CheckCaptures. Also
31+
* included are types that indicate captures on enclosing call-by-name parameters
32+
* before phase ElimByName
33+
*/
2034
object EventuallyCapturingType:
2135

22-
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
23-
if tp.annot.symbol == defn.RetainsAnnot then
36+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] =
37+
val sym = tp.annot.symbol
38+
if sym == defn.RetainsAnnot || sym == defn.RetainsByNameAnnot then
2439
tp.annot match
25-
case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed))
40+
case ann: CaptureAnnotation =>
41+
Some((tp.parent, ann.refs, ann.kind))
2642
case ann =>
27-
try Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing))
43+
val kind =
44+
if ann.tree.isBoxedCapturing then CapturingKind.Boxed
45+
else if sym == defn.RetainsByNameAnnot then CapturingKind.ByName
46+
else CapturingKind.Regular
47+
try Some((tp.parent, ann.tree.toCaptureSet, kind))
2848
catch case ex: IllegalCaptureRef => None
2949
else None
3050

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ extends tpd.TreeTraverser:
2525
.toFunctionType(isJava = false, alwaysDependent = true)
2626

2727
private def box(tp: Type)(using Context): Type = tp match
28-
case CapturingType(parent, refs, false) => CapturingType(parent, refs, true)
28+
case CapturingType(parent, refs, CapturingKind.Regular) =>
29+
CapturingType(parent, refs, CapturingKind.Boxed)
2930
case _ => tp
3031

3132
private def setBoxed(tp: Type)(using Context) = tp match
@@ -77,7 +78,7 @@ extends tpd.TreeTraverser:
7778
cls.paramGetters.foldLeft(tp) { (core, getter) =>
7879
if getter.termRef.isTracked then
7980
val getterType = tp.memberInfo(getter).strippedDealias
80-
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false))
81+
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), CapturingKind.Regular))
8182
.showing(i"add capture refinement $tp --> $result", capt)
8283
else
8384
core
@@ -130,7 +131,7 @@ extends tpd.TreeTraverser:
130131
case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) =>
131132
CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2)
132133
case _ if canHaveInferredCapture(tp) =>
133-
CapturingType(tp, CaptureSet.Var(), boxed = false)
134+
CapturingType(tp, CaptureSet.Var(), CapturingKind.Regular)
134135
case _ =>
135136
tp
136137

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import Comments.CommentsContext
1515
import Comments.Comment
1616
import util.Spans.NoSpan
1717
import Symbols.requiredModuleRef
18-
import cc.{CapturingType, CaptureSet}
18+
import cc.{CapturingType, CaptureSet, CapturingKind, EventuallyCapturingType}
1919

2020
import scala.annotation.tailrec
2121

@@ -118,9 +118,9 @@ class Definitions {
118118
*
119119
* ErasedFunctionN and ErasedContextFunctionN erase to Function0.
120120
*
121-
* EffXYZFunctionN afollow this template:
121+
* ImpureXYZFunctionN follow this template:
122122
*
123-
* type EffXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
123+
* type ImpureXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
124124
*/
125125
private def newFunctionNType(name: TypeName): Symbol = {
126126
val impure = name.startsWith("Impure")
@@ -136,7 +136,7 @@ class Definitions {
136136
HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)(
137137
tl => List.fill(arity + 1)(TypeBounds.empty),
138138
tl => CapturingType(underlyingClass.typeRef.appliedTo(tl.paramRefs),
139-
CaptureSet.universal, boxed = false)
139+
CaptureSet.universal, CapturingKind.Regular)
140140
))
141141
else
142142
val cls = denot.asClass.classSymbol
@@ -1015,6 +1015,7 @@ class Definitions {
10151015
@tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs")
10161016
@tu lazy val SinceAnnot: ClassSymbol = requiredClass("scala.annotation.since")
10171017
@tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains")
1018+
@tu lazy val RetainsByNameAnnot: ClassSymbol = requiredClass("scala.retainsByName")
10181019

10191020
@tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable")
10201021

@@ -1148,9 +1149,16 @@ class Definitions {
11481149
}
11491150
}
11501151

1152+
/** Extractor for function types representing by-name parameters, of the form
1153+
* `() ?=> T`.
1154+
* Under -Ycc, this becomes `() ?-> T` or `{r1, ..., rN} () ?-> T`.
1155+
*/
11511156
object ByNameFunction:
1152-
def apply(tp: Type)(using Context): Type =
1153-
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
1157+
def apply(tp: Type)(using Context): Type = tp match
1158+
case EventuallyCapturingType(tp1, refs, CapturingKind.ByName) =>
1159+
CapturingType(apply(tp1), refs, CapturingKind.Regular)
1160+
case _ =>
1161+
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
11541162
def unapply(tp: Type)(using Context): Option[Type] = tp match
11551163
case tp @ AppliedType(tycon, arg :: Nil) if defn.isByNameFunctionClass(tycon.typeSymbol) =>
11561164
Some(arg)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ object StdNames {
565565
val reify : N = "reify"
566566
val releaseFence : N = "releaseFence"
567567
val retains: N = "retains"
568+
val retainsByName: N = "retainsByName"
568569
val rootMirror : N = "rootMirror"
569570
val run: N = "run"
570571
val runOrElse: N = "runOrElse"

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import typer.ProtoTypes.constrained
2424
import typer.Applications.productSelectorTypes
2525
import reporting.trace
2626
import annotation.constructorOnly
27-
import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing}
27+
import cc.{CapturingType, derivedCapturingType, CaptureSet, CapturingKind, stripCapturing}
2828

2929
/** Provides methods to compare types.
3030
*/
@@ -858,7 +858,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
858858
tp1 match
859859
case tp1: CaptureRef if tp1.isTracked =>
860860
val stripped = tp1w.stripCapturing
861-
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false)
861+
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, CapturingKind.Regular)
862862
case _ =>
863863
isSubType(tp1w, tp2, approx.addLow)
864864
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import scala.util.hashing.{ MurmurHash3 => hashing }
3535
import config.Printers.{core, typr, matchTypes}
3636
import reporting.{trace, Message}
3737
import java.lang.ref.WeakReference
38-
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing}
38+
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing, CapturingKind}
3939
import CaptureSet.CompareResult
4040

4141
import scala.annotation.internal.sharable
@@ -1875,13 +1875,15 @@ object Types {
18751875

18761876
def capturing(ref: CaptureRef)(using Context): Type =
18771877
if captureSet.accountsFor(ref) then this
1878-
else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing)
1878+
else CapturingType(this, ref.singletonCaptureSet,
1879+
if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular)
18791880

18801881
def capturing(cs: CaptureSet)(using Context): Type =
18811882
if cs.isConst && cs.subCaptures(captureSet, frozen = true).isOK then this
18821883
else this match
18831884
case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs)
1884-
case _ => CapturingType(this, cs, this.isBoxedCapturing)
1885+
case _ => CapturingType(this, cs,
1886+
if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular)
18851887

18861888
/** The set of distinct symbols referred to by this type, after all aliases are expanded */
18871889
def coveringSet(using Context): Set[Symbol] =
@@ -3812,10 +3814,11 @@ object Types {
38123814
CapturingType(parent1, CaptureSet.universal, boxed))
38133815
case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) =>
38143816
val parent1 = mapOver(parent)
3815-
if ann.symbol == defn.RetainsAnnot then
3817+
if ann.symbol == defn.RetainsAnnot || ann.symbol == defn.RetainsByNameAnnot then
3818+
val byName = ann.symbol == defn.RetainsByNameAnnot
38163819
range(
3817-
AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation),
3818-
AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation))
3820+
AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation(byName)),
3821+
AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation(byName)))
38193822
else
38203823
parent1
38213824
case _ => mapOver(tp)

0 commit comments

Comments
 (0)