@@ -21,7 +21,7 @@ import CheckRealizable._
21
21
import Variances .{Variance , setStructuralVariances , Invariant }
22
22
import typer .Nullables
23
23
import util .Stats ._
24
- import util .SimpleIdentitySet
24
+ import util .{ SimpleIdentityMap , SimpleIdentitySet }
25
25
import ast .tpd ._
26
26
import ast .TreeTypeMap
27
27
import printing .Texts ._
@@ -1746,7 +1746,7 @@ object Types {
1746
1746
t
1747
1747
case t if defn.isErasedFunctionType(t) =>
1748
1748
t
1749
- case t @ SAMType (_) =>
1749
+ case t @ SAMType (_, _ ) =>
1750
1750
t
1751
1751
case _ =>
1752
1752
NoType
@@ -5505,104 +5505,119 @@ object Types {
5505
5505
* A type is a SAM type if it is a reference to a class or trait, which
5506
5506
*
5507
5507
* - has a single abstract method with a method type (ExprType
5508
- * and PolyType not allowed!) whose result type is not an implicit function type
5509
- * and which is not marked inline.
5508
+ * and PolyType not allowed!) according to `possibleSamMethods`.
5510
5509
* - can be instantiated without arguments or with just () as argument.
5511
5510
*
5512
- * The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5513
- * type of the single abstract method.
5511
+ * The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5512
+ * type of the single abstract method and `samParent` is a subtype of the matched
5513
+ * SAM type which has been stripped of wildcards to turn it into a valid parent
5514
+ * type.
5514
5515
*/
5515
5516
object SAMType {
5516
- def zeroParamClass (tp : Type )(using Context ): Type = tp match {
5517
+ /** If possible, return a type which is both a subtype of `origTp` and a type
5518
+ * application of `samClass` where none of the type arguments are
5519
+ * wildcards (thus making it a valid parent type), otherwise return
5520
+ * NoType.
5521
+ *
5522
+ * A wildcard in the original type will be replaced by its upper or lower bound in a way
5523
+ * that maximizes the number of possible implementations of `samMeth`. For example,
5524
+ * java.util.function defines an interface equivalent to:
5525
+ *
5526
+ * trait Function[T, R]:
5527
+ * def apply(t: T): R
5528
+ *
5529
+ * and it usually appears with wildcards to compensate for the lack of
5530
+ * definition-site variance in Java:
5531
+ *
5532
+ * (x => x.toInt): Function[? >: String, ? <: Int]
5533
+ *
5534
+ * When typechecking this lambda, we need to approximate the wildcards to find
5535
+ * a valid parent type for our lambda to extend. We can see that in `apply`,
5536
+ * `T` only appears contravariantly and `R` only appears covariantly, so by
5537
+ * minimizing the first parameter and maximizing the second, we maximize the
5538
+ * number of valid implementations of `apply` which lets us implement the lambda
5539
+ * with a closure equivalent to:
5540
+ *
5541
+ * new Function[String, Int] { def apply(x: String): Int = x.toInt }
5542
+ *
5543
+ * If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5544
+ * we arbitrarily pick the upper-bound.
5545
+ */
5546
+ def samParent (origTp : Type , samClass : Symbol , samMeth : Symbol )(using Context ): Type =
5547
+ val tp = origTp.baseType(samClass)
5548
+ if ! (tp <:< origTp) then NoType
5549
+ else tp match
5550
+ case tp @ AppliedType (tycon, args) if tp.hasWildcardArg =>
5551
+ val accu = new TypeAccumulator [VarianceMap [Symbol ]]:
5552
+ def apply (vmap : VarianceMap [Symbol ], t : Type ): VarianceMap [Symbol ] = t match
5553
+ case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) =>
5554
+ vmap.recordLocalVariance(tp.symbol, variance)
5555
+ case _ =>
5556
+ foldOver(vmap, t)
5557
+ val vmap = accu(VarianceMap .empty, samMeth.info)
5558
+ val tparams = tycon.typeParamSymbols
5559
+ val args1 = args.zipWithConserve(tparams):
5560
+ case (arg @ TypeBounds (lo, hi), tparam) =>
5561
+ val v = vmap.computedVariance(tparam)
5562
+ if v.uncheckedNN < 0 then lo
5563
+ else hi
5564
+ case (arg, _) => arg
5565
+ tp.derivedAppliedType(tycon, args1)
5566
+ case _ =>
5567
+ tp
5568
+
5569
+ def samClass (tp : Type )(using Context ): Symbol = tp match
5517
5570
case tp : ClassInfo =>
5518
- def zeroParams (tp : Type ): Boolean = tp.stripPoly match {
5571
+ def zeroParams (tp : Type ): Boolean = tp.stripPoly match
5519
5572
case mt : MethodType => mt.paramInfos.isEmpty && ! mt.resultType.isInstanceOf [MethodType ]
5520
5573
case et : ExprType => true
5521
5574
case _ => false
5522
- }
5523
- // `ContextFunctionN` does not have constructors
5524
- val ctor = tp.cls.primaryConstructor
5525
- if (! ctor.exists || zeroParams(ctor.info)) tp
5526
- else NoType
5575
+ val cls = tp.cls
5576
+ val validCtor =
5577
+ val ctor = cls.primaryConstructor
5578
+ // `ContextFunctionN` does not have constructors
5579
+ ! ctor.exists || zeroParams(ctor.info)
5580
+ val isInstantiable = ! cls.isOneOf(FinalOrSealed ) && (tp.appliedRef <:< tp.selfType)
5581
+ if validCtor && isInstantiable then tp.cls
5582
+ else NoSymbol
5527
5583
case tp : AppliedType =>
5528
- zeroParamClass (tp.superType)
5584
+ samClass (tp.superType)
5529
5585
case tp : TypeRef =>
5530
- zeroParamClass (tp.underlying)
5586
+ samClass (tp.underlying)
5531
5587
case tp : RefinedType =>
5532
- zeroParamClass (tp.underlying)
5588
+ samClass (tp.underlying)
5533
5589
case tp : TypeBounds =>
5534
- zeroParamClass (tp.underlying)
5590
+ samClass (tp.underlying)
5535
5591
case tp : TypeVar =>
5536
- zeroParamClass (tp.underlying)
5592
+ samClass (tp.underlying)
5537
5593
case tp : AnnotatedType =>
5538
- zeroParamClass(tp.underlying)
5539
- case _ =>
5540
- NoType
5541
- }
5542
- def isInstantiatable (tp : Type )(using Context ): Boolean = zeroParamClass(tp) match {
5543
- case cinfo : ClassInfo if ! cinfo.cls.isOneOf(FinalOrSealed ) =>
5544
- val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5545
- tp <:< selfType
5594
+ samClass(tp.underlying)
5546
5595
case _ =>
5547
- false
5548
- }
5549
- def unapply (tp : Type )(using Context ): Option [MethodType ] =
5550
- if (isInstantiatable(tp)) {
5551
- val absMems = tp.possibleSamMethods
5552
- if (absMems.size == 1 )
5553
- absMems.head.info match {
5554
- case mt : MethodType if ! mt.isParamDependent &&
5555
- mt.resultType.isValueTypeOrWildcard =>
5556
- val cls = tp.classSymbol
5557
-
5558
- // Given a SAM type such as:
5559
- //
5560
- // import java.util.function.Function
5561
- // Function[? >: String, ? <: Int]
5562
- //
5563
- // the single abstract method will have type:
5564
- //
5565
- // (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5566
- //
5567
- // which is not implementable outside of the scope of Function.
5568
- //
5569
- // To avoid this kind of issue, we approximate references to
5570
- // parameters of the SAM type by their bounds, this way in the
5571
- // above example we get:
5572
- //
5573
- // (x: String): Int
5574
- val approxParams = new ApproximatingTypeMap {
5575
- def apply (tp : Type ): Type = tp match {
5576
- case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) && tp.symbol.owner == cls =>
5577
- tp.info match {
5578
- case info : AliasingBounds =>
5579
- mapOver(info.alias)
5580
- case TypeBounds (lo, hi) =>
5581
- range(atVariance(- variance)(apply(lo)), apply(hi))
5582
- case _ =>
5583
- range(defn.NothingType , defn.AnyType ) // should happen only in error cases
5584
- }
5585
- case _ =>
5586
- mapOver(tp)
5587
- }
5588
- }
5589
- val approx =
5590
- if ctx.owner.isContainedIn(cls) then mt
5591
- else approxParams(mt).asInstanceOf [MethodType ]
5592
- Some (approx)
5596
+ NoSymbol
5597
+
5598
+ def unapply (tp : Type )(using Context ): Option [(MethodType , Type )] =
5599
+ val cls = samClass(tp)
5600
+ if cls.exists then
5601
+ val absMems =
5602
+ if tp.isRef(defn.PartialFunctionClass ) then
5603
+ // To maintain compatibility with 2.x, we treat PartialFunction specially,
5604
+ // pretending it is a SAM type. In the future it would be better to merge
5605
+ // Function and PartialFunction, have Function1 contain a isDefinedAt method
5606
+ // def isDefinedAt(x: T) = true
5607
+ // and overwrite that method whenever the function body is a sequence of
5608
+ // case clauses.
5609
+ List (defn.PartialFunction_apply )
5610
+ else
5611
+ tp.possibleSamMethods.map(_.symbol)
5612
+ if absMems.lengthCompare(1 ) == 0 then
5613
+ val samMethSym = absMems.head
5614
+ val parent = samParent(tp, cls, samMethSym)
5615
+ samMethSym.asSeenFrom(parent).info match
5616
+ case mt : MethodType if ! mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5617
+ Some (mt, parent)
5593
5618
case _ =>
5594
5619
None
5595
- }
5596
- else if (tp isRef defn.PartialFunctionClass )
5597
- // To maintain compatibility with 2.x, we treat PartialFunction specially,
5598
- // pretending it is a SAM type. In the future it would be better to merge
5599
- // Function and PartialFunction, have Function1 contain a isDefinedAt method
5600
- // def isDefinedAt(x: T) = true
5601
- // and overwrite that method whenever the function body is a sequence of
5602
- // case clauses.
5603
- absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf [MethodType ])
5604
5620
else None
5605
- }
5606
5621
else None
5607
5622
}
5608
5623
@@ -6435,6 +6450,37 @@ object Types {
6435
6450
}
6436
6451
}
6437
6452
6453
+ object VarianceMap :
6454
+ /** An immutable map representing the variance of keys of type `K` */
6455
+ opaque type VarianceMap [K <: AnyRef ] <: AnyRef = SimpleIdentityMap [K , Integer ]
6456
+ def empty [K <: AnyRef ]: VarianceMap [K ] = SimpleIdentityMap .empty[K ]
6457
+ extension [K <: AnyRef ](vmap : VarianceMap [K ])
6458
+ /** The backing map used to implement this VarianceMap. */
6459
+ inline def underlying : SimpleIdentityMap [K , Integer ] = vmap
6460
+
6461
+ /** Return a new map taking into account that K appears in a
6462
+ * {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6463
+ */
6464
+ def recordLocalVariance (k : K , localVariance : Int ): VarianceMap [K ] =
6465
+ val previousVariance = vmap(k)
6466
+ if previousVariance == null then
6467
+ vmap.updated(k, localVariance)
6468
+ else if previousVariance == localVariance || previousVariance == 0 then
6469
+ vmap
6470
+ else
6471
+ vmap.updated(k, 0 )
6472
+
6473
+ /** Return the variance of `k`:
6474
+ * - A positive value means that `k` appears only covariantly.
6475
+ * - A negative value means that `k` appears only contravariantly.
6476
+ * - A zero value means that `k` appears both covariantly and
6477
+ * contravariantly, or appears invariantly.
6478
+ * - A null value means that `k` does not appear at all.
6479
+ */
6480
+ def computedVariance (k : K ): Integer | Null =
6481
+ vmap(k)
6482
+ export VarianceMap .VarianceMap
6483
+
6438
6484
// ----- Name Filters --------------------------------------------------
6439
6485
6440
6486
/** A name filter selects or discards a member name of a type `pre`.
0 commit comments