Skip to content

Commit 7eda38c

Browse files
committed
Merge pull request scala-js#1982 from nicolasstucki/support-secondary-contructors-in-JS-classes
Fix scala-js#1811: Add support for secondary constructors in JS classes
2 parents bf2aafe + c9d60c7 commit 7eda38c

File tree

5 files changed

+541
-64
lines changed

5 files changed

+541
-64
lines changed

compiler/src/main/scala/org/scalajs/core/compiler/GenJSCode.scala

Lines changed: 330 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -649,29 +649,341 @@ abstract class GenJSCode extends plugins.PluginComponent
649649
constructorTrees: List[DefDef]): js.Tree = {
650650
implicit val pos = classSym.pos
651651

652-
val (primaryCtorTree :: Nil, secondaryCtorTrees) =
653-
constructorTrees.partition(_.symbol.isPrimaryConstructor)
654-
655652
// Implementation restriction
656-
val sym = primaryCtorTree.symbol
653+
val syms = constructorTrees.map(_.symbol)
657654
val hasBadParam = enteringPhase(currentRun.uncurryPhase) {
658-
sym.paramss.flatten.exists(p => p.hasDefault || isRepeated(p))
655+
syms.exists(_.paramss.flatten.exists(p => p.hasDefault))
659656
}
660657
if (hasBadParam) {
661658
reporter.error(pos,
662659
"Implementation restriction: the constructor of a " +
663-
"Scala.js-defined JS classes cannot have default parameters nor " +
664-
"repeated parameters.")
660+
"Scala.js-defined JS classes cannot have default parameters.")
665661
}
666662

667-
// Implementation restriction
668-
for (tree <- secondaryCtorTrees) {
669-
reporter.error(tree.pos,
670-
"Implementation restriction: Scala.js-defined JS classes cannot " +
671-
"have secondary constructors")
663+
withNewLocalNameScope {
664+
val ctors: List[js.MethodDef] = constructorTrees.flatMap { tree =>
665+
genMethodWithCurrentLocalNameScope(tree)
666+
}
667+
668+
val dispatch =
669+
genJSConstructorExport(constructorTrees.map(_.symbol))
670+
val js.MethodDef(_, dispatchName, dispatchArgs, dispatchResultType,
671+
dispatchResolution) = dispatch
672+
673+
val jsConstructorBuilder = mkJSConstructorBuilder(ctors)
674+
675+
val overloadIdent = freshLocalIdent("overload")
676+
677+
// Section containing the overload resolution and casts of parameters
678+
val overloadSelection = mkOverloadSelection(jsConstructorBuilder,
679+
overloadIdent, dispatchResolution)
680+
681+
/* Section containing all the code executed before the call to `this`
682+
* for every secondary constructor.
683+
*/
684+
val prePrimaryCtorBody =
685+
jsConstructorBuilder.mkPrePrimaryCtorBody(overloadIdent)
686+
687+
val primaryCtorBody = jsConstructorBuilder.primaryCtorBody
688+
689+
/* Section containing all the code executed after the call to this for
690+
* every secondary constructor.
691+
*/
692+
val postPrimaryCtorBody =
693+
jsConstructorBuilder.mkPostPrimaryCtorBody(overloadIdent)
694+
695+
val newBody = js.Block(overloadSelection ::: prePrimaryCtorBody ::
696+
primaryCtorBody :: postPrimaryCtorBody :: Nil)
697+
698+
js.MethodDef(static = false, dispatchName, dispatchArgs, jstpe.NoType,
699+
newBody)(dispatch.optimizerHints, None)
700+
}
701+
}
702+
703+
private class ConstructorTree(val overrideNum: Int, val method: js.MethodDef,
704+
val subConstructors: List[ConstructorTree]) {
705+
706+
lazy val overrideNumBounds: (Int, Int) =
707+
if (subConstructors.isEmpty) (overrideNum, overrideNum)
708+
else (subConstructors.head.overrideNumBounds._1, overrideNum)
709+
710+
def get(methodName: String): Option[ConstructorTree] = {
711+
if (methodName == this.method.name.name) {
712+
Some(this)
713+
} else {
714+
subConstructors.iterator.map(_.get(methodName)).collectFirst {
715+
case Some(node) => node
716+
}
717+
}
718+
}
719+
720+
def getParamRefs(implicit pos: Position): List[js.VarRef] =
721+
method.args.map(_.ref)
722+
723+
def getAllParamDefsAsVars(implicit pos: Position): List[js.VarDef] = {
724+
val localDefs = method.args.map { pDef =>
725+
js.VarDef(pDef.name, pDef.ptpe, mutable = true, jstpe.zeroOf(pDef.ptpe))
726+
}
727+
localDefs ++ subConstructors.flatMap(_.getAllParamDefsAsVars)
728+
}
729+
}
730+
731+
private class JSConstructorBuilder(root: ConstructorTree) {
732+
733+
def primaryCtorBody: js.Tree = root.method.body
734+
735+
def hasSubConstructors: Boolean = root.subConstructors.nonEmpty
736+
737+
def getOverrideNum(methodName: String): Int =
738+
root.get(methodName).fold(-1)(_.overrideNum)
739+
740+
def getParamRefsFor(methodName: String)(implicit pos: Position): List[js.VarRef] =
741+
root.get(methodName).fold(List.empty[js.VarRef])(_.getParamRefs)
742+
743+
def getAllParamDefsAsVars(implicit pos: Position): List[js.VarDef] =
744+
root.getAllParamDefsAsVars
745+
746+
def mkPrePrimaryCtorBody(overrideNumIdent: js.Ident)(
747+
implicit pos: Position): js.Tree = {
748+
val overrideNumRef = js.VarRef(overrideNumIdent)(jstpe.IntType)
749+
mkSubPreCalls(root, overrideNumRef)
750+
}
751+
752+
def mkPostPrimaryCtorBody(overrideNumIdent: js.Ident)(
753+
implicit pos: Position): js.Tree = {
754+
val overrideNumRef = js.VarRef(overrideNumIdent)(jstpe.IntType)
755+
js.Block(mkSubPostCalls(root, overrideNumRef))
756+
}
757+
758+
private def mkSubPreCalls(constructorTree: ConstructorTree,
759+
overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = {
760+
val overrideNumss = constructorTree.subConstructors.map(_.overrideNumBounds)
761+
val paramRefs = constructorTree.getParamRefs
762+
val bodies = constructorTree.subConstructors.map { constructorTree =>
763+
mkPrePrimaryCtorBodyOnSndCtr(constructorTree, overrideNumRef, paramRefs)
764+
}
765+
overrideNumss.zip(bodies).foldRight[js.Tree](js.Skip()) {
766+
case ((numBounds, body), acc) =>
767+
val cond = mkOverrideNumsCond(overrideNumRef, numBounds)
768+
js.If(cond, body, acc)(jstpe.BooleanType)
769+
}
770+
}
771+
772+
private def mkPrePrimaryCtorBodyOnSndCtr(constructorTree: ConstructorTree,
773+
overrideNumRef: js.VarRef, outputParams: List[js.VarRef])(
774+
implicit pos: Position): js.Tree = {
775+
val subCalls =
776+
mkSubPreCalls(constructorTree, overrideNumRef)
777+
778+
val preSuperCall = {
779+
constructorTree.method.body match {
780+
case js.Block(stats) =>
781+
val beforeSuperCall = stats.takeWhile {
782+
case js.ApplyStatic(_, mtd, _) => !ir.Definitions.isConstructorName(mtd.name)
783+
case _ => true
784+
}
785+
val superCallParams = stats.collectFirst {
786+
case js.ApplyStatic(_, mtd, js.This() :: args)
787+
if ir.Definitions.isConstructorName(mtd.name) =>
788+
zipMap(outputParams, args)(js.Assign(_, _))
789+
}.getOrElse(Nil)
790+
791+
beforeSuperCall ::: superCallParams
792+
793+
case js.ApplyStatic(_, mtd, js.This() :: args)
794+
if ir.Definitions.isConstructorName(mtd.name) =>
795+
zipMap(outputParams, args)(js.Assign(_, _))
796+
797+
case _ => Nil
798+
}
799+
}
800+
801+
js.Block(subCalls :: preSuperCall)
802+
}
803+
804+
private def mkSubPostCalls(constructorTree: ConstructorTree,
805+
overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = {
806+
val overrideNumss = constructorTree.subConstructors.map(_.overrideNumBounds)
807+
val bodies = constructorTree.subConstructors.map { ct =>
808+
mkPostPrimaryCtorBodyOnSndCtr(ct, overrideNumRef)
809+
}
810+
overrideNumss.zip(bodies).foldRight[js.Tree](js.Skip()) {
811+
case ((numBounds, js.Skip()), acc) => acc
812+
813+
case ((numBounds, body), acc) =>
814+
val cond = mkOverrideNumsCond(overrideNumRef, numBounds)
815+
js.If(cond, body, acc)(jstpe.BooleanType)
816+
}
817+
}
818+
819+
private def mkPostPrimaryCtorBodyOnSndCtr(constructorTree: ConstructorTree,
820+
overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = {
821+
val postSuperCall = {
822+
constructorTree.method.body match {
823+
case js.Block(stats) =>
824+
stats.dropWhile {
825+
case js.ApplyStatic(_, mtd, _) => !ir.Definitions.isConstructorName(mtd.name)
826+
case _ => true
827+
}.tail
828+
829+
case _ => Nil
830+
}
831+
}
832+
js.Block(postSuperCall :+ mkSubPostCalls(constructorTree, overrideNumRef))
833+
}
834+
835+
private def mkOverrideNumsCond(numRef: js.VarRef,
836+
numBounds: (Int, Int))(implicit pos: Position) = numBounds match {
837+
case (lo, hi) if lo == hi =>
838+
js.BinaryOp(js.BinaryOp.===, js.IntLiteral(lo), numRef)
839+
840+
case (lo, hi) if lo == hi - 1 =>
841+
val lhs = js.BinaryOp(js.BinaryOp.===, numRef, js.IntLiteral(lo))
842+
val rhs = js.BinaryOp(js.BinaryOp.===, numRef, js.IntLiteral(hi))
843+
js.If(lhs, js.BooleanLiteral(true), rhs)(jstpe.BooleanType)
844+
845+
case (lo, hi) =>
846+
val lhs = js.BinaryOp(js.BinaryOp.Num_<=, js.IntLiteral(lo), numRef)
847+
val rhs = js.BinaryOp(js.BinaryOp.Num_<=, numRef, js.IntLiteral(hi))
848+
js.BinaryOp(js.BinaryOp.Boolean_&, lhs, rhs)
849+
js.If(lhs, rhs, js.BooleanLiteral(false))(jstpe.BooleanType)
850+
}
851+
}
852+
853+
private def zipMap[T, U, V](xs: List[T], ys: List[U])(
854+
f: (T, U) => V): List[V] = {
855+
for ((x, y) <- xs zip ys) yield f(x, y)
856+
}
857+
858+
/** mkOverloadSelection return a list of `stats` with that starts with:
859+
* 1) The definition for the local variable that will hold the overload
860+
* resolution number.
861+
* 2) The definitions of all local variables that are used as parameters
862+
* in all the constructors.
863+
* 3) The overload resolution match/if statements. For each overload the
864+
* overload number is assigned and the parameters are cast and assigned
865+
* to their corresponding variables.
866+
*/
867+
private def mkOverloadSelection(jsConstructorBuilder: JSConstructorBuilder,
868+
overloadIdent: js.Ident, dispatchResolution: js.Tree)(
869+
implicit pos: Position): List[js.Tree]= {
870+
if (!jsConstructorBuilder.hasSubConstructors) {
871+
dispatchResolution match {
872+
/* Dispatch to constructor with no arguments.
873+
* Contains trivial parameterless call to the constructor.
874+
*/
875+
case js.ApplyStatic(_, mtd, js.This() :: Nil)
876+
if ir.Definitions.isConstructorName(mtd.name) =>
877+
Nil
878+
879+
/* Dispatch to constructor with at least one argument.
880+
* Where js.Block's stats.init corresponds to the parameter casts and
881+
* js.Block's stats.last contains the call to the constructor.
882+
*/
883+
case js.Block(stats) =>
884+
val js.ApplyStatic(_, method, _) = stats.last
885+
val refs = jsConstructorBuilder.getParamRefsFor(method.name)
886+
val paramCasts = stats.init.map(_.asInstanceOf[js.VarDef])
887+
zipMap(refs, paramCasts) { (ref, paramCast) =>
888+
js.VarDef(ref.ident, ref.tpe, mutable = false, paramCast.rhs)
889+
}
890+
}
891+
} else {
892+
val overloadRef = js.VarRef(overloadIdent)(jstpe.IntType)
893+
894+
/* transformDispatch takes the body of the method generated by
895+
* `genJSConstructorExport` and transform it recursively.
896+
*/
897+
def transformDispatch(tree: js.Tree): js.Tree = tree match {
898+
/* Dispatch to constructor with no arguments.
899+
* Contains trivial parameterless call to the constructor.
900+
*/
901+
case js.ApplyStatic(_, method, js.This() :: Nil)
902+
if ir.Definitions.isConstructorName(method.name) =>
903+
js.Assign(overloadRef,
904+
js.IntLiteral(jsConstructorBuilder.getOverrideNum(method.name)))
905+
906+
/* Dispatch to constructor with at least one argument.
907+
* Where js.Block's stats.init corresponds to the parameter casts and
908+
* js.Block's stats.last contains the call to the constructor.
909+
*/
910+
case js.Block(stats) =>
911+
val js.ApplyStatic(_, method, _) = stats.last
912+
913+
val num = jsConstructorBuilder.getOverrideNum(method.name)
914+
val overloadAssign = js.Assign(overloadRef, js.IntLiteral(num))
915+
916+
val refs = jsConstructorBuilder.getParamRefsFor(method.name)
917+
val paramCasts = stats.init.map(_.asInstanceOf[js.VarDef].rhs)
918+
val parameterAssigns = zipMap(refs, paramCasts)(js.Assign(_, _))
919+
920+
js.Block(overloadAssign :: parameterAssigns)
921+
922+
// Parameter count resolution
923+
case js.Match(selector, cases, default) =>
924+
val newCases = cases.map {
925+
case (literals, body) => (literals, transformDispatch(body))
926+
}
927+
val newDefault = transformDispatch(default)
928+
js.Match(selector, newCases, newDefault)(tree.tpe)
929+
930+
// Parameter type resolution
931+
case js.If(cond, thenp, elsep) =>
932+
js.If(cond, transformDispatch(thenp),
933+
transformDispatch(elsep))(tree.tpe)
934+
935+
// Throw(StringLiteral(No matching overload))
936+
case tree: js.Throw =>
937+
tree
938+
}
939+
940+
val newDispatchResolution = transformDispatch(dispatchResolution)
941+
val allParamDefsAsVars = jsConstructorBuilder.getAllParamDefsAsVars
942+
val overrideNumDef =
943+
js.VarDef(overloadIdent, jstpe.IntType, mutable = true, js.IntLiteral(0))
944+
945+
overrideNumDef :: allParamDefsAsVars ::: newDispatchResolution :: Nil
946+
}
947+
}
948+
949+
private def mkJSConstructorBuilder(ctors: List[js.MethodDef])(
950+
implicit pos: Position): JSConstructorBuilder = {
951+
def findCtorForwarderCall(tree: js.Tree): String = tree match {
952+
case js.ApplyStatic(_, method, js.This() :: _)
953+
if ir.Definitions.isConstructorName(method.name) =>
954+
method.name
955+
956+
case js.Block(stats) =>
957+
stats.collectFirst {
958+
case js.ApplyStatic(_, method, js.This() :: _)
959+
if ir.Definitions.isConstructorName(method.name) =>
960+
method.name
961+
}.get
962+
}
963+
964+
val (primaryCtor :: Nil, secondaryCtors) = ctors.partition {
965+
_.body match {
966+
case js.Block(stats) =>
967+
stats.exists(_.isInstanceOf[js.JSSuperConstructorCall])
968+
969+
case _: js.JSSuperConstructorCall => true
970+
case _ => false
971+
}
972+
}
973+
974+
val ctorToChildren = secondaryCtors.map { ctor =>
975+
findCtorForwarderCall(ctor.body) -> ctor
976+
}.groupBy(_._1).mapValues(_.map(_._2)).withDefaultValue(Nil)
977+
978+
var overrideNum = -1
979+
def mkConstructorTree(method: js.MethodDef): ConstructorTree = {
980+
val methodName = method.name.name
981+
val subCtrTrees = ctorToChildren(methodName).map(mkConstructorTree)
982+
overrideNum += 1
983+
new ConstructorTree(overrideNum, method, subCtrTrees)
672984
}
673985

674-
genMethod(primaryCtorTree).get
986+
new JSConstructorBuilder(mkConstructorTree(primaryCtor))
675987
}
676988

677989
// Generate a method -------------------------------------------------------
@@ -720,9 +1032,7 @@ abstract class GenJSCode extends plugins.PluginComponent
7201032
val isJSClassConstructor =
7211033
sym.isClassConstructor && isScalaJSDefinedJSClass(currentClassSym)
7221034

723-
val methodName: js.PropertyName =
724-
if (isJSClassConstructor) js.StringLiteral("constructor")
725-
else encodeMethodSym(sym)
1035+
val methodName: js.PropertyName = encodeMethodSym(sym)
7261036

7271037
def jsParams = for (param <- params) yield {
7281038
implicit val pos = param.pos
@@ -793,12 +1103,11 @@ abstract class GenJSCode extends plugins.PluginComponent
7931103
val methodDef = {
7941104
if (isJSClassConstructor) {
7951105
val body0 = genStat(rhs)
796-
val body1 = moveAllStatementsAfterSuperConstructorCall(body0)
797-
val (patchedParams, patchedBody) =
798-
patchFunBodyWithBoxes(sym, jsParams, body1)
1106+
val body1 =
1107+
if (!sym.isPrimaryConstructor) body0
1108+
else moveAllStatementsAfterSuperConstructorCall(body0)
7991109
js.MethodDef(static = false, methodName,
800-
patchedParams, jstpe.NoType, patchedBody)(
801-
optimizerHints, None)
1110+
jsParams, jstpe.NoType, body1)(optimizerHints, None)
8021111
} else if (sym.isClassConstructor) {
8031112
js.MethodDef(static = false, methodName,
8041113
jsParams, jstpe.NoType,

0 commit comments

Comments
 (0)