@@ -1379,21 +1379,19 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
1379
1379
}
1380
1380
1381
1381
/** Try to typecheck any arguments in `pt` that are function values missing a
1382
- * parameter type. The expected type for these arguments is the lub of the
1383
- * corresponding formal parameter types of all alternatives. Type variables
1384
- * in formal parameter types are replaced by wildcards. The result of the
1385
- * typecheck is stored in `pt`, to be retrieved when its `typedArgs` are selected.
1382
+ * parameter type. If the formal parameter types corresponding to a closure argument
1383
+ * all agree on their argument types, typecheck the argument with an expected
1384
+ * function or partial function type that contains these argument types,
1385
+ * The result of the typecheck is stored in `pt`, to be retrieved when its `typedArgs` are selected.
1386
1386
* The benefit of doing this is to allow idioms like this:
1387
1387
*
1388
1388
* def map(f: Char => Char): String = ???
1389
1389
* def map[U](f: Char => U): Seq[U] = ???
1390
1390
* map(x => x.toUpper)
1391
1391
*
1392
1392
* Without `pretypeArgs` we'd get a "missing parameter type" error for `x`.
1393
- * With `pretypeArgs`, we use the union of the two formal parameter types
1394
- * `Char => Char` and `Char => ?` as the expected type of the closure `x => x.toUpper`.
1395
- * That union is `Char => Char`, so we have an expected parameter type `Char`
1396
- * for `x`, and the code typechecks.
1393
+ * With `pretypeArgs`, we use the `Char => ?` as the expected type of the
1394
+ * closure `x => x.toUpper`, which makes the code typecheck.
1397
1395
*/
1398
1396
private def pretypeArgs (alts : List [TermRef ], pt : FunProto )(implicit ctx : Context ): Unit = {
1399
1397
def recur (altFormals : List [List [Type ]], args : List [untpd.Tree ]): Unit = args match {
@@ -1402,30 +1400,35 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
1402
1400
case ValDef (_, tpt, _) => tpt.isEmpty
1403
1401
case _ => false
1404
1402
}
1405
- if (untpd.isFunctionWithUnknownParamType(arg)) {
1403
+ val fn = untpd.functionWithUnknownParamType(arg)
1404
+ if (fn.isDefined) {
1406
1405
def isUniform [T ](xs : List [T ])(p : (T , T ) => Boolean ) = xs.forall(p(_, xs.head))
1407
1406
val formalsForArg : List [Type ] = altFormals.map(_.head)
1408
- // For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
1409
- // (p_1_1, ..., p_m_1) => r_1
1410
- // ...
1411
- // (p_1_n, ..., p_m_n) => r_n
1412
- val decomposedFormalsForArg : List [Option [(List [Type ], Type , Boolean )]] =
1413
- formalsForArg.map(defn.FunctionOf .unapply)
1414
- if (decomposedFormalsForArg.forall(_.isDefined)) {
1415
- val formalParamTypessForArg : List [List [Type ]] =
1416
- decomposedFormalsForArg.map(_.get._1)
1417
- if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1418
- val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1419
- // Given definitions above, for i = 1,...,m,
1420
- // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1421
- // If all p_i_k's are the same, assume the type as formal parameter
1422
- // type of the i'th parameter of the closure.
1423
- if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1424
- else WildcardType )
1425
- val commonFormal = defn.FunctionOf (commonParamTypes, WildcardType )
1426
- overload.println(i " pretype arg $arg with expected type $commonFormal" )
1427
- pt.typedArg(arg, commonFormal)(ctx.addMode(Mode .ImplicitsEnabled ))
1407
+ def argTypesOfFormal (formal : Type ): List [Type ] =
1408
+ formal match {
1409
+ case defn.FunctionOf (args, result, isImplicit) => args
1410
+ case defn.PartialFunctionOf (arg, result) => arg :: Nil
1411
+ case _ => Nil
1428
1412
}
1413
+ val formalParamTypessForArg : List [List [Type ]] =
1414
+ formalsForArg.map(argTypesOfFormal)
1415
+ if (formalParamTypessForArg.forall(_.nonEmpty) &&
1416
+ isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1417
+ val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1418
+ // Given definitions above, for i = 1,...,m,
1419
+ // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1420
+ // If all p_i_k's are the same, assume the type as formal parameter
1421
+ // type of the i'th parameter of the closure.
1422
+ if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1423
+ else WildcardType )
1424
+ def isPartial = // we should generate a partial function for the arg
1425
+ fn.get.isInstanceOf [untpd.Match ] &&
1426
+ formalsForArg.exists(_.isRef(defn.PartialFunctionClass ))
1427
+ val commonFormal =
1428
+ if (isPartial) defn.PartialFunctionOf (commonParamTypes.head, newTypeVar(TypeBounds .empty))
1429
+ else defn.FunctionOf (commonParamTypes, WildcardType )
1430
+ overload.println(i " pretype arg $arg with expected type $commonFormal" )
1431
+ pt.typedArg(arg, commonFormal)(ctx.addMode(Mode .ImplicitsEnabled ))
1429
1432
}
1430
1433
}
1431
1434
recur(altFormals.map(_.tail), args1)
0 commit comments