Skip to content

Adapt type parameters of typed eta expansion according to expected variances #950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 17, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ object Flags {
/** A bridge method. Set by Erasure */
final val Bridge = termFlag(34, "<bridge>")

/** All class attributes are fully defined */
final val FullyCompleted = typeFlag(34, "<fully-completed>")

/** Symbol is a Java varargs bridge */ // (needed?)
final val VBridge = termFlag(35, "<vbridge>") // TODO remove

Expand Down
2 changes: 1 addition & 1 deletion src/dotty/tools/dotc/core/Hashable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object Hashable {
trait Hashable {
import Hashable._

protected def hashSeed: Int = getClass.getSimpleName.hashCode
protected def hashSeed: Int = getClass.hashCode

protected final def finishHash(hashCode: Int, arity: Int): Int =
avoidNotCached(hashing.finalizeHash(hashCode, arity))
Expand Down
93 changes: 45 additions & 48 deletions src/dotty/tools/dotc/core/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,64 +240,61 @@ object Names {
/** Create a term name from the characters in cs[offset..offset+len-1].
* Assume they are already encoded.
*/
def termName(cs: Array[Char], offset: Int, len: Int): TermName = {
def termName(cs: Array[Char], offset: Int, len: Int): TermName = synchronized {
util.Stats.record("termName")
val h = hashValue(cs, offset, len) & (table.size - 1)

synchronized {

/** Make sure the capacity of the character array is at least `n` */
def ensureCapacity(n: Int) =
if (n > chrs.length) {
val newchrs = new Array[Char](chrs.length * 2)
chrs.copyToArray(newchrs)
chrs = newchrs
}

/** Enter characters into chrs array. */
def enterChars(): Unit = {
ensureCapacity(nc + len)
var i = 0
while (i < len) {
chrs(nc + i) = cs(offset + i)
i += 1
}
nc += len
/** Make sure the capacity of the character array is at least `n` */
def ensureCapacity(n: Int) =
if (n > chrs.length) {
val newchrs = new Array[Char](chrs.length * 2)
chrs.copyToArray(newchrs)
chrs = newchrs
}

/** Rehash chain of names */
def rehash(name: TermName): Unit =
if (name != null) {
val oldNext = name.next
val h = hashValue(chrs, name.start, name.length) & (table.size - 1)
name.next = table(h)
table(h) = name
rehash(oldNext)
}
/** Enter characters into chrs array. */
def enterChars(): Unit = {
ensureCapacity(nc + len)
var i = 0
while (i < len) {
chrs(nc + i) = cs(offset + i)
i += 1
}
nc += len
}

/** Make sure the hash table is large enough for the given load factor */
def incTableSize() = {
size += 1
if (size.toDouble / table.size > fillFactor) {
val oldTable = table
table = new Array[TermName](table.size * 2)
for (i <- 0 until oldTable.size) rehash(oldTable(i))
}
/** Rehash chain of names */
def rehash(name: TermName): Unit =
if (name != null) {
val oldNext = name.next
val h = hashValue(chrs, name.start, name.length) & (table.size - 1)
name.next = table(h)
table(h) = name
rehash(oldNext)
}

val next = table(h)
var name = next
while (name ne null) {
if (name.length == len && equals(name.start, cs, offset, len))
return name
name = name.next
/** Make sure the hash table is large enough for the given load factor */
def incTableSize() = {
size += 1
if (size.toDouble / table.size > fillFactor) {
val oldTable = table
table = new Array[TermName](table.size * 2)
for (i <- 0 until oldTable.size) rehash(oldTable(i))
}
name = new TermName(nc, len, next)
enterChars()
table(h) = name
incTableSize()
name
}

val next = table(h)
var name = next
while (name ne null) {
if (name.length == len && equals(name.start, cs, offset, len))
return name
name = name.next
}
name = new TermName(nc, len, next)
enterChars()
table(h) = name
incTableSize()
name
}

/** Create a type name from the characters in cs[offset..offset+len-1].
Expand Down
51 changes: 34 additions & 17 deletions src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ object SymDenotations {

/** is this symbol a trait representing a type lambda? */
final def isLambdaTrait(implicit ctx: Context): Boolean =
isClass && name.startsWith(tpnme.LambdaPrefix)
isClass && name.startsWith(tpnme.LambdaPrefix) && owner == defn.ScalaPackageClass

/** Is this symbol a package object or its module class? */
def isPackageObject(implicit ctx: Context): Boolean = {
Expand Down Expand Up @@ -1211,10 +1211,20 @@ object SymDenotations {

/** The denotation is fully completed: all attributes are fully defined.
* ClassDenotations compiled from source are first completed, then fully completed.
* Packages are never fully completed since members can be added at any time.
* @see Namer#ClassCompleter
*/
private def isFullyCompleted(implicit ctx: Context): Boolean =
isCompleted && classParents.nonEmpty
private def isFullyCompleted(implicit ctx: Context): Boolean = {
def isFullyCompletedRef(tp: TypeRef) = tp.denot match {
case d: ClassDenotation => d.isFullyCompleted
case _ => false
}
def testFullyCompleted =
if (classParents.isEmpty) !is(Package) && symbol.eq(defn.AnyClass)
else classParents.forall(isFullyCompletedRef)
flagsUNSAFE.is(FullyCompleted) ||
isCompleted && testFullyCompleted && { setFlag(FullyCompleted); true }
}

// ------ syncing inheritance-related info -----------------------------

Expand Down Expand Up @@ -1300,7 +1310,7 @@ object SymDenotations {
baseTypeRefValid = ctx.runId
}

private def computeBases(implicit ctx: Context): Unit = {
private def computeBases(implicit ctx: Context): (List[ClassSymbol], BitSet) = {
if (myBaseClasses eq Nil) throw CyclicReference(this)
myBaseClasses = Nil
val seen = new mutable.BitSet
Expand All @@ -1324,17 +1334,22 @@ object SymDenotations {
case nil =>
to
}
myBaseClasses = classSymbol :: addParentBaseClasses(classParents, Nil)
mySuperClassBits = seen.toImmutable
val bcs = classSymbol :: addParentBaseClasses(classParents, Nil)
val scbits = seen.toImmutable
if (isFullyCompleted) {
myBaseClasses = bcs
mySuperClassBits = scbits
}
else myBaseClasses = null
(bcs, scbits)
}

/** A bitset that contains the superId's of all base classes */
private def superClassBits(implicit ctx: Context): BitSet =
if (classParents.isEmpty) BitSet() // can happen when called too early in Namers
else {
checkBasesUpToDate()
if (mySuperClassBits == null) computeBases
mySuperClassBits
if (mySuperClassBits != null) mySuperClassBits else computeBases._2
}

/** The base classes of this class in linearization order,
Expand All @@ -1344,8 +1359,7 @@ object SymDenotations {
if (classParents.isEmpty) classSymbol :: Nil // can happen when called too early in Namers
else {
checkBasesUpToDate()
if (myBaseClasses == null) computeBases
myBaseClasses
if (myBaseClasses != null) myBaseClasses else computeBases._1
}

final override def derivesFrom(base: Symbol)(implicit ctx: Context): Boolean =
Expand Down Expand Up @@ -1378,9 +1392,9 @@ object SymDenotations {
while (ps.nonEmpty) {
val parent = ps.head.typeSymbol
parent.denot match {
case classd: ClassDenotation =>
fp.include(classd.memberFingerPrint)
parent.denot.setFlag(Frozen)
case parentDenot: ClassDenotation =>
fp.include(parentDenot.memberFingerPrint)
if (parentDenot.isFullyCompleted) parentDenot.setFlag(Frozen)
case _ =>
}
ps = ps.tail
Expand All @@ -1393,10 +1407,13 @@ object SymDenotations {
* not be used for package classes because cache never
* gets invalidated.
*/
def memberFingerPrint(implicit ctx: Context): FingerPrint = {
if (myMemberFingerPrint == FingerPrint.unknown) myMemberFingerPrint = computeMemberFingerPrint
myMemberFingerPrint
}
def memberFingerPrint(implicit ctx: Context): FingerPrint =
if (myMemberFingerPrint != FingerPrint.unknown) myMemberFingerPrint
else {
val fp = computeMemberFingerPrint
if (isFullyCompleted) myMemberFingerPrint = fp
fp
}

private[this] var myMemberCache: LRUCache[Name, PreDenotation] = null
private[this] var myMemberCachePeriod: Period = Nowhere
Expand Down
25 changes: 16 additions & 9 deletions src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -524,29 +524,36 @@ class TypeApplications(val self: Type) extends AnyVal {
}
}

/** Convert a type constructor `TC` with type parameters `T1, ..., Tn` to
/** Convert a type constructor `TC` which has type parameters `T1, ..., Tn`
* in a context where type parameters `U1,...,Un` are expected to
*
* LambdaXYZ { Apply = TC[hk$0, ..., hk$n] }
*
* where XYZ is a corresponds to the variances of the type parameters.
* Here, XYZ corresponds to the variances of
* - `U1,...,Un` if the variances of `T1,...,Tn` are pairwise compatible with `U1,...,Un`,
* - `T1,...,Tn` otherwise.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming that the "otherwise" case here only reached in ill-typed programs, and will result in an error being issued in adapt. Is that assumption right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it can also appear in tryLifted, which gets called during a subtype check. Essentially, we compare an unexpanded hk type with a type lambda and go on expanding the unexpanded type. It seems we can get a discrepancy in variances there.

* v1 is compatible with v2, if v1 = v2 or v2 is non-variant.
*/
def EtaExpand(implicit ctx: Context): Type = {
val tparams = typeParams
self.appliedTo(tparams map (_.typeRef)).LambdaAbstract(tparams)
def EtaExpand(tparams: List[Symbol])(implicit ctx: Context): Type = {
def varianceCompatible(actual: Symbol, formal: Symbol) =
formal.variance == 0 || actual.variance == formal.variance
val tparamsToUse =
if (typeParams.corresponds(tparams)(varianceCompatible)) tparams else typeParams
self.appliedTo(tparams map (_.typeRef)).LambdaAbstract(tparamsToUse)
//.ensuring(res => res.EtaReduce =:= self, s"res = $res, core = ${res.EtaReduce}, self = $self, hc = ${res.hashCode}")
}

/** Eta expand if `bound` is a higher-kinded type */
def EtaExpandIfHK(bound: Type)(implicit ctx: Context): Type =
if (bound.isHK && !isHK && self.typeSymbol.isClass && typeParams.nonEmpty) EtaExpand
if (bound.isHK && !isHK && self.typeSymbol.isClass && typeParams.nonEmpty) EtaExpand(bound.typeParams)
else self

/** Eta expand the prefix in front of any refinements. */
def EtaExpandCore(implicit ctx: Context): Type = self.stripTypeVar match {
case self: RefinedType =>
self.derivedRefinedType(self.parent.EtaExpandCore, self.refinedName, self.refinedInfo)
case _ =>
self.EtaExpand
self.EtaExpand(self.typeParams)
}

/** If `self` is a (potentially partially instantiated) eta expansion of type T, return T,
Expand Down Expand Up @@ -645,7 +652,7 @@ class TypeApplications(val self: Type) extends AnyVal {
param2.variance == param2.variance || param2.variance == 0
if (classBounds.exists(tycon.derivesFrom(_)) &&
tycon.typeParams.corresponds(tparams)(variancesMatch)) {
val expanded = tycon.EtaExpand
val expanded = tycon.EtaExpand(tparams)
val lifted = (expanded /: targs) { (partialInst, targ) =>
val tparam = partialInst.typeParams.head
RefinedType(partialInst, tparam.name, targ.bounds.withVariance(tparam.variance))
Expand All @@ -659,7 +666,7 @@ class TypeApplications(val self: Type) extends AnyVal {
false
}
tparams.nonEmpty &&
(typeParams.nonEmpty && p(EtaExpand) ||
(typeParams.hasSameLengthAs(tparams) && p(EtaExpand(tparams)) ||
classBounds.nonEmpty && tryLift(self.baseClasses))
}
}
2 changes: 1 addition & 1 deletion src/dotty/tools/dotc/core/Uniques.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ object Uniques {
}

final class RefinedUniques extends HashSet[RefinedType](Config.initialUniquesCapacity) with Hashable {
override val hashSeed = "CachedRefinedType".hashCode // some types start life as CachedRefinedTypes, need to have same hash seed
override val hashSeed = classOf[CachedRefinedType].hashCode // some types start life as CachedRefinedTypes, need to have same hash seed
override def hash(x: RefinedType): Int = x.hash

private def findPrevious(h: Int, parent: Type, refinedName: Name, refinedInfo: Type): RefinedType = {
Expand Down
51 changes: 47 additions & 4 deletions src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ object Scala2Unpickler {
if (tsym.exists) tsym.setFlag(TypeParam)
else denot.enter(tparam, decls)
}
denot.info = ClassInfo(
denot.owner.thisType, denot.classSymbol, parentRefs, decls, ost) // more refined infowith parents
if (!(denot.flagsUNSAFE is JavaModule)) ensureConstructor(denot.symbol.asClass, decls)

val scalacCompanion = denot.classSymbol.scalacLinkedClass
Expand Down Expand Up @@ -151,6 +149,51 @@ object Scala2Unpickler {
denot.info = ClassInfo( // final info
denot.owner.thisType, denot.classSymbol, parentRefs, declsInRightOrder, ost)
}

/** Adapt arguments to type parameters so that variance of type lambda arguments
* agrees with variance of corresponding higherkinded type parameters. Example:
*
* class Companion[+CC[X]]
* Companion[List]
*
* with adaptArgs, this will expand to
*
* Companion[[X] => List[X]]
*
* instead of
*
* Companion[[+X] => List[X]]
*
* even though `List` is covariant. This adaptation is necessary to ignore conflicting
* variances in overriding members that have types of hk-type parameters such as `Companion[GenTraversable]`
* or `Companion[ListBuffer]`. Without the adaptation we would end up with
*
* Companion[[+X] => GenTraversable[X]]
* Companion[[X] => List[X]]
*
* and the second is not a subtype of the first. So if we have overridding memebrs of the two
* types we get an error.
*/
def adaptArgs(tparams: List[Symbol], args: List[Type])(implicit ctx: Context): List[Type] = tparams match {
case tparam :: tparams1 =>
val boundLambda = tparam.infoOrCompleter match {
case TypeBounds(_, hi) => hi.LambdaClass(forcing = false)
case _ => NoSymbol
}
def adaptArg(arg: Type): Type = arg match {
case arg: TypeRef if arg.symbol.isLambdaTrait =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this change, but the definition of isLambdaTrait:

    final def isLambdaTrait(implicit ctx: Context): Boolean =
      isClass && name.startsWith(tpnme.LambdaPrefix)

Looks prone to falsely matching:

scala> class `Lambda|`
defined class Lambda$bar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, we should restrict this to members of the scala package.

assert(arg.symbol.typeParams.length == boundLambda.typeParams.length)
arg.prefix.select(boundLambda)
case arg: RefinedType =>
arg.derivedRefinedType(adaptArg(arg.parent), arg.refinedName, arg.refinedInfo)
case _ =>
arg
}
val arg = args.head
val adapted = if (boundLambda.exists) adaptArg(arg) else arg
adapted :: adaptArgs(tparams1, args.tail)
case nil => args
}
}

/** Unpickle symbol table information descending from a class and/or module root
Expand Down Expand Up @@ -723,8 +766,8 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
else TypeRef(pre, sym.name.asTypeName)
val args = until(end, readTypeRef)
if (sym == defn.ByNameParamClass2x) ExprType(args.head)
else if (args.isEmpty && sym.typeParams.nonEmpty) tycon.EtaExpand
else tycon.appliedTo(args)
else if (args.isEmpty && sym.typeParams.nonEmpty) tycon.EtaExpand(sym.typeParams)
else tycon.appliedTo(adaptArgs(sym.typeParams, args))
case TYPEBOUNDStpe =>
TypeBounds(readTypeRef(), readTypeRef())
case REFINEDtpe =>
Expand Down
Loading