@@ -946,7 +946,7 @@ class Typer extends Namer
946
946
* def double(x: Char): String = s"$x$x"
947
947
* "abc" flatMap double
948
948
*/
949
- private def decomposeProtoFunction (pt : Type , defaultArity : Int )(using Context ): (List [Type ], untpd.Tree ) = {
949
+ private def decomposeProtoFunction (pt : Type , defaultArity : Int , tree : untpd. Tree )(using Context ): (List [Type ], untpd.Tree ) = {
950
950
def typeTree (tp : Type ) = tp match {
951
951
case _ : WildcardType => untpd.TypeTree ()
952
952
case _ => untpd.TypeTree (tp)
@@ -957,7 +957,15 @@ class Typer extends Namer
957
957
newTypeVar(apply(bounds.orElse(TypeBounds .empty)).bounds)
958
958
case _ => mapOver(t)
959
959
}
960
- pt.stripTypeVar.dealias match {
960
+ val pt1 = pt.stripTypeVar.dealias
961
+ if (pt1 ne pt1.dropDependentRefinement)
962
+ && defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType)
963
+ then
964
+ ctx.error(
965
+ i """ Implementation restriction: Expected result type $pt1
966
+ |is a curried dependent context function type. Such types are not yet supported. """ ,
967
+ tree.sourcePos)
968
+ pt1 match {
961
969
case pt1 if defn.isNonRefinedFunction(pt1) =>
962
970
// if expected parameter type(s) are wildcards, approximate from below.
963
971
// if expected result type is a wildcard, approximate from above.
@@ -970,7 +978,7 @@ class Typer extends Namer
970
978
else
971
979
typeTree(restpe))
972
980
case tp : TypeParamRef =>
973
- decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity)
981
+ decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree )
974
982
case _ =>
975
983
(List .tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree ())
976
984
}
@@ -1131,7 +1139,7 @@ class Typer extends Namer
1131
1139
case _ =>
1132
1140
}
1133
1141
1134
- val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
1142
+ val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree )
1135
1143
1136
1144
/** The inferred parameter type for a parameter in a lambda that does
1137
1145
* not have an explicit type given.
@@ -1261,7 +1269,7 @@ class Typer extends Namer
1261
1269
typedMatchFinish(tree, tpd.EmptyTree , defn.ImplicitScrutineeTypeRef , cases1, pt)
1262
1270
}
1263
1271
else {
1264
- val (protoFormals, _) = decomposeProtoFunction(pt, 1 )
1272
+ val (protoFormals, _) = decomposeProtoFunction(pt, 1 , tree )
1265
1273
val checkMode =
1266
1274
if (pt.isRef(defn.PartialFunctionClass )) desugar.MatchCheck .None
1267
1275
else desugar.MatchCheck .Exhaustive
@@ -1447,17 +1455,40 @@ class Typer extends Namer
1447
1455
}
1448
1456
1449
1457
def typedReturn (tree : untpd.Return )(using Context ): Return = {
1458
+
1459
+ /** If `pt` is a context function type, its return type. If the CFT
1460
+ * is dependent, instantiate with the parameters of the associated
1461
+ * anonymous function.
1462
+ * @param paramss the parameters of the anonymous functions
1463
+ * enclosing the return expression
1464
+ */
1465
+ def instantiateCFT (pt : Type , paramss : => List [List [Symbol ]]): Type =
1466
+ val ift = defn.asContextFunctionType(pt)
1467
+ if ift.exists then
1468
+ ift.nonPrivateMember(nme.apply).info match
1469
+ case appType : MethodType =>
1470
+ instantiateCFT(appType.instantiate(paramss.head.map(_.termRef)), paramss.tail)
1471
+ else pt
1472
+
1450
1473
def returnProto (owner : Symbol , locals : Scope ): Type =
1451
1474
if (owner.isConstructor) defn.UnitType
1452
- else owner.info match {
1453
- case info : PolyType =>
1454
- val tparams = locals.toList.takeWhile(_ is TypeParam )
1455
- assert(info.paramNames.length == tparams.length,
1456
- i " return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, % " )
1457
- info.instantiate(tparams.map(_.typeRef)).finalResultType
1458
- case info =>
1459
- info.finalResultType
1460
- }
1475
+ else
1476
+ val rt = owner.info match
1477
+ case info : PolyType =>
1478
+ val tparams = locals.toList.takeWhile(_ is TypeParam )
1479
+ assert(info.paramNames.length == tparams.length,
1480
+ i " return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, % " )
1481
+ info.instantiate(tparams.map(_.typeRef)).finalResultType
1482
+ case info =>
1483
+ info.finalResultType
1484
+ def iftParamss = ctx.owner.ownersIterator
1485
+ .filter(_.is(Method , butNot = Accessor ))
1486
+ .takeWhile(_.isAnonymousFunction)
1487
+ .toList
1488
+ .reverse
1489
+ .map(_.paramSymss.head)
1490
+ instantiateCFT(rt, iftParamss)
1491
+
1461
1492
def enclMethInfo (cx : Context ): (Tree , Type ) = {
1462
1493
val owner = cx.owner
1463
1494
if (owner.isType) {
@@ -3155,7 +3186,7 @@ class Typer extends Namer
3155
3186
3156
3187
def isContextFunctionRef (wtp : Type ): Boolean = wtp match {
3157
3188
case RefinedType (parent, nme.apply, _) =>
3158
- isContextFunctionRef(parent) // apply refinements indicate a dependent IFT
3189
+ isContextFunctionRef(parent) // apply refinements indicate a dependent CFT
3159
3190
case _ =>
3160
3191
val underlying = wtp.underlyingClassRef(refinementOK = false ) // other refinements are not OK
3161
3192
defn.isContextFunctionClass(underlying.classSymbol)
0 commit comments