Skip to content

Allow partial eta expansion #2691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,13 @@ object Types {
case _ => NoType
}

/** If this is a FunProto or PolyProto, WildcardType, otherwise this. */
/** If this is a repeated type, its element type, otherwise the type itself */
def repeatedToSingle(implicit ctx: Context): Type = this match {
case tp @ ExprType(tp1) => tp.derivedExprType(tp1.repeatedToSingle)
case _ => if (isRepeatedParam) this.argTypesHi.head else this
}

/** If this is a FunProto or PolyProto, WildcardType, otherwise this. */
def notApplied: Type = this

// ----- Normalizing typerefs over refined types ----------------------------
Expand Down
6 changes: 1 addition & 5 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1056,12 +1056,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
*/
def isAsSpecific(alt1: TermRef, tp1: Type, alt2: TermRef, tp2: Type): Boolean = ctx.traceIndented(i"isAsSpecific $tp1 $tp2", overload) { tp1 match {
case tp1: MethodType => // (1)
def repeatedToSingle(tp: Type): Type = tp match {
case tp @ ExprType(tp1) => tp.derivedExprType(repeatedToSingle(tp1))
case _ => if (tp.isRepeatedParam) tp.argTypesHi.head else tp
}
val formals1 =
if (tp1.isVarArgsMethod && tp2.isVarArgsMethod) tp1.paramInfos map repeatedToSingle
if (tp1.isVarArgsMethod && tp2.isVarArgsMethod) tp1.paramInfos.map(_.repeatedToSingle)
else tp1.paramInfos
isApplicable(alt2, formals1, WildcardType) ||
tp1.paramInfos.isEmpty && tp2.isInstanceOf[LambdaType]
Expand Down
65 changes: 47 additions & 18 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -724,41 +724,70 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
*/
var fnBody = tree.body

/** A map from parameter names to unique positions where the parameter
* appears in the argument list of an application.
*/
var paramIndex = Map[Name, Int]()

/** If parameter `param` appears exactly once as an argument in `args`,
* the singleton list consisting of its position in `args`, otherwise `Nil`.
*/
def paramIndices(param: untpd.ValDef, args: List[untpd.Tree], start: Int): List[Int] = args match {
case arg :: args1 =>
if (refersTo(arg, param))
if (paramIndices(param, args1, start + 1).isEmpty) start :: Nil
else Nil
else paramIndices(param, args1, start + 1)
case _ => Nil
}

/** If function is of the form
* (x1, ..., xN) => f(x1, ..., XN)
* the type of `f`, otherwise NoType. (updates `fnBody` as a side effect).
* (x1, ..., xN) => f(... x1, ..., XN, ...)
* where each `xi` occurs exactly once in the argument list of `f` (in
* any order), the type of `f`, otherwise NoType.
* Updates `fnBody` and `paramIndex` as a side effect.
* @post: If result exists, `paramIndex` is defined for the name of
* every parameter in `params`.
*/
def calleeType: Type = fnBody match {
case Apply(expr, args) if (args corresponds params)(refersTo) =>
expr match {
case untpd.TypedSplice(expr1) =>
expr1.tpe
case _ =>
val protoArgs = args map (_ withType WildcardType)
val callProto = FunProto(protoArgs, WildcardType, this)
val expr1 = typedExpr(expr, callProto)
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
expr1.tpe
}
case Apply(expr, args) =>
paramIndex = {
for (param <- params; idx <- paramIndices(param, args, 0))
yield param.name -> idx
}.toMap
if (paramIndex.size == params.length)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's a map, wouldn't this miss the (unlikely) case where the same param (one key) occurs in multiple args (multiple values, of which only the last one is retained). I don't think it's a big deal, as this could not arise from the _ syntax, but the other comments do mention "exactly once" as a requirement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The revised implementation of paramIndices makes sure that a parameter is entered in the map only if there is a unique occurrence in args.

expr match {
case untpd.TypedSplice(expr1) =>
expr1.tpe
case _ =>
val protoArgs = args map (_ withType WildcardType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, the current implementation in Scala 2 includes the known argument types / result type, to enable overload resolution. Would be good to indicate the difference in a comment.

val callProto = FunProto(protoArgs, WildcardType, this)
val expr1 = typedExpr(expr, callProto)
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
expr1.tpe
}
else NoType
case _ =>
NoType
}

/** Two attempts: First, if expected type is fully defined pick this one.
* Second, if function is of the form
* (x1, ..., xN) => f(x1, ..., XN)
* and f has a method type MT, pick the corresponding parameter type in MT,
* if this one is fully defined.
* (x1, ..., xN) => f(... x1, ..., XN, ...)
* where each `xi` occurs exactly once in the argument list of `f` (in
* any order), and f has a method type MT, pick the corresponding parameter
* type in MT, if this one is fully defined.
* If both attempts fail, issue a "missing parameter type" error.
*/
def inferredParamType(param: untpd.ValDef, formal: Type): Type = {
if (isFullyDefined(formal, ForceDegree.noBottom)) return formal
calleeType.widen match {
case mtpe: MethodType =>
val pos = params indexWhere (_.name == param.name)
val pos = paramIndex(param.name)
if (pos < mtpe.paramInfos.length) {
val ptype = mtpe.paramInfos(pos)
if (isFullyDefined(ptype, ForceDegree.noBottom)) return ptype
if (isFullyDefined(ptype, ForceDegree.noBottom) && !ptype.isRepeatedParam)
return ptype
}
case _ =>
}
Expand Down
29 changes: 29 additions & 0 deletions tests/pos/i2570.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
object Test {

def repeat(s: String, i: Int, j: Int = 22) = s * i
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When discussing this with Seth, he suggested to play around with currying by adding another argument list. The partial application should be fine: repeat(_, 1), but what about repeat(_, 1)(_)? (uncurrying?)

Copy link
Contributor Author

@odersky odersky Jun 13, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In dotc you get:

1 |repeat(1, _)(_)
  |             ^
  |             unbound placeholder parameter; incorrect use of `_`

In scalac you get:

scala> repeat(1, _)(_) 
<console>:13: error: missing parameter type for expanded function ((x$1: <error>, x$2) => repeat(1, x$1)(x$2))
       repeat(1, _)(_)
                 ^
<console>:13: error: missing parameter type for expanded function ((x$1: <error>, x$2: <error>) => repeat(1, x$1)(x$2))
       repeat(1, _)(_)

The fact that the types already contain <error> markers makes me suspect that scalac also issues a parsing error but somehow suppresses the error message.


val f1 = repeat("abc", _)
val f2: Int => String = f1
val f3 = repeat(_, 3)
val f4: String => String = f3
val f5 = repeat("abc", _, _)
val f6: (Int, Int) => String = f5
val f7 = repeat(_, 11, _)
val f8: (String, Int) => String = f7

def sum(x: Int, y: => Int) = x + y

val g1 = sum(2, _)
val g2: (=> Int) => Int = g1

val h0: ((Int, => Int) => Int) = sum

def sum2(x: Int, ys: Int*) = (x /: ys)(_ + _)
val h1: ((Int, Seq[Int]) => Int) = sum2

// Not yet:
// val h1 = repeat
// val h2: (String, Int, Int) = h1
// val h3 = sum
// val h4: (Int, => Int) = h3
}