Skip to content

Fix #7757: Do auto-parameter-untupling also for overloaded methods #7766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1316,33 +1316,39 @@ object desugar {
Function(params, Match(makeSelector(selector, checkMode), cases))
}

/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
/** Map n-ary function `(x1: T1, ..., xn: Tn) => body` where n != 1 to unary function as follows:
*
* x$1 => {
* def p1 = x$1._1
* (x$1: (T1, ..., Tn)) => {
* def x1: T1 = x$1._1
* ...
* def pn = x$1._n
* def xn: Tn = x$1._n
* body
* }
*
* or if `isGenericTuple`
*
* x$1 => {
* def p1 = x$1.apply(0)
* (x$1: (T1, ... Tn) => {
* def x1: T1 = x$1.apply(0)
* ...
* def pn = x$1.apply(n-1)
* def xn: Tn = x$1.apply(n-1)
* body
* }
*
* If some of the Ti's are absent, omit the : (T1, ..., Tn) type ascription
* in the selector.
*/
def makeTupledFunction(params: List[ValDef], body: Tree, isGenericTuple: Boolean)(implicit ctx: Context): Tree = {
val param = makeSyntheticParameter()
val param = makeSyntheticParameter(
tpt =
if params.exists(_.tpt.isEmpty) then TypeTree()
else Tuple(params.map(_.tpt)))
def selector(n: Int) =
if (isGenericTuple) Apply(Select(refOfDef(param), nme.apply), Literal(Constant(n)))
else Select(refOfDef(param), nme.selectorName(n))
val vdefs =
params.zipWithIndex.map {
case (param, idx) =>
DefDef(param.name, Nil, Nil, TypeTree(), selector(idx)).withSpan(param.span)
DefDef(param.name, Nil, Nil, param.tpt, selector(idx)).withSpan(param.span)
}
Function(param :: Nil, Block(vdefs, body))
}
Expand Down
16 changes: 16 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Decorators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,22 @@ object Decorators {
else x1 :: xs1
}

/** Like `xs.lazyZip(xs.indices).map(f)`, but returns list `xs` itself
* - instead of a copy - if function `f` maps all elements of
* `xs` to themselves.
*/
def mapWithIndexConserve[U <: T](f: (T, Int) => U): List[U] =
def recur(xs: List[T], idx: Int): List[U] =
if xs.isEmpty then Nil
else
val x1 = f(xs.head, idx)
val xs1 = recur(xs.tail, idx + 1)
if (x1.asInstanceOf[AnyRef] eq xs.head.asInstanceOf[AnyRef])
&& (xs1 eq xs.tail)
then xs.asInstanceOf[List[U]]
else x1 :: xs1
recur(xs, 0)

final def hasSameLengthAs[U](ys: List[U]): Boolean = {
@tailrec def loop(xs: List[T], ys: List[U]): Boolean =
if (xs.isEmpty) ys.isEmpty
Expand Down
69 changes: 50 additions & 19 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package dotc
package typer

import core._
import ast.{Trees, tpd, untpd}
import ast.{Trees, tpd, untpd, desugar}
import util.Spans._
import util.Stats.record
import util.{SourcePosition, NoSourcePosition, SourceFile}
Expand Down Expand Up @@ -864,7 +864,7 @@ trait Applications extends Compatibility {
case funRef: TermRef =>
val app =
if (proto.allArgTypesAreCurrent())
new ApplyToTyped(tree, fun1, funRef, proto.unforcedTypedArgs, pt)
new ApplyToTyped(tree, fun1, funRef, proto.typedArgs(), pt)
else
new ApplyToUntyped(tree, fun1, funRef, proto, pt)(
given fun1.nullableInArgContext(given argCtx(tree)))
Expand All @@ -891,7 +891,7 @@ trait Applications extends Compatibility {
}

fun1.tpe match {
case err: ErrorType => cpy.Apply(tree)(fun1, proto.unforcedTypedArgs).withType(err)
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err)
case TryDynamicCallType => typedDynamicApply(tree, pt)
case _ =>
if (originalProto.isDropped) fun1
Expand Down Expand Up @@ -1635,14 +1635,46 @@ trait Applications extends Compatibility {
def narrowByTypes(alts: List[TermRef], argTypes: List[Type], resultType: Type): List[TermRef] =
alts filter (isApplicableMethodRef(_, argTypes, resultType))

/** Normalization steps before checking arguments:
*
* { expr } --> expr
* (x1, ..., xn) => expr --> ((x1, ..., xn)) => expr
* if n != 1, no alternative has a corresponding formal parameter that
* is an n-ary function, and at least one alternative has a corresponding
* formal parameter that is a unary function.
*/
def normArg(alts: List[TermRef], arg: untpd.Tree, idx: Int): untpd.Tree = arg match
case Block(Nil, expr) => normArg(alts, expr, idx)
case untpd.Function(args: List[untpd.ValDef] @unchecked, body) =>

// If ref refers to a method whose parameter at index `idx` is a function type,
// the arity of that function, otherise -1.
def paramCount(ref: TermRef) =
val formals = ref.widen.firstParamTypes
if formals.length > idx then
formals(idx) match
case defn.FunctionOf(args, _, _, _) => args.length
case _ => -1
else -1

val numArgs = args.length
if numArgs != 1
&& !alts.exists(paramCount(_) == numArgs)
&& alts.exists(paramCount(_) == 1)
then
desugar.makeTupledFunction(args, body, isGenericTuple = true)
// `isGenericTuple = true` is the safe choice here. It means the i'th tuple
// element is selected with `(i)` instead of `_i`, which gives the same code
// in the end, but the compilation time and the ascribed type are more involved.
// It also means that -Ytest-pickler -Xprint-types fails for sources exercising
// the idiom since after pickling the target is known, so _i is used directly.
else arg
case _ => arg
end normArg

val candidates = pt match {
case pt @ FunProto(args, resultType) =>
val numArgs = args.length
val normArgs = args.mapConserve {
case Block(Nil, expr) => expr
case x => x
}

def sizeFits(alt: TermRef): Boolean = alt.widen.stripPoly match {
case tp: MethodType =>
val ptypes = tp.paramInfos
Expand All @@ -1661,9 +1693,10 @@ trait Applications extends Compatibility {
alts.filter(sizeFits(_))

def narrowByShapes(alts: List[TermRef]): List[TermRef] =
if (normArgs exists untpd.isFunctionWithUnknownParamType)
if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType)
else narrowByTypes(alts, normArgs map typeShape, resultType)
val normArgs = args.mapWithIndexConserve(normArg(alts, _, _))
if normArgs.exists(untpd.isFunctionWithUnknownParamType) then
if hasNamedArg(args) then narrowByTrees(alts, normArgs.map(treeShape), resultType)
else narrowByTypes(alts, normArgs.map(typeShape), resultType)
else
alts

Expand All @@ -1681,16 +1714,14 @@ trait Applications extends Compatibility {

val alts1 = narrowBySize(alts)
//ctx.log(i"narrowed by size: ${alts1.map(_.symbol.showDcl)}%, %")
if (isDetermined(alts1)) alts1
else {
if isDetermined(alts1) then alts1
else
val alts2 = narrowByShapes(alts1)
//ctx.log(i"narrowed by shape: ${alts2.map(_.symbol.showDcl)}%, %")
if (isDetermined(alts2)) alts2
else {
if isDetermined(alts2) then alts2
else
pretypeArgs(alts2, pt)
narrowByTrees(alts2, pt.unforcedTypedArgs, resultType)
}
}
narrowByTrees(alts2, pt.typedArgs(normArg(alts2, _, _)), resultType)

case pt @ PolyProto(targs1, pt1) if targs.isEmpty =>
val alts1 = alts.filter(pt.isMatchedBy(_))
Expand Down Expand Up @@ -1749,7 +1780,7 @@ trait Applications extends Compatibility {
else pt match {
case pt @ FunProto(_, resType: FunProto) =>
// try to narrow further with snd argument list
val advanced = advanceCandidates(pt.unforcedTypedArgs.tpes)
val advanced = advanceCandidates(pt.typedArgs().tpes)
resolveOverloaded(advanced.map(_._1), resType, Nil) // resolve with candidates where first params are stripped
.map(advanced.toMap) // map surviving result(s) back to original candidates
case _ =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object ErrorReporting {
case _: WildcardType | _: IgnoredProto => ""
case tp => em" and expected result type $tp"
}
em"arguments (${tp.unforcedTypedArgs.tpes}%, %)$result"
em"arguments (${tp.typedArgs().tpes}%, %)$result"
case _ =>
em"expected type $tp"
}
Expand Down
15 changes: 10 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ object ProtoTypes {
override def resultType(implicit ctx: Context): Type = resType

def isMatchedBy(tp: Type, keepConstraint: Boolean)(implicit ctx: Context): Boolean = {
val args = unforcedTypedArgs
val args = typedArgs()
def isPoly(tree: Tree) = tree.tpe.widenSingleton.isInstanceOf[PolyType]
// See remark in normalizedCompatible for why we can't keep the constraint
// if one of the arguments has a PolyType.
Expand Down Expand Up @@ -301,15 +301,18 @@ object ProtoTypes {
* However, any constraint changes are also propagated to the currently passed
* context.
*
* @param norm a normalization function that is applied to an untyped argument tree
* before it is typed. The second Int parameter is the parameter index.
*/
def unforcedTypedArgs(implicit ctx: Context): List[Tree] =
def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree = sameTree)(implicit ctx: Context): List[Tree] =
if (state.typedArgs.size == args.length) state.typedArgs
else {
val prevConstraint = this.ctx.typerState.constraint

try {
implicit val ctx = this.ctx
val args1 = args.mapconserve(cacheTypedArg(_, typer.typed(_), force = false))
val args1 = args.mapWithIndexConserve((arg, idx) =>
cacheTypedArg(arg, arg => typer.typed(norm(arg, idx)), force = false))
if (!args1.exists(arg => isUndefined(arg.tpe))) state.typedArgs = args1
args1
}
Expand Down Expand Up @@ -375,7 +378,7 @@ object ProtoTypes {
derivedFunProto(args, tm(resultType), typer)

def fold[T](x: T, ta: TypeAccumulator[T])(implicit ctx: Context): T =
ta(ta.foldOver(x, unforcedTypedArgs.tpes), resultType)
ta(ta.foldOver(x, typedArgs().tpes), resultType)

override def deepenProto(implicit ctx: Context): FunProto = derivedFunProto(args, resultType.deepenProto, typer)

Expand All @@ -389,7 +392,7 @@ object ProtoTypes {
* [](args): resultType, where args are known to be typed
*/
class FunProtoTyped(args: List[tpd.Tree], resultType: Type)(typer: Typer, isGivenApply: Boolean)(implicit ctx: Context) extends FunProto(args, resultType)(typer, isGivenApply)(ctx) {
override def unforcedTypedArgs(implicit ctx: Context): List[tpd.Tree] = args
override def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree)(implicit ctx: Context): List[tpd.Tree] = args
override def withContext(ctx: Context): FunProtoTyped = this
}

Expand Down Expand Up @@ -682,4 +685,6 @@ object ProtoTypes {
case _ => None
}
}

private val sameTree = (t: untpd.Tree, n: Int) => t
}
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@ class Typer extends Namer
}

def typedFunctionValue(tree: untpd.Function, pt: Type)(implicit ctx: Context): Tree = {
val untpd.Function(params: List[untpd.ValDef] @unchecked, body) = tree
val untpd.Function(params: List[untpd.ValDef] @unchecked, _) = tree

val isContextual = tree match {
case tree: untpd.FunctionWithMods => tree.mods.is(Given)
Expand Down
3 changes: 3 additions & 0 deletions compiler/test/dotc/pos-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ i7580.scala

# Nullability
nullable.scala

# parameter untupling with overloaded functions (see comment in Applications.normArg)
i7757.scala
10 changes: 10 additions & 0 deletions tests/pos/i7757.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
val m: Map[Int, String] = ???
val _ = m.map((a, b) => a + b.length)

trait Foo
def g(f: ((Int, Int)) => Int): Int = 1
def g(f: ((Int, Int)) => (Int, Int)): String = "2"

@main def Test =
val m: Foo = ???
m.g((x: Int, b: Int) => (x, x))