Skip to content

Commit 9e1ed7a

Browse files
authored
Merge pull request #2868 from dotty-staging/fix-#2867
Fix #2866: Handle partial functions in preTypeArgs
2 parents b12326c + 8ade44c commit 9e1ed7a

File tree

4 files changed

+62
-37
lines changed

4 files changed

+62
-37
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,17 +273,24 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
273273
case _ => false
274274
}
275275

276-
def isFunctionWithUnknownParamType(tree: Tree) = tree match {
276+
def functionWithUnknownParamType(tree: Tree): Option[Tree] = tree match {
277277
case Function(args, _) =>
278-
args.exists {
278+
if (args.exists {
279279
case ValDef(_, tpt, _) => tpt.isEmpty
280280
case _ => false
281-
}
281+
}) Some(tree)
282+
else None
282283
case Match(EmptyTree, _) =>
283-
true
284-
case _ => false
284+
Some(tree)
285+
case Block(Nil, expr) =>
286+
functionWithUnknownParamType(expr)
287+
case _ =>
288+
None
285289
}
286290

291+
def isFunctionWithUnknownParamType(tree: Tree): Boolean =
292+
functionWithUnknownParamType(tree).isDefined
293+
287294
/** Is `tree` an implicit function or closure, possibly nested in a block? */
288295
def isImplicitClosure(tree: Tree)(implicit ctx: Context): Boolean = unsplice(tree) match {
289296
case Function((param: untpd.ValDef) :: _, _) => param.mods.is(Implicit)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,18 @@ class Definitions {
704704
}
705705
}
706706

707+
object PartialFunctionOf {
708+
def apply(arg: Type, result: Type)(implicit ctx: Context) =
709+
PartialFunctionType.appliedTo(arg :: result :: Nil)
710+
def unapply(pft: Type)(implicit ctx: Context) = {
711+
if (pft.isRef(PartialFunctionClass)) {
712+
val targs = pft.dealias.argInfos
713+
if (targs.length == 2) Some((targs.head, targs.tail)) else None
714+
}
715+
else None
716+
}
717+
}
718+
707719
object ArrayOf {
708720
def apply(elem: Type)(implicit ctx: Context) =
709721
if (ctx.erasedTypes) JavaArrayType(elem)

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,21 +1379,19 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
13791379
}
13801380

13811381
/** Try to typecheck any arguments in `pt` that are function values missing a
1382-
* parameter type. The expected type for these arguments is the lub of the
1383-
* corresponding formal parameter types of all alternatives. Type variables
1384-
* in formal parameter types are replaced by wildcards. The result of the
1385-
* typecheck is stored in `pt`, to be retrieved when its `typedArgs` are selected.
1382+
* parameter type. If the formal parameter types corresponding to a closure argument
1383+
* all agree on their argument types, typecheck the argument with an expected
1384+
* function or partial function type that contains these argument types,
1385+
* The result of the typecheck is stored in `pt`, to be retrieved when its `typedArgs` are selected.
13861386
* The benefit of doing this is to allow idioms like this:
13871387
*
13881388
* def map(f: Char => Char): String = ???
13891389
* def map[U](f: Char => U): Seq[U] = ???
13901390
* map(x => x.toUpper)
13911391
*
13921392
* Without `pretypeArgs` we'd get a "missing parameter type" error for `x`.
1393-
* With `pretypeArgs`, we use the union of the two formal parameter types
1394-
* `Char => Char` and `Char => ?` as the expected type of the closure `x => x.toUpper`.
1395-
* That union is `Char => Char`, so we have an expected parameter type `Char`
1396-
* for `x`, and the code typechecks.
1393+
* With `pretypeArgs`, we use the `Char => ?` as the expected type of the
1394+
* closure `x => x.toUpper`, which makes the code typecheck.
13971395
*/
13981396
private def pretypeArgs(alts: List[TermRef], pt: FunProto)(implicit ctx: Context): Unit = {
13991397
def recur(altFormals: List[List[Type]], args: List[untpd.Tree]): Unit = args match {
@@ -1402,30 +1400,35 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
14021400
case ValDef(_, tpt, _) => tpt.isEmpty
14031401
case _ => false
14041402
}
1405-
if (untpd.isFunctionWithUnknownParamType(arg)) {
1403+
val fn = untpd.functionWithUnknownParamType(arg)
1404+
if (fn.isDefined) {
14061405
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
14071406
val formalsForArg: List[Type] = altFormals.map(_.head)
1408-
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
1409-
// (p_1_1, ..., p_m_1) => r_1
1410-
// ...
1411-
// (p_1_n, ..., p_m_n) => r_n
1412-
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
1413-
formalsForArg.map(defn.FunctionOf.unapply)
1414-
if (decomposedFormalsForArg.forall(_.isDefined)) {
1415-
val formalParamTypessForArg: List[List[Type]] =
1416-
decomposedFormalsForArg.map(_.get._1)
1417-
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1418-
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1419-
// Given definitions above, for i = 1,...,m,
1420-
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1421-
// If all p_i_k's are the same, assume the type as formal parameter
1422-
// type of the i'th parameter of the closure.
1423-
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1424-
else WildcardType)
1425-
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
1426-
overload.println(i"pretype arg $arg with expected type $commonFormal")
1427-
pt.typedArg(arg, commonFormal)(ctx.addMode(Mode.ImplicitsEnabled))
1407+
def argTypesOfFormal(formal: Type): List[Type] =
1408+
formal match {
1409+
case defn.FunctionOf(args, result, isImplicit) => args
1410+
case defn.PartialFunctionOf(arg, result) => arg :: Nil
1411+
case _ => Nil
14281412
}
1413+
val formalParamTypessForArg: List[List[Type]] =
1414+
formalsForArg.map(argTypesOfFormal)
1415+
if (formalParamTypessForArg.forall(_.nonEmpty) &&
1416+
isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1417+
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1418+
// Given definitions above, for i = 1,...,m,
1419+
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1420+
// If all p_i_k's are the same, assume the type as formal parameter
1421+
// type of the i'th parameter of the closure.
1422+
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1423+
else WildcardType)
1424+
def isPartial = // we should generate a partial function for the arg
1425+
fn.get.isInstanceOf[untpd.Match] &&
1426+
formalsForArg.exists(_.isRef(defn.PartialFunctionClass))
1427+
val commonFormal =
1428+
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, newTypeVar(TypeBounds.empty))
1429+
else defn.FunctionOf(commonParamTypes, WildcardType)
1430+
overload.println(i"pretype arg $arg with expected type $commonFormal")
1431+
pt.typedArg(arg, commonFormal)(ctx.addMode(Mode.ImplicitsEnabled))
14291432
}
14301433
}
14311434
recur(altFormals.map(_.tail), args1)

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,15 +401,18 @@ object ProtoTypes {
401401
/** Same as `constrained(tl, EmptyTree)`, but returns just the created type lambda */
402402
def constrained(tl: TypeLambda)(implicit ctx: Context): TypeLambda = constrained(tl, EmptyTree)._1
403403

404-
/** Create a new TypeVar that represents a dependent method parameter singleton */
405-
def newDepTypeVar(tp: Type)(implicit ctx: Context): TypeVar = {
404+
def newTypeVar(bounds: TypeBounds)(implicit ctx: Context): TypeVar = {
406405
val poly = PolyType(DepParamName.fresh().toTypeName :: Nil)(
407-
pt => TypeBounds.upper(AndType(tp, defn.SingletonClass.typeRef)) :: Nil,
406+
pt => bounds :: Nil,
408407
pt => defn.AnyType)
409408
constrained(poly, untpd.EmptyTree, alwaysAddTypeVars = true)
410409
._2.head.tpe.asInstanceOf[TypeVar]
411410
}
412411

412+
/** Create a new TypeVar that represents a dependent method parameter singleton */
413+
def newDepTypeVar(tp: Type)(implicit ctx: Context): TypeVar =
414+
newTypeVar(TypeBounds.upper(AndType(tp, defn.SingletonClass.typeRef)))
415+
413416
/** The result type of `mt`, where all references to parameters of `mt` are
414417
* replaced by either wildcards (if typevarsMissContext) or TypeParamRefs.
415418
*/

0 commit comments

Comments
 (0)