|
| 1 | +package dotty.tools.dotc |
| 2 | +package transform |
| 3 | + |
| 4 | +import core._ |
| 5 | +import Names._ |
| 6 | +import StdNames.{nme, tpnme} |
| 7 | +import Types._ |
| 8 | +import dotty.tools.dotc.transform.MegaPhase._ |
| 9 | +import Flags._ |
| 10 | +import Contexts.Context |
| 11 | +import Symbols._ |
| 12 | +import Constants._ |
| 13 | +import Decorators._ |
| 14 | +import DenotTransformers._ |
| 15 | + |
| 16 | +object CompleteJavaEnums { |
| 17 | + val name: String = "completeJavaEnums" |
| 18 | + |
| 19 | + private val nameParamName: TermName = "$name".toTermName |
| 20 | + private val ordinalParamName: TermName = "$ordinal".toTermName |
| 21 | +} |
| 22 | + |
| 23 | +/** For Scala enums that inherit from java.lang.Enum: |
| 24 | + * Add constructor parameters for `name` and `ordinal` to pass from each |
| 25 | + * case to the java.lang.Enum class. |
| 26 | + */ |
| 27 | +class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase => |
| 28 | + import CompleteJavaEnums._ |
| 29 | + import ast.tpd._ |
| 30 | + |
| 31 | + override def phaseName: String = CompleteJavaEnums.name |
| 32 | + |
| 33 | + override def relaxedTypingInGroup: Boolean = true |
| 34 | + // Because it adds additional parameters to some constructors |
| 35 | + |
| 36 | + def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type = |
| 37 | + if (sym.isConstructor && ( |
| 38 | + sym == defn.JavaEnumClass.primaryConstructor || |
| 39 | + derivesFromJavaEnum(sym.owner))) |
| 40 | + addConstrParams(sym.info) |
| 41 | + else tp |
| 42 | + |
| 43 | + /** Is `sym` a Scala enum class that derives from `java.lang.Enum`? |
| 44 | + */ |
| 45 | + private def derivesFromJavaEnum(sym: Symbol)(implicit ctx: Context) = |
| 46 | + sym.is(Enum, butNot = Case) && sym.derivesFrom(defn.JavaEnumClass) |
| 47 | + |
| 48 | + /** Add constructor parameters `$name: String` and `$ordinal: Int` to the end of |
| 49 | + * the last parameter list of (method- or poly-) type `tp`. |
| 50 | + */ |
| 51 | + private def addConstrParams(tp: Type)(implicit ctx: Context): Type = tp match { |
| 52 | + case tp: PolyType => |
| 53 | + tp.derivedLambdaType(resType = addConstrParams(tp.resType)) |
| 54 | + case tp: MethodType => |
| 55 | + tp.resType match { |
| 56 | + case restpe: MethodType => |
| 57 | + tp.derivedLambdaType(resType = addConstrParams(restpe)) |
| 58 | + case _ => |
| 59 | + tp.derivedLambdaType( |
| 60 | + paramNames = tp.paramNames ++ List(nameParamName, ordinalParamName), |
| 61 | + paramInfos = tp.paramInfos ++ List(defn.StringType, defn.IntType)) |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + /** The list of parameter definitions `$name: String, $ordinal: Int`, in given `owner` |
| 66 | + * with given flags (either `Param` or `ParamAccessor`) |
| 67 | + */ |
| 68 | + private def addedParams(owner: Symbol, flag: FlagSet)(implicit ctx: Context): List[ValDef] = { |
| 69 | + val nameParam = ctx.newSymbol(owner, nameParamName, flag | Synthetic, defn.StringType, coord = owner.span) |
| 70 | + val ordinalParam = ctx.newSymbol(owner, ordinalParamName, flag | Synthetic, defn.IntType, coord = owner.span) |
| 71 | + List(ValDef(nameParam), ValDef(ordinalParam)) |
| 72 | + } |
| 73 | + |
| 74 | + /** Add arguments `args` to the parent constructor application in `parents` that invokes |
| 75 | + * a constructor of `targetCls`, |
| 76 | + */ |
| 77 | + private def addEnumConstrArgs(targetCls: Symbol, parents: List[Tree], args: List[Tree])(implicit ctx: Context): List[Tree] = |
| 78 | + parents.map { |
| 79 | + case app @ Apply(fn, args0) if fn.symbol.owner == targetCls => cpy.Apply(app)(fn, args0 ++ args) |
| 80 | + case p => p |
| 81 | + } |
| 82 | + |
| 83 | + /** 1. If this is a constructor of a enum class that extends, add $name and $ordinal parameters to it. |
| 84 | + * |
| 85 | + * 2. If this is a $new method that creates simple cases, pass $name and $ordinal parameters |
| 86 | + * to the enum superclass. The $new method looks like this: |
| 87 | + * |
| 88 | + * def $new(..., enumTag: Int, name: String) = { |
| 89 | + * class $anon extends E(...) { ... } |
| 90 | + * new $anon |
| 91 | + * } |
| 92 | + * |
| 93 | + * After the transform it is expanded to |
| 94 | + * |
| 95 | + * def $new(..., enumTag: Int, name: String) = { |
| 96 | + * class $anon extends E(..., name, enumTag) { ... } |
| 97 | + * new $anon |
| 98 | + * } |
| 99 | + */ |
| 100 | + override def transformDefDef(tree: DefDef)(implicit ctx: Context): DefDef = { |
| 101 | + val sym = tree.symbol |
| 102 | + if (sym.isConstructor && derivesFromJavaEnum(sym.owner)) |
| 103 | + cpy.DefDef(tree)( |
| 104 | + vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param))) |
| 105 | + else if (sym.name == nme.DOLLAR_NEW && derivesFromJavaEnum(sym.owner.linkedClass)) { |
| 106 | + val Block((tdef @ TypeDef(tpnme.ANON_CLASS, templ: Template)) :: Nil, call) = tree.rhs |
| 107 | + val args = tree.vparamss.last.takeRight(2).map(param => ref(param.symbol)).reverse |
| 108 | + val templ1 = cpy.Template(templ)( |
| 109 | + parents = addEnumConstrArgs(sym.owner.linkedClass, templ.parents, args)) |
| 110 | + cpy.DefDef(tree)( |
| 111 | + rhs = cpy.Block(tree.rhs)(cpy.TypeDef(tdef)(tdef.name, templ1) :: Nil, call)) |
| 112 | + } |
| 113 | + else tree |
| 114 | + } |
| 115 | + |
| 116 | + /** 1. If this is an enum class, add $name and $ordinal parameters to its |
| 117 | + * parameter accessors and pass them on to the java.lang.Enum constructor. |
| 118 | + * |
| 119 | + * 2. If this is an anonymous class that implement a value enum case, |
| 120 | + * pass $name and $ordinal parameters to the enum superclass. The class |
| 121 | + * looks like this: |
| 122 | + * |
| 123 | + * class $anon extends E(...) { |
| 124 | + * ... |
| 125 | + * def enumTag = N |
| 126 | + * def toString = S |
| 127 | + * ... |
| 128 | + * } |
| 129 | + * |
| 130 | + * After the transform it is expanded to |
| 131 | + * |
| 132 | + * class $anon extends E(..., N, S) { |
| 133 | + * "same as before" |
| 134 | + * } |
| 135 | + */ |
| 136 | + override def transformTemplate(templ: Template)(implicit ctx: Context): Template = { |
| 137 | + val cls = templ.symbol.owner |
| 138 | + if (derivesFromJavaEnum(cls)) { |
| 139 | + val (params, rest) = decomposeTemplateBody(templ.body) |
| 140 | + val addedDefs = addedParams(cls, ParamAccessor) |
| 141 | + val addedSyms = addedDefs.map(_.symbol.entered) |
| 142 | + cpy.Template(templ)( |
| 143 | + parents = addEnumConstrArgs(defn.JavaEnumClass, templ.parents, addedSyms.map(ref)), |
| 144 | + body = params ++ addedDefs ++ rest) |
| 145 | + } |
| 146 | + else if (cls.isAnonymousClass && cls.owner.is(EnumCase) && derivesFromJavaEnum(cls.owner.owner.linkedClass)) { |
| 147 | + def rhsOf(name: TermName) = |
| 148 | + templ.body.collect { |
| 149 | + case mdef: DefDef if mdef.name == name => mdef.rhs |
| 150 | + }.head |
| 151 | + val args = List(rhsOf(nme.toString_), rhsOf(nme.enumTag)) |
| 152 | + cpy.Template(templ)( |
| 153 | + parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args)) |
| 154 | + } |
| 155 | + else templ |
| 156 | + } |
| 157 | +} |
0 commit comments