diff --git a/compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala b/compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala index 183b25213b4d..c533f2e15d9e 100644 --- a/compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala +++ b/compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala @@ -104,6 +104,17 @@ class JSCodeGen()(using genCtx: Context) { } } + private def withPerMethodBodyState[A](methodSym: Symbol)(body: => A): A = { + withScopedVars( + currentMethodSym := methodSym, + thisLocalVarIdent := None, + isModuleInitialized := new ScopedVar.VarBox(false), + undefinedDefaultParams := mutable.Set.empty, + ) { + body + } + } + private def acquireContextualJSClassValue[A](f: Option[js.Tree] => A): A = { val jsClassValue = contextualJSClassValue.get withScopedVars( @@ -456,6 +467,13 @@ class JSCodeGen()(using genCtx: Context) { i"genNonNativeJSClass() must be called only for non-native JS classes: $sym") assert(sym.superClass != NoSymbol, sym) + if (hasDefaultCtorArgsAndJSModule(sym)) { + report.error( + "Implementation restriction: " + + "constructors of non-native JS classes cannot have default parameters if their companion module is JS native.", + td) + } + val classIdent = encodeClassNameIdent(sym) val originalName = originalNameOfClass(sym) @@ -522,14 +540,25 @@ class JSCodeGen()(using genCtx: Context) { val topLevelExports = jsExportsGen.genTopLevelExports(sym) - val (jsClassCaptures, generatedConstructor) = - genJSClassCapturesAndConstructor(sym, constructorTrees.toList) + val (generatedConstructor, jsClassCaptures) = withNewLocalNameScope { + val isNested = sym.isNestedJSClass - /* If there is one, the JS super class value is always the first JS class - * capture. This is a JSCodeGen-specific invariant (the IR does not rely - * on this) enforced in genJSClassCapturesAndConstructor. - */ - val jsSuperClass = jsClassCaptures.map(_.head.ref) + if (isNested) + localNames.reserveLocalName(JSSuperClassParamName) + + val (captures, ctor) = genJSClassCapturesAndConstructor(constructorTrees.toList) + + val jsClassCaptures = if (isNested) { + val superParam = js.ParamDef(js.LocalIdent(JSSuperClassParamName), + NoOriginalName, jstpe.AnyType, mutable = false) + Some(superParam :: captures) + } else { + assert(captures.isEmpty, s"found non nested JS class with captures $captures at $pos") + None + } + + (ctor, jsClassCaptures) + } // Generate fields (and add to methods + ctors) val generatedMembers = { @@ -555,7 +584,7 @@ class JSCodeGen()(using genCtx: Context) { jsClassCaptures, Some(encodeClassNameIdent(sym.superClass)), genClassInterfaces(sym, forJSClass = true), - jsSuperClass, + jsSuperClass = jsClassCaptures.map(_.head.ref), None, hashedMemberDefs, topLevelExports)( @@ -960,47 +989,327 @@ class JSCodeGen()(using genCtx: Context) { // Constructor of a non-native JS class ------------------------------------ - def genJSClassCapturesAndConstructor(classSym: Symbol, - constructorTrees: List[DefDef]): (Option[List[js.ParamDef]], js.JSMethodDef) = { - implicit val pos = classSym.span + def genJSClassCapturesAndConstructor(constructorTrees: List[DefDef])( + implicit pos: SourcePosition): (List[js.ParamDef], js.JSMethodDef) = { + /* We need to merge all Scala constructors into a single one because the + * IR, like JavaScript, only allows a single one. + * + * We do this by applying: + * 1. Applying runtime type based dispatch, just like exports. + * 2. Splitting secondary ctors into parts before and after the `this` call. + * 3. Topo-sorting all constructor statements and including/excluding + * them based on the overload that was chosen. + */ - if (hasDefaultCtorArgsAndJSModule(classSym)) { - report.error( - "Implementation restriction: " + - "constructors of non-native JS classes cannot have default parameters if their companion module is JS native.", - classSym.srcPos) - val ctorDef = js.JSMethodDef(js.MemberFlags.empty, - js.StringLiteral("constructor"), Nil, None, js.Skip())( - OptimizerHints.empty, None) - (None, ctorDef) - } else { - withNewLocalNameScope { - localNames.reserveLocalName(JSSuperClassParamName) + val (primaryTree :: Nil, secondaryTrees) = + constructorTrees.partition(_.symbol.isPrimaryConstructor) - val ctors: List[js.MethodDef] = constructorTrees.flatMap { tree => - genMethodWithCurrentLocalNameScope(tree) - } + val primaryCtor = genPrimaryJSClassCtor(primaryTree) + val secondaryCtors = secondaryTrees.map(genSecondaryJSClassCtor(_)) - val (captureParams, dispatch) = - jsExportsGen.genJSConstructorDispatch(constructorTrees.map(_.symbol)) + // VarDefs for the parameters of all constructors. + val paramVarDefs = for { + vparam <- constructorTrees.flatMap(_.paramss.flatten) + } yield { + val sym = vparam.symbol + val tpe = toIRType(sym.info) + js.VarDef(encodeLocalSym(sym), originalNameOfLocal(sym), tpe, mutable = true, jstpe.zeroOf(tpe))(vparam.span) + } - /* Ensure that the first JS class capture is a reference to the JS super class value. - * genNonNativeJSClass and genNewAnonJSClass rely on this. - */ - val captureParamsWithJSSuperClass = captureParams.map { params => - val jsSuperClassParam = js.ParamDef( - js.LocalIdent(JSSuperClassParamName), NoOriginalName, - jstpe.AnyType, mutable = false) - jsSuperClassParam :: params + /* organize constructors in a called-by tree + * (the implicit root is the primary constructor) + */ + val ctorTree = { + val ctorToChildren = secondaryCtors + .groupBy(_.targetCtor) + .withDefaultValue(Nil) + + /* when constructing the call-by tree, we use pre-order traversal to + * assign overload numbers. + * this puts all descendants of a ctor in a range of overloads numbers. + * + * this property is useful, later, when we need to make statements + * conditional based on the chosen overload. + */ + var nextOverloadNum = 0 + def subTree[T <: JSCtor](ctor: T): ConstructorTree[T] = { + val overloadNum = nextOverloadNum + nextOverloadNum += 1 + val subtrees = ctorToChildren(ctor.sym).map(subTree(_)) + new ConstructorTree(overloadNum, ctor, subtrees) + } + + subTree(primaryCtor) + } + + /* prepare overload dispatch for all constructors. + * as a side-product, we retrieve the capture parameters. + */ + val (exports, jsClassCaptures) = { + val exports = List.newBuilder[jsExportsGen.Exported] + val jsClassCaptures = List.newBuilder[js.ParamDef] + + def add(tree: ConstructorTree[_ <: JSCtor]): Unit = { + val (e, c) = genJSClassCtorDispatch(tree.ctor.sym, tree.ctor.params, tree.overloadNum) + exports += e + jsClassCaptures ++= c + tree.subCtors.foreach(add(_)) + } + + add(ctorTree) + + (exports.result(), jsClassCaptures.result()) + } + + val (formalArgs, restParam, overloadDispatchBody) = + jsExportsGen.genOverloadDispatch(JSName.Literal("constructor"), exports, jstpe.IntType) + + val overloadVar = js.VarDef(freshLocalIdent("overload"), NoOriginalName, + jstpe.IntType, mutable = false, overloadDispatchBody) + + val ctorStats = genJSClassCtorStats(overloadVar.ref, ctorTree) + + val constructorBody = js.Block( + paramVarDefs ::: List(overloadVar, ctorStats, js.Undefined())) + + val constructorDef = js.JSMethodDef( + js.MemberFlags.empty, + js.StringLiteral("constructor"), + formalArgs, restParam, constructorBody)(OptimizerHints.empty, None) + + (jsClassCaptures, constructorDef) + } + + private def genPrimaryJSClassCtor(dd: DefDef): PrimaryJSCtor = { + val sym = dd.symbol + val Block(stats, _) = dd.rhs + assert(sym.isPrimaryConstructor, s"called with non-primary ctor: $sym") + + var jsSuperCall: Option[js.JSSuperConstructorCall] = None + val jsStats = List.newBuilder[js.Tree] + + /* Move all statements after the super constructor call since JS + * cannot access `this` before the super constructor call. + * + * dotc inserts statements before the super constructor call for param + * accessor initializers (including val's and var's declared in the + * params). We move those after the super constructor call, and are + * therefore executed later than for a Scala class. + */ + withPerMethodBodyState(sym) { + stats.foreach { + case tree @ Apply(fun @ Select(Super(This(_), _), _), args) + if fun.symbol.isClassConstructor => + assert(jsSuperCall.isEmpty, s"Found 2 JS Super calls at ${dd.sourcePos}") + implicit val pos: Position = tree.span + jsSuperCall = Some(js.JSSuperConstructorCall(genActualJSArgs(fun.symbol, args))) + + case stat => + val jsStat = genStat(stat) + assert(jsSuperCall.isDefined || !jsStat.isInstanceOf[js.VarDef], + "Trying to move a local VarDef after the super constructor call of a non-native JS class at " + + dd.sourcePos) + jsStats += jsStat + } + } + + assert(jsSuperCall.isDefined, + s"Did not find Super call in primary JS construtor at ${dd.sourcePos}") + + val params = dd.paramss.flatten.map(_.symbol) + + new PrimaryJSCtor(sym, params, jsSuperCall.get :: jsStats.result()) + } + + private def genSecondaryJSClassCtor(dd: DefDef): SplitSecondaryJSCtor = { + val sym = dd.symbol + val Block(stats, _) = dd.rhs + assert(!sym.isPrimaryConstructor, s"called with primary ctor $sym") + + val beforeThisCall = List.newBuilder[js.Tree] + var thisCall: Option[(Symbol, List[js.Tree])] = None + val afterThisCall = List.newBuilder[js.Tree] + + withPerMethodBodyState(sym) { + stats.foreach { + case tree @ Apply(fun @ Select(This(_), _), args) + if fun.symbol.isClassConstructor => + assert(thisCall.isEmpty, + s"duplicate this() call in secondary JS constructor at ${dd.sourcePos}") + + implicit val pos: Position = tree.span + val sym = fun.symbol + thisCall = Some((sym, genActualArgs(sym, args))) + + case stat => + val jsStat = genStat(stat) + if (thisCall.isEmpty) + beforeThisCall += jsStat + else + afterThisCall += jsStat + } + } + + val Some((targetCtor, ctorArgs)) = thisCall + + val params = dd.paramss.flatten.map(_.symbol) + + new SplitSecondaryJSCtor(sym, params, beforeThisCall.result(), targetCtor, + ctorArgs, afterThisCall.result()) + } + + private def genJSClassCtorDispatch(sym: Symbol, allParams: List[Symbol], + overloadNum: Int): (jsExportsGen.Exported, List[js.ParamDef]) = { + + implicit val pos: SourcePosition = sym.sourcePos + + /* `allParams` are the parameters as seen from inside the constructor body, + * i.e., the ones generated by the trees in the constructor body. + */ + val (captureParamsAndInfos, normalParamsAndInfos) = + allParams.zip(sym.jsParamInfos).partition(_._2.capture) + + /* For class captures, we need to generate different names than the ones + * used by the constructor body. This is necessary so that we can forward + * captures properly between constructor delegation calls. + */ + val (jsClassCaptures, captureAssigns) = (for { + (param, info) <- captureParamsAndInfos + } yield { + val ident = freshLocalIdent(param.name.toTermName) + val jsClassCapture = + js.ParamDef(ident, originalNameOfLocal(param), toIRType(info.info), mutable = false) + val captureAssign = + js.Assign(genVarRef(param), jsClassCapture.ref) + (jsClassCapture, captureAssign) + }).unzip + + val normalInfos = normalParamsAndInfos.map(_._2).toIndexedSeq + + val jsExport = new jsExportsGen.Exported(sym, normalInfos) { + def genBody(formalArgsRegistry: jsExportsGen.FormalArgsRegistry): js.Tree = { + val paramAssigns = for { + ((param, info), i) <- normalParamsAndInfos.zipWithIndex + } yield { + val rhs = jsExportsGen.genScalaArg(this, i, formalArgsRegistry, info, static = true)( + prevArgsCount => allParams.take(prevArgsCount).map(genVarRef(_))) + + js.Assign(genVarRef(param), rhs) } - val ctorDef = JSConstructorGen.buildJSConstructorDef(dispatch, ctors, freshLocalIdent("overload")) { - msg => report.error(msg, classSym.srcPos) + js.Block(captureAssigns ::: paramAssigns, js.IntLiteral(overloadNum)) + } + } + + (jsExport, jsClassCaptures) + } + + /** generates a sequence of JS constructor statements based on a constructor tree. */ + private def genJSClassCtorStats(overloadVar: js.VarRef, + ctorTree: ConstructorTree[PrimaryJSCtor])(implicit pos: Position): js.Tree = { + + /* generates a statement that conditionally executes body iff the chosen + * overload is any of the descendants of `tree` (including itself). + * + * here we use the property from building the trees, that a set of + * descendants always has a range of overload numbers. + */ + def ifOverload(tree: ConstructorTree[_], body: js.Tree): js.Tree = body match { + case js.Skip() => js.Skip() + + case body => + val x = overloadVar + val cond = { + import tree.{lo, hi} + + if (lo == hi) { + js.BinaryOp(js.BinaryOp.Int_==, js.IntLiteral(lo), x) + } else { + val lhs = js.BinaryOp(js.BinaryOp.Int_<=, js.IntLiteral(lo), x) + val rhs = js.BinaryOp(js.BinaryOp.Int_<=, x, js.IntLiteral(hi)) + js.If(lhs, rhs, js.BooleanLiteral(false))(jstpe.BooleanType) + } } - (captureParamsWithJSSuperClass, ctorDef) + js.If(cond, body, js.Skip())(jstpe.NoType) + } + + /* preStats / postStats use pre/post order traversal respectively to + * generate a topo-sorted sequence of statements. + */ + + def preStats(tree: ConstructorTree[SplitSecondaryJSCtor], + nextParams: List[Symbol]): js.Tree = { + assert(tree.ctor.ctorArgs.size == nextParams.size, "param count mismatch") + + val inner = tree.subCtors.map(preStats(_, tree.ctor.params)) + + /* Reject undefined params (i.e. using a default value of another + * constructor) via implementation restriction. + * + * This is mostly for historical reasons. The ideal solution here would + * be to recognize calls to default param getters of JS class + * constructors and not even translate them to UndefinedParam in the + * first place. + */ + def isUndefinedParam(tree: js.Tree): Boolean = tree match { + case js.Transient(UndefinedParam) => true + case _ => false + } + + if (tree.ctor.ctorArgs.exists(isUndefinedParam)) { + report.error( + "Implementation restriction: " + + "in a JS class, a secondary constructor calling another constructor " + + "with default parameters must provide the values of all parameters.", + tree.ctor.sym.sourcePos) } + + val assignments = for { + (param, arg) <- nextParams.zip(tree.ctor.ctorArgs) + if !isUndefinedParam(arg) + } yield { + js.Assign(genVarRef(param), arg) + } + + ifOverload(tree, js.Block(inner ++ tree.ctor.beforeCall ++ assignments)) } + + def postStats(tree: ConstructorTree[SplitSecondaryJSCtor]): js.Tree = { + val inner = tree.subCtors.map(postStats(_)) + ifOverload(tree, js.Block(tree.ctor.afterCall ++ inner)) + } + + val primaryCtor = ctorTree.ctor + val secondaryCtorTrees = ctorTree.subCtors + + js.Block( + secondaryCtorTrees.map(preStats(_, primaryCtor.params)) ++ + primaryCtor.body ++ + secondaryCtorTrees.map(postStats(_)) + ) + } + + private sealed trait JSCtor { + val sym: Symbol + val params: List[Symbol] + } + + private class PrimaryJSCtor(val sym: Symbol, + val params: List[Symbol], val body: List[js.Tree]) extends JSCtor + + private class SplitSecondaryJSCtor(val sym: Symbol, + val params: List[Symbol], val beforeCall: List[js.Tree], + val targetCtor: Symbol, val ctorArgs: List[js.Tree], + val afterCall: List[js.Tree]) extends JSCtor + + private class ConstructorTree[Ctor <: JSCtor]( + val overloadNum: Int, val ctor: Ctor, + val subCtors: List[ConstructorTree[SplitSecondaryJSCtor]]) { + val lo: Int = overloadNum + val hi: Int = subCtors.lastOption.fold(lo)(_.hi) + + assert(lo <= hi, "bad overload range") } // Generate a method ------------------------------------------------------- @@ -1043,12 +1352,7 @@ class JSCodeGen()(using genCtx: Context) { val vparamss = dd.termParamss val rhs = dd.rhs - withScopedVars( - currentMethodSym := sym, - undefinedDefaultParams := mutable.Set.empty, - thisLocalVarIdent := None, - isModuleInitialized := new ScopedVar.VarBox(false) - ) { + withPerMethodBodyState(sym) { assert(vparamss.isEmpty || vparamss.tail.isEmpty, "Malformed parameter list: " + vparamss) val params = if (vparamss.isEmpty) Nil else vparamss.head.map(_.symbol) @@ -1104,15 +1408,9 @@ class JSCodeGen()(using genCtx: Context) { val methodDef = { if (sym.isClassConstructor) { - val body0 = genStat(rhs) - val body1 = { - val needsMove = currentClassSym.isNonNativeJSClass && sym.isPrimaryConstructor - if (needsMove) moveAllStatementsAfterSuperConstructorCall(body0) - else body0 - } val namespace = js.MemberNamespace.Constructor js.MethodDef(js.MemberFlags.empty.withNamespace(namespace), - methodName, originalName, jsParams, jstpe.NoType, Some(body1))( + methodName, originalName, jsParams, jstpe.NoType, Some(genStat(rhs)))( optimizerHints, None) } else { val namespace = if (isMethodStaticInIR(sym)) { @@ -1179,53 +1477,6 @@ class JSCodeGen()(using genCtx: Context) { } } - /** Moves all statements after the super constructor call. - * - * This is used for the primary constructor of a non-native JS class, - * because those cannot access `this` before the super constructor call. - * - * Normally, in Scala, param accessors (i.e., fields declared directly in - * constructor parameters) are initialized *before* the super constructor - * call. This is important for cases like - * - * abstract class A { - * def a: Int - * println(a) - * } - * class B(val a: Int) extends A - * - * where `a` is supposed to be correctly initialized by the time `println` - * is executed. - * - * However, in a JavaScript class, this is forbidden: it is not allowed to - * read the `this` value in a constructor before the super constructor call. - * - * Therefore, for JavaScript classes, we specifically move all those early - * assignments after the super constructor call, to comply with JavaScript - * limitations. This clearly introduces a semantic difference in - * initialization order between Scala classes and JavaScript classes, but - * there is nothing we can do about it. That difference in behavior is - * basically spec'ed in Scala.js the language, since specifying it any other - * way would prevent JavaScript classes from ever having constructor - * parameters. - * - * We do the same thing in Scala 2, obviously. - */ - private def moveAllStatementsAfterSuperConstructorCall(body: js.Tree): js.Tree = { - val bodyStats = body match { - case js.Block(stats) => stats - case _ => body :: Nil - } - - val (beforeSuper, superCall :: afterSuper) = - bodyStats.span(!_.isInstanceOf[js.JSSuperConstructorCall]) - - assert(!beforeSuper.exists(_.isInstanceOf[js.VarDef]), - s"Trying to move a local VarDef after the super constructor call of a non-native JS class at ${body.pos}") - - js.Block(superCall :: beforeSuper ::: afterSuper)(body.pos) - } - // ParamDefs --------------------------------------------------------------- def genParamDef(sym: Symbol): js.ParamDef = @@ -1370,8 +1621,12 @@ class JSCodeGen()(using genCtx: Context) { } case If(cond, thenp, elsep) => + val tpe = + if (isStat) jstpe.NoType + else toIRType(tree.tpe) + js.If(genExpr(cond), genStatOrExpr(thenp, isStat), - genStatOrExpr(elsep, isStat))(toIRType(tree.tpe)) + genStatOrExpr(elsep, isStat))(tpe) case Labeled(bind, expr) => js.Labeled(encodeLabelSym(bind.symbol), toIRType(tree.tpe), genStatOrExpr(expr, isStat)) @@ -1439,7 +1694,7 @@ class JSCodeGen()(using genCtx: Context) { */ js.Transient(UndefinedParam) } else { - js.VarRef(encodeLocalSym(sym))(toIRType(sym.info)) + genVarRef(sym) } } { select => genStatOrExpr(select, isStat) @@ -1524,9 +1779,7 @@ class JSCodeGen()(using genCtx: Context) { } case _ => - js.Assign( - js.VarRef(encodeLocalSym(sym))(toIRType(sym.info)), - genRhs) + js.Assign(genVarRef(sym), genRhs) } /** Array constructor */ @@ -1608,7 +1861,10 @@ class JSCodeGen()(using genCtx: Context) { val Try(block, catches, finalizer) = tree val blockAST = genStatOrExpr(block, isStat) - val resultType = toIRType(tree.tpe) + + val resultType = + if (isStat) jstpe.NoType + else toIRType(tree.tpe) val handled = if (catches.isEmpty) blockAST @@ -2377,7 +2633,10 @@ class JSCodeGen()(using genCtx: Context) { js.UnaryOp(IntToLong, intValue) } case jstpe.FloatType => - js.UnaryOp(js.UnaryOp.DoubleToFloat, doubleValue) + if (from == jstpe.LongType) + js.UnaryOp(js.UnaryOp.LongToFloat, value) + else + js.UnaryOp(js.UnaryOp.DoubleToFloat, doubleValue) case jstpe.DoubleType => doubleValue } @@ -2804,9 +3063,8 @@ class JSCodeGen()(using genCtx: Context) { s"Trying to call the super constructor of Object in a non-native JS class at $pos") genApplyMethod(genReceiver, sym, genScalaArgs) } else if (sym.isClassConstructor) { - assert(genReceiver.isInstanceOf[js.This], - s"Trying to call a JS super constructor with a non-`this` receiver at $pos") - js.JSSuperConstructorCall(genJSArgs) + throw new AssertionError( + s"calling a JS super constructor should have happened in genPrimaryJSClassCtor at $pos") } else if (sym.owner.isNonNativeJSClass && !sym.isJSExposed) { // Reroute to the static method genApplyJSClassMethod(genReceiver, sym, genScalaArgs) @@ -3437,6 +3695,15 @@ class JSCodeGen()(using genCtx: Context) { // BoxedUnit.UNIT, which is the boxed version of () js.Undefined() + case JS_IMPORT => + // js.import(arg) + val arg = genArgs1 + js.JSImportCall(arg) + + case JS_IMPORT_META => + // js.import.meta + js.JSImportMeta() + case JS_NATIVE => // js.native report.error( @@ -3727,46 +3994,24 @@ class JSCodeGen()(using genCtx: Context) { private def genActualJSArgs(sym: Symbol, args: List[Tree])( implicit pos: Position): List[js.TreeOrJSSpread] = { - def paramNamesAndTypes(using Context): List[(Names.TermName, Type)] = - sym.info.paramNamess.flatten.zip(sym.info.paramInfoss.flatten) - - val wereRepeated = atPhase(elimRepeatedPhase) { - val list = - for ((name, tpe) <- paramNamesAndTypes) - yield (name -> tpe.isRepeatedParam) - list.toMap - } - - val paramTypes = atPhase(elimErasedValueTypePhase) { - paramNamesAndTypes.toMap - } - var reversedArgs: List[js.TreeOrJSSpread] = Nil - val argsParamNamesAndTypes = args.zip(paramNamesAndTypes) - for ((arg, (paramName, paramType)) <- argsParamNamesAndTypes) { - val wasRepeated = wereRepeated.get(paramName) - - wasRepeated match { - case Some(true) => - reversedArgs = genJSRepeatedParam(arg) reverse_::: reversedArgs - - case Some(false) => - val unboxedArg = genExpr(arg) - val boxedArg = unboxedArg match { - case js.Transient(UndefinedParam) => - unboxedArg - case _ => - val tpe = paramTypes.getOrElse(paramName, paramType) - box(unboxedArg, tpe) - } - reversedArgs ::= boxedArg - - case None => - // This is a parameter introduced by erasure or lambdalift, which we ignore. - assert(sym.isClassConstructor, - i"Found an unknown param $paramName in method " + - i"${sym.fullName}, which is not a class constructor, at $pos") + for ((arg, info) <- args.zip(sym.jsParamInfos)) { + if (info.repeated) { + reversedArgs = genJSRepeatedParam(arg) reverse_::: reversedArgs + } else if (info.capture) { + // Ignore captures + assert(sym.isClassConstructor, + i"Found a capture param in method ${sym.fullName}, which is not a class constructor, at $pos") + } else { + val unboxedArg = genExpr(arg) + val boxedArg = unboxedArg match { + case js.Transient(UndefinedParam) => + unboxedArg + case _ => + box(unboxedArg, info.info) + } + reversedArgs ::= boxedArg } } @@ -3899,6 +4144,9 @@ class JSCodeGen()(using genCtx: Context) { } } + private def genVarRef(sym: Symbol)(implicit pos: Position): js.VarRef = + js.VarRef(encodeLocalSym(sym))(toIRType(sym.info)) + private def genAssignableField(sym: Symbol, qualifier: Tree)(implicit pos: SourcePosition): (js.AssignLhs, Boolean) = { def qual = genExpr(qualifier) diff --git a/compiler/src/dotty/tools/backend/sjs/JSConstructorGen.scala b/compiler/src/dotty/tools/backend/sjs/JSConstructorGen.scala deleted file mode 100644 index 25ec8ff53c6b..000000000000 --- a/compiler/src/dotty/tools/backend/sjs/JSConstructorGen.scala +++ /dev/null @@ -1,376 +0,0 @@ -package dotty.tools.backend.sjs - -import org.scalajs.ir -import org.scalajs.ir.{Position, Trees => js, Types => jstpe} -import org.scalajs.ir.Names._ -import org.scalajs.ir.OriginalName.NoOriginalName - -import JSCodeGen.UndefinedParam - -object JSConstructorGen { - - /** Builds one JS constructor out of several "init" methods and their - * dispatcher. - * - * This method and the rest of this file are copied verbatim from `GenJSCode` - * for scalac, since there is no dependency on the compiler trees/symbols/etc. - * We are only manipulating IR trees and types. - * - * The only difference is the two parameters `overloadIdent` and `reportError`, - * which are added so that this entire file can be even more isolated. - */ - def buildJSConstructorDef(dispatch: js.JSMethodDef, ctors: List[js.MethodDef], - overloadIdent: js.LocalIdent)( - reportError: String => Unit)( - implicit pos: Position): js.JSMethodDef = { - - val js.JSMethodDef(_, dispatchName, dispatchArgs, dispatchRestParam, dispatchResolution) = - dispatch - - val jsConstructorBuilder = mkJSConstructorBuilder(ctors, reportError) - - // Section containing the overload resolution and casts of parameters - val overloadSelection = mkOverloadSelection(jsConstructorBuilder, - overloadIdent, dispatchResolution) - - /* Section containing all the code executed before the call to `this` - * for every secondary constructor. - */ - val prePrimaryCtorBody = - jsConstructorBuilder.mkPrePrimaryCtorBody(overloadIdent) - - val primaryCtorBody = jsConstructorBuilder.primaryCtorBody - - /* Section containing all the code executed after the call to this for - * every secondary constructor. - */ - val postPrimaryCtorBody = - jsConstructorBuilder.mkPostPrimaryCtorBody(overloadIdent) - - val newBody = js.Block(overloadSelection ::: prePrimaryCtorBody :: - primaryCtorBody :: postPrimaryCtorBody :: js.Undefined() :: Nil) - - js.JSMethodDef(js.MemberFlags.empty, dispatchName, dispatchArgs, dispatchRestParam, newBody)( - dispatch.optimizerHints, None) - } - - private class ConstructorTree(val overrideNum: Int, val method: js.MethodDef, - val subConstructors: List[ConstructorTree]) { - - lazy val overrideNumBounds: (Int, Int) = - if (subConstructors.isEmpty) (overrideNum, overrideNum) - else (subConstructors.head.overrideNumBounds._1, overrideNum) - - def get(methodName: MethodName): Option[ConstructorTree] = { - if (methodName == this.method.methodName) { - Some(this) - } else { - subConstructors.iterator.map(_.get(methodName)).collectFirst { - case Some(node) => node - } - } - } - - def getParamRefs(implicit pos: Position): List[js.VarRef] = - method.args.map(_.ref) - - def getAllParamDefsAsVars(implicit pos: Position): List[js.VarDef] = { - val localDefs = method.args.map { pDef => - js.VarDef(pDef.name, pDef.originalName, pDef.ptpe, mutable = true, - jstpe.zeroOf(pDef.ptpe)) - } - localDefs ++ subConstructors.flatMap(_.getAllParamDefsAsVars) - } - } - - private class JSConstructorBuilder(root: ConstructorTree, reportError: String => Unit) { - - def primaryCtorBody: js.Tree = root.method.body.getOrElse( - throw new AssertionError("Found abstract constructor")) - - def hasSubConstructors: Boolean = root.subConstructors.nonEmpty - - def getOverrideNum(methodName: MethodName): Int = - root.get(methodName).fold(-1)(_.overrideNum) - - def getParamRefsFor(methodName: MethodName)(implicit pos: Position): List[js.VarRef] = - root.get(methodName).fold(List.empty[js.VarRef])(_.getParamRefs) - - def getAllParamDefsAsVars(implicit pos: Position): List[js.VarDef] = - root.getAllParamDefsAsVars - - def mkPrePrimaryCtorBody(overrideNumIdent: js.LocalIdent)( - implicit pos: Position): js.Tree = { - val overrideNumRef = js.VarRef(overrideNumIdent)(jstpe.IntType) - mkSubPreCalls(root, overrideNumRef) - } - - def mkPostPrimaryCtorBody(overrideNumIdent: js.LocalIdent)( - implicit pos: Position): js.Tree = { - val overrideNumRef = js.VarRef(overrideNumIdent)(jstpe.IntType) - js.Block(mkSubPostCalls(root, overrideNumRef)) - } - - private def mkSubPreCalls(constructorTree: ConstructorTree, - overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = { - val overrideNumss = constructorTree.subConstructors.map(_.overrideNumBounds) - val paramRefs = constructorTree.getParamRefs - val bodies = constructorTree.subConstructors.map { constructorTree => - mkPrePrimaryCtorBodyOnSndCtr(constructorTree, overrideNumRef, paramRefs) - } - overrideNumss.zip(bodies).foldRight[js.Tree](js.Skip()) { - case ((numBounds, body), acc) => - val cond = mkOverrideNumsCond(overrideNumRef, numBounds) - js.If(cond, body, acc)(jstpe.BooleanType) - } - } - - private def mkPrePrimaryCtorBodyOnSndCtr(constructorTree: ConstructorTree, - overrideNumRef: js.VarRef, outputParams: List[js.VarRef])( - implicit pos: Position): js.Tree = { - val subCalls = - mkSubPreCalls(constructorTree, overrideNumRef) - - val preSuperCall = { - def checkForUndefinedParams(args: List[js.Tree]): List[js.Tree] = { - def isUndefinedParam(tree: js.Tree): Boolean = tree match { - case js.Transient(UndefinedParam) => true - case _ => false - } - - if (!args.exists(isUndefinedParam)) { - args - } else { - /* If we find an undefined param here, we're in trouble, because - * the handling of a default param for the target constructor has - * already been done during overload resolution. If we store an - * `undefined` now, it will fall through without being properly - * processed. - * - * Since this seems very tricky to deal with, and a pretty rare - * use case (with a workaround), we emit an "implementation - * restriction" error. - */ - reportError( - "Implementation restriction: in a JS class, a secondary " + - "constructor calling another constructor with default " + - "parameters must provide the values of all parameters.") - - /* Replace undefined params by undefined to prevent subsequent - * compiler crashes. - */ - args.map { arg => - if (isUndefinedParam(arg)) - js.Undefined()(arg.pos) - else - arg - } - } - } - - constructorTree.method.body.get match { - case js.Block(stats) => - val beforeSuperCall = stats.takeWhile { - case js.ApplyStatic(_, _, mtd, _) => !mtd.name.isConstructor - case _ => true - } - val superCallParams = stats.collectFirst { - case js.ApplyStatic(_, _, mtd, js.This() :: args) - if mtd.name.isConstructor => - val checkedArgs = checkForUndefinedParams(args) - zipMap(outputParams, checkedArgs)(js.Assign(_, _)) - }.getOrElse(Nil) - - beforeSuperCall ::: superCallParams - - case js.ApplyStatic(_, _, mtd, js.This() :: args) - if mtd.name.isConstructor => - val checkedArgs = checkForUndefinedParams(args) - zipMap(outputParams, checkedArgs)(js.Assign(_, _)) - - case _ => Nil - } - } - - js.Block(subCalls :: preSuperCall) - } - - private def mkSubPostCalls(constructorTree: ConstructorTree, - overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = { - val overrideNumss = constructorTree.subConstructors.map(_.overrideNumBounds) - val bodies = constructorTree.subConstructors.map { ct => - mkPostPrimaryCtorBodyOnSndCtr(ct, overrideNumRef) - } - overrideNumss.zip(bodies).foldRight[js.Tree](js.Skip()) { - case ((numBounds, js.Skip()), acc) => acc - - case ((numBounds, body), acc) => - val cond = mkOverrideNumsCond(overrideNumRef, numBounds) - js.If(cond, body, acc)(jstpe.BooleanType) - } - } - - private def mkPostPrimaryCtorBodyOnSndCtr(constructorTree: ConstructorTree, - overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = { - val postSuperCall = { - constructorTree.method.body.get match { - case js.Block(stats) => - stats.dropWhile { - case js.ApplyStatic(_, _, mtd, _) => !mtd.name.isConstructor - case _ => true - }.tail - - case _ => Nil - } - } - js.Block(postSuperCall :+ mkSubPostCalls(constructorTree, overrideNumRef)) - } - - private def mkOverrideNumsCond(numRef: js.VarRef, - numBounds: (Int, Int))(implicit pos: Position) = numBounds match { - case (lo, hi) if lo == hi => - js.BinaryOp(js.BinaryOp.Int_==, js.IntLiteral(lo), numRef) - - case (lo, hi) if lo == hi - 1 => - val lhs = js.BinaryOp(js.BinaryOp.Int_==, numRef, js.IntLiteral(lo)) - val rhs = js.BinaryOp(js.BinaryOp.Int_==, numRef, js.IntLiteral(hi)) - js.If(lhs, js.BooleanLiteral(true), rhs)(jstpe.BooleanType) - - case (lo, hi) => - val lhs = js.BinaryOp(js.BinaryOp.Int_<=, js.IntLiteral(lo), numRef) - val rhs = js.BinaryOp(js.BinaryOp.Int_<=, numRef, js.IntLiteral(hi)) - js.BinaryOp(js.BinaryOp.Boolean_&, lhs, rhs) - js.If(lhs, rhs, js.BooleanLiteral(false))(jstpe.BooleanType) - } - } - - private def zipMap[T, U, V](xs: List[T], ys: List[U])( - f: (T, U) => V): List[V] = { - for ((x, y) <- xs zip ys) yield f(x, y) - } - - /** mkOverloadSelection return a list of `stats` with that starts with: - * 1) The definition for the local variable that will hold the overload - * resolution number. - * 2) The definitions of all local variables that are used as parameters - * in all the constructors. - * 3) The overload resolution match/if statements. For each overload the - * overload number is assigned and the parameters are cast and assigned - * to their corresponding variables. - */ - private def mkOverloadSelection(jsConstructorBuilder: JSConstructorBuilder, - overloadIdent: js.LocalIdent, dispatchResolution: js.Tree)( - implicit pos: Position): List[js.Tree] = { - - def deconstructApplyCtor(body: js.Tree): (List[js.Tree], MethodName, List[js.Tree]) = { - val (prepStats, applyCtor) = (body: @unchecked) match { - case applyCtor: js.ApplyStatic => - (Nil, applyCtor) - case js.Block(prepStats :+ (applyCtor: js.ApplyStatic)) => - (prepStats, applyCtor) - } - val js.ApplyStatic(_, _, js.MethodIdent(ctorName), js.This() :: ctorArgs) = - applyCtor - assert(ctorName.isConstructor, - s"unexpected super constructor call to non-constructor $ctorName at ${applyCtor.pos}") - (prepStats, ctorName, ctorArgs) - } - - if (!jsConstructorBuilder.hasSubConstructors) { - val (prepStats, ctorName, ctorArgs) = - deconstructApplyCtor(dispatchResolution) - - val refs = jsConstructorBuilder.getParamRefsFor(ctorName) - assert(refs.size == ctorArgs.size, s"at $pos") - val assignCtorParams = zipMap(refs, ctorArgs) { (ref, ctorArg) => - js.VarDef(ref.ident, NoOriginalName, ref.tpe, mutable = false, ctorArg) - } - - prepStats ::: assignCtorParams - } else { - val overloadRef = js.VarRef(overloadIdent)(jstpe.IntType) - - /* transformDispatch takes the body of the method generated by - * `genJSConstructorDispatch` and transform it recursively. - */ - def transformDispatch(tree: js.Tree): js.Tree = tree match { - // Parameter count resolution - case js.Match(selector, cases, default) => - val newCases = cases.map { - case (literals, body) => (literals, transformDispatch(body)) - } - val newDefault = transformDispatch(default) - js.Match(selector, newCases, newDefault)(tree.tpe) - - // Parameter type resolution - case js.If(cond, thenp, elsep) => - js.If(cond, transformDispatch(thenp), - transformDispatch(elsep))(tree.tpe) - - // Throw(StringLiteral(No matching overload)) - case tree: js.Throw => - tree - - // Overload resolution done, apply the constructor - case _ => - val (prepStats, ctorName, ctorArgs) = deconstructApplyCtor(tree) - - val num = jsConstructorBuilder.getOverrideNum(ctorName) - val overloadAssign = js.Assign(overloadRef, js.IntLiteral(num)) - - val refs = jsConstructorBuilder.getParamRefsFor(ctorName) - assert(refs.size == ctorArgs.size, s"at $pos") - val assignCtorParams = zipMap(refs, ctorArgs)(js.Assign(_, _)) - - js.Block(overloadAssign :: prepStats ::: assignCtorParams) - } - - val newDispatchResolution = transformDispatch(dispatchResolution) - val allParamDefsAsVars = jsConstructorBuilder.getAllParamDefsAsVars - val overrideNumDef = js.VarDef(overloadIdent, NoOriginalName, - jstpe.IntType, mutable = true, js.IntLiteral(0)) - - overrideNumDef :: allParamDefsAsVars ::: newDispatchResolution :: Nil - } - } - - private def mkJSConstructorBuilder(ctors: List[js.MethodDef], reportError: String => Unit)( - implicit pos: Position): JSConstructorBuilder = { - def findCtorForwarderCall(tree: js.Tree): MethodName = (tree: @unchecked) match { - case js.ApplyStatic(_, _, method, js.This() :: _) - if method.name.isConstructor => - method.name - - case js.Block(stats) => - stats.collectFirst { - case js.ApplyStatic(_, _, method, js.This() :: _) - if method.name.isConstructor => - method.name - }.get - } - - val (primaryCtor :: Nil, secondaryCtors) = ctors.partition { - _.body.get match { - case js.Block(stats) => - stats.exists(_.isInstanceOf[js.JSSuperConstructorCall]) - - case _: js.JSSuperConstructorCall => true - case _ => false - } - } - - val ctorToChildren = secondaryCtors.map { ctor => - findCtorForwarderCall(ctor.body.get) -> ctor - }.groupBy(_._1).map(kv => kv._1 -> kv._2.map(_._2)).withDefaultValue(Nil) - - var overrideNum = -1 - def mkConstructorTree(method: js.MethodDef): ConstructorTree = { - val subCtrTrees = ctorToChildren(method.methodName).map(mkConstructorTree) - overrideNum += 1 - new ConstructorTree(overrideNum, method, subCtrTrees) - } - - new JSConstructorBuilder(mkConstructorTree(primaryCtor), reportError: String => Unit) - } - -} diff --git a/compiler/src/dotty/tools/backend/sjs/JSDefinitions.scala b/compiler/src/dotty/tools/backend/sjs/JSDefinitions.scala index a97b6ad2687e..c02e0c030657 100644 --- a/compiler/src/dotty/tools/backend/sjs/JSDefinitions.scala +++ b/compiler/src/dotty/tools/backend/sjs/JSDefinitions.scala @@ -145,6 +145,13 @@ final class JSDefinitions()(using Context) { @threadUnsafe lazy val JSConstructorTag_materializeR = JSConstructorTagModule.requiredMethodRef("materialize") def JSConstructorTag_materialize(using Context) = JSConstructorTag_materializeR.symbol + @threadUnsafe lazy val JSImportModuleRef = requiredModuleRef("scala.scalajs.js.import") + def JSImportModule(using Context) = JSImportModuleRef.symbol + @threadUnsafe lazy val JSImport_applyR = JSImportModule.requiredMethodRef(nme.apply) + def JSImport_apply(using Context) = JSImport_applyR.symbol + @threadUnsafe lazy val JSImport_metaR = JSImportModule.requiredMethodRef("meta") + def JSImport_meta(using Context) = JSImport_metaR.symbol + @threadUnsafe lazy val RuntimePackageVal = requiredPackage("scala.scalajs.runtime") @threadUnsafe lazy val RuntimePackageClass = RuntimePackageVal.moduleClass.asClass @threadUnsafe lazy val RuntimePackage_wrapJavaScriptExceptionR = RuntimePackageClass.requiredMethodRef("wrapJavaScriptException") diff --git a/compiler/src/dotty/tools/backend/sjs/JSExportsGen.scala b/compiler/src/dotty/tools/backend/sjs/JSExportsGen.scala index f52cf1d29d53..22aafc95e11b 100644 --- a/compiler/src/dotty/tools/backend/sjs/JSExportsGen.scala +++ b/compiler/src/dotty/tools/backend/sjs/JSExportsGen.scala @@ -174,7 +174,7 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { js.TopLevelJSClassExportDef(info.moduleID, info.jsName) case Constructor | Method => - val exported = tups.map(t => Exported(t._2)) + val exported = tups.map(_._2) val methodDef = withNewLocalNameScope { genExportMethod(exported, JSName.Literal(info.jsName), static = true) @@ -330,32 +330,10 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { if (isProp) genExportProperty(alts, jsName, static) else - genExportMethod(alts.map(Exported.apply), jsName, static) + genExportMethod(alts, jsName, static) } } - def genJSConstructorDispatch(alts: List[Symbol]): (Option[List[js.ParamDef]], js.JSMethodDef) = { - val exporteds = alts.map(Exported.apply) - - val isConstructorOfNestedJSClass = exporteds.head.isConstructorOfNestedJSClass - assert(exporteds.tail.forall(_.isConstructorOfNestedJSClass == isConstructorOfNestedJSClass), - s"Alternative constructors $alts do not agree on whether they are in a nested JS class or not") - val captureParams = if (!isConstructorOfNestedJSClass) { - None - } else { - Some(for { - exported <- exporteds - param <- exported.captureParamsFront ::: exported.captureParamsBack - } yield { - param - }) - } - - val ctorDef = genExportMethod(exporteds, JSName.Literal("constructor"), static = false) - - (captureParams, ctorDef) - } - private def genExportProperty(alts: List[Symbol], jsName: JSName, static: Boolean): js.JSPropertyDef = { assert(!alts.isEmpty, s"genExportProperty with empty alternatives for $jsName") @@ -382,7 +360,7 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { } val getterBody = getter.headOption.map { getterSym => - genApplyForSingleExported(new FormalArgsRegistry(0, false), Exported(getterSym), static) + genApplyForSingleExported(new FormalArgsRegistry(0, false), new ExportedSymbol(getterSym, static), static) } val setterArgAndBody = { @@ -391,7 +369,8 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { } else { val formalArgsRegistry = new FormalArgsRegistry(1, false) val (List(arg), None) = formalArgsRegistry.genFormalArgs() - val body = genExportSameArgc(jsName, formalArgsRegistry, setters.map(Exported.apply), static, None) + val body = genOverloadDispatchSameArgc(jsName, formalArgsRegistry, + setters.map(new ExportedSymbol(_, static)), jstpe.AnyType, None) Some((arg, body)) } } @@ -399,10 +378,10 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { js.JSPropertyDef(flags, genExpr(jsName)(alts.head.sourcePos), getterBody, setterArgAndBody) } - private def genExportMethod(alts0: List[Exported], jsName: JSName, static: Boolean): js.JSMethodDef = { + private def genExportMethod(alts0: List[Symbol], jsName: JSName, static: Boolean)(using Context): js.JSMethodDef = { assert(alts0.nonEmpty, "need at least one alternative to generate exporter method") - implicit val pos = alts0.head.pos + implicit val pos: SourcePosition = alts0.head.sourcePos val namespace = if (static) js.MemberNamespace.PublicStatic @@ -411,12 +390,24 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { // toString() is always exported. We might need to add it here to get correct overloading. val alts = jsName match { - case JSName.Literal("toString") if alts0.forall(_.params.nonEmpty) => - Exported(defn.Any_toString) :: alts0 + case JSName.Literal("toString") if alts0.forall(_.info.paramInfoss.exists(_.nonEmpty)) => + defn.Any_toString :: alts0 case _ => alts0 } + val overloads = alts.map(new ExportedSymbol(_, static)) + + val (formalArgs, restParam, body) = + genOverloadDispatch(jsName, overloads, jstpe.AnyType) + + js.JSMethodDef(flags, genExpr(jsName), formalArgs, restParam, body)( + OptimizerHints.empty, None) + } + + def genOverloadDispatch(jsName: JSName, alts: List[Exported], tpe: jstpe.Type)( + using pos: SourcePosition): (List[js.ParamDef], Option[js.ParamDef], js.Tree) = { + // Create the formal args registry val hasVarArg = alts.exists(_.hasRepeatedParam) val minArgc = alts.map(_.minArgc).min @@ -437,14 +428,14 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { * ported to dotc. */ val body = - if (alts.tail.isEmpty) genApplyForSingleExported(formalArgsRegistry, alts.head, static) - else genExportMethodMultiAlts(formalArgsRegistry, maxNonRepeatedArgc, alts, jsName, static) + if (alts.tail.isEmpty) alts.head.genBody(formalArgsRegistry) + else genExportMethodMultiAlts(formalArgsRegistry, maxNonRepeatedArgc, alts, tpe, jsName) - js.JSMethodDef(flags, genExpr(jsName), formalArgs, restParam, body)(OptimizerHints.empty, None) + (formalArgs, restParam, body) } private def genExportMethodMultiAlts(formalArgsRegistry: FormalArgsRegistry, - maxNonRepeatedArgc: Int, alts: List[Exported], jsName: JSName, static: Boolean)( + maxNonRepeatedArgc: Int, alts: List[Exported], tpe: jstpe.Type, jsName: JSName)( implicit pos: SourcePosition): js.Tree = { // Generate tuples (argc, method) @@ -469,7 +460,7 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { if methods != altsWithVarArgs // exclude default case we're generating anyways for varargs } yield { // body of case to disambiguates methods with current count - val caseBody = genExportSameArgc(jsName, formalArgsRegistry, methods, static, Some(argc)) + val caseBody = genOverloadDispatchSameArgc(jsName, formalArgsRegistry, methods, tpe, Some(argc)) List(js.IntLiteral(argc - formalArgsRegistry.minArgc)) -> caseBody } @@ -477,7 +468,7 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { if (altsWithVarArgs.isEmpty) genThrowTypeError() else - genExportSameArgc(jsName, formalArgsRegistry, altsWithVarArgs, static, None) + genOverloadDispatchSameArgc(jsName, formalArgsRegistry, altsWithVarArgs, tpe, None) } val body = { @@ -491,7 +482,7 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { js.AsInstanceOf(js.JSSelect(restArgRef, js.StringLiteral("length")), jstpe.IntType), cases, defaultCase)( - jstpe.AnyType) + tpe) } } @@ -506,14 +497,14 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { * The registry of all the formal arguments * @param alts * Alternative methods - * @param static - * Whether we are generating a static method + * @param tpe + * Result type * @param maxArgc * Maximum number of arguments to use for disambiguation */ - private def genExportSameArgc(jsName: JSName, formalArgsRegistry: FormalArgsRegistry, - alts: List[Exported], static: Boolean, maxArgc: Option[Int]): js.Tree = { - genExportSameArgcRec(jsName, formalArgsRegistry, alts, paramIndex = 0, static, maxArgc) + private def genOverloadDispatchSameArgc(jsName: JSName, formalArgsRegistry: FormalArgsRegistry, + alts: List[Exported], tpe: jstpe.Type, maxArgc: Option[Int]): js.Tree = { + genOverloadDispatchSameArgcRec(jsName, formalArgsRegistry, alts, tpe, paramIndex = 0, maxArgc) } /** Resolves method calls to [[alts]] while assuming they have the same parameter count. @@ -524,20 +515,20 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { * The registry of all the formal arguments * @param alts * Alternative methods + * @param tpe + * Result type * @param paramIndex * Index where to start disambiguation (starts at 0, increases through recursion) - * @param static - * Whether we are generating a static method * @param maxArgc * Maximum number of arguments to use for disambiguation */ - private def genExportSameArgcRec(jsName: JSName, formalArgsRegistry: FormalArgsRegistry, alts: List[Exported], - paramIndex: Int, static: Boolean, maxArgc: Option[Int]): js.Tree = { + private def genOverloadDispatchSameArgcRec(jsName: JSName, formalArgsRegistry: FormalArgsRegistry, + alts: List[Exported], tpe: jstpe.Type, paramIndex: Int, maxArgc: Option[Int]): js.Tree = { implicit val pos = alts.head.pos if (alts.sizeIs == 1) { - genApplyForSingleExported(formalArgsRegistry, alts.head, static) + alts.head.genBody(formalArgsRegistry) } else if (maxArgc.exists(_ <= paramIndex) || !alts.exists(_.params.size > paramIndex)) { // We reach here in three cases: // 1. The parameter list has been exhausted @@ -553,10 +544,25 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { if (altsByTypeTest.size == 1) { // Testing this parameter is not doing any us good - genExportSameArgcRec(jsName, formalArgsRegistry, alts, paramIndex + 1, static, maxArgc) + genOverloadDispatchSameArgcRec(jsName, formalArgsRegistry, alts, tpe, paramIndex + 1, maxArgc) } else { // Sort them so that, e.g., isInstanceOf[String] comes before isInstanceOf[Object] - val sortedAltsByTypeTest = topoSortDistinctsBy(altsByTypeTest)(_._1) + val sortedAltsByTypeTest = topoSortDistinctsWith(altsByTypeTest) { (lhs, rhs) => + (lhs._1, rhs._1) match { + // NoTypeTest is always last + case (_, NoTypeTest) => true + case (NoTypeTest, _) => false + + case (PrimitiveTypeTest(_, rank1), PrimitiveTypeTest(_, rank2)) => + rank1 <= rank2 + + case (InstanceOfTypeTest(t1), InstanceOfTypeTest(t2)) => + t1 <:< t2 + + case (_: PrimitiveTypeTest, _: InstanceOfTypeTest) => true + case (_: InstanceOfTypeTest, _: PrimitiveTypeTest) => false + } + } val defaultCase = genThrowTypeError() @@ -565,14 +571,10 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { implicit val pos = subAlts.head.pos val paramRef = formalArgsRegistry.genArgRef(paramIndex) - val genSubAlts = genExportSameArgcRec(jsName, formalArgsRegistry, - subAlts, paramIndex + 1, static, maxArgc) + val genSubAlts = genOverloadDispatchSameArgcRec(jsName, formalArgsRegistry, + subAlts, tpe, paramIndex + 1, maxArgc) - def hasDefaultParam = subAlts.exists { exported => - val params = exported.params - params.size > paramIndex && - params(paramIndex).hasDefault - } + def hasDefaultParam = subAlts.exists(_.hasDefaultAt(paramIndex)) val optCond = typeTest match { case PrimitiveTypeTest(tpe, _) => Some(js.IsInstanceOf(paramRef, tpe)) @@ -588,7 +590,7 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { js.BinaryOp(js.BinaryOp.===, paramRef, js.Undefined()))( jstpe.BooleanType) } - js.If(condOrUndef, genSubAlts, elsep)(jstpe.AnyType) + js.If(condOrUndef, genSubAlts, elsep)(tpe) } } } @@ -677,60 +679,50 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { implicit val pos = exported.pos - // Generate JS code to prepare arguments (repeated args, default getters and unboxes) - val jsArgPrep = genPrepareArgs(formalArgsRegistry, exported, static) - val jsArgPrepRefs = jsArgPrep.map(_.ref) + val varDefs = new mutable.ListBuffer[js.VarDef] + + for ((param, i) <- exported.params.zipWithIndex) { + val rhs = genScalaArg(exported, i, formalArgsRegistry, param, static)( + prevArgsCount => varDefs.take(prevArgsCount).toList.map(_.ref)) - // Combine prep'ed formal arguments with captures - val allJSArgs = { - exported.captureParamsFront.map(_.ref) ::: - jsArgPrepRefs ::: - exported.captureParamsBack.map(_.ref) + varDefs += js.VarDef(freshLocalIdent("prep" + i), NoOriginalName, rhs.tpe, mutable = false, rhs) } - val jsResult = genResult(exported, allJSArgs, static) + val builtVarDefs = varDefs.result() + + val jsResult = genResult(exported, builtVarDefs.map(_.ref), static) - js.Block(jsArgPrep :+ jsResult) + js.Block(builtVarDefs :+ jsResult) } - /** Generate the necessary JavaScript code to prepare the arguments of an - * exported method (unboxing and default parameter handling) + /** Generates a Scala argument from dispatched JavaScript arguments + * (unboxing and default parameter handling). */ - private def genPrepareArgs(formalArgsRegistry: FormalArgsRegistry, exported: Exported, static: Boolean)( - implicit pos: SourcePosition): List[js.VarDef] = { - - val result = new mutable.ListBuffer[js.VarDef] + def genScalaArg(exported: Exported, paramIndex: Int, formalArgsRegistry: FormalArgsRegistry, + param: JSParamInfo, static: Boolean)( + previousArgsValues: Int => List[js.Tree])( + implicit pos: SourcePosition): js.Tree = { - for ((param, i) <- exported.params.zipWithIndex) yield { - val verifiedOrDefault = if (param.isRepeated) { - genJSArrayToVarArgs(formalArgsRegistry.genVarargRef(i)) - } else { - val jsArg = formalArgsRegistry.genArgRef(i) + if (param.repeated) { + genJSArrayToVarArgs(formalArgsRegistry.genVarargRef(paramIndex)) + } else { + val jsArg = formalArgsRegistry.genArgRef(paramIndex) - // Unboxed argument (if it is defined) - val unboxedArg = unbox(jsArg, param.info) + // Unboxed argument (if it is defined) + val unboxedArg = unbox(jsArg, param.info) + if (exported.hasDefaultAt(paramIndex)) { // If argument is undefined and there is a default getter, call it - if (param.hasDefault) { - js.If(js.BinaryOp(js.BinaryOp.===, jsArg, js.Undefined()), { - genCallDefaultGetter(exported.sym, i, static) { - prevArgsCount => result.take(prevArgsCount).toList.map(_.ref) - } - }, { - // Otherwise, unbox the argument - unboxedArg - })(unboxedArg.tpe) - } else { - // Otherwise, it is always the unboxed argument + js.If(js.BinaryOp(js.BinaryOp.===, jsArg, js.Undefined()), { + genCallDefaultGetter(exported.sym, paramIndex, static)(previousArgsValues) + }, { unboxedArg - } + })(unboxedArg.tpe) + } else { + // Otherwise, it is always the unboxed argument + unboxedArg } - - result += js.VarDef(freshLocalIdent("prep" + i), NoOriginalName, - verifiedOrDefault.tpe, mutable = false, verifiedOrDefault) } - - result.toList } private def genCallDefaultGetter(sym: Symbol, paramIndex: Int, static: Boolean)( @@ -757,6 +749,13 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { js.Undefined() else genApplyJSClassMethod(targetTree, defaultGetter, defaultGetterArgs) + } else if (defaultGetter.owner == targetSym) { + /* We get here if a non-native constructor has a native companion. + * This is reported on a per-class level. + */ + assert(sym.isClassConstructor, + s"got non-constructor method $sym with default method in JS native companion") + js.Undefined() } else { report.error( "When overriding a native method with default arguments, " + @@ -809,106 +808,33 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { private def genThrowTypeError(msg: String = "No matching overload")(implicit pos: Position): js.Tree = js.Throw(js.JSNew(js.JSGlobalRef("TypeError"), js.StringLiteral(msg) :: Nil)) - private final class ParamSpec(val info: Type, val isRepeated: Boolean, val hasDefault: Boolean) { - override def toString(): String = - i"ParamSpec($info, isRepeated = $isRepeated, hasDefault = $hasDefault)" - } + abstract class Exported( + val sym: Symbol, + // Parameters participating in overload resolution. + val params: scala.collection.immutable.IndexedSeq[JSParamInfo] + ) { + assert(!params.exists(_.capture), "illegal capture params in Exported") - private object ParamSpec { - def apply(methodSym: Symbol, infoAtElimRepeated: Type, infoAtElimEVT: Type, - methodHasDefaultParams: Boolean, paramIndex: Int): ParamSpec = { - val isRepeated = infoAtElimRepeated.isRepeatedParam - val info = - if (isRepeated) atPhase(elimRepeatedPhase)(infoAtElimRepeated.repeatedToSingle.widenDealias) - else infoAtElimEVT - val hasDefault = methodHasDefaultParams && defaultGetterDenot(methodSym, paramIndex).exists - new ParamSpec(info, isRepeated, hasDefault) - } - } - - // This is a case class because we rely on its structural equality - private final case class Exported(sym: Symbol) { - val isConstructorOfNestedJSClass = - sym.isClassConstructor && sym.owner.isNestedJSClass - - // params: List[ParamSpec] ; captureParams and captureParamsBack: List[js.ParamDef] - val (params, captureParamsFront, captureParamsBack) = { - val (paramNamesAtElimRepeated, paramInfosAtElimRepeated, methodHasDefaultParams) = - atPhase(elimRepeatedPhase)((sym.info.paramNamess.flatten, sym.info.paramInfoss.flatten, sym.hasDefaultParams)) - val (paramNamesAtElimEVT, paramInfosAtElimEVT) = - atPhase(elimErasedValueTypePhase)((sym.info.firstParamNames, sym.info.firstParamTypes)) - val (paramNamesNow, paramInfosNow) = - (sym.info.firstParamNames, sym.info.firstParamTypes) - - val formalParamCount = paramInfosAtElimRepeated.size - - def buildFormalParams(formalParamInfosAtElimEVT: List[Type]): IndexedSeq[ParamSpec] = { - (for { - (infoAtElimRepeated, infoAtElimEVT, paramIndex) <- - paramInfosAtElimRepeated.lazyZip(formalParamInfosAtElimEVT).lazyZip(0 until formalParamCount) - } yield { - ParamSpec(sym, infoAtElimRepeated, infoAtElimEVT, methodHasDefaultParams, paramIndex) - }).toIndexedSeq - } - - def buildCaptureParams(namesAndInfosNow: List[(TermName, Type)]): List[js.ParamDef] = { - implicit val pos: Position = sym.span - for ((name, info) <- namesAndInfosNow) yield { - js.ParamDef(freshLocalIdent(name.mangledString), NoOriginalName, toIRType(info), - mutable = false) - } - } - - if (!isConstructorOfNestedJSClass) { - // Easy case: all params are formal params - assert(paramInfosAtElimEVT.size == formalParamCount && paramInfosNow.size == formalParamCount, - s"Found $formalParamCount params entering elimRepeated but ${paramInfosAtElimEVT.size} params entering " + - s"elimErasedValueType and ${paramInfosNow.size} params at the back-end for non-lifted symbol ${sym.fullName}") - val formalParams = buildFormalParams(paramInfosAtElimEVT) - (formalParams, Nil, Nil) - } else if (formalParamCount == 0) { - // Fast path: all params are capture params - val captureParams = buildCaptureParams(paramNamesNow.zip(paramInfosNow)) - (IndexedSeq.empty, Nil, captureParams) + private val paramsHasDefault = { + if (!atPhase(elimRepeatedPhase)(sym.hasDefaultParams)) { + Vector.empty } else { - /* Slow path: we have to isolate formal params (which were already present at elimRepeated) - * from capture params (which are later, typically by erasure and/or lambdalift). - */ - - def findStartOfFormalParamsIn(paramNames: List[TermName]): Int = { - val start = paramNames.indexOfSlice(paramNamesAtElimRepeated) - assert(start >= 0, s"could not find formal param names $paramNamesAtElimRepeated in $paramNames") - start - } - - // Find the infos of formal params at elimEVT - val startOfFormalParamsAtElimEVT = findStartOfFormalParamsIn(paramNamesAtElimEVT) - val formalParamInfosAtElimEVT = paramInfosAtElimEVT.drop(startOfFormalParamsAtElimEVT).take(formalParamCount) - - // Build the formal param specs from their infos at elimRepeated and elimEVT - val formalParams = buildFormalParams(formalParamInfosAtElimEVT) - - // Find the formal params now to isolate the capture params (before and after the formal params) - val startOfFormalParamsNow = findStartOfFormalParamsIn(paramNamesNow) - val paramNamesAndInfosNow = paramNamesNow.zip(paramInfosNow) - val (captureParamsFrontNow, restOfParamsNow) = paramNamesAndInfosNow.splitAt(startOfFormalParamsNow) - val captureParamsBackNow = restOfParamsNow.drop(formalParamCount) - - // Build the capture param defs from the isolated capture params - val captureParamsFront = buildCaptureParams(captureParamsFrontNow) - val captureParamsBack = buildCaptureParams(captureParamsBackNow) - - (formalParams, captureParamsFront, captureParamsBack) + val targetSym = targetSymForDefaultGetter(sym) + params.indices.map(i => defaultGetterDenot(targetSym, sym, i).exists) } } - val hasRepeatedParam = params.nonEmpty && params.last.isRepeated + def hasDefaultAt(paramIndex: Int): Boolean = + paramIndex < paramsHasDefault.size && paramsHasDefault(paramIndex) + + val hasRepeatedParam = params.nonEmpty && params.last.repeated val minArgc = { // Find the first default param or repeated param - val firstOptionalParamIndex = params.indexWhere(p => p.hasDefault || p.isRepeated) - if (firstOptionalParamIndex == -1) params.size - else firstOptionalParamIndex + params + .indices + .find(i => hasDefaultAt(i) || params(i).repeated) + .getOrElse(params.size) } val maxNonRepeatedArgc = if (hasRepeatedParam) params.size - 1 else params.size @@ -925,6 +851,15 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { } def typeInfo: String = sym.info.toString + + def genBody(formalArgsRegistry: FormalArgsRegistry): js.Tree + } + + private class ExportedSymbol(sym: Symbol, static: Boolean) + extends Exported(sym, sym.jsParamInfos.toIndexedSeq) { + + def genBody(formalArgsRegistry: FormalArgsRegistry): js.Tree = + genApplyForSingleExported(formalArgsRegistry, this, static) } // !!! Hash codes of RTTypeTest are meaningless because of InstanceOfTypeTest @@ -944,46 +879,14 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { private case object NoTypeTest extends RTTypeTest - private object RTTypeTest { - given PartialOrdering[RTTypeTest] with { - override def tryCompare(lhs: RTTypeTest, rhs: RTTypeTest): Option[Int] = { - if (lteq(lhs, rhs)) if (lteq(rhs, lhs)) Some(0) else Some(-1) - else if (lteq(rhs, lhs)) Some(1) else None - } - - override def lteq(lhs: RTTypeTest, rhs: RTTypeTest): Boolean = { - (lhs, rhs) match { - // NoTypeTest is always last - case (_, NoTypeTest) => true - case (NoTypeTest, _) => false - - case (PrimitiveTypeTest(_, rank1), PrimitiveTypeTest(_, rank2)) => - rank1 <= rank2 - - case (InstanceOfTypeTest(t1), InstanceOfTypeTest(t2)) => - t1 <:< t2 - - case (_: PrimitiveTypeTest, _: InstanceOfTypeTest) => true - case (_: InstanceOfTypeTest, _: PrimitiveTypeTest) => false - } - } - - override def equiv(lhs: RTTypeTest, rhs: RTTypeTest): Boolean = { - lhs == rhs - } - } - } - /** Very simple O(n²) topological sort for elements assumed to be distinct. */ - private def topoSortDistinctsBy[A <: AnyRef, B](coll: List[A])(f: A => B)( - using ord: PartialOrdering[B]): List[A] = { - + private def topoSortDistinctsWith[A <: AnyRef](coll: List[A])(lteq: (A, A) => Boolean): List[A] = { @tailrec def loop(coll: List[A], acc: List[A]): List[A] = { if (coll.isEmpty) acc else if (coll.tail.isEmpty) coll.head :: acc else { - val (lhs, rhs) = coll.span(x => !coll.forall(y => (x eq y) || !ord.lteq(f(x), f(y)))) + val (lhs, rhs) = coll.span(x => !coll.forall(y => (x eq y) || !lteq(x, y))) assert(!rhs.isEmpty, s"cycle while ordering $coll") loop(lhs ::: rhs.tail, rhs.head :: acc) } @@ -1039,7 +942,7 @@ final class JSExportsGen(jsCodeGen: JSCodeGen)(using Context) { m.toList } - private class FormalArgsRegistry(val minArgc: Int, needsRestParam: Boolean) { + class FormalArgsRegistry(val minArgc: Int, needsRestParam: Boolean) { private val fixedParamNames: scala.collection.immutable.IndexedSeq[jsNames.LocalName] = (0 until minArgc).toIndexedSeq.map(_ => freshLocalIdent("arg")(NoPosition).name) diff --git a/compiler/src/dotty/tools/backend/sjs/JSPrimitives.scala b/compiler/src/dotty/tools/backend/sjs/JSPrimitives.scala index d4a96f29ca5c..372393733051 100644 --- a/compiler/src/dotty/tools/backend/sjs/JSPrimitives.scala +++ b/compiler/src/dotty/tools/backend/sjs/JSPrimitives.scala @@ -27,7 +27,10 @@ object JSPrimitives { final val UNITVAL = JS_NATIVE + 1 // () value, which is undefined - final val CONSTRUCTOROF = UNITVAL + 1 // runtime.constructorOf(clazz) + final val JS_IMPORT = UNITVAL + 1 // js.import.apply(specifier) + final val JS_IMPORT_META = JS_IMPORT + 1 // js.import.meta + + final val CONSTRUCTOROF = JS_IMPORT_META + 1 // runtime.constructorOf(clazz) final val CREATE_INNER_JS_CLASS = CONSTRUCTOROF + 1 // runtime.createInnerJSClass final val CREATE_LOCAL_JS_CLASS = CREATE_INNER_JS_CLASS + 1 // runtime.createLocalJSClass final val WITH_CONTEXTUAL_JS_CLASS_VALUE = CREATE_LOCAL_JS_CLASS + 1 // runtime.withContextualJSClassValue @@ -106,6 +109,9 @@ class JSPrimitives(ictx: Context) extends DottyPrimitives(ictx) { addPrimitive(defn.BoxedUnit_UNIT, UNITVAL) + addPrimitive(jsdefn.JSImport_apply, JS_IMPORT) + addPrimitive(jsdefn.JSImport_meta, JS_IMPORT_META) + addPrimitive(jsdefn.Runtime_constructorOf, CONSTRUCTOROF) addPrimitive(jsdefn.Runtime_createInnerJSClass, CREATE_INNER_JS_CLASS) addPrimitive(jsdefn.Runtime_createLocalJSClass, CREATE_LOCAL_JS_CLASS) diff --git a/compiler/src/dotty/tools/dotc/transform/sjs/JSSymUtils.scala b/compiler/src/dotty/tools/dotc/transform/sjs/JSSymUtils.scala index 0651e33c4d7d..f718d68e9588 100644 --- a/compiler/src/dotty/tools/dotc/transform/sjs/JSSymUtils.scala +++ b/compiler/src/dotty/tools/dotc/transform/sjs/JSSymUtils.scala @@ -92,6 +92,24 @@ object JSSymUtils { } } + /** Info about a Scala method param when called as JS method. + * + * @param info + * Parameter type (type of a single element if repeated). + * @param repeated + * Whether the parameter is repeated. + * @param capture + * Whether the parameter is a capture. + */ + final class JSParamInfo( + val info: Type, + val repeated: Boolean = false, + val capture: Boolean = false + ) { + override def toString(): String = + s"ParamSpec($info, repeated = $repeated, capture = $capture)" + } + extension (sym: Symbol) { /** Is this symbol a JavaScript type? */ def isJSType(using Context): Boolean = @@ -190,6 +208,43 @@ object JSSymUtils { def defaultJSName(using Context): String = if (sym.isTerm) sym.asTerm.name.unexpandedName.getterName.toString() else sym.name.unexpandedName.stripModuleClassSuffix.toString() + + def jsParamInfos(using Context): List[JSParamInfo] = { + assert(sym.is(Method), s"trying to take JS param info of non-method: $sym") + + def paramNamesAndTypes(using Context): List[(Names.TermName, Type)] = + sym.info.paramNamess.flatten.zip(sym.info.paramInfoss.flatten) + + val paramInfosAtElimRepeated = atPhase(elimRepeatedPhase) { + val list = + for ((name, info) <- paramNamesAndTypes) yield { + val v = + if (info.isRepeatedParam) Some(info.repeatedToSingle.widenDealias) + else None + name -> v + } + list.toMap + } + + val paramInfosAtElimEVT = atPhase(elimErasedValueTypePhase) { + paramNamesAndTypes.toMap + } + + for ((paramName, paramInfoNow) <- paramNamesAndTypes) yield { + paramInfosAtElimRepeated.get(paramName) match { + case None => + // This is a capture parameter introduced by erasure or lambdalift + new JSParamInfo(paramInfoNow, capture = true) + + case Some(Some(info)) => + new JSParamInfo(info, repeated = true) + + case Some(None) => + val info = paramInfosAtElimEVT.getOrElse(paramName, paramInfoNow) + new JSParamInfo(info) + } + } + } } private object JSUnaryOpMethodName { diff --git a/project/Build.scala b/project/Build.scala index dc20ee2d3e2a..ecebf6e00fb2 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -1160,7 +1160,8 @@ object Build { "compliantModuleInit" -> (sems.moduleInit == CheckedBehavior.Compliant), "strictFloats" -> sems.strictFloats, "productionMode" -> sems.productionMode, - "es2015" -> linkerConfig.esFeatures.useECMAScript2015, + "esVersion" -> linkerConfig.esFeatures.esVersion.edition, + "useECMAScript2015Semantics" -> linkerConfig.esFeatures.useECMAScript2015Semantics, ) }.taskValue, diff --git a/project/ConstantHolderGenerator.scala b/project/ConstantHolderGenerator.scala index 9dc45313861d..e22609a305fd 100644 --- a/project/ConstantHolderGenerator.scala +++ b/project/ConstantHolderGenerator.scala @@ -35,6 +35,7 @@ object ConstantHolderGenerator { private final def literal(v: Any): String = v match { case s: String => "raw\"\"\"" + s + "\"\"\"" case b: Boolean => b.toString + case i: Int => i.toString case f: File => literal(f.getAbsolutePath) case _ => diff --git a/project/plugins.sbt b/project/plugins.sbt index b1a292ec9624..5a49901967ca 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -2,7 +2,7 @@ // // e.g. addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.1.0") -addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.5.1") +addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.6.0") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.6")