Skip to content

Commit 9b70af9

Browse files
authored
Convert SAM result types to function types (#17740)
2 parents 889c208 + f641a87 commit 9b70af9

File tree

6 files changed

+55
-15
lines changed

6 files changed

+55
-15
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5577,6 +5577,16 @@ object Types {
55775577
else None
55785578
}
55795579
else None
5580+
5581+
def isSamCompatible(lhs: Type, rhs: Type)(using Context): Boolean = rhs match
5582+
case SAMType(mt) if !isParamDependentRec(mt) =>
5583+
lhs <:< mt.toFunctionType(isJava = rhs.classSymbol.is(JavaDefined))
5584+
case _ => false
5585+
5586+
def isParamDependentRec(mt: MethodType)(using Context): Boolean =
5587+
mt.isParamDependent || mt.resultType.match
5588+
case mt: MethodType => isParamDependentRec(mt)
5589+
case _ => false
55805590
}
55815591

55825592
// ----- TypeMaps --------------------------------------------------------------------

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -695,9 +695,7 @@ trait Applications extends Compatibility {
695695
val argtpe1 = argtpe.widen
696696

697697
def SAMargOK =
698-
defn.isFunctionType(argtpe1) && formal.match
699-
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
700-
case _ => false
698+
defn.isFunctionType(argtpe1) && SAMType.isSamCompatible(argtpe, formal)
701699

702700
isCompatible(argtpe, formal)
703701
// Only allow SAM-conversion to PartialFunction if implicit conversions

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,7 +1324,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13241324
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
13251325
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
13261326
(formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
1327-
case SAMType(mt @ MethodTpe(_, formals, restpe)) =>
1327+
case pt1 @ SAMType(mt @ MethodTpe(_, formals, _)) if !SAMType.isParamDependentRec(mt) =>
1328+
val restpe = mt.resultType match
1329+
case mt: MethodType => mt.toFunctionType(isJava = pt1.classSymbol.is(JavaDefined))
1330+
case tp => tp
13281331
(formals,
13291332
if (mt.isResultDependent)
13301333
untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))
@@ -4115,17 +4118,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
41154118
// convert function literal to SAM closure
41164119
tree match {
41174120
case closure(Nil, id @ Ident(nme.ANON_FUN), _)
4118-
if defn.isFunctionType(wtp) && !defn.isFunctionType(pt) =>
4119-
pt match {
4120-
case SAMType(sam)
4121-
if wtp <:< sam.toFunctionType(isJava = pt.classSymbol.is(JavaDefined)) =>
4122-
// was ... && isFullyDefined(pt, ForceDegree.flipBottom)
4123-
// but this prevents case blocks from implementing polymorphic partial functions,
4124-
// since we do not know the result parameter a priori. Have to wait until the
4125-
// body is typechecked.
4126-
return toSAM(tree)
4127-
case _ =>
4128-
}
4121+
if defn.isFunctionType(wtp) && !defn.isFunctionType(pt) && SAMType.isSamCompatible(wtp, pt) =>
4122+
// was ... && isFullyDefined(pt, ForceDegree.flipBottom)
4123+
// but this prevents case blocks from implementing polymorphic partial functions,
4124+
// since we do not know the result parameter a priori. Have to wait until the
4125+
// body is typechecked.
4126+
return toSAM(tree)
41294127
case _ =>
41304128
}
41314129

tests/neg/i17183.check

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
-- [E081] Type Error: tests/neg/i17183.scala:11:24 ---------------------------------------------------------------------
2+
11 |def test = Context(f = (_, _) => ???) // error // error
3+
| ^
4+
| Missing parameter type
5+
|
6+
| I could not infer the type of the parameter _$1 of expanded function:
7+
| (_$1, _$2) => ???.
8+
-- [E081] Type Error: tests/neg/i17183.scala:11:27 ---------------------------------------------------------------------
9+
11 |def test = Context(f = (_, _) => ???) // error // error
10+
| ^
11+
| Missing parameter type
12+
|
13+
| I could not infer the type of the parameter _$2 of expanded function:
14+
| (_$1, _$2) => ???.

tests/neg/i17183.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
trait Dependency {
2+
trait More
3+
}
4+
5+
trait MyFunc {
6+
def apply(a: Int, b: String)(using dep: Dependency, more: dep.More): String
7+
}
8+
9+
case class Context(f: MyFunc)
10+
11+
def test = Context(f = (_, _) => ???) // error // error

tests/pos/i17183.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
trait Dependency
2+
3+
trait MyFunc {
4+
def apply(a: Int, b: String)(using Dependency): String
5+
}
6+
7+
case class Context(f: MyFunc)
8+
9+
def test = Context(f = (_, _) => ???)

0 commit comments

Comments
 (0)