Skip to content

Commit 2ca1165

Browse files
Merge pull request #9303 from dotty-staging/fix-ift-return
Disallow curried dependent context function types
2 parents 7723864 + 3d42606 commit 2ca1165

File tree

6 files changed

+99
-23
lines changed

6 files changed

+99
-23
lines changed

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,10 @@ class Namer { typer: Typer =>
246246
val xtree = expanded(tree)
247247
xtree.getAttachment(TypedAhead) match {
248248
case Some(ttree) => ttree.symbol
249-
case none => xtree.attachment(SymOfTree)
249+
case none =>
250+
xtree.getAttachment(SymOfTree) match
251+
case Some(sym) => sym
252+
case _ => throw IllegalArgumentException(i"$xtree does not have a symbol")
250253
}
251254
}
252255

@@ -443,14 +446,11 @@ class Namer { typer: Typer =>
443446
/** If `sym` exists, enter it in effective scope. Check that
444447
* package members are not entered twice in the same run.
445448
*/
446-
def enterSymbol(sym: Symbol)(using Context): Symbol = {
449+
def enterSymbol(sym: Symbol)(using Context): Unit =
447450
// We do not enter Scala 2 macros defined in Scala 3 as they have an equivalent Scala 3 inline method.
448-
if (sym.exists && !sym.isScala2MacroInScala3) {
451+
if sym.exists && !sym.isScala2MacroInScala3 then
449452
typr.println(s"entered: $sym in ${ctx.owner}")
450453
ctx.enter(sym)
451-
}
452-
sym
453-
}
454454

455455
/** Create package if it does not yet exist. */
456456
private def createPackageSymbol(pid: RefTree)(using Context): Symbol = {
@@ -540,7 +540,8 @@ class Namer { typer: Typer =>
540540
case imp: Import =>
541541
ctx.importContext(imp, createSymbol(imp))
542542
case mdef: DefTree =>
543-
val sym = enterSymbol(createSymbol(mdef))
543+
val sym = createSymbol(mdef)
544+
enterSymbol(sym)
544545
setDocstring(sym, origStat)
545546
addEnumConstants(mdef, sym)
546547
ctx

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ class Typer extends Namer
946946
* def double(x: Char): String = s"$x$x"
947947
* "abc" flatMap double
948948
*/
949-
private def decomposeProtoFunction(pt: Type, defaultArity: Int)(using Context): (List[Type], untpd.Tree) = {
949+
private def decomposeProtoFunction(pt: Type, defaultArity: Int, tree: untpd.Tree)(using Context): (List[Type], untpd.Tree) = {
950950
def typeTree(tp: Type) = tp match {
951951
case _: WildcardType => untpd.TypeTree()
952952
case _ => untpd.TypeTree(tp)
@@ -957,7 +957,15 @@ class Typer extends Namer
957957
newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds)
958958
case _ => mapOver(t)
959959
}
960-
pt.stripTypeVar.dealias match {
960+
val pt1 = pt.stripTypeVar.dealias
961+
if (pt1 ne pt1.dropDependentRefinement)
962+
&& defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType)
963+
then
964+
ctx.error(
965+
i"""Implementation restriction: Expected result type $pt1
966+
|is a curried dependent context function type. Such types are not yet supported.""",
967+
tree.sourcePos)
968+
pt1 match {
961969
case pt1 if defn.isNonRefinedFunction(pt1) =>
962970
// if expected parameter type(s) are wildcards, approximate from below.
963971
// if expected result type is a wildcard, approximate from above.
@@ -970,7 +978,7 @@ class Typer extends Namer
970978
else
971979
typeTree(restpe))
972980
case tp: TypeParamRef =>
973-
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity)
981+
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree)
974982
case _ =>
975983
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
976984
}
@@ -1131,7 +1139,7 @@ class Typer extends Namer
11311139
case _ =>
11321140
}
11331141

1134-
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
1142+
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree)
11351143

11361144
/** The inferred parameter type for a parameter in a lambda that does
11371145
* not have an explicit type given.
@@ -1261,7 +1269,7 @@ class Typer extends Namer
12611269
typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt)
12621270
}
12631271
else {
1264-
val (protoFormals, _) = decomposeProtoFunction(pt, 1)
1272+
val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree)
12651273
val checkMode =
12661274
if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None
12671275
else desugar.MatchCheck.Exhaustive
@@ -1447,17 +1455,40 @@ class Typer extends Namer
14471455
}
14481456

14491457
def typedReturn(tree: untpd.Return)(using Context): Return = {
1458+
1459+
/** If `pt` is a context function type, its return type. If the CFT
1460+
* is dependent, instantiate with the parameters of the associated
1461+
* anonymous function.
1462+
* @param paramss the parameters of the anonymous functions
1463+
* enclosing the return expression
1464+
*/
1465+
def instantiateCFT(pt: Type, paramss: => List[List[Symbol]]): Type =
1466+
val ift = defn.asContextFunctionType(pt)
1467+
if ift.exists then
1468+
ift.nonPrivateMember(nme.apply).info match
1469+
case appType: MethodType =>
1470+
instantiateCFT(appType.instantiate(paramss.head.map(_.termRef)), paramss.tail)
1471+
else pt
1472+
14501473
def returnProto(owner: Symbol, locals: Scope): Type =
14511474
if (owner.isConstructor) defn.UnitType
1452-
else owner.info match {
1453-
case info: PolyType =>
1454-
val tparams = locals.toList.takeWhile(_ is TypeParam)
1455-
assert(info.paramNames.length == tparams.length,
1456-
i"return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, %")
1457-
info.instantiate(tparams.map(_.typeRef)).finalResultType
1458-
case info =>
1459-
info.finalResultType
1460-
}
1475+
else
1476+
val rt = owner.info match
1477+
case info: PolyType =>
1478+
val tparams = locals.toList.takeWhile(_ is TypeParam)
1479+
assert(info.paramNames.length == tparams.length,
1480+
i"return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, %")
1481+
info.instantiate(tparams.map(_.typeRef)).finalResultType
1482+
case info =>
1483+
info.finalResultType
1484+
def iftParamss = ctx.owner.ownersIterator
1485+
.filter(_.is(Method, butNot = Accessor))
1486+
.takeWhile(_.isAnonymousFunction)
1487+
.toList
1488+
.reverse
1489+
.map(_.paramSymss.head)
1490+
instantiateCFT(rt, iftParamss)
1491+
14611492
def enclMethInfo(cx: Context): (Tree, Type) = {
14621493
val owner = cx.owner
14631494
if (owner.isType) {
@@ -3155,7 +3186,7 @@ class Typer extends Namer
31553186

31563187
def isContextFunctionRef(wtp: Type): Boolean = wtp match {
31573188
case RefinedType(parent, nme.apply, _) =>
3158-
isContextFunctionRef(parent) // apply refinements indicate a dependent IFT
3189+
isContextFunctionRef(parent) // apply refinements indicate a dependent CFT
31593190
case _ =>
31603191
val underlying = wtp.underlyingClassRef(refinementOK = false) // other refinements are not OK
31613192
defn.isContextFunctionClass(underlying.classSymbol)

tests/neg/curried-dependent-ift.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
trait Ctx1:
2+
type T
3+
val x: T
4+
val y: T
5+
6+
trait Ctx2:
7+
type T
8+
val x: T
9+
val y: T
10+
11+
trait A
12+
trait B
13+
14+
def h(x: Boolean): A ?=> B ?=> (A, B) =
15+
(summon[A], summon[B]) // OK
16+
17+
def g(x: Boolean): (c1: Ctx1) ?=> Ctx2 ?=> (c1.T, Ctx2) =
18+
return ??? // error
19+
???

tests/neg/i4668.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ trait Functor[F[_]] { def map[A,B](x: F[A])(f: A => B): F[B] }
88
object Functor { implicit object listFun extends Functor[List] { def map[A,B](ls: List[A])(f: A => B) = ls.map(f) } }
99

1010
val map: (A:Type,B:Type,F:Type1) ?=> (Functor[F.T]) ?=> (F.T[A.T]) => (A.T => B.T) => F.T[B.T] =
11-
fun ?=> x => f => fun.map(x)(f) // error // error // error: Missing parameter type
11+
fun ?=> x => f => fun.map(x)(f) // error

tests/run/ift-return.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
(22,abc)
2+
(22,def)

tests/run/ift-return.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
trait A:
2+
val x: Int
3+
4+
trait Ctx:
5+
type T
6+
val x: T
7+
val y: T
8+
9+
def f(x: Boolean): A ?=> (c: Ctx) ?=> (Int, c.T) =
10+
if x then return (summon[A].x, summon[Ctx].x)
11+
(summon[A].x, summon[Ctx].y)
12+
13+
@main def Test =
14+
given A:
15+
val x = 22
16+
given Ctx:
17+
type T = String
18+
val x = "abc"
19+
val y = "def"
20+
21+
println(f(true))
22+
println(f(false))
23+

0 commit comments

Comments
 (0)