Skip to content

[Backport] Overloading resolution: Handle SAM types more like Java and Scala 2 #12131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ class Definitions {
ClassInfo(ScalaPackageClass.thisType, cls, ObjectType :: Nil, decls)
}
}
val flags0 = Trait | NoInits
val flags = if (name.isContextFunction) flags0 | Final else flags0
val flags = Trait | NoInits
newPermanentClassSymbol(ScalaPackageClass, name, flags, completer)
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5043,7 +5043,7 @@ object Types {
NoType
}
def isInstantiatable(tp: Type)(using Context): Boolean = zeroParamClass(tp) match {
case cinfo: ClassInfo =>
case cinfo: ClassInfo if !cinfo.cls.isOneOf(FinalOrSealed) =>
val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
tp <:< selfType
case _ =>
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ enum ErrorMessageID extends java.lang.Enum[ErrorMessageID] {
OverrideTypeMismatchErrorID,
OverrideErrorID,
MatchableWarningID,
IllegalParameterInitID
IllegalParameterInitID,
CannotExtendFunctionID

def errorNumber = ordinal - 2
}
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,12 @@ import transform.SymUtils._
def explain = ""
}

class CannotExtendContextFunction(sym: Symbol)(using Context)
extends SyntaxMsg(CannotExtendFunctionID) {
def msg = em"""$sym cannot extend a context function class"""
def explain = ""
}

class JavaEnumParentArgs(parent: Type)(using Context)
extends TypeMsg(JavaEnumParentArgsID) {
def msg = em"""not enough arguments for constructor Enum: ${hl("(name: String, ordinal: Int)")}: ${hl(parent.show)}"""
Expand Down
61 changes: 35 additions & 26 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,6 @@ trait Applications extends Compatibility {

/** The degree to which an argument has to match a formal parameter */
enum ArgMatch:
case SubType // argument is a relaxed subtype of formal
case Compatible // argument is compatible with formal
case CompatibleCAP // capture-converted argument is compatible with formal

Expand All @@ -635,21 +634,38 @@ trait Applications extends Compatibility {
// matches expected type
false
case argtpe =>
def SAMargOK = formal match {
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
case _ => false
}
if argMatch == ArgMatch.SubType then
argtpe relaxed_<:< formal.widenExpr
else
isCompatible(argtpe, formal)
|| ctx.mode.is(Mode.ImplicitsEnabled) && SAMargOK
|| argMatch == ArgMatch.CompatibleCAP
&& {
val argtpe1 = argtpe.widen
val captured = captureWildcards(argtpe1)
(captured ne argtpe1) && isCompatible(captured, formal.widenExpr)
}
val argtpe1 = argtpe.widen

def SAMargOK =
defn.isFunctionType(argtpe1) && formal.match
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
case _ => false

isCompatible(argtpe, formal)
// Only allow SAM-conversion to PartialFunction if implicit conversions
// are enabled. This is necessary to avoid ambiguity between an overload
// taking a PartialFunction and one taking a Function1 because
// PartialFunction extends Function1 but Function1 is SAM-convertible to
// PartialFunction. Concretely, given:
//
// def foo(a: Int => Int): Unit = println("1")
// def foo(a: PartialFunction[Int, Int]): Unit = println("2")
//
// - `foo(x => x)` will print 1, because the PartialFunction overload
// won't be seen as applicable in the first call to
// `resolveOverloaded`, this behavior happens to match what Java does
// since PartialFunction is not a SAM type according to Java
// (`isDefined` is abstract).
// - `foo { case x if x % 2 == 0 => x }` will print 2, because both
// overloads are applicable, but PartialFunction is a subtype of
// Function1 so it's more specific.
|| (!formal.isRef(defn.PartialFunctionClass) || ctx.mode.is(Mode.ImplicitsEnabled)) && SAMargOK
|| argMatch == ArgMatch.CompatibleCAP
&& {
val argtpe1 = argtpe.widen
val captured = captureWildcards(argtpe1)
(captured ne argtpe1) && isCompatible(captured, formal.widenExpr)
}

/** The type of the given argument */
protected def argType(arg: Arg, formal: Type): Type
Expand Down Expand Up @@ -1863,17 +1879,10 @@ trait Applications extends Compatibility {
else
alts

def narrowByTrees(alts: List[TermRef], args: List[Tree], resultType: Type): List[TermRef] = {
val alts2 = alts.filterConserve(alt =>
isApplicableMethodRef(alt, args, resultType, keepConstraint = false, ArgMatch.SubType)
def narrowByTrees(alts: List[TermRef], args: List[Tree], resultType: Type): List[TermRef] =
alts.filterConserve(alt =>
isApplicableMethodRef(alt, args, resultType, keepConstraint = false, ArgMatch.CompatibleCAP)
)
if (alts2.isEmpty && !ctx.isAfterTyper)
alts.filterConserve(alt =>
isApplicableMethodRef(alt, args, resultType, keepConstraint = false, ArgMatch.CompatibleCAP)
)
else
alts2
}

record("resolveOverloaded.FunProto", alts.length)
val alts1 = narrowBySize(alts)
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ object RefChecks {

/** Check that self type of this class conforms to self types of parents
* and required classes. Also check that only `enum` constructs extend
* `java.lang.Enum`.
* `java.lang.Enum` and no user-written class extends ContextFunctionN.
*/
private def checkParents(cls: Symbol, parentTrees: List[Tree])(using Context): Unit = cls.info match {
case cinfo: ClassInfo =>
Expand Down Expand Up @@ -132,6 +132,8 @@ object RefChecks {
case _ =>
false
}
if psyms.exists(defn.isContextFunctionClass) then
report.error(CannotExtendContextFunction(cls), cls.sourcePos)

/** Check that arguments passed to trait parameters conform to the parameter types
* in the current class. This is necessary since parameter types might be narrowed
Expand Down
4 changes: 0 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1398,10 +1398,6 @@ class Typer extends Namer
else
report.error(ex"result type of lambda is an underspecified SAM type $pt", tree.srcPos)
pt
if (pt.classSymbol.isOneOf(FinalOrSealed)) {
val offendingFlag = pt.classSymbol.flags & FinalOrSealed
report.error(ex"lambda cannot implement $offendingFlag ${pt.classSymbol}", tree.srcPos)
}
TypeTree(targetTpe)
case _ =>
if (mt.isParamDependent)
Expand Down
16 changes: 16 additions & 0 deletions tests/neg/i11938.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import java.util.function.Function

object Test {
def foo[V](v: V): Int = 1
def foo[U](fn: Function[Int, U]): Int = 2

def main(args: Array[String]): Unit = {
val f: Int => Int = x => x
foo(f) // error
// Like Scala 2, we emit an error here because the Function1 argument was
// deemed SAM-convertible to Function, even though it's not a lambda
// expression and therefore not convertible. If we wanted to support this,
// we would have to tweak TestApplication#argOK to look at the shape of
// `arg` and turn off SAM conversions when it's a non-closure tree.
}
}
22 changes: 22 additions & 0 deletions tests/run/i11938.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import java.util.function.Function

object Test {
def foo[V](v: V): Int = 1
def foo[U](fn: Function[Int, U]): Int = 2

def foo2(a: Int => Int): Int = 1
def foo2(a: PartialFunction[Int, Int]): Int = 2

def main(args: Array[String]): Unit = {
assert(foo((x: Int) => x) == 2)
val jf: Function[Int, Int] = x => x
assert(foo(jf) == 2)

assert(foo2(x => x) == 1)
val f: Int => Int = x => x
assert(foo2(f) == 1)
assert(foo2({ case x if x % 2 == 0 => x }) == 2)
val pf: PartialFunction[Int, Int] = { case x if x % 2 == 0 => x }
assert(foo2(pf) == 2)
}
}