diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index f32bc2e1f570..9671c603b1c3 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1316,33 +1316,39 @@ object desugar { Function(params, Match(makeSelector(selector, checkMode), cases)) } - /** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows: + /** Map n-ary function `(x1: T1, ..., xn: Tn) => body` where n != 1 to unary function as follows: * - * x$1 => { - * def p1 = x$1._1 + * (x$1: (T1, ..., Tn)) => { + * def x1: T1 = x$1._1 * ... - * def pn = x$1._n + * def xn: Tn = x$1._n * body * } * * or if `isGenericTuple` * - * x$1 => { - * def p1 = x$1.apply(0) + * (x$1: (T1, ... Tn) => { + * def x1: T1 = x$1.apply(0) * ... - * def pn = x$1.apply(n-1) + * def xn: Tn = x$1.apply(n-1) * body * } + * + * If some of the Ti's are absent, omit the : (T1, ..., Tn) type ascription + * in the selector. */ def makeTupledFunction(params: List[ValDef], body: Tree, isGenericTuple: Boolean)(implicit ctx: Context): Tree = { - val param = makeSyntheticParameter() + val param = makeSyntheticParameter( + tpt = + if params.exists(_.tpt.isEmpty) then TypeTree() + else Tuple(params.map(_.tpt))) def selector(n: Int) = if (isGenericTuple) Apply(Select(refOfDef(param), nme.apply), Literal(Constant(n))) else Select(refOfDef(param), nme.selectorName(n)) val vdefs = params.zipWithIndex.map { case (param, idx) => - DefDef(param.name, Nil, Nil, TypeTree(), selector(idx)).withSpan(param.span) + DefDef(param.name, Nil, Nil, param.tpt, selector(idx)).withSpan(param.span) } Function(param :: Nil, Block(vdefs, body)) } diff --git a/compiler/src/dotty/tools/dotc/core/Decorators.scala b/compiler/src/dotty/tools/dotc/core/Decorators.scala index 279ce0d29293..71ba89770625 100644 --- a/compiler/src/dotty/tools/dotc/core/Decorators.scala +++ b/compiler/src/dotty/tools/dotc/core/Decorators.scala @@ -114,6 +114,22 @@ object Decorators { else x1 :: xs1 } + /** Like `xs.lazyZip(xs.indices).map(f)`, but returns list `xs` itself + * - instead of a copy - if function `f` maps all elements of + * `xs` to themselves. + */ + def mapWithIndexConserve[U <: T](f: (T, Int) => U): List[U] = + def recur(xs: List[T], idx: Int): List[U] = + if xs.isEmpty then Nil + else + val x1 = f(xs.head, idx) + val xs1 = recur(xs.tail, idx + 1) + if (x1.asInstanceOf[AnyRef] eq xs.head.asInstanceOf[AnyRef]) + && (xs1 eq xs.tail) + then xs.asInstanceOf[List[U]] + else x1 :: xs1 + recur(xs, 0) + final def hasSameLengthAs[U](ys: List[U]): Boolean = { @tailrec def loop(xs: List[T], ys: List[U]): Boolean = if (xs.isEmpty) ys.isEmpty diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index a3073e06c25a..808586ffed96 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -3,7 +3,7 @@ package dotc package typer import core._ -import ast.{Trees, tpd, untpd} +import ast.{Trees, tpd, untpd, desugar} import util.Spans._ import util.Stats.record import util.{SourcePosition, NoSourcePosition, SourceFile} @@ -864,7 +864,7 @@ trait Applications extends Compatibility { case funRef: TermRef => val app = if (proto.allArgTypesAreCurrent()) - new ApplyToTyped(tree, fun1, funRef, proto.unforcedTypedArgs, pt) + new ApplyToTyped(tree, fun1, funRef, proto.typedArgs(), pt) else new ApplyToUntyped(tree, fun1, funRef, proto, pt)( given fun1.nullableInArgContext(given argCtx(tree))) @@ -891,7 +891,7 @@ trait Applications extends Compatibility { } fun1.tpe match { - case err: ErrorType => cpy.Apply(tree)(fun1, proto.unforcedTypedArgs).withType(err) + case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err) case TryDynamicCallType => typedDynamicApply(tree, pt) case _ => if (originalProto.isDropped) fun1 @@ -1635,14 +1635,46 @@ trait Applications extends Compatibility { def narrowByTypes(alts: List[TermRef], argTypes: List[Type], resultType: Type): List[TermRef] = alts filter (isApplicableMethodRef(_, argTypes, resultType)) + /** Normalization steps before checking arguments: + * + * { expr } --> expr + * (x1, ..., xn) => expr --> ((x1, ..., xn)) => expr + * if n != 1, no alternative has a corresponding formal parameter that + * is an n-ary function, and at least one alternative has a corresponding + * formal parameter that is a unary function. + */ + def normArg(alts: List[TermRef], arg: untpd.Tree, idx: Int): untpd.Tree = arg match + case Block(Nil, expr) => normArg(alts, expr, idx) + case untpd.Function(args: List[untpd.ValDef] @unchecked, body) => + + // If ref refers to a method whose parameter at index `idx` is a function type, + // the arity of that function, otherise -1. + def paramCount(ref: TermRef) = + val formals = ref.widen.firstParamTypes + if formals.length > idx then + formals(idx) match + case defn.FunctionOf(args, _, _, _) => args.length + case _ => -1 + else -1 + + val numArgs = args.length + if numArgs != 1 + && !alts.exists(paramCount(_) == numArgs) + && alts.exists(paramCount(_) == 1) + then + desugar.makeTupledFunction(args, body, isGenericTuple = true) + // `isGenericTuple = true` is the safe choice here. It means the i'th tuple + // element is selected with `(i)` instead of `_i`, which gives the same code + // in the end, but the compilation time and the ascribed type are more involved. + // It also means that -Ytest-pickler -Xprint-types fails for sources exercising + // the idiom since after pickling the target is known, so _i is used directly. + else arg + case _ => arg + end normArg + val candidates = pt match { case pt @ FunProto(args, resultType) => val numArgs = args.length - val normArgs = args.mapConserve { - case Block(Nil, expr) => expr - case x => x - } - def sizeFits(alt: TermRef): Boolean = alt.widen.stripPoly match { case tp: MethodType => val ptypes = tp.paramInfos @@ -1661,9 +1693,10 @@ trait Applications extends Compatibility { alts.filter(sizeFits(_)) def narrowByShapes(alts: List[TermRef]): List[TermRef] = - if (normArgs exists untpd.isFunctionWithUnknownParamType) - if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType) - else narrowByTypes(alts, normArgs map typeShape, resultType) + val normArgs = args.mapWithIndexConserve(normArg(alts, _, _)) + if normArgs.exists(untpd.isFunctionWithUnknownParamType) then + if hasNamedArg(args) then narrowByTrees(alts, normArgs.map(treeShape), resultType) + else narrowByTypes(alts, normArgs.map(typeShape), resultType) else alts @@ -1681,16 +1714,14 @@ trait Applications extends Compatibility { val alts1 = narrowBySize(alts) //ctx.log(i"narrowed by size: ${alts1.map(_.symbol.showDcl)}%, %") - if (isDetermined(alts1)) alts1 - else { + if isDetermined(alts1) then alts1 + else val alts2 = narrowByShapes(alts1) //ctx.log(i"narrowed by shape: ${alts2.map(_.symbol.showDcl)}%, %") - if (isDetermined(alts2)) alts2 - else { + if isDetermined(alts2) then alts2 + else pretypeArgs(alts2, pt) - narrowByTrees(alts2, pt.unforcedTypedArgs, resultType) - } - } + narrowByTrees(alts2, pt.typedArgs(normArg(alts2, _, _)), resultType) case pt @ PolyProto(targs1, pt1) if targs.isEmpty => val alts1 = alts.filter(pt.isMatchedBy(_)) @@ -1749,7 +1780,7 @@ trait Applications extends Compatibility { else pt match { case pt @ FunProto(_, resType: FunProto) => // try to narrow further with snd argument list - val advanced = advanceCandidates(pt.unforcedTypedArgs.tpes) + val advanced = advanceCandidates(pt.typedArgs().tpes) resolveOverloaded(advanced.map(_._1), resType, Nil) // resolve with candidates where first params are stripped .map(advanced.toMap) // map surviving result(s) back to original candidates case _ => diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index a2fcd6295143..b03ace7f8d50 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -55,7 +55,7 @@ object ErrorReporting { case _: WildcardType | _: IgnoredProto => "" case tp => em" and expected result type $tp" } - em"arguments (${tp.unforcedTypedArgs.tpes}%, %)$result" + em"arguments (${tp.typedArgs().tpes}%, %)$result" case _ => em"expected type $tp" } diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index cb2bdb20822a..2fa5b8105525 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -250,7 +250,7 @@ object ProtoTypes { override def resultType(implicit ctx: Context): Type = resType def isMatchedBy(tp: Type, keepConstraint: Boolean)(implicit ctx: Context): Boolean = { - val args = unforcedTypedArgs + val args = typedArgs() def isPoly(tree: Tree) = tree.tpe.widenSingleton.isInstanceOf[PolyType] // See remark in normalizedCompatible for why we can't keep the constraint // if one of the arguments has a PolyType. @@ -301,15 +301,18 @@ object ProtoTypes { * However, any constraint changes are also propagated to the currently passed * context. * + * @param norm a normalization function that is applied to an untyped argument tree + * before it is typed. The second Int parameter is the parameter index. */ - def unforcedTypedArgs(implicit ctx: Context): List[Tree] = + def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree = sameTree)(implicit ctx: Context): List[Tree] = if (state.typedArgs.size == args.length) state.typedArgs else { val prevConstraint = this.ctx.typerState.constraint try { implicit val ctx = this.ctx - val args1 = args.mapconserve(cacheTypedArg(_, typer.typed(_), force = false)) + val args1 = args.mapWithIndexConserve((arg, idx) => + cacheTypedArg(arg, arg => typer.typed(norm(arg, idx)), force = false)) if (!args1.exists(arg => isUndefined(arg.tpe))) state.typedArgs = args1 args1 } @@ -375,7 +378,7 @@ object ProtoTypes { derivedFunProto(args, tm(resultType), typer) def fold[T](x: T, ta: TypeAccumulator[T])(implicit ctx: Context): T = - ta(ta.foldOver(x, unforcedTypedArgs.tpes), resultType) + ta(ta.foldOver(x, typedArgs().tpes), resultType) override def deepenProto(implicit ctx: Context): FunProto = derivedFunProto(args, resultType.deepenProto, typer) @@ -389,7 +392,7 @@ object ProtoTypes { * [](args): resultType, where args are known to be typed */ class FunProtoTyped(args: List[tpd.Tree], resultType: Type)(typer: Typer, isGivenApply: Boolean)(implicit ctx: Context) extends FunProto(args, resultType)(typer, isGivenApply)(ctx) { - override def unforcedTypedArgs(implicit ctx: Context): List[tpd.Tree] = args + override def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree)(implicit ctx: Context): List[tpd.Tree] = args override def withContext(ctx: Context): FunProtoTyped = this } @@ -682,4 +685,6 @@ object ProtoTypes { case _ => None } } + + private val sameTree = (t: untpd.Tree, n: Int) => t } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index b795ab296467..78086949966f 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -947,7 +947,7 @@ class Typer extends Namer } def typedFunctionValue(tree: untpd.Function, pt: Type)(implicit ctx: Context): Tree = { - val untpd.Function(params: List[untpd.ValDef] @unchecked, body) = tree + val untpd.Function(params: List[untpd.ValDef] @unchecked, _) = tree val isContextual = tree match { case tree: untpd.FunctionWithMods => tree.mods.is(Given) diff --git a/compiler/test/dotc/pos-test-pickling.blacklist b/compiler/test/dotc/pos-test-pickling.blacklist index e5c98b5666aa..76e6ff22551a 100644 --- a/compiler/test/dotc/pos-test-pickling.blacklist +++ b/compiler/test/dotc/pos-test-pickling.blacklist @@ -32,3 +32,6 @@ i7580.scala # Nullability nullable.scala + +# parameter untupling with overloaded functions (see comment in Applications.normArg) +i7757.scala \ No newline at end of file diff --git a/tests/pos/i7757.scala b/tests/pos/i7757.scala new file mode 100644 index 000000000000..0f35c260b7a3 --- /dev/null +++ b/tests/pos/i7757.scala @@ -0,0 +1,10 @@ +val m: Map[Int, String] = ??? +val _ = m.map((a, b) => a + b.length) + +trait Foo + def g(f: ((Int, Int)) => Int): Int = 1 + def g(f: ((Int, Int)) => (Int, Int)): String = "2" + +@main def Test = + val m: Foo = ??? + m.g((x: Int, b: Int) => (x, x)) \ No newline at end of file