Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit be7676b

Browse files
committedJul 14, 2023
Properly handle SAM types with wildcards
When typing a closure with an expected type containing a wildcard, the closure type itself should not contain wildcards, because it might be expanded to an anonymous class extending the closure type (this happens on non-JVM backends as well as on the JVM itself in situations where a SAM trait does not compile down to a SAM interface). We were already approximating wildcards in the method type returned by the SAMType extractor, but to fix this issue we had to change the extractor to perform the approximation on the expected type itself to generate a valid parent type. The SAMType extractor now returns both the approximated parent type and the type of the method itself. The wildcard approximation analysis relies on a new `VarianceMap` opaque type extracted from Inferencing#variances. Fixes #16065. Fixes #18096.
1 parent d9a7900 commit be7676b

File tree

11 files changed

+209
-131
lines changed

11 files changed

+209
-131
lines changed
 

‎compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,7 @@ class Definitions {
744744
@tu lazy val StringContextModule_processEscapes: Symbol = StringContextModule.requiredMethod(nme.processEscapes)
745745

746746
@tu lazy val PartialFunctionClass: ClassSymbol = requiredClass("scala.PartialFunction")
747+
@tu lazy val PartialFunction_apply: Symbol = PartialFunctionClass.requiredMethod(nme.apply)
747748
@tu lazy val PartialFunction_isDefinedAt: Symbol = PartialFunctionClass.requiredMethod(nme.isDefinedAt)
748749
@tu lazy val PartialFunction_applyOrElse: Symbol = PartialFunctionClass.requiredMethod(nme.applyOrElse)
749750

‎compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 128 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import CheckRealizable._
2121
import Variances.{Variance, setStructuralVariances, Invariant}
2222
import typer.Nullables
2323
import util.Stats._
24-
import util.SimpleIdentitySet
24+
import util.{SimpleIdentityMap, SimpleIdentitySet}
2525
import ast.tpd._
2626
import ast.TreeTypeMap
2727
import printing.Texts._
@@ -1751,7 +1751,7 @@ object Types {
17511751
t
17521752
case t if defn.isErasedFunctionType(t) =>
17531753
t
1754-
case t @ SAMType(_) =>
1754+
case t @ SAMType(_, _) =>
17551755
t
17561756
case _ =>
17571757
NoType
@@ -5520,104 +5520,119 @@ object Types {
55205520
* A type is a SAM type if it is a reference to a class or trait, which
55215521
*
55225522
* - has a single abstract method with a method type (ExprType
5523-
* and PolyType not allowed!) whose result type is not an implicit function type
5524-
* and which is not marked inline.
5523+
* and PolyType not allowed!) according to `possibleSamMethods`.
55255524
* - can be instantiated without arguments or with just () as argument.
55265525
*
5527-
* The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5528-
* type of the single abstract method.
5526+
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5527+
* type of the single abstract method and `samParent` is a subtype of the matched
5528+
* SAM type which has been stripped of wildcards to turn it into a valid parent
5529+
* type.
55295530
*/
55305531
object SAMType {
5531-
def zeroParamClass(tp: Type)(using Context): Type = tp match {
5532+
/** If possible, return a type which is both a subtype of `origTp` and a type
5533+
* application of `samClass` where none of the type arguments are
5534+
* wildcards (thus making it a valid parent type), otherwise return
5535+
* NoType.
5536+
*
5537+
* A wildcard in the original type will be replaced by its upper or lower bound in a way
5538+
* that maximizes the number of possible implementations of `samMeth`. For example,
5539+
* java.util.function defines an interface equivalent to:
5540+
*
5541+
* trait Function[T, R]:
5542+
* def apply(t: T): R
5543+
*
5544+
* and it usually appears with wildcards to compensate for the lack of
5545+
* definition-site variance in Java:
5546+
*
5547+
* (x => x.toInt): Function[? >: String, ? <: Int]
5548+
*
5549+
* When typechecking this lambda, we need to approximate the wildcards to find
5550+
* a valid parent type for our lambda to extend. We can see that in `apply`,
5551+
* `T` only appears contravariantly and `R` only appears covariantly, so by
5552+
* minimizing the first parameter and maximizing the second, we maximize the
5553+
* number of valid implementations of `apply` which lets us implement the lambda
5554+
* with a closure equivalent to:
5555+
*
5556+
* new Function[String, Int] { def apply(x: String): Int = x.toInt }
5557+
*
5558+
* If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5559+
* we arbitrarily pick the upper-bound.
5560+
*/
5561+
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
5562+
val tp = origTp.baseType(samClass)
5563+
if !(tp <:< origTp) then NoType
5564+
else tp match
5565+
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5566+
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5567+
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5568+
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5569+
vmap.recordLocalVariance(tp.symbol, variance)
5570+
case _ =>
5571+
foldOver(vmap, t)
5572+
val vmap = accu(VarianceMap.empty, samMeth.info)
5573+
val tparams = tycon.typeParamSymbols
5574+
val args1 = args.zipWithConserve(tparams):
5575+
case (arg @ TypeBounds(lo, hi), tparam) =>
5576+
val v = vmap.computedVariance(tparam)
5577+
if v.uncheckedNN < 0 then lo
5578+
else hi
5579+
case (arg, _) => arg
5580+
tp.derivedAppliedType(tycon, args1)
5581+
case _ =>
5582+
tp
5583+
5584+
def samClass(tp: Type)(using Context): Symbol = tp match
55325585
case tp: ClassInfo =>
5533-
def zeroParams(tp: Type): Boolean = tp.stripPoly match {
5586+
def zeroParams(tp: Type): Boolean = tp.stripPoly match
55345587
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
55355588
case et: ExprType => true
55365589
case _ => false
5537-
}
5538-
// `ContextFunctionN` does not have constructors
5539-
val ctor = tp.cls.primaryConstructor
5540-
if (!ctor.exists || zeroParams(ctor.info)) tp
5541-
else NoType
5590+
val cls = tp.cls
5591+
val validCtor =
5592+
val ctor = cls.primaryConstructor
5593+
// `ContextFunctionN` does not have constructors
5594+
!ctor.exists || zeroParams(ctor.info)
5595+
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
5596+
if validCtor && isInstantiable then tp.cls
5597+
else NoSymbol
55425598
case tp: AppliedType =>
5543-
zeroParamClass(tp.superType)
5599+
samClass(tp.superType)
55445600
case tp: TypeRef =>
5545-
zeroParamClass(tp.underlying)
5601+
samClass(tp.underlying)
55465602
case tp: RefinedType =>
5547-
zeroParamClass(tp.underlying)
5603+
samClass(tp.underlying)
55485604
case tp: TypeBounds =>
5549-
zeroParamClass(tp.underlying)
5605+
samClass(tp.underlying)
55505606
case tp: TypeVar =>
5551-
zeroParamClass(tp.underlying)
5607+
samClass(tp.underlying)
55525608
case tp: AnnotatedType =>
5553-
zeroParamClass(tp.underlying)
5554-
case _ =>
5555-
NoType
5556-
}
5557-
def isInstantiatable(tp: Type)(using Context): Boolean = zeroParamClass(tp) match {
5558-
case cinfo: ClassInfo if !cinfo.cls.isOneOf(FinalOrSealed) =>
5559-
val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5560-
tp <:< selfType
5609+
samClass(tp.underlying)
55615610
case _ =>
5562-
false
5563-
}
5564-
def unapply(tp: Type)(using Context): Option[MethodType] =
5565-
if (isInstantiatable(tp)) {
5566-
val absMems = tp.possibleSamMethods
5567-
if (absMems.size == 1)
5568-
absMems.head.info match {
5569-
case mt: MethodType if !mt.isParamDependent &&
5570-
mt.resultType.isValueTypeOrWildcard =>
5571-
val cls = tp.classSymbol
5572-
5573-
// Given a SAM type such as:
5574-
//
5575-
// import java.util.function.Function
5576-
// Function[? >: String, ? <: Int]
5577-
//
5578-
// the single abstract method will have type:
5579-
//
5580-
// (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5581-
//
5582-
// which is not implementable outside of the scope of Function.
5583-
//
5584-
// To avoid this kind of issue, we approximate references to
5585-
// parameters of the SAM type by their bounds, this way in the
5586-
// above example we get:
5587-
//
5588-
// (x: String): Int
5589-
val approxParams = new ApproximatingTypeMap {
5590-
def apply(tp: Type): Type = tp match {
5591-
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) && tp.symbol.owner == cls =>
5592-
tp.info match {
5593-
case info: AliasingBounds =>
5594-
mapOver(info.alias)
5595-
case TypeBounds(lo, hi) =>
5596-
range(atVariance(-variance)(apply(lo)), apply(hi))
5597-
case _ =>
5598-
range(defn.NothingType, defn.AnyType) // should happen only in error cases
5599-
}
5600-
case _ =>
5601-
mapOver(tp)
5602-
}
5603-
}
5604-
val approx =
5605-
if ctx.owner.isContainedIn(cls) then mt
5606-
else approxParams(mt).asInstanceOf[MethodType]
5607-
Some(approx)
5611+
NoSymbol
5612+
5613+
def unapply(tp: Type)(using Context): Option[(MethodType, Type)] =
5614+
val cls = samClass(tp)
5615+
if cls.exists then
5616+
val absMems =
5617+
if tp.isRef(defn.PartialFunctionClass) then
5618+
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5619+
// pretending it is a SAM type. In the future it would be better to merge
5620+
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5621+
// def isDefinedAt(x: T) = true
5622+
// and overwrite that method whenever the function body is a sequence of
5623+
// case clauses.
5624+
List(defn.PartialFunction_apply)
5625+
else
5626+
tp.possibleSamMethods.map(_.symbol)
5627+
if absMems.lengthCompare(1) == 0 then
5628+
val samMethSym = absMems.head
5629+
val parent = samParent(tp, cls, samMethSym)
5630+
samMethSym.asSeenFrom(parent).info match
5631+
case mt: MethodType if !mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5632+
Some(mt, parent)
56085633
case _ =>
56095634
None
5610-
}
5611-
else if (tp isRef defn.PartialFunctionClass)
5612-
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5613-
// pretending it is a SAM type. In the future it would be better to merge
5614-
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5615-
// def isDefinedAt(x: T) = true
5616-
// and overwrite that method whenever the function body is a sequence of
5617-
// case clauses.
5618-
absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf[MethodType])
56195635
else None
5620-
}
56215636
else None
56225637
}
56235638

@@ -6450,6 +6465,37 @@ object Types {
64506465
}
64516466
}
64526467

6468+
object VarianceMap:
6469+
/** An immutable map representing the variance of keys of type `K` */
6470+
opaque type VarianceMap[K <: AnyRef] <: AnyRef = SimpleIdentityMap[K, Integer]
6471+
def empty[K <: AnyRef]: VarianceMap[K] = SimpleIdentityMap.empty[K]
6472+
extension [K <: AnyRef](vmap: VarianceMap[K])
6473+
/** The backing map used to implement this VarianceMap. */
6474+
inline def underlying: SimpleIdentityMap[K, Integer] = vmap
6475+
6476+
/** Return a new map taking into account that K appears in a
6477+
* {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6478+
*/
6479+
def recordLocalVariance(k: K, localVariance: Int): VarianceMap[K] =
6480+
val previousVariance = vmap(k)
6481+
if previousVariance == null then
6482+
vmap.updated(k, localVariance)
6483+
else if previousVariance == localVariance || previousVariance == 0 then
6484+
vmap
6485+
else
6486+
vmap.updated(k, 0)
6487+
6488+
/** Return the variance of `k`:
6489+
* - A positive value means that `k` appears only covariantly.
6490+
* - A negative value means that `k` appears only contravariantly.
6491+
* - A zero value means that `k` appears both covariantly and
6492+
* contravariantly, or appears invariantly.
6493+
* - A null value means that `k` does not appear at all.
6494+
*/
6495+
def computedVariance(k: K): Integer | Null =
6496+
vmap(k)
6497+
export VarianceMap.VarianceMap
6498+
64536499
// ----- Name Filters --------------------------------------------------
64546500

64556501
/** A name filter selects or discards a member name of a type `pre`.

‎compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class ExpandSAMs extends MiniPhase:
5050
tree // it's a plain function
5151
case tpe if defn.isContextFunctionType(tpe) =>
5252
tree
53-
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
53+
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
5454
val tpe1 = checkRefinements(tpe, fn)
5555
toPartialFunction(tree, tpe1)
56-
case tpe @ SAMType(_) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
56+
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
5757
checkRefinements(tpe, fn)
5858
tree
5959
case tpe =>

‎compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ trait Applications extends Compatibility {
696696

697697
def SAMargOK =
698698
defn.isFunctionNType(argtpe1) && formal.match
699-
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
699+
case SAMType(samMeth, samParent) => argtpe <:< samMeth.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined))
700700
case _ => false
701701

702702
isCompatible(argtpe, formal)
@@ -2074,7 +2074,7 @@ trait Applications extends Compatibility {
20742074
* new java.io.ObjectOutputStream(f)
20752075
*/
20762076
pt match {
2077-
case SAMType(mtp) =>
2077+
case SAMType(mtp, _) =>
20782078
narrowByTypes(alts, mtp.paramInfos, mtp.resultType)
20792079
case _ =>
20802080
// pick any alternatives that are not methods since these might be convertible

‎compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ object Inferencing {
407407
val vs = variances(tp)
408408
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
409409
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
410-
vs foreachBinding { (tvar, v) =>
410+
vs.underlying foreachBinding { (tvar, v) =>
411411
if !tvar.isInstantiated then
412412
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
413413
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
@@ -440,8 +440,6 @@ object Inferencing {
440440
res
441441
}
442442

443-
type VarianceMap = SimpleIdentityMap[TypeVar, Integer]
444-
445443
/** All occurrences of type vars in `tp` that satisfy predicate
446444
* `include` mapped to their variances (-1/0/1) in both `tp` and
447445
* `pt.finalResultType`, where
@@ -465,23 +463,18 @@ object Inferencing {
465463
*
466464
* we want to instantiate U to x.type right away. No need to wait further.
467465
*/
468-
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap = {
466+
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
469467
Stats.record("variances")
470468
val constraint = ctx.typerState.constraint
471469

472-
object accu extends TypeAccumulator[VarianceMap] {
470+
object accu extends TypeAccumulator[VarianceMap[TypeVar]]:
473471
def setVariance(v: Int) = variance = v
474-
def apply(vmap: VarianceMap, t: Type): VarianceMap = t match {
472+
def apply(vmap: VarianceMap[TypeVar], t: Type): VarianceMap[TypeVar] = t match
475473
case t: TypeVar
476474
if !t.isInstantiated && accCtx.typerState.constraint.contains(t) =>
477-
val v = vmap(t)
478-
if (v == null) vmap.updated(t, variance)
479-
else if (v == variance || v == 0) vmap
480-
else vmap.updated(t, 0)
475+
vmap.recordLocalVariance(t, variance)
481476
case _ =>
482477
foldOver(vmap, t)
483-
}
484-
}
485478

486479
/** Include in `vmap` type variables occurring in the constraints of type variables
487480
* already in `vmap`. Specifically:
@@ -493,10 +486,10 @@ object Inferencing {
493486
* bounds as non-variant.
494487
* Do this in a fixpoint iteration until `vmap` stabilizes.
495488
*/
496-
def propagate(vmap: VarianceMap): VarianceMap = {
489+
def propagate(vmap: VarianceMap[TypeVar]): VarianceMap[TypeVar] = {
497490
var vmap1 = vmap
498491
def traverse(tp: Type) = { vmap1 = accu(vmap1, tp) }
499-
vmap.foreachBinding { (tvar, v) =>
492+
vmap.underlying.foreachBinding { (tvar, v) =>
500493
val param = tvar.origin
501494
constraint.entry(param) match
502495
case TypeBounds(lo, hi) =>
@@ -512,7 +505,7 @@ object Inferencing {
512505
if (vmap1 eq vmap) vmap else propagate(vmap1)
513506
}
514507

515-
propagate(accu(accu(SimpleIdentityMap.empty, tp), pt.finalResultType))
508+
propagate(accu(accu(VarianceMap.empty, tp), pt.finalResultType))
516509
}
517510

518511
/** Run the transformation after dealiasing but return the original type if it was a no-op. */
@@ -638,7 +631,7 @@ trait Inferencing { this: Typer =>
638631
if !tvar.isInstantiated then
639632
// isInstantiated needs to be checked again, since previous interpolations could already have
640633
// instantiated `tvar` through unification.
641-
val v = vs(tvar)
634+
val v = vs.computedVariance(tvar)
642635
if v == null then buf += ((tvar, 0))
643636
else if v.intValue != 0 then buf += ((tvar, v.intValue))
644637
else comparing(cmp =>

‎compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
10491049
case _ => tree
10501050
}
10511051

1052-
10531052
def typedNamedArg(tree: untpd.NamedArg, pt: Type)(using Context): NamedArg = {
10541053
/* Special case for resolving types for arguments of an annotation defined in Java.
10551054
* It allows that value of any type T can appear in positions where Array[T] is expected.
@@ -1330,9 +1329,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13301329
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
13311330
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
13321331
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
1333-
case pt1 @ SAMType(mt @ MethodTpe(_, formals, _)) =>
1332+
case SAMType(mt @ MethodTpe(_, formals, _), samParent) =>
13341333
val restpe = mt.resultType match
1335-
case mt: MethodType => mt.toFunctionType(isJava = pt1.classSymbol.is(JavaDefined))
1334+
case mt: MethodType => mt.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined))
13361335
case tp => tp
13371336
(formals,
13381337
if (mt.isResultDependent)
@@ -1686,28 +1685,22 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16861685
meth1.tpe.widen match {
16871686
case mt: MethodType =>
16881687
pt.findFunctionType match {
1689-
case pt @ SAMType(sam)
1690-
if !defn.isFunctionNType(pt) && mt <:< sam =>
1688+
case SAMType(samMeth, samParent)
1689+
if !defn.isFunctionNType(samParent) && mt <:< samMeth =>
16911690
if defn.isContextFunctionType(mt.resultType) then
16921691
report.error(
1693-
em"""Implementation restriction: cannot convert this expression to `$pt`
1692+
em"""Implementation restriction: cannot convert this expression to `$samParent`
16941693
|because its result type `${mt.resultType}` is a contextual function type.""",
16951694
tree.srcPos)
1696-
1697-
// SAMs of the form C[?] where C is a class cannot be conversion targets.
1698-
// The resulting class `class $anon extends C[?] {...}` would be illegal,
1699-
// since type arguments to `C`'s super constructor cannot be constructed.
1700-
def isWildcardClassSAM =
1701-
!pt.classSymbol.is(Trait) && pt.argInfos.exists(_.isInstanceOf[TypeBounds])
17021695
val targetTpe =
1703-
if isFullyDefined(pt, ForceDegree.all) && !isWildcardClassSAM then
1704-
pt
1705-
else if pt.isRef(defn.PartialFunctionClass) then
1696+
if isFullyDefined(samParent, ForceDegree.all) then
1697+
samParent
1698+
else if samParent.isRef(defn.PartialFunctionClass) then
17061699
// Replace the underspecified expected type by one based on the closure method type
17071700
defn.PartialFunctionOf(mt.firstParamTypes.head, mt.resultType)
17081701
else
1709-
report.error(em"result type of lambda is an underspecified SAM type $pt", tree.srcPos)
1710-
pt
1702+
report.error(em"result type of lambda is an underspecified SAM type $samParent", tree.srcPos)
1703+
samParent
17111704
TypeTree(targetTpe)
17121705
case _ =>
17131706
if (mt.isParamDependent)
@@ -4000,8 +3993,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
40003993
else
40013994
if (!defn.isFunctionNType(pt))
40023995
pt match {
4003-
case SAMType(_) if !pt.classSymbol.hasAnnotation(defn.FunctionalInterfaceAnnot) =>
4004-
report.warning(em"${tree.symbol} is eta-expanded even though $pt does not have the @FunctionalInterface annotation.", tree.srcPos)
3996+
case SAMType(_, samParent) if !pt1.classSymbol.hasAnnotation(defn.FunctionalInterfaceAnnot) =>
3997+
report.warning(em"${tree.symbol} is eta-expanded even though $samParent does not have the @FunctionalInterface annotation.", tree.srcPos)
40053998
case _ =>
40063999
}
40074000
simplify(typed(etaExpand(tree, wtp, arity), pt), pt, locked)
@@ -4169,9 +4162,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
41694162
}
41704163
}
41714164

4172-
def toSAM(tree: Tree): Tree = tree match {
4173-
case tree: Block => tpd.cpy.Block(tree)(tree.stats, toSAM(tree.expr))
4174-
case tree: Closure => cpy.Closure(tree)(tpt = TypeTree(pt)).withType(pt)
4165+
def toSAM(tree: Tree, samParent: Type): Tree = tree match {
4166+
case tree: Block => tpd.cpy.Block(tree)(tree.stats, toSAM(tree.expr, samParent))
4167+
case tree: Closure => cpy.Closure(tree)(tpt = TypeTree(samParent)).withType(samParent)
41754168
}
41764169

41774170
def adaptToSubType(wtp: Type): Tree =
@@ -4210,13 +4203,13 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
42104203
case closure(Nil, id @ Ident(nme.ANON_FUN), _)
42114204
if defn.isFunctionNType(wtp) && !defn.isFunctionNType(pt) =>
42124205
pt match {
4213-
case SAMType(sam)
4214-
if wtp <:< sam.toFunctionType(isJava = pt.classSymbol.is(JavaDefined)) =>
4206+
case SAMType(samMeth, samParent)
4207+
if wtp <:< samMeth.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined)) =>
42154208
// was ... && isFullyDefined(pt, ForceDegree.flipBottom)
42164209
// but this prevents case blocks from implementing polymorphic partial functions,
42174210
// since we do not know the result parameter a priori. Have to wait until the
42184211
// body is typechecked.
4219-
return toSAM(tree)
4212+
return toSAM(tree, samParent)
42204213
case _ =>
42214214
}
42224215
case _ =>

‎tests/neg/i15741.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
type IS = Int ?=> String
99

10-
def pf3: PartialFunction[String, IS] = {
10+
def pf3: PartialFunction[String, IS] = { // error
1111
case "hoge" => get
1212
case "huga" => get
1313
} // error

‎tests/neg/i8012.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ class C extends Q[?] // error: Type argument must be fully defined
99

1010
object O {
1111
def m(i: Int): Int = i
12-
val x: Q[_] = m // error: result type of lambda is an underspecified SAM type Q[?]
13-
}
12+
val x: Q[_] = m
13+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
object I0 { val i1: PartialFunction[_, Int] = { case i2 => i2 } }
1+
object I0 { val i1: PartialFunction[_, Any] = { case i2 => i2 } }

‎tests/pos/i18096.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
trait F1[-T1, +R] extends AnyRef { def apply(v1: T1): R }
2+
class R { def l: List[Any] = Nil }
3+
class S { def m[T](f: F1[R, ? <: List[T]]): S = this }
4+
class T1 { def t1(s: S) = s.m((r: R) => r.l) }

‎tests/run/i16065.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
trait Consumer1[T]:
2+
var x: Int = 1 // To force anonymous class generation
3+
def accept(x: T): Unit
4+
5+
trait Consumer2[T]:
6+
def accept(x: T): Unit
7+
8+
trait Producer1[T]:
9+
var x: Int = 1 // To force anonymous class generation
10+
def produce(x: Any): T
11+
12+
trait Producer2[T]:
13+
def produce(x: Any): T
14+
15+
trait ProdCons1[T]:
16+
var x: Int = 1 // To force anonymous class generation
17+
def apply(x: T): T
18+
19+
trait ProdCons2[T]:
20+
var x: Int = 1 // To force anonymous class generation
21+
def apply(x: T): T
22+
23+
object Test {
24+
def main(args: Array[String]): Unit = {
25+
val a1: Consumer1[? >: String] = x => ()
26+
a1.accept("foo")
27+
28+
val a2: Consumer2[? >: String] = x => ()
29+
a2.accept("foo")
30+
31+
val b1: Producer1[? <: String] = x => ""
32+
val bo1: String = b1.produce(1)
33+
34+
val b2: Producer2[? <: String] = x => ""
35+
val bo2: String = b2.produce(1)
36+
37+
val c1: ProdCons1[? <: String] = x => x
38+
val c2: ProdCons2[? <: String] = x => x
39+
// Can't do much with `c1` or `c2` but we should still pass Ycheck.
40+
}
41+
}

0 commit comments

Comments
 (0)
Please sign in to comment.