Skip to content

Commit 0da5970

Browse files
committed
Properly handle SAM types with wildcards
When typing a closure with an expected type containing a wildcard, the closure type itself should not contain wildcards, because it might be expanded to an anonymous class extending the closure type (this happens on non-JVM backends as well as on the JVM itself in situations where a SAM trait does not compile down to a SAM interface). We were already approximating wildcards in the method type returned by the SAMType extractor, but to fix this issue we had to change the extractor to perform the approximation on the expected type itself to generate a valid parent type. The SAMType extractor now returns both the approximated parent type and the type of the method itself. The wildcard approximation analysis relies on a new `VarianceMap` opaque type extracted from Inferencing#variances. Fixes #16065. Fixes #18096.
1 parent fef1110 commit 0da5970

11 files changed

+209
-131
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,7 @@ class Definitions {
744744
@tu lazy val StringContextModule_processEscapes: Symbol = StringContextModule.requiredMethod(nme.processEscapes)
745745

746746
@tu lazy val PartialFunctionClass: ClassSymbol = requiredClass("scala.PartialFunction")
747+
@tu lazy val PartialFunction_apply: Symbol = PartialFunctionClass.requiredMethod(nme.apply)
747748
@tu lazy val PartialFunction_isDefinedAt: Symbol = PartialFunctionClass.requiredMethod(nme.isDefinedAt)
748749
@tu lazy val PartialFunction_applyOrElse: Symbol = PartialFunctionClass.requiredMethod(nme.applyOrElse)
749750

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 128 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import CheckRealizable._
2121
import Variances.{Variance, setStructuralVariances, Invariant}
2222
import typer.Nullables
2323
import util.Stats._
24-
import util.SimpleIdentitySet
24+
import util.{SimpleIdentityMap, SimpleIdentitySet}
2525
import ast.tpd._
2626
import ast.TreeTypeMap
2727
import printing.Texts._
@@ -1746,7 +1746,7 @@ object Types {
17461746
t
17471747
case t if defn.isErasedFunctionType(t) =>
17481748
t
1749-
case t @ SAMType(_) =>
1749+
case t @ SAMType(_, _) =>
17501750
t
17511751
case _ =>
17521752
NoType
@@ -5505,104 +5505,119 @@ object Types {
55055505
* A type is a SAM type if it is a reference to a class or trait, which
55065506
*
55075507
* - has a single abstract method with a method type (ExprType
5508-
* and PolyType not allowed!) whose result type is not an implicit function type
5509-
* and which is not marked inline.
5508+
* and PolyType not allowed!) according to `possibleSamMethods`.
55105509
* - can be instantiated without arguments or with just () as argument.
55115510
*
5512-
* The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5513-
* type of the single abstract method.
5511+
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5512+
* type of the single abstract method and `samParent` is a subtype of the matched
5513+
* SAM type which has been stripped of wildcards to turn it into a valid parent
5514+
* type.
55145515
*/
55155516
object SAMType {
5516-
def zeroParamClass(tp: Type)(using Context): Type = tp match {
5517+
/** If possible, return a type which is both a subtype of `origTp` and a type
5518+
* application of `samClass` where none of the type arguments are
5519+
* wildcards (thus making it a valid parent type), otherwise return
5520+
* NoType.
5521+
*
5522+
* A wildcard in the original type will be replaced by its upper or lower bound in a way
5523+
* that maximizes the number of possible implementations of `samMeth`. For example,
5524+
* java.util.function defines an interface equivalent to:
5525+
*
5526+
* trait Function[T, R]:
5527+
* def apply(t: T): R
5528+
*
5529+
* and it usually appears with wildcards to compensate for the lack of
5530+
* definition-site variance in Java:
5531+
*
5532+
* (x => x.toInt): Function[? >: String, ? <: Int]
5533+
*
5534+
* When typechecking this lambda, we need to approximate the wildcards to find
5535+
* a valid parent type for our lambda to extend. We can see that in `apply`,
5536+
* `T` only appears contravariantly and `R` only appears covariantly, so by
5537+
* minimizing the first parameter and maximizing the second, we maximize the
5538+
* number of valid implementations of `apply` which lets us implement the lambda
5539+
* with a closure equivalent to:
5540+
*
5541+
* new Function[String, Int] { def apply(x: String): Int = x.toInt }
5542+
*
5543+
* If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5544+
* we arbitrarily pick the upper-bound.
5545+
*/
5546+
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
5547+
val tp = origTp.baseType(samClass)
5548+
if !(tp <:< origTp) then NoType
5549+
else tp match
5550+
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5551+
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5552+
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5553+
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5554+
vmap.recordLocalVariance(tp.symbol, variance)
5555+
case _ =>
5556+
foldOver(vmap, t)
5557+
val vmap = accu(VarianceMap.empty, samMeth.info)
5558+
val tparams = tycon.typeParamSymbols
5559+
val args1 = args.zipWithConserve(tparams):
5560+
case (arg @ TypeBounds(lo, hi), tparam) =>
5561+
val v = vmap.computedVariance(tparam)
5562+
if v.uncheckedNN < 0 then lo
5563+
else hi
5564+
case (arg, _) => arg
5565+
tp.derivedAppliedType(tycon, args1)
5566+
case _ =>
5567+
tp
5568+
5569+
def samClass(tp: Type)(using Context): Symbol = tp match
55175570
case tp: ClassInfo =>
5518-
def zeroParams(tp: Type): Boolean = tp.stripPoly match {
5571+
def zeroParams(tp: Type): Boolean = tp.stripPoly match
55195572
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
55205573
case et: ExprType => true
55215574
case _ => false
5522-
}
5523-
// `ContextFunctionN` does not have constructors
5524-
val ctor = tp.cls.primaryConstructor
5525-
if (!ctor.exists || zeroParams(ctor.info)) tp
5526-
else NoType
5575+
val cls = tp.cls
5576+
val validCtor =
5577+
val ctor = cls.primaryConstructor
5578+
// `ContextFunctionN` does not have constructors
5579+
!ctor.exists || zeroParams(ctor.info)
5580+
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
5581+
if validCtor && isInstantiable then tp.cls
5582+
else NoSymbol
55275583
case tp: AppliedType =>
5528-
zeroParamClass(tp.superType)
5584+
samClass(tp.superType)
55295585
case tp: TypeRef =>
5530-
zeroParamClass(tp.underlying)
5586+
samClass(tp.underlying)
55315587
case tp: RefinedType =>
5532-
zeroParamClass(tp.underlying)
5588+
samClass(tp.underlying)
55335589
case tp: TypeBounds =>
5534-
zeroParamClass(tp.underlying)
5590+
samClass(tp.underlying)
55355591
case tp: TypeVar =>
5536-
zeroParamClass(tp.underlying)
5592+
samClass(tp.underlying)
55375593
case tp: AnnotatedType =>
5538-
zeroParamClass(tp.underlying)
5539-
case _ =>
5540-
NoType
5541-
}
5542-
def isInstantiatable(tp: Type)(using Context): Boolean = zeroParamClass(tp) match {
5543-
case cinfo: ClassInfo if !cinfo.cls.isOneOf(FinalOrSealed) =>
5544-
val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5545-
tp <:< selfType
5594+
samClass(tp.underlying)
55465595
case _ =>
5547-
false
5548-
}
5549-
def unapply(tp: Type)(using Context): Option[MethodType] =
5550-
if (isInstantiatable(tp)) {
5551-
val absMems = tp.possibleSamMethods
5552-
if (absMems.size == 1)
5553-
absMems.head.info match {
5554-
case mt: MethodType if !mt.isParamDependent &&
5555-
mt.resultType.isValueTypeOrWildcard =>
5556-
val cls = tp.classSymbol
5557-
5558-
// Given a SAM type such as:
5559-
//
5560-
// import java.util.function.Function
5561-
// Function[? >: String, ? <: Int]
5562-
//
5563-
// the single abstract method will have type:
5564-
//
5565-
// (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5566-
//
5567-
// which is not implementable outside of the scope of Function.
5568-
//
5569-
// To avoid this kind of issue, we approximate references to
5570-
// parameters of the SAM type by their bounds, this way in the
5571-
// above example we get:
5572-
//
5573-
// (x: String): Int
5574-
val approxParams = new ApproximatingTypeMap {
5575-
def apply(tp: Type): Type = tp match {
5576-
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) && tp.symbol.owner == cls =>
5577-
tp.info match {
5578-
case info: AliasingBounds =>
5579-
mapOver(info.alias)
5580-
case TypeBounds(lo, hi) =>
5581-
range(atVariance(-variance)(apply(lo)), apply(hi))
5582-
case _ =>
5583-
range(defn.NothingType, defn.AnyType) // should happen only in error cases
5584-
}
5585-
case _ =>
5586-
mapOver(tp)
5587-
}
5588-
}
5589-
val approx =
5590-
if ctx.owner.isContainedIn(cls) then mt
5591-
else approxParams(mt).asInstanceOf[MethodType]
5592-
Some(approx)
5596+
NoSymbol
5597+
5598+
def unapply(tp: Type)(using Context): Option[(MethodType, Type)] =
5599+
val cls = samClass(tp)
5600+
if cls.exists then
5601+
val absMems =
5602+
if tp.isRef(defn.PartialFunctionClass) then
5603+
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5604+
// pretending it is a SAM type. In the future it would be better to merge
5605+
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5606+
// def isDefinedAt(x: T) = true
5607+
// and overwrite that method whenever the function body is a sequence of
5608+
// case clauses.
5609+
List(defn.PartialFunction_apply)
5610+
else
5611+
tp.possibleSamMethods.map(_.symbol)
5612+
if absMems.lengthCompare(1) == 0 then
5613+
val samMethSym = absMems.head
5614+
val parent = samParent(tp, cls, samMethSym)
5615+
samMethSym.asSeenFrom(parent).info match
5616+
case mt: MethodType if !mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5617+
Some(mt, parent)
55935618
case _ =>
55945619
None
5595-
}
5596-
else if (tp isRef defn.PartialFunctionClass)
5597-
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5598-
// pretending it is a SAM type. In the future it would be better to merge
5599-
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5600-
// def isDefinedAt(x: T) = true
5601-
// and overwrite that method whenever the function body is a sequence of
5602-
// case clauses.
5603-
absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf[MethodType])
56045620
else None
5605-
}
56065621
else None
56075622
}
56085623

@@ -6435,6 +6450,37 @@ object Types {
64356450
}
64366451
}
64376452

6453+
object VarianceMap:
6454+
/** An immutable map representing the variance of keys of type `K` */
6455+
opaque type VarianceMap[K <: AnyRef] <: AnyRef = SimpleIdentityMap[K, Integer]
6456+
def empty[K <: AnyRef]: VarianceMap[K] = SimpleIdentityMap.empty[K]
6457+
extension [K <: AnyRef](vmap: VarianceMap[K])
6458+
/** The backing map used to implement this VarianceMap. */
6459+
inline def underlying: SimpleIdentityMap[K, Integer] = vmap
6460+
6461+
/** Return a new map taking into account that K appears in a
6462+
* {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6463+
*/
6464+
def recordLocalVariance(k: K, localVariance: Int): VarianceMap[K] =
6465+
val previousVariance = vmap(k)
6466+
if previousVariance == null then
6467+
vmap.updated(k, localVariance)
6468+
else if previousVariance == localVariance || previousVariance == 0 then
6469+
vmap
6470+
else
6471+
vmap.updated(k, 0)
6472+
6473+
/** Return the variance of `k`:
6474+
* - A positive value means that `k` appears only covariantly.
6475+
* - A negative value means that `k` appears only contravariantly.
6476+
* - A zero value means that `k` appears both covariantly and
6477+
* contravariantly, or appears invariantly.
6478+
* - A null value means that `k` does not appear at all.
6479+
*/
6480+
def computedVariance(k: K): Integer | Null =
6481+
vmap(k)
6482+
export VarianceMap.VarianceMap
6483+
64386484
// ----- Name Filters --------------------------------------------------
64396485

64406486
/** A name filter selects or discards a member name of a type `pre`.

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class ExpandSAMs extends MiniPhase:
5050
tree // it's a plain function
5151
case tpe if defn.isContextFunctionType(tpe) =>
5252
tree
53-
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
53+
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
5454
val tpe1 = checkRefinements(tpe, fn)
5555
toPartialFunction(tree, tpe1)
56-
case tpe @ SAMType(_) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
56+
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
5757
checkRefinements(tpe, fn)
5858
tree
5959
case tpe =>

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ trait Applications extends Compatibility {
696696

697697
def SAMargOK =
698698
defn.isFunctionType(argtpe1) && formal.match
699-
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
699+
case SAMType(samMeth, samParent) => argtpe <:< samMeth.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined))
700700
case _ => false
701701

702702
isCompatible(argtpe, formal)
@@ -2074,7 +2074,7 @@ trait Applications extends Compatibility {
20742074
* new java.io.ObjectOutputStream(f)
20752075
*/
20762076
pt match {
2077-
case SAMType(mtp) =>
2077+
case SAMType(mtp, _) =>
20782078
narrowByTypes(alts, mtp.paramInfos, mtp.resultType)
20792079
case _ =>
20802080
// pick any alternatives that are not methods since these might be convertible

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ object Inferencing {
407407
val vs = variances(tp)
408408
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
409409
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
410-
vs foreachBinding { (tvar, v) =>
410+
vs.underlying foreachBinding { (tvar, v) =>
411411
if !tvar.isInstantiated then
412412
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
413413
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
@@ -440,8 +440,6 @@ object Inferencing {
440440
res
441441
}
442442

443-
type VarianceMap = SimpleIdentityMap[TypeVar, Integer]
444-
445443
/** All occurrences of type vars in `tp` that satisfy predicate
446444
* `include` mapped to their variances (-1/0/1) in both `tp` and
447445
* `pt.finalResultType`, where
@@ -465,23 +463,18 @@ object Inferencing {
465463
*
466464
* we want to instantiate U to x.type right away. No need to wait further.
467465
*/
468-
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap = {
466+
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
469467
Stats.record("variances")
470468
val constraint = ctx.typerState.constraint
471469

472-
object accu extends TypeAccumulator[VarianceMap] {
470+
object accu extends TypeAccumulator[VarianceMap[TypeVar]]:
473471
def setVariance(v: Int) = variance = v
474-
def apply(vmap: VarianceMap, t: Type): VarianceMap = t match {
472+
def apply(vmap: VarianceMap[TypeVar], t: Type): VarianceMap[TypeVar] = t match
475473
case t: TypeVar
476474
if !t.isInstantiated && accCtx.typerState.constraint.contains(t) =>
477-
val v = vmap(t)
478-
if (v == null) vmap.updated(t, variance)
479-
else if (v == variance || v == 0) vmap
480-
else vmap.updated(t, 0)
475+
vmap.recordLocalVariance(t, variance)
481476
case _ =>
482477
foldOver(vmap, t)
483-
}
484-
}
485478

486479
/** Include in `vmap` type variables occurring in the constraints of type variables
487480
* already in `vmap`. Specifically:
@@ -493,10 +486,10 @@ object Inferencing {
493486
* bounds as non-variant.
494487
* Do this in a fixpoint iteration until `vmap` stabilizes.
495488
*/
496-
def propagate(vmap: VarianceMap): VarianceMap = {
489+
def propagate(vmap: VarianceMap[TypeVar]): VarianceMap[TypeVar] = {
497490
var vmap1 = vmap
498491
def traverse(tp: Type) = { vmap1 = accu(vmap1, tp) }
499-
vmap.foreachBinding { (tvar, v) =>
492+
vmap.underlying.foreachBinding { (tvar, v) =>
500493
val param = tvar.origin
501494
constraint.entry(param) match
502495
case TypeBounds(lo, hi) =>
@@ -512,7 +505,7 @@ object Inferencing {
512505
if (vmap1 eq vmap) vmap else propagate(vmap1)
513506
}
514507

515-
propagate(accu(accu(SimpleIdentityMap.empty, tp), pt.finalResultType))
508+
propagate(accu(accu(VarianceMap.empty, tp), pt.finalResultType))
516509
}
517510

518511
/** Run the transformation after dealiasing but return the original type if it was a no-op. */
@@ -638,7 +631,7 @@ trait Inferencing { this: Typer =>
638631
if !tvar.isInstantiated then
639632
// isInstantiated needs to be checked again, since previous interpolations could already have
640633
// instantiated `tvar` through unification.
641-
val v = vs(tvar)
634+
val v = vs.computedVariance(tvar)
642635
if v == null then buf += ((tvar, 0))
643636
else if v.intValue != 0 then buf += ((tvar, v.intValue))
644637
else comparing(cmp =>

0 commit comments

Comments
 (0)