Skip to content

Commit cf37bf5

Browse files
committed
Handle partial functions correctly in preTypeArgs
1 parent 762ebbf commit cf37bf5

File tree

3 files changed

+46
-23
lines changed

3 files changed

+46
-23
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,18 @@ class Definitions {
702702
}
703703
}
704704

705+
object PartialFunctionOf {
706+
def apply(arg: Type, result: Type)(implicit ctx: Context) =
707+
PartialFunctionType.appliedTo(arg :: result :: Nil)
708+
def unapply(pft: Type)(implicit ctx: Context) = {
709+
if (pft.isRef(PartialFunctionClass)) {
710+
val targs = pft.dealias.argInfos
711+
if (targs.length == 2) Some((targs.head, targs.tail)) else None
712+
}
713+
else None
714+
}
715+
}
716+
705717
object ArrayOf {
706718
def apply(elem: Type)(implicit ctx: Context) =
707719
if (ctx.erasedTypes) JavaArrayType(elem)

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,27 +1405,29 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
14051405
if (untpd.isFunctionWithUnknownParamType(arg)) {
14061406
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
14071407
val formalsForArg: List[Type] = altFormals.map(_.head)
1408-
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
1409-
// (p_1_1, ..., p_m_1) => r_1
1410-
// ...
1411-
// (p_1_n, ..., p_m_n) => r_n
1412-
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
1413-
formalsForArg.map(defn.FunctionOf.unapply)
1414-
if (decomposedFormalsForArg.forall(_.isDefined)) {
1415-
val formalParamTypessForArg: List[List[Type]] =
1416-
decomposedFormalsForArg.map(_.get._1)
1417-
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1418-
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1419-
// Given definitions above, for i = 1,...,m,
1420-
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1421-
// If all p_i_k's are the same, assume the type as formal parameter
1422-
// type of the i'th parameter of the closure.
1423-
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1424-
else WildcardType)
1425-
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
1426-
overload.println(i"pretype arg $arg with expected type $commonFormal")
1427-
pt.typedArg(arg, commonFormal)(ctx.addMode(Mode.ImplicitsEnabled))
1408+
def argTypesOfFormal(formal: Type): List[Type] =
1409+
formal match {
1410+
case defn.FunctionOf(args, result, isImplicit) => args
1411+
case defn.PartialFunctionOf(arg, result) => arg :: Nil
1412+
case _ => Nil
14281413
}
1414+
val formalParamTypessForArg: List[List[Type]] =
1415+
formalsForArg.map(argTypesOfFormal)
1416+
if (formalParamTypessForArg.forall(_.nonEmpty) &&
1417+
isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1418+
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1419+
// Given definitions above, for i = 1,...,m,
1420+
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1421+
// If all p_i_k's are the same, assume the type as formal parameter
1422+
// type of the i'th parameter of the closure.
1423+
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1424+
else WildcardType)
1425+
def isPartial = formalsForArg.forall(_.isRef(defn.PartialFunctionClass))
1426+
val commonFormal =
1427+
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, newTypeVar(TypeBounds.empty))
1428+
else defn.FunctionOf(commonParamTypes, WildcardType)
1429+
overload.println(i"pretype arg $arg with expected type $commonFormal")
1430+
pt.typedArg(arg, commonFormal)(ctx.addMode(Mode.ImplicitsEnabled))
14291431
}
14301432
}
14311433
recur(altFormals.map(_.tail), args1)

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,12 @@ object ProtoTypes {
238238
typer.adapt(targ, formal, arg)
239239
}
240240

241+
/** Retype argument, removing any previously cached entries */
242+
def retypeArg(arg: untpd.Tree, formal: Type)(implicit ctx: Context): Tree = {
243+
myTypedArg = myTypedArg.remove(arg)
244+
typedArg(arg, formal)
245+
}
246+
241247
/** The type of the argument `arg`.
242248
* @pre `arg` has been typed before
243249
*/
@@ -401,15 +407,18 @@ object ProtoTypes {
401407
/** Same as `constrained(tl, EmptyTree)`, but returns just the created type lambda */
402408
def constrained(tl: TypeLambda)(implicit ctx: Context): TypeLambda = constrained(tl, EmptyTree)._1
403409

404-
/** Create a new TypeVar that represents a dependent method parameter singleton */
405-
def newDepTypeVar(tp: Type)(implicit ctx: Context): TypeVar = {
410+
def newTypeVar(bounds: TypeBounds)(implicit ctx: Context): TypeVar = {
406411
val poly = PolyType(DepParamName.fresh().toTypeName :: Nil)(
407-
pt => TypeBounds.upper(AndType(tp, defn.SingletonClass.typeRef)) :: Nil,
412+
pt => bounds :: Nil,
408413
pt => defn.AnyType)
409414
constrained(poly, untpd.EmptyTree, alwaysAddTypeVars = true)
410415
._2.head.tpe.asInstanceOf[TypeVar]
411416
}
412417

418+
/** Create a new TypeVar that represents a dependent method parameter singleton */
419+
def newDepTypeVar(tp: Type)(implicit ctx: Context): TypeVar =
420+
newTypeVar(TypeBounds.upper(AndType(tp, defn.SingletonClass.typeRef)))
421+
413422
/** The result type of `mt`, where all references to parameters of `mt` are
414423
* replaced by either wildcards (if typevarsMissContext) or TypeParamRefs.
415424
*/

0 commit comments

Comments
 (0)