Skip to content

Commit 6ecc32e

Browse files
committed
Fix scala#16405 - wildcards prematurely resolving to Nothing
This was a problem because it could it get in the way of some metaprogramming techniques. The main issue was the fact that when typing functions, the type inference would first look at the types from the source method (resolving type wildcards to Nothing) and only after that, it could look at the target method. Now, in the case of wildcards we save that fact for later (while still resolving the prototype parameter to Nothing) and we in that case we prioritize according to the target method, after which we fallback to the default procedure.
1 parent b65b0f2 commit 6ecc32e

File tree

2 files changed

+67
-17
lines changed

2 files changed

+67
-17
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,7 +1197,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
11971197
)
11981198
end typedIf
11991199

1200-
/** Decompose function prototype into a list of parameter prototypes and a result prototype
1200+
/** Decompose function prototype into a list of parameter prototypes, an optional list
1201+
* describing whether the parameter prototypes come from WildcardTypes, and a result prototype
12011202
* tree, using WildcardTypes where a type is not known.
12021203
* For the result type we do this even if the expected type is not fully
12031204
* defined, which is a bit of a hack. But it's needed to make the following work
@@ -1206,7 +1207,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12061207
* def double(x: Char): String = s"$x$x"
12071208
* "abc" flatMap double
12081209
*/
1209-
private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = {
1210+
private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], Option[List[Boolean]], untpd.Tree) = {
12101211
def typeTree(tp: Type) = tp match {
12111212
case _: WildcardType => new untpd.InferredTypeTree()
12121213
case _ => untpd.InferredTypeTree(tp)
@@ -1234,18 +1235,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12341235
// if expected parameter type(s) are wildcards, approximate from below.
12351236
// if expected result type is a wildcard, approximate from above.
12361237
// this can type the greatest set of admissible closures.
1237-
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
1238+
// However, we still keep the information on whether expected parameter types were
1239+
// WildcardTypes, in case of types inferred from target being more specific
1240+
1241+
val fromWildcards = pt1.argInfos.init.map{
1242+
case bounds @ TypeBounds(nt, at) if nt == defn.NothingType && at == defn.AnyType => true
1243+
case bounds => false
1244+
}
1245+
1246+
(pt1.argTypesLo.init, Some(fromWildcards), typeTree(interpolateWildcards(pt1.argTypesHi.last)))
12381247
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
12391248
if defn.isNonRefinedFunction(parent) && formals.length == defaultArity =>
1240-
(formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
1249+
(formals, None, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
12411250
case SAMType(mt @ MethodTpe(_, formals, restpe)) =>
1242-
(formals,
1251+
(formals, None,
12431252
if (mt.isResultDependent)
12441253
untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))
12451254
else
12461255
typeTree(restpe))
12471256
case _ =>
1248-
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
1257+
(List.tabulate(defaultArity)(alwaysWildcardType), None, untpd.TypeTree())
12491258
}
12501259
}
12511260
}
@@ -1259,15 +1268,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12591268
* The inference makes two attempts:
12601269
*
12611270
* 1. Compute the target type `T` and make it known that `S <: T`.
1262-
* If the expected type `S` can be fully defined under ForceDegree.flipBottom,
1263-
* pick this one (this might use the fact that S <: T for an upper approximation).
1271+
* If the expected type `S` can be fully defined under ForceDegree.flipBottom
1272+
* and with minimizedSelected option set as true, pick this one
1273+
* (this might use the fact that S <: T for an upper approximation).
12641274
* 2. Otherwise, if the target type `T` can be fully defined under ForceDegree.flipBottom,
12651275
* pick this one.
12661276
*
12671277
* If both attempts fail, return `NoType`.
12681278
*/
12691279
def inferredFromTarget(
1270-
param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int)(using Context): Type =
1280+
param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int, isWildcardParam: Boolean)(using Context): Type =
12711281
val target = calleeType.widen match
12721282
case mtpe: MethodType =>
12731283
val pos = paramIndex(param.name)
@@ -1280,7 +1290,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12801290
else NoType
12811291
case _ => NoType
12821292
if target.exists then formal <:< target
1283-
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
1293+
if !isWildcardParam && isFullyDefined(formal, ForceDegree.flipBottom) then formal
12841294
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
12851295
else NoType
12861296

@@ -1457,7 +1467,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
14571467
case _ =>
14581468
}
14591469

1460-
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)
1470+
val (protoFormals, areWildcardParams, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)
14611471

14621472
def protoFormal(i: Int): Type =
14631473
if (protoFormals.length == params.length) protoFormals(i)
@@ -1500,13 +1510,22 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
15001510
if (!param.tpt.isEmpty) param
15011511
else
15021512
val formal = protoFormal(i)
1513+
val isWildcardParam = areWildcardParams.map(list => if i < list.length then list(i) else false).getOrElse(false)
15031514
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
1504-
val paramType =
1505-
if knownFormal then formal
1506-
else inferredFromTarget(param, formal, calleeType, paramIndex)
1507-
.orElse(errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos))
1515+
// Since decomposeProtoFunction eagerly approximates function arguments
1516+
// from below, then in the case that parameter was also identified as
1517+
// a wildcard we try to prioritize inferring from target, if possible.
1518+
// See issue 16405 (16405.scala)
1519+
val (usingFormal, paramType) =
1520+
if !isWildcardParam && knownFormal then (true, formal)
1521+
else
1522+
val fromTarget = inferredFromTarget(param, formal, calleeType, paramIndex, isWildcardParam)
1523+
if fromTarget.exists then
1524+
(false, fromTarget)
1525+
else if knownFormal then (true, formal)
1526+
else (false, errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos))
15081527
val paramTpt = untpd.TypedSplice(
1509-
(if knownFormal then InferredTypeTree() else untpd.TypeTree())
1528+
(if usingFormal then InferredTypeTree() else untpd.TypeTree())
15101529
.withType(paramType.translateFromRepeated(toArray = false))
15111530
.withSpan(param.span.endPos)
15121531
)
@@ -1577,7 +1596,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
15771596
typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt)
15781597
}
15791598
else {
1580-
val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree.srcPos)
1599+
val (protoFormals, _, _) = decomposeProtoFunction(pt, 1, tree.srcPos)
15811600
val checkMode =
15821601
if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None
15831602
else desugar.MatchCheck.Exhaustive

tests/run/16405.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import scala.compiletime.summonInline
2+
3+
case class TypeDesc[T](tpe: String)
4+
object TypeDesc {
5+
given nothing: TypeDesc[Nothing] = TypeDesc("Nothing")
6+
given string: TypeDesc[String] = TypeDesc("String")
7+
given int: TypeDesc[Int] = TypeDesc("Int")
8+
}
9+
10+
def exampleFn(s: String, i: Int): Unit = ()
11+
12+
inline def argumentTypesOf[R](fun: (_, _) => R): (TypeDesc[?], TypeDesc[?]) = {
13+
inline fun match {
14+
case x: ((a, b) => R) =>
15+
(scala.compiletime.summonInline[TypeDesc[a]], scala.compiletime.summonInline[TypeDesc[b]])
16+
}
17+
}
18+
inline def argumentTypesOfNoWildCard[A, B, R](fun: (A, B) => R): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun)
19+
inline def argumentTypesOfAllWildCard(fun: (?, ?) => ?): (TypeDesc[?], TypeDesc[?]) = argumentTypesOf(fun)
20+
21+
object Test {
22+
def main(args: Array[String]): Unit = {
23+
val expected = (TypeDesc.string, TypeDesc.int)
24+
assert(argumentTypesOf(exampleFn) == expected)
25+
assert(argumentTypesOf(exampleFn(_, _)) == expected)
26+
assert(argumentTypesOfNoWildCard(exampleFn) == expected)
27+
assert(argumentTypesOfNoWildCard(exampleFn(_, _)) == expected)
28+
assert(argumentTypesOfAllWildCard(exampleFn) == expected)
29+
assert(argumentTypesOfAllWildCard(exampleFn(_, _)) == expected)
30+
}
31+
}

0 commit comments

Comments
 (0)