Skip to content

Commit 6d672b1

Browse files
committed
Map regular function types to impure function types when unpickling
Map regular function types to impure function types when unpickling a class under -Ycc that was not itself compiled with -Ycc.
1 parent 9033e0c commit 6d672b1

File tree

6 files changed

+72
-4
lines changed

6 files changed

+72
-4
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package cc
55
import core.*
66
import Types.*, Symbols.*, Contexts.*, Annotations.*
77
import ast.{tpd, untpd}
8-
import Decorators.*
8+
import Decorators.*, NameOps.*
99
import config.Printers.capt
1010
import util.Property.Key
1111
import tpd.*
@@ -71,3 +71,17 @@ extension (tp: Type)
7171
atd.derivedAnnotatedType(parent.stripCapturing, annot)
7272
case _ =>
7373
tp
74+
75+
/** Under -Ycc, map regular function type to impure function type
76+
*/
77+
def adaptFunctionType(using Context): Type = tp match
78+
case AppliedType(fn, args)
79+
if ctx.settings.Ycc.value && defn.isFunctionClass(fn.typeSymbol) =>
80+
val fname = fn.typeSymbol.name
81+
defn.FunctionType(
82+
fname.functionArity,
83+
isContextual = fname.isContextFunction,
84+
isErased = fname.isErasedFunction,
85+
isImpure = true).appliedTo(args)
86+
case _ =>
87+
tp

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import ast.{TreeTypeMap, Trees, tpd, untpd}
3131
import Trees._
3232
import Decorators._
3333
import transform.SymUtils._
34+
import cc.adaptFunctionType
3435

3536
import dotty.tools.tasty.{TastyBuffer, TastyReader}
3637
import TastyBuffer._
@@ -87,6 +88,9 @@ class TreeUnpickler(reader: TastyReader,
8788
/** The root owner tree. See `OwnerTree` class definition. Set by `enterTopLevel`. */
8889
private var ownerTree: OwnerTree = _
8990

91+
/** Was unpickled class compiled with -Ycc? */
92+
private var wasCaptureChecked: Boolean = false
93+
9094
private def registerSym(addr: Addr, sym: Symbol) =
9195
symAtAddr(addr) = sym
9296

@@ -357,7 +361,7 @@ class TreeUnpickler(reader: TastyReader,
357361
// Note that the lambda "rt => ..." is not equivalent to a wildcard closure!
358362
// Eta expansion of the latter puts readType() out of the expression.
359363
case APPLIEDtype =>
360-
readType().appliedTo(until(end)(readType()))
364+
postProcessFunction(readType().appliedTo(until(end)(readType())))
361365
case TYPEBOUNDS =>
362366
val lo = readType()
363367
if nothingButMods(end) then
@@ -470,6 +474,12 @@ class TreeUnpickler(reader: TastyReader,
470474
def readTermRef()(using Context): TermRef =
471475
readType().asInstanceOf[TermRef]
472476

477+
/** Under -Ycc, map all function types to impure function types,
478+
* unless the unpickled class was also compiled with -Ycc.
479+
*/
480+
private def postProcessFunction(tp: Type)(using Context): Type =
481+
if wasCaptureChecked then tp else tp.adaptFunctionType
482+
473483
// ------ Reading definitions -----------------------------------------------------
474484

475485
private def nothingButMods(end: Addr): Boolean =
@@ -605,6 +615,8 @@ class TreeUnpickler(reader: TastyReader,
605615
}
606616
registerSym(start, sym)
607617
if (isClass) {
618+
if sym.owner.is(Package) && annots.exists(_.symbol == defn.CaptureCheckedAnnot) then
619+
wasCaptureChecked = true
608620
sym.completer.withDecls(newScope)
609621
forkAt(templateStart).indexTemplateParams()(using localContext(sym))
610622
}
@@ -1265,7 +1277,7 @@ class TreeUnpickler(reader: TastyReader,
12651277
val args = until(end)(readTpt())
12661278
val tree = untpd.AppliedTypeTree(tycon, args)
12671279
val ownType = ctx.typeAssigner.processAppliedType(tree, tycon.tpe.safeAppliedTo(args.tpes))
1268-
tree.withType(ownType)
1280+
tree.withType(postProcessFunction(ownType))
12691281
case ANNOTATEDtpt =>
12701282
Annotated(readTpt(), readTerm())
12711283
case LAMBDAtpt =>

compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import scala.collection.mutable
3030
import scala.collection.mutable.ListBuffer
3131
import scala.annotation.switch
3232
import reporting._
33+
import cc.adaptFunctionType
3334

3435
object Scala2Unpickler {
3536

@@ -818,7 +819,9 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
818819
// special-case in erasure, see TypeErasure#eraseInfo.
819820
OrType(args(0), args(1), soft = false)
820821
}
821-
else if (args.nonEmpty) tycon.safeAppliedTo(EtaExpandIfHK(sym.typeParams, args.map(translateTempPoly)))
822+
else if args.nonEmpty then
823+
tycon.safeAppliedTo(EtaExpandIfHK(sym.typeParams, args.map(translateTempPoly)))
824+
.adaptFunctionType
822825
else if (sym.typeParams.nonEmpty) tycon.EtaExpand(sym.typeParams)
823826
else tycon
824827
case TYPEBOUNDStpe =>
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
object Lib:
2+
extension [A](xs: Seq[A])
3+
def mapp[B](f: A => B): Seq[B] =
4+
xs.map(f.asInstanceOf[A -> B])
5+
6+
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import language.experimental.saferExceptions
2+
import Lib.*
3+
4+
class LimitExceeded extends Exception
5+
6+
val limit = 10e9
7+
8+
def f(x: Double): Double throws LimitExceeded =
9+
if x < limit then x * x else throw LimitExceeded()
10+
11+
@main def test(xs: Double*) =
12+
try println(xs.mapp(f).sum)
13+
catch case ex: LimitExceeded => println("too large")
14+
15+
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import language.experimental.saferExceptions
2+
3+
class LimitExceeded extends Exception
4+
5+
val limit = 10e9
6+
7+
extension [A](xs: Seq[A])
8+
def mapp[B](f: A => B): Seq[B] =
9+
xs.map(f.asInstanceOf[A -> B])
10+
11+
def f(x: Double): Double throws LimitExceeded =
12+
if x < limit then x * x else throw LimitExceeded()
13+
14+
@main def test(xs: Double*) =
15+
try println(xs.mapp(f).sum + xs.map(f).sum)
16+
catch case ex: LimitExceeded => println("too large")
17+
18+

0 commit comments

Comments
 (0)