Skip to content

Remove ProductN parent on case classes #2314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 5, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 7 additions & 42 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ object desugar {

/** Names of methods that are added unconditionally to case classes */
def isDesugaredCaseClassMethodName(name: Name)(implicit ctx: Context): Boolean =
name == nme.copy ||
name == nme.productArity ||
name.isSelectorName
name == nme.copy || name.isSelectorName

// ----- DerivedTypeTrees -----------------------------------

Expand Down Expand Up @@ -291,7 +289,8 @@ object desugar {
case _ => false
}

val isCaseClass = mods.is(Case) && !mods.is(Module)
val isCaseClass = mods.is(Case) && !mods.is(Module)
val isCaseObject = mods.is(Case) && mods.is(Module)
val isEnum = mods.hasMod[Mod.Enum]
val isEnumCase = isLegalEnumCase(cdef)
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
Expand Down Expand Up @@ -360,31 +359,12 @@ object desugar {
// pN: TN = pN: @uncheckedVariance)(moreParams) =
// new C[...](p1, ..., pN)(moreParams)
//
// Above arity 22 we also synthesize:
// def productArity = N
// def productElement(i: Int): Any = i match { ... }
//
// Note: copy default parameters need @uncheckedVariance; see
// neg/t1843-variances.scala for a test case. The test would give
// two errors without @uncheckedVariance, one of them spurious.
val caseClassMeths = {
def syntheticProperty(name: TermName, rhs: Tree) =
DefDef(name, Nil, Nil, TypeTree(), rhs).withMods(synthetic)
def productArity = syntheticProperty(nme.productArity, Literal(Constant(arity)))
def productElement = {
val param = makeSyntheticParameter(tpt = ref(defn.IntType))
// case N => _${N + 1}
val cases = 0.until(arity).map { i =>
CaseDef(Literal(Constant(i)), EmptyTree, Select(This(EmptyTypeIdent), nme.selectorName(i)))
}
val ioob = ref(defn.IndexOutOfBoundsException.typeRef)
val error = Throw(New(ioob, List(List(Select(refOfDef(param), nme.toString_)))))
// case _ => throw new IndexOutOfBoundsException(i.toString)
val defaultCase = CaseDef(untpd.Ident(nme.WILDCARD), EmptyTree, error)
val body = Match(refOfDef(param), (cases :+ defaultCase).toList)
DefDef(nme.productElement, Nil, List(List(param)), TypeTree(defn.AnyType), body)
.withMods(synthetic)
}
def productElemMeths = {
val caseParams = constrVparamss.head.toArray
for (i <- 0 until arity if nme.selectorName(i) `ne` caseParams(i).name)
Expand Down Expand Up @@ -414,33 +394,19 @@ object desugar {
}
}

// Above MaxTupleArity we extend Product instead of ProductN, in this
// case we need to synthesise productElement & productArity.
def largeProductMeths =
if (arity > Definitions.MaxTupleArity) productElement :: productArity :: Nil
else Nil

if (isCaseClass)
largeProductMeths ::: copyMeths ::: enumTagMeths ::: productElemMeths.toList
copyMeths ::: enumTagMeths ::: productElemMeths.toList
else Nil
}

def anyRef = ref(defn.AnyRefAlias.typeRef)
def productConstr(n: Int) = {
val tycon = scalaDot((str.Product + n).toTypeName)
val targs = constrVparamss.head map (_.tpt)
if (targs.isEmpty) tycon else AppliedTypeTree(tycon, targs)
}
def product =
if (arity > Definitions.MaxTupleArity) scalaDot(str.Product.toTypeName)
else productConstr(arity)

// Case classes and case objects get Product/ProductN parents
// Case classes and case objects get Product parents
var parents1 = parents
if (isEnumCase && parents.isEmpty)
parents1 = enumClassTypeRef :: Nil
if (mods.is(Case))
parents1 = parents1 :+ product // TODO: This also adds Product0 to case objects. Do we want that?
if (isCaseClass | isCaseObject)
parents1 = parents1 :+ scalaDot(str.Product.toTypeName)
if (isEnum)
parents1 = parents1 :+ ref(defn.EnumType)

Expand Down Expand Up @@ -499,7 +465,6 @@ object desugar {
companionDefs(anyRef, Nil)
else Nil


// For an implicit class C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, .., pMN: TMN), the method
// synthetic implicit C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, ..., pMN: TMN): C[Ts] =
// new C[Ts](p11, ..., p1N) ... (pM1, ..., pMN) =
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ class Definitions {
def Product_canEqual(implicit ctx: Context) = Product_canEqualR.symbol
lazy val Product_productArityR = ProductClass.requiredMethodRef(nme.productArity)
def Product_productArity(implicit ctx: Context) = Product_productArityR.symbol
lazy val Product_productElementR = ProductClass.requiredMethodRef(nme.productElement)
def Product_productElement(implicit ctx: Context) = Product_productElementR.symbol
lazy val Product_productPrefixR = ProductClass.requiredMethodRef(nme.productPrefix)
def Product_productPrefix(implicit ctx: Context) = Product_productPrefixR.symbol
lazy val LanguageModuleRef = ctx.requiredModule("scala.language")
Expand Down Expand Up @@ -702,7 +704,6 @@ class Definitions {
def FunctionClassPerRun = new PerRun[Array[Symbol]](implicit ctx => ImplementedFunctionType.map(_.symbol.asClass))

lazy val TupleType = mkArityArray("scala.Tuple", MaxTupleArity, 2)
lazy val ProductNType = mkArityArray("scala.Product", MaxTupleArity, 0)

def FunctionClass(n: Int, isImplicit: Boolean = false)(implicit ctx: Context) =
if (isImplicit) ctx.requiredClass("scala.ImplicitFunction" + n.toString)
Expand All @@ -717,7 +718,6 @@ class Definitions {
else FunctionClass(n, isImplicit).typeRef

private lazy val TupleTypes: Set[TypeRef] = TupleType.toSet
private lazy val ProductTypes: Set[TypeRef] = ProductNType.toSet

/** If `cls` is a class in the scala package, its name, otherwise EmptyTypeName */
def scalaClassName(cls: Symbol)(implicit ctx: Context): TypeName =
Expand Down
112 changes: 84 additions & 28 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@ import scala.language.postfixOps

/** Synthetic method implementations for case classes, case objects,
* and value classes.
*
* Selectively added to case classes/objects, unless a non-default
* implementation already exists:
* def equals(other: Any): Boolean
* def hashCode(): Int
* def canEqual(other: Any): Boolean
* def toString(): String
* def productElement(i: Int): Any
* def productArity: Int
* def productPrefix: String
*
* Special handling:
* protected def readResolve(): AnyRef
*
* Selectively added to value classes, unless a non-default
* implementation already exists:
*
* def equals(other: Any): Boolean
* def hashCode(): Int
*/
Expand All @@ -44,14 +46,13 @@ class SyntheticMethods(thisTransformer: DenotTransformer) {
if (myValueSymbols.isEmpty) {
myValueSymbols = List(defn.Any_hashCode, defn.Any_equals)
myCaseSymbols = myValueSymbols ++ List(defn.Any_toString, defn.Product_canEqual,
defn.Product_productArity, defn.Product_productPrefix)
defn.Product_productArity, defn.Product_productPrefix, defn.Product_productElement)
}

def valueSymbols(implicit ctx: Context) = { initSymbols; myValueSymbols }
def caseSymbols(implicit ctx: Context) = { initSymbols; myCaseSymbols }

/** The synthetic methods of the case or value class `clazz`.
*/
/** The synthetic methods of the case or value class `clazz`. */
def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
val clazzType = clazz.typeRef
lazy val accessors =
Expand Down Expand Up @@ -91,25 +92,68 @@ class SyntheticMethods(thisTransformer: DenotTransformer) {
case nme.canEqual_ => vrefss => canEqualBody(vrefss.head.head)
case nme.productArity => vrefss => Literal(Constant(accessors.length))
case nme.productPrefix => ownName
case nme.productElement => vrefss => productElementBody(accessors.length, vrefss.head.head)
}
ctx.log(s"adding $synthetic to $clazz at ${ctx.phase}")
DefDef(synthetic, syntheticRHS(ctx.withOwner(synthetic)))
}

/** The class
*
* case class C(x: T, y: U)
* ```
* case class C(x: T, y: T)
* ```
*
* gets the `productElement` method:
*
* ```
* def productElement(index: Int): Any = index match {
* case 0 => this._1
* case 1 => this._2
* case _ => throw new IndexOutOfBoundsException(index.toString)
* }
* ```
*/
def productElementBody(arity: Int, index: Tree)(implicit ctx: Context): Tree = {
val ioob = defn.IndexOutOfBoundsException.typeRef
// Second constructor of ioob that takes a String argument
def filterStringConstructor(s: Symbol): Boolean = s.info match {
case m: MethodType if s.isConstructor => m.paramInfos == List(defn.StringType)
case _ => false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's an easier way to get this, but I'll play with it and do it in a different PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IndexOutOfBoundsException contains two declarations, its primary constructor (which takes not argument), and this other constructor that takes a String. I first did .decls.toList.tail.head, then changed it to this very precise filter, do you have something in between?

}
val constructor = ioob.typeSymbol.info.decls.find(filterStringConstructor _).asTerm
val stringIndex = Apply(Select(index, nme.toString_), Nil)
val error = Throw(New(ioob, constructor, List(stringIndex)))

// case _ => throw new IndexOutOfBoundsException(i.toString)
val defaultCase = CaseDef(Underscore(defn.IntType), EmptyTree, error)

// case N => _${N + 1}
val cases = 0.until(arity).map { i =>
CaseDef(Literal(Constant(i)), EmptyTree, Select(This(clazz), nme.selectorName(i)))
}

Match(index, (cases :+ defaultCase).toList)
}

/** The class
*
* ```
* case class C(x: T, y: U)
* ```
*
* gets the equals method:
* gets the `equals` method:
*
* def equals(that: Any): Boolean =
* (this eq that) || {
* that match {
* case x$0 @ (_: C) => this.x == this$0.x && this.y == x$0.y
* case _ => false
* }
* ```
* def equals(that: Any): Boolean =
* (this eq that) || {
* that match {
* case x$0 @ (_: C) => this.x == this$0.x && this.y == x$0.y
* case _ => false
* }
* ```
*
* If C is a value class the initial `eq` test is omitted.
* If `C` is a value class the initial `eq` test is omitted.
*/
def equalsBody(that: Tree)(implicit ctx: Context): Tree = {
val thatAsClazz = ctx.newSymbol(ctx.owner, nme.x_0, Synthetic, clazzType, coord = ctx.owner.pos) // x$0
Expand All @@ -131,11 +175,15 @@ class SyntheticMethods(thisTransformer: DenotTransformer) {

/** The class
*
* ```
* class C(x: T) extends AnyVal
* ```
*
* gets the hashCode method:
* gets the `hashCode` method:
*
* def hashCode: Int = x.hashCode()
* ```
* def hashCode: Int = x.hashCode()
* ```
*/
def valueHashCodeBody(implicit ctx: Context): Tree = {
assert(accessors.length == 1)
Expand All @@ -144,17 +192,21 @@ class SyntheticMethods(thisTransformer: DenotTransformer) {

/** The class
*
* package p
* case class C(x: T, y: T)
* ```
* package p
* case class C(x: T, y: T)
* ```
*
* gets the hashCode method:
* gets the `hashCode` method:
*
* def hashCode: Int = {
* <synthetic> var acc: Int = "p.C".hashCode // constant folded
* acc = Statics.mix(acc, x);
* acc = Statics.mix(acc, Statics.this.anyHash(y));
* Statics.finalizeHash(acc, 2)
* }
* ```
* def hashCode: Int = {
* <synthetic> var acc: Int = "p.C".hashCode // constant folded
* acc = Statics.mix(acc, x);
* acc = Statics.mix(acc, Statics.this.anyHash(y));
* Statics.finalizeHash(acc, 2)
* }
* ```
*/
def caseHashCodeBody(implicit ctx: Context): Tree = {
val acc = ctx.newSymbol(ctx.owner, "acc".toTermName, Mutable | Synthetic, defn.IntType, coord = ctx.owner.pos)
Expand All @@ -165,7 +217,7 @@ class SyntheticMethods(thisTransformer: DenotTransformer) {
Block(accDef :: mixes, finish)
}

/** The hashCode implementation for given symbol `sym`. */
/** The `hashCode` implementation for given symbol `sym`. */
def hashImpl(sym: Symbol)(implicit ctx: Context): Tree =
defn.scalaClassName(sym.info.finalResultType) match {
case tpnme.Unit | tpnme.Null => Literal(Constant(0))
Expand All @@ -180,11 +232,15 @@ class SyntheticMethods(thisTransformer: DenotTransformer) {

/** The class
*
* case class C(...)
* ```
* case class C(...)
* ```
*
* gets the canEqual method
* gets the `canEqual` method
*
* def canEqual(that: Any) = that.isInstanceOf[C]
* ```
* def canEqual(that: Any) = that.isInstanceOf[C]
* ```
*/
def canEqualBody(that: Tree): Tree = that.isInstance(clazzType)

Expand Down
2 changes: 1 addition & 1 deletion tests/neg/t1843-variances.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

object Crash {
trait UpdateType[A]
case class StateUpdate[+A](updateType : UpdateType[A], value : A) // error
case class StateUpdate[+A](updateType : UpdateType[A], value : A) // error // error
case object IntegerUpdateType extends UpdateType[Integer]

//However this method will cause a crash
Expand Down
37 changes: 37 additions & 0 deletions tests/run/i2314.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
case class A(i: Int, s: String)

case class B(i: Int, s: String) {
// No override, these methods will be added by SyntheticMethods only if
// there are not user defined.
def productArity = -1
def productElement(i: Int): Any = None
}

object Test {
def main(args: Array[String]): Unit = {
val a = A(1, "s")
assert(a.productArity == 2)
assert(a.productElement(0) == 1)
assert(a.productElement(1) == "s")

try {
a.productElement(-1)
???
} catch {
case e: IndexOutOfBoundsException =>
assert(e.getMessage == "-1")
}
try {
a.productElement(2)
???
} catch {
case e: IndexOutOfBoundsException =>
assert(e.getMessage == "2")
}

val b = B(1, "s")
assert(b.productArity == -1)
assert(b.productElement(0) == None)
assert(b.productElement(1) == None)
}
}