Skip to content

Commit 85b331d

Browse files
committed
Mark inferred closure parameters as InferredTypeTrees
# Conflicts: # compiler/src/dotty/tools/dotc/typer/Typer.scala
1 parent 5893a72 commit 85b331d

File tree

2 files changed

+50
-45
lines changed

2 files changed

+50
-45
lines changed

compiler/src/dotty/tools/dotc/reporting/messages.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,14 @@ import transform.SymUtils._
149149
}
150150

151151
class AnonymousFunctionMissingParamType(param: untpd.ValDef,
152-
args: List[untpd.Tree],
153152
tree: untpd.Function,
154153
pt: Type)
155154
(using Context)
156155
extends TypeMsg(AnonymousFunctionMissingParamTypeID) {
157156
def msg = {
158157
val ofFun =
159158
if param.name.is(WildcardParamName)
160-
|| (MethodType.syntheticParamNames(args.length + 1) contains param.name)
159+
|| (MethodType.syntheticParamNames(tree.args.length + 1) contains param.name)
161160
then i" of expanded function:\n$tree"
162161
else ""
163162

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,7 @@ class Typer extends Namer
11161116
* def double(x: Char): String = s"$x$x"
11171117
* "abc" flatMap double
11181118
*/
1119-
private def decomposeProtoFunction(pt: Type, defaultArity: Int, tree: untpd.Tree)(using Context): (List[Type], untpd.Tree) = {
1119+
private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = {
11201120
def typeTree(tp: Type) = tp match {
11211121
case _: WildcardType => untpd.TypeTree()
11221122
case _ => untpd.TypeTree(tp)
@@ -1135,11 +1135,10 @@ class Typer extends Namer
11351135
report.error(
11361136
i"""Implementation restriction: Expected result type $pt1
11371137
|is a curried dependent context function type. Such types are not yet supported.""",
1138-
tree.srcPos)
1139-
1138+
pos)
11401139
pt1 match {
11411140
case tp: TypeParamRef =>
1142-
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree)
1141+
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, pos)
11431142
case _ => pt1.findFunctionTypeInUnion match {
11441143
case pt1 if defn.isNonRefinedFunction(pt1) =>
11451144
// if expected parameter type(s) are wildcards, approximate from below.
@@ -1161,6 +1160,37 @@ class Typer extends Namer
11611160
}
11621161
}
11631162

1163+
/** The parameter type for a parameter in a lambda that does
1164+
* not have an explicit type given, and where the type is not known from the context.
1165+
* In this case the paranmeter type needs to be inferred the "target type" T known
1166+
* from the callee `f` if the lambda is of a form like `x => f(x)`.
1167+
* If `T` exists, we know that `S <: I <: T`.
1168+
*
1169+
* The inference makes two attempts:
1170+
*
1171+
* 1. Compute the target type `T` and make it known that `S <: T`.
1172+
* If the expected type `S` can be fully defined under ForceDegree.flipBottom,
1173+
* pick this one (this might use the fact that S <: T for an upper approximation).
1174+
* 2. Otherwise, if the target type `T` can be fully defined under ForceDegree.flipBottom,
1175+
* pick this one.
1176+
*
1177+
* If both attempts fail, issue a "missing parameter type" error.
1178+
*/
1179+
def inferredFromTarget(
1180+
param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int)(using Context): Type =
1181+
val target = calleeType.widen match
1182+
case mtpe: MethodType =>
1183+
val pos = paramIndex(param.name)
1184+
if pos < mtpe.paramInfos.length then
1185+
val ptype = mtpe.paramInfos(pos)
1186+
if ptype.isRepeatedParam then NoType else ptype
1187+
else NoType
1188+
case _ => NoType
1189+
if target.exists then formal <:< target
1190+
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1191+
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
1192+
else NoType
1193+
11641194
def typedFunction(tree: untpd.Function, pt: Type)(using Context): Tree =
11651195
if (ctx.mode is Mode.Type) typedFunctionType(tree, pt)
11661196
else typedFunctionValue(tree, pt)
@@ -1333,41 +1363,7 @@ class Typer extends Namer
13331363
case _ =>
13341364
}
13351365

1336-
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree)
1337-
1338-
/** The inferred parameter type for a parameter in a lambda that does
1339-
* not have an explicit type given.
1340-
* An inferred parameter type I has two possible sources:
1341-
* - the type S known from the context
1342-
* - the "target type" T known from the callee `f` if the lambda is of a form like `x => f(x)`
1343-
* If `T` exists, we know that `S <: I <: T`.
1344-
*
1345-
* The inference makes three attempts:
1346-
*
1347-
* 1. If the expected type `S` is already fully defined under ForceDegree.failBottom
1348-
* pick this one.
1349-
* 2. Compute the target type `T` and make it known that `S <: T`.
1350-
* If the expected type `S` can be fully defined under ForceDegree.flipBottom,
1351-
* pick this one (this might use the fact that S <: T for an upper approximation).
1352-
* 3. Otherwise, if the target type `T` can be fully defined under ForceDegree.flipBottom,
1353-
* pick this one.
1354-
*
1355-
* If all attempts fail, issue a "missing parameter type" error.
1356-
*/
1357-
def inferredParamType(param: untpd.ValDef, formal: Type): Type =
1358-
if isFullyDefined(formal, ForceDegree.failBottom) then return formal
1359-
val target = calleeType.widen match
1360-
case mtpe: MethodType =>
1361-
val pos = paramIndex(param.name)
1362-
if pos < mtpe.paramInfos.length then
1363-
val ptype = mtpe.paramInfos(pos)
1364-
if ptype.isRepeatedParam then NoType else ptype
1365-
else NoType
1366-
case _ => NoType
1367-
if target.exists then formal <:< target
1368-
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1369-
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
1370-
else errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.srcPos)
1366+
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)
13711367

13721368
def protoFormal(i: Int): Type =
13731369
if (protoFormals.length == params.length) protoFormals(i)
@@ -1393,9 +1389,19 @@ class Typer extends Namer
13931389
val inferredParams: List[untpd.ValDef] =
13941390
for ((param, i) <- params.zipWithIndex) yield
13951391
if (!param.tpt.isEmpty) param
1396-
else cpy.ValDef(param)(
1397-
tpt = untpd.TypeTree(
1398-
inferredParamType(param, protoFormal(i)).translateFromRepeated(toArray = false)))
1392+
else
1393+
val formal = protoFormal(i)
1394+
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
1395+
val paramType =
1396+
if knownFormal then formal
1397+
else inferredFromTarget(param, formal, calleeType, paramIndex)
1398+
.orElse(errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos))
1399+
val paramTpt = untpd.TypedSplice(
1400+
(if knownFormal then InferredTypeTree() else untpd.TypeTree())
1401+
.withType(paramType.translateFromRepeated(toArray = false))
1402+
.withSpan(param.span.endPos)
1403+
)
1404+
cpy.ValDef(param)(tpt = paramTpt)
13991405
desugar.makeClosure(inferredParams, fnBody, resultTpt, isContextual, tree.span)
14001406
}
14011407
typed(desugared, pt)
@@ -1461,7 +1467,7 @@ class Typer extends Namer
14611467
typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt)
14621468
}
14631469
else {
1464-
val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree)
1470+
val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree.srcPos)
14651471
val checkMode =
14661472
if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None
14671473
else desugar.MatchCheck.Exhaustive

0 commit comments

Comments
 (0)