Skip to content

Commit 7c38fe8

Browse files
committed
Special case singletons
This leads to better generated code. Downside is a slight increase in the code of the deriving typeclasses.
1 parent f7d2658 commit 7c38fe8

File tree

2 files changed

+100
-87
lines changed

2 files changed

+100
-87
lines changed

tests/run/typeclass-derivation2c.check

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ ListBuffer(0, 0, 11, 0, 22, 0, 33, 1, 0, 0, 11, 0, 22, 1, 1)
44
Cons(Cons(11,Cons(22,Cons(33,Nil))),Cons(Cons(11,Cons(22,Nil)),Nil))
55
ListBuffer(1, 2)
66
Pair(1,2)
7-
Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil())))
8-
Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil()))), tl = Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Nil())), tl = Nil()))
9-
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil()))
10-
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil()))
7+
Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil)))
8+
Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil))), tl = Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Nil)), tl = Nil))
9+
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil))
10+
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil))

tests/run/typeclass-derivation2c.scala

Lines changed: 96 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -27,32 +27,38 @@ object Deriving {
2727
* enums, case classes and objects, and their sealed parents.
2828
*/
2929
sealed abstract class Generic[T]
30+
object Generic {
3031

31-
/** The Generic for a sum type */
32-
abstract class GenericSum[T] extends Generic[T] {
32+
/** The Generic for a sum type */
33+
abstract class Sum[T] extends Generic[T] {
3334

34-
/** The ordinal number of the case class of `x`. For enums, `ordinal(x) == x.ordinal` */
35-
def ordinal(x: T): Int
35+
/** The ordinal number of the case class of `x`. For enums, `ordinal(x) == x.ordinal` */
36+
def ordinal(x: T): Int
3637

37-
/** The number of cases in the sum.
38-
* Implemented by an inline method in concrete subclasses.
39-
*/
40-
erased def numberOfCases: Int = ???
38+
/** The number of cases in the sum.
39+
* Implemented by an inline method in concrete subclasses.
40+
*/
41+
erased def numberOfCases: Int = ???
4142

42-
/** The Generic representations of the sum's alternatives.
43-
* Implemented by an inline method in concrete subclasses.
44-
*/
45-
erased def alternative(n: Int): GenericProduct[_ <: T] = ???
46-
}
43+
/** The Generic representations of the sum's alternatives.
44+
* Implemented by an inline method in concrete subclasses.
45+
*/
46+
erased def alternative(n: Int): Generic[_ <: T] = ???
47+
}
4748

48-
/** A Generic for a product type */
49-
abstract class GenericProduct[T] extends Generic[T] {
50-
type ElemTypes <: Tuple
51-
type CaseLabel
52-
type ElemLabels <: Tuple
49+
/** A Generic for a product type */
50+
abstract class Product[T] extends Generic[T] {
51+
type ElemTypes <: Tuple
52+
type CaseLabel <: String
53+
type ElemLabels <: Tuple
5354

54-
def toProduct(x: T): Product
55-
def fromProduct(p: Product): T
55+
def toProduct(x: T): scala.Product
56+
def fromProduct(p: scala.Product): T
57+
}
58+
59+
class Singleton[T](val value: T) extends Generic[T] {
60+
type CaseLabel <: String
61+
}
5662
}
5763
}
5864

@@ -65,13 +71,13 @@ sealed trait Lst[+T] // derives Eq, Pickler, Show
6571
object Lst {
6672
import Deriving._
6773

68-
class GenericLst[T] extends GenericSum[Lst[T]] {
74+
class GenericLst[T] extends Generic.Sum[Lst[T]] {
6975
def ordinal(x: Lst[T]) = x match {
7076
case x: Cons[_] => 0
7177
case Nil => 1
7278
}
7379
inline override def numberOfCases = 2
74-
inline override def alternative(n: Int) <: GenericProduct[_ <: Lst[T]] =
80+
inline override def alternative(n: Int) <: Generic[_ <: Lst[T]] =
7581
inline n match {
7682
case 0 => Cons.GenericCons[T]
7783
case 1 => Nil.GenericNil
@@ -85,7 +91,7 @@ object Lst {
8591
object Cons {
8692
def apply[T](x: T, xs: Lst[T]): Lst[T] = new Cons(x, xs)
8793

88-
class GenericCons[T] extends GenericProduct[Cons[T]] {
94+
class GenericCons[T] extends Generic.Product[Cons[T]] {
8995
type ElemTypes = (T, Lst[T])
9096
type CaseLabel = "Cons"
9197
type ElemLabels = ("hd", "tl")
@@ -98,12 +104,8 @@ object Lst {
98104
}
99105

100106
case object Nil extends Lst[Nothing] {
101-
class GenericNil extends GenericProduct[Nil.type] {
102-
type ElemTypes = Unit
107+
class GenericNil extends Generic.Singleton[Nil.type](Nil) {
103108
type CaseLabel = "Nil"
104-
type ElemLabels = Unit
105-
def toProduct(x: Nil.type): Product = EmptyProduct
106-
def fromProduct(p: Product): Nil.type = Nil
107109
}
108110
implicit def GenericNil: GenericNil = new GenericNil
109111
}
@@ -121,7 +123,7 @@ object Pair {
121123
// common compiler-generated infrastructure
122124
import Deriving._
123125

124-
class GenericPair[T] extends GenericProduct[Pair[T]] {
126+
class GenericPair[T] extends Generic.Product[Pair[T]] {
125127
type ElemTypes = (T, T)
126128
type CaseLabel = "Pair"
127129
type ElemLabels = ("x", "y")
@@ -143,13 +145,13 @@ sealed trait Either[+L, +R] extends Product with Serializable // derives Eq, Pic
143145
object Either {
144146
import Deriving._
145147

146-
class GenericEither[L, R] extends GenericSum[Either[L, R]] {
148+
class GenericEither[L, R] extends Generic.Sum[Either[L, R]] {
147149
def ordinal(x: Either[L, R]) = x match {
148150
case x: Left[L] => 0
149151
case x: Right[R] => 1
150152
}
151153
inline override def numberOfCases = 2
152-
inline override def alternative(n: Int) <: GenericProduct[_ <: Either[L, R]] =
154+
inline override def alternative(n: Int) <: Generic[_ <: Either[L, R]] =
153155
inline n match {
154156
case 0 => Left.GenericLeft[L]
155157
case 1 => Right.GenericRight[R]
@@ -167,7 +169,7 @@ case class Right[R](elem: R) extends Either[Nothing, R]
167169

168170
object Left {
169171
import Deriving._
170-
class GenericLeft[L] extends GenericProduct[Left[L]] {
172+
class GenericLeft[L] extends Generic.Product[Left[L]] {
171173
type ElemTypes = L *: Unit
172174
type CaseLabel = "Left"
173175
type ElemLabels = "x" *: Unit
@@ -179,7 +181,7 @@ object Left {
179181

180182
object Right {
181183
import Deriving._
182-
class GenericRight[R] extends GenericProduct[Right[R]] {
184+
class GenericRight[R] extends Generic.Product[Right[R]] {
183185
type ElemTypes = R *: Unit
184186
type CaseLabel = "Right"
185187
type ElemLabels = "x" *: Unit
@@ -217,26 +219,29 @@ object Eq {
217219
true
218220
}
219221

220-
inline def eqlCase[T](gp: GenericProduct[T])(x: T, y: T): Boolean =
221-
eqlElems[gp.ElemTypes](0)(gp.toProduct(x), gp.toProduct(y))
222+
inline def eqlProduct[T](g: Generic.Product[T])(x: T, y: T): Boolean =
223+
eqlElems[g.ElemTypes](0)(g.toProduct(x), g.toProduct(y))
222224

223-
inline def eqlCases[T](gs: GenericSum[T], n: Int)(x: T, y: T, ord: Int): Boolean =
224-
inline if (n == gs.numberOfCases)
225+
inline def eqlCases[T](g: Generic.Sum[T], n: Int)(x: T, y: T, ord: Int): Boolean =
226+
inline if (n == g.numberOfCases)
225227
false
226228
else if (ord == n)
227-
inline gs.alternative(n) match {
228-
case gp: GenericProduct[p] => eqlCase[p](gp)(x.asInstanceOf[p], y.asInstanceOf[p])
229+
inline g.alternative(n) match {
230+
case g: Generic.Product[p] => eqlProduct[p](g)(x.asInstanceOf[p], y.asInstanceOf[p])
231+
case g: Generic.Singleton[_] => true
229232
}
230-
else eqlCases[T](gs, n + 1)(x, y, ord)
233+
else eqlCases[T](g, n + 1)(x, y, ord)
231234

232235
inline def derived[T](implicit ev: Generic[T]): Eq[T] = new Eq[T] {
233236
def eql(x: T, y: T): Boolean =
234237
inline ev match {
235-
case gs: GenericSum[T] =>
236-
val ord = gs.ordinal(x)
237-
ord == gs.ordinal(y) && eqlCases[T](gs, 0)(x, y, ord)
238-
case gp: GenericProduct[T] =>
239-
eqlCase[T](gp)(x, y)
238+
case g: Generic.Sum[T] =>
239+
val ord = g.ordinal(x)
240+
ord == g.ordinal(y) && eqlCases[T](g, 0)(x, y, ord)
241+
case g: Generic.Product[T] =>
242+
eqlProduct[T](g)(x, y)
243+
case g: Generic.Singleton[_] =>
244+
true
240245
}
241246
}
242247

@@ -269,17 +274,18 @@ object Pickler {
269274
case _: Unit =>
270275
}
271276

272-
inline def pickleCase[T](gp: GenericProduct[T])(buf: mutable.ListBuffer[Int], x: T): Unit =
273-
pickleElems[gp.ElemTypes](0)(buf, gp.toProduct(x))
277+
inline def pickleProduct[T](g: Generic.Product[T])(buf: mutable.ListBuffer[Int], x: T): Unit =
278+
pickleElems[g.ElemTypes](0)(buf, g.toProduct(x))
274279

275-
inline def pickleCases[T](gs: GenericSum[T], inline n: Int)(buf: mutable.ListBuffer[Int], x: T, ord: Int): Unit =
276-
inline if (n == gs.numberOfCases)
280+
inline def pickleCases[T](g: Generic.Sum[T], inline n: Int)(buf: mutable.ListBuffer[Int], x: T, ord: Int): Unit =
281+
inline if (n == g.numberOfCases)
277282
()
278283
else if (ord == n)
279-
inline gs.alternative(n) match {
280-
case gp: GenericProduct[p] => pickleCase(gp)(buf, x.asInstanceOf[p])
284+
inline g.alternative(n) match {
285+
case g: Generic.Product[p] => pickleProduct(g)(buf, x.asInstanceOf[p])
286+
case g: Generic.Singleton[s] =>
281287
}
282-
else pickleCases[T](gs, n + 1)(buf, x, ord)
288+
else pickleCases[T](g, n + 1)(buf, x, ord)
283289

284290
inline def tryUnpickle[T](buf: mutable.ListBuffer[Int]): T = implicit match {
285291
case pkl: Pickler[T] => pkl.unpickle(buf)
@@ -293,42 +299,46 @@ object Pickler {
293299
case _: Unit =>
294300
}
295301

296-
inline def unpickleCase[T](gp: GenericProduct[T])(buf: mutable.ListBuffer[Int]): T = {
297-
inline val size = constValue[Tuple.Size[gp.ElemTypes]]
302+
inline def unpickleProduct[T](g: Generic.Product[T])(buf: mutable.ListBuffer[Int]): T = {
303+
inline val size = constValue[Tuple.Size[g.ElemTypes]]
298304
inline if (size == 0)
299-
gp.fromProduct(EmptyProduct)
305+
g.fromProduct(EmptyProduct)
300306
else {
301307
val elems = new Array[Object](size)
302-
unpickleElems[gp.ElemTypes](0)(buf, elems)
303-
gp.fromProduct(ArrayProduct(elems))
308+
unpickleElems[g.ElemTypes](0)(buf, elems)
309+
g.fromProduct(ArrayProduct(elems))
304310
}
305311
}
306312

307-
inline def unpickleCases[T](gs: GenericSum[T], n: Int)(buf: mutable.ListBuffer[Int], ord: Int): T =
308-
inline if (n == gs.numberOfCases)
313+
inline def unpickleCases[T](g: Generic.Sum[T], n: Int)(buf: mutable.ListBuffer[Int], ord: Int): T =
314+
inline if (n == g.numberOfCases)
309315
throw new IndexOutOfBoundsException(s"unexpected ordinal number: $ord")
310316
else if (ord == n)
311-
inline gs.alternative(n) match {
312-
case gp: GenericProduct[p] => unpickleCase(gp)(buf)
317+
inline g.alternative(n) match {
318+
case g: Generic.Product[p] => unpickleProduct(g)(buf)
319+
case g: Generic.Singleton[s] => g.value
313320
}
314-
else unpickleCases[T](gs, n + 1)(buf, ord)
321+
else unpickleCases[T](g, n + 1)(buf, ord)
315322

316323
inline def derived[T](implicit ev: Generic[T]): Pickler[T] = new {
317324
def pickle(buf: mutable.ListBuffer[Int], x: T): Unit =
318325
inline ev match {
319-
case gs: GenericSum[T] =>
320-
val ord = gs.ordinal(x)
326+
case g: Generic.Sum[T] =>
327+
val ord = g.ordinal(x)
321328
buf += ord
322-
pickleCases[T](gs, 0)(buf, x, ord)
323-
case gp: GenericProduct[p] =>
324-
pickleCase(gp)(buf, x)
329+
pickleCases[T](g, 0)(buf, x, ord)
330+
case g: Generic.Product[p] =>
331+
pickleProduct(g)(buf, x)
332+
case g: Generic.Singleton[_] =>
325333
}
326334
def unpickle(buf: mutable.ListBuffer[Int]): T =
327335
inline ev match {
328-
case gs: GenericSum[T] =>
329-
unpickleCases[T](gs, 0)(buf, nextInt(buf))
330-
case gp: GenericProduct[T] =>
331-
unpickleCase[T](gp)(buf)
336+
case g: Generic.Sum[T] =>
337+
unpickleCases[T](g, 0)(buf, nextInt(buf))
338+
case g: Generic.Product[T] =>
339+
unpickleProduct[T](g)(buf)
340+
case g: Generic.Singleton[s] =>
341+
constValue[s]
332342
}
333343
}
334344

@@ -363,27 +373,30 @@ object Show {
363373
Nil
364374
}
365375

366-
inline def showCase[T](gp: GenericProduct[T])(x: T): String = {
367-
val labl = constValue[gp.CaseLabel]
368-
showElems[gp.ElemTypes, gp.ElemLabels](0)(gp.toProduct(x)).mkString(s"$labl(", ", ", ")")
376+
inline def showProduct[T](g: Generic.Product[T])(x: T): String = {
377+
val labl = constValue[g.CaseLabel]
378+
showElems[g.ElemTypes, g.ElemLabels](0)(g.toProduct(x)).mkString(s"$labl(", ", ", ")")
369379
}
370380

371-
inline def showCases[T](gs: GenericSum[T], n: Int)(x: T, ord: Int): String =
372-
inline if (n == gs.numberOfCases)
381+
inline def showCases[T](g: Generic.Sum[T], n: Int)(x: T, ord: Int): String =
382+
inline if (n == g.numberOfCases)
373383
""
374384
else if (ord == n)
375-
inline gs.alternative(n) match {
376-
case gp: GenericProduct[p] => showCase(gp)(x.asInstanceOf[p])
385+
inline g.alternative(n) match {
386+
case g: Generic.Product[p] => showProduct(g)(x.asInstanceOf[p])
387+
case g: Generic.Singleton[s] => constValue[g.CaseLabel]
377388
}
378-
else showCases[T](gs, n + 1)(x, ord)
389+
else showCases[T](g, n + 1)(x, ord)
379390

380391
inline def derived[T](implicit ev: Generic[T]): Show[T] = new {
381392
def show(x: T): String =
382393
inline ev match {
383-
case gs: GenericSum[T] =>
384-
showCases(gs, 0)(x, gs.ordinal(x))
385-
case gp: GenericProduct[p] =>
386-
showCase(gp)(x)
394+
case g: Generic.Sum[T] =>
395+
showCases(g, 0)(x, g.ordinal(x))
396+
case g: Generic.Product[p] =>
397+
showProduct(g)(x)
398+
case g: Generic.Singleton[s] =>
399+
constValue[g.CaseLabel]
387400
}
388401
}
389402

0 commit comments

Comments
 (0)