Skip to content

Strengthen overloading resolution to deal with extension methods #6116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
case dummyTreeOfType(tp) :: Nil if !(tp isRef defn.NullClass) => "null: " ~ toText(tp)
case _ => toTextGlobal(args, ", ")
}
return "FunProto(" ~ (Str("given ") provided tp.isContextual) ~ argsText ~ "):" ~ toText(resultType)
return "[applied to " ~ (Str("given ") provided tp.isContextual) ~ "(" ~ argsText ~ ") returning " ~ toText(resultType) ~ "]"
case IgnoredProto(ignored) =>
return "?" ~ (("(ignored: " ~ toText(ignored) ~ ")") provided ctx.settings.verbose.value)
case tp @ PolyProto(targs, resType) =>
return "PolyProto(" ~ toTextGlobal(targs, ", ") ~ "): " ~ toText(resType)
return "[applied to [" ~ toTextGlobal(targs, ", ") ~ "] returning " ~ toText(resType)
case _ =>
}
super.toText(tp)
Expand Down
59 changes: 49 additions & 10 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Trees.Untyped
import Contexts._
import Flags._
import Symbols._
import Denotations.Denotation
import Types._
import Decorators._
import ErrorReporting._
Expand Down Expand Up @@ -1204,8 +1205,12 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
* result matching `resultType`?
*/
def hasExtensionMethod(tp: Type, name: TermName, argType: Type, resultType: Type)(implicit ctx: Context) = {
val mbr = tp.memberBasedOnFlags(name, required = ExtensionMethod)
mbr.exists && isApplicable(tp.select(name, mbr), argType :: Nil, resultType)
def qualifies(mbr: Denotation) =
mbr.exists && isApplicable(tp.select(name, mbr), argType :: Nil, resultType)
tp.memberBasedOnFlags(name, required = ExtensionMethod) match {
case mbr: SingleDenotation => qualifies(mbr)
case mbr => mbr.hasAltWith(qualifies(_))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of the dispatch here? Semantically it seems the same as:

tp.memberBasedOnFlags(name, required = ExtensionMethod).hasAltWith(qualifies(_))

Is it for performance? It seems we save several exists check this way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's for performance. We optimize for the common case where the member is not overloaded.

}

/** Compare owner inheritance level.
Expand Down Expand Up @@ -1627,16 +1632,50 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
}
else compat
}

/** For each candidate `C`, a proxy termref paired with `C`.
* The proxy termref has as symbol a copy of the original candidate symbol,
* with an info that strips the first value parameter list away.
* @param argTypes The types of the arguments of the FunProto `pt`.
*/
def advanceCandidates(argTypes: List[Type]): List[(TermRef, TermRef)] = {
def strippedType(tp: Type): Type = tp match {
case tp: PolyType =>
val rt = strippedType(tp.resultType)
if (rt.exists) tp.derivedLambdaType(resType = rt) else rt
case tp: MethodType =>
tp.instantiate(argTypes)
case _ =>
NoType
}
def cloneCandidate(cand: TermRef): List[(TermRef, TermRef)] = {
val strippedInfo = strippedType(cand.widen)
if (strippedInfo.exists) {
val sym = cand.symbol.asTerm.copy(info = strippedInfo)
(TermRef(NoPrefix, sym), cand) :: Nil
}
else Nil
}
overload.println(i"look at more params: ${candidates.head.symbol}: ${candidates.map(_.widen)}%, % with $pt, [$targs%, %]")
candidates.flatMap(cloneCandidate)
}

val found = narrowMostSpecific(candidates)
if (found.length <= 1) found
else {
val noDefaults = alts.filter(!_.symbol.hasDefaultParams)
if (noDefaults.length == 1) noDefaults // return unique alternative without default parameters if it exists
else {
val deepPt = pt.deepenProto
if (deepPt ne pt) resolveOverloaded(alts, deepPt, targs)
else alts
}
else pt match {
case pt @ FunProto(_, resType: FunProto) =>
// try to narrow further with snd argument list
val advanced = advanceCandidates(pt.typedArgs.tpes)
resolveOverloaded(advanced.map(_._1), resType, Nil) // resolve with candidates where first params are stripped
.map(advanced.toMap) // map surviving result(s) back to original candidates
case _ =>
val noDefaults = alts.filter(!_.symbol.hasDefaultParams)
if (noDefaults.length == 1) noDefaults // return unique alternative without default parameters if it exists
else {
val deepPt = pt.deepenProto
if (deepPt ne pt) resolveOverloaded(alts, deepPt, targs)
else alts
}
}
}

Expand Down
22 changes: 14 additions & 8 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,16 @@ object Implicits {
/** A "massaging" function for displayed types to give better info in error diagnostics */
def clarify(tp: Type)(implicit ctx: Context): Type = tp

final protected def qualify(implicit ctx: Context): String =
if (expectedType.exists)
if (argument.isEmpty) em"match type ${clarify(expectedType)}"
else em"convert from ${argument.tpe} to ${clarify(expectedType)}"
else
final protected def qualify(implicit ctx: Context): String = expectedType match {
case SelectionProto(name, mproto, _, _) if !argument.isEmpty =>
em"provide an extension method `$name` on ${argument.tpe}"
case NoType =>
if (argument.isEmpty) em"match expected type"
else em"convert from ${argument.tpe} to expected type"
case _ =>
if (argument.isEmpty) em"match type ${clarify(expectedType)}"
else em"convert from ${argument.tpe} to ${clarify(expectedType)}"
}

/** An explanation of the cause of the failure as a string */
def explanation(implicit ctx: Context): String
Expand Down Expand Up @@ -425,9 +428,12 @@ object Implicits {
class AmbiguousImplicits(val alt1: SearchSuccess, val alt2: SearchSuccess, val expectedType: Type, val argument: Tree) extends SearchFailureType {
def explanation(implicit ctx: Context): String =
em"both ${err.refStr(alt1.ref)} and ${err.refStr(alt2.ref)} $qualify"
override def whyNoConversion(implicit ctx: Context): String =
"\nNote that implicit conversions cannot be applied because they are ambiguous;" +
"\n" + explanation
override def whyNoConversion(implicit ctx: Context): String = {
val what = if (expectedType.isInstanceOf[SelectionProto]) "extension methods" else "conversions"
i"""
|Note that implicit $what cannot be applied because they are ambiguous;
|$explanation"""
}
}

class MismatchedImplicit(ref: TermRef,
Expand Down
11 changes: 8 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ object RefChecks {
* 1.8.1 M's type is a subtype of O's type, or
* 1.8.2 M is of type []S, O is of type ()T and S <: T, or
* 1.8.3 M is of type ()S, O is of type []T and S <: T, or
* 1.9 If M or O are erased, they must be both erased
* 1.9.1 If M or O are erased, they must both be erased
* 1.9.2 If M or O are extension methods, they must both be extension methods
* 1.10 If M is an inline or Scala-2 macro method, O cannot be deferred unless
* there's also a concrete method that M overrides.
* 1.11. If O is a Scala-2 macro, M must be a Scala-2 macro.
Expand Down Expand Up @@ -391,10 +392,14 @@ object RefChecks {
overrideError("may not override a non-lazy value")
} else if (other.is(Lazy) && !other.isRealMethod && !member.is(Lazy)) {
overrideError("must be declared lazy to override a lazy value")
} else if (member.is(Erased) && !other.is(Erased)) { // (1.9)
} else if (member.is(Erased) && !other.is(Erased)) { // (1.9.1)
overrideError("is erased, cannot override non-erased member")
} else if (other.is(Erased) && !member.is(Erased)) { // (1.9)
} else if (other.is(Erased) && !member.is(Erased)) { // (1.9.1)
overrideError("is not erased, cannot override erased member")
} else if (member.is(Extension) && !other.is(Extension)) { // (1.9.2)
overrideError("is an extension method, cannot override a normal method")
} else if (other.is(Extension) && !member.is(Extension)) { // (1.9.2)
overrideError("is a normal method, cannot override an extension method")
} else if ((member.isInlineMethod || member.is(Scala2Macro)) && other.is(Deferred) &&
member.extendedOverriddenSymbols.forall(_.is(Deferred))) { // (1.10)
overrideError("is an inline method, must override at least one concrete method")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class SignatureHelpTest {
}""".withSource
.signatureHelp(m1, List(sig0, sig1), None, 0)
.signatureHelp(m2, List(sig0, sig1), None, 0)
.signatureHelp(m3, List(sig0, sig1), Some(1), 1)
.signatureHelp(m3, List(), Some(1), 1) // TODO: investigate we do not get help at $m3
}

@Test def multipleParameterLists: Unit = {
Expand Down
14 changes: 14 additions & 0 deletions tests/neg/extmethod-overload.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
object Test {
implied A {
def (x: Int) |+| (y: Int) = x + y
}
implied B {
def (x: Int) |+| (y: String) = x + y.length
}
assert((1 |+| 2) == 3) // error ambiguous

locally {
import B.|+|
assert((1 |+| "2") == 2) // OK
}
}
8 changes: 8 additions & 0 deletions tests/neg/extmethod-override.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class A {
def f(x: Int)(y: Int): Int = 0
def (x: Int) g (y: Int): Int = 1
}
class B extends A {
override def (x: Int) f (y: Int): Int = 1 // error
override def g(x: Int)(y: Int): Int = 0 // error
}
122 changes: 122 additions & 0 deletions tests/run/extmethod-overload.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
object Test extends App {
// warmup
def f(x: Int)(y: Int) = y
def f(x: Int)(y: String) = y.length
assert(f(1)(2) == 2)
assert(f(1)("two") == 3)

def g[T](x: T)(y: Int) = y
def g[T](x: T)(y: String) = y.length
assert(g[Int](1)(2) == 2)
assert(g[Int](1)("two") == 3)
assert(g(1)(2) == 2)
assert(g(1)("two") == 3)

def h[T](x: T)(y: T)(z: Int) = z
def h[T](x: T)(y: T)(z: String) = z.length
assert(h[Int](1)(1)(2) == 2)
assert(h[Int](1)(1)("two") == 3)
assert(h(1)(1)(2) == 2)
assert(h(1)(1)("two") == 3)

// Test with extension methods in implied object
object test1 {

implied Foo {
def (x: Int) |+| (y: Int) = x + y
def (x: Int) |+| (y: String) = x + y.length

def (xs: List[T]) +++ [T] (ys: List[T]): List[T] = xs ++ ys ++ ys
def (xs: List[T]) +++ [T] (ys: Iterator[T]): List[T] = xs ++ ys ++ ys
}

assert((1 |+| 2) == 3)
assert((1 |+| "2") == 2)

val xs = List(1, 2)
assert((xs +++ xs).length == 6)
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
}
test1

// Test with imported extension methods
object test2 {
import test1.Foo._

assert((1 |+| 2) == 3)
assert((1 |+| "2") == 2)

val xs = List(1, 2)
assert((xs +++ xs).length == 6)
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
}
test2

// Test with implied extension methods coming from base class
object test3 {
class Foo {
def (x: Int) |+| (y: Int) = x + y
def (x: Int) |+| (y: String) = x + y.length

def (xs: List[T]) +++ [T] (ys: List[T]): List[T] = xs ++ ys ++ ys
def (xs: List[T]) +++ [T] (ys: Iterator[T]): List[T] = xs ++ ys ++ ys
}
implied Bar for Foo

assert((1 |+| 2) == 3)
assert((1 |+| "2") == 2)

val xs = List(1, 2)
assert((xs +++ xs).length == 6)
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
}
test3

// Test with implied extension methods coming from implied alias
object test4 {
implied for test3.Foo = test3.Bar

assert((1 |+| 2) == 3)
assert((1 |+| "2") == 2)

val xs = List(1, 2)
assert((xs +++ xs).length == 6)
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
}
test4

class C {
def xx (x: Any) = 2
}
def (c: C) xx (x: Int) = 1

val c = new C
assert(c.xx(1) == 2) // member method takes precedence

object D {
def (x: Int) yy (y: Int) = x + y
}

implied {
def (x: Int) yy (y: Int) = x - y
}

import D._
assert((1 yy 2) == 3) // imported extension method takes precedence

trait Rectangle {
def a: Long
def b: Long
}

case class GenericRectangle(a: Long, b: Long) extends Rectangle
case class Square(a: Long) extends Rectangle {
def b: Long = a
}

def (rectangle: Rectangle) area: Long = 0
def (square: Square) area: Long = square.a * square.a
val rectangles = List(GenericRectangle(2, 3), Square(5))
val areas = rectangles.map(_.area)
assert(areas.sum == 0)
}