diff --git a/src/dotty/tools/dotc/ast/tpd.scala b/src/dotty/tools/dotc/ast/tpd.scala index defcf4838f59..51011f90b1fe 100644 --- a/src/dotty/tools/dotc/ast/tpd.scala +++ b/src/dotty/tools/dotc/ast/tpd.scala @@ -302,7 +302,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { true case pre: ThisType => pre.cls.isStaticOwner || - tp.symbol.is(ParamOrAccessor) && ctx.owner.enclosingClass == pre.cls + tp.symbol.is(ParamOrAccessor) && !pre.cls.is(Trait) && ctx.owner.enclosingClass == pre.cls // was ctx.owner.enclosingClass.derivesFrom(pre.cls) which was not tight enough // and was spuriously triggered in case inner class would inherit from outer one // eg anonymous TypeMap inside TypeMap.andThen diff --git a/src/dotty/tools/dotc/core/SymDenotations.scala b/src/dotty/tools/dotc/core/SymDenotations.scala index 9db75ee94c02..d8dddb0828cb 100644 --- a/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/src/dotty/tools/dotc/core/SymDenotations.scala @@ -516,7 +516,7 @@ object SymDenotations { !isAnonymousFunction && !isCompanionMethod - /** Is this a setter? */ + /** Is this a getter? */ final def isGetter(implicit ctx: Context) = (this is Accessor) && !originalName.isSetterName && !originalName.isScala2LocalSuffix diff --git a/src/dotty/tools/dotc/transform/Mixin.scala b/src/dotty/tools/dotc/transform/Mixin.scala index 8562f4f02660..de6cde8f2cc6 100644 --- a/src/dotty/tools/dotc/transform/Mixin.scala +++ b/src/dotty/tools/dotc/transform/Mixin.scala @@ -20,7 +20,7 @@ import collection.mutable /** This phase performs the following transformations: * - * 1. (done in `traitDefs`) Map every concrete trait getter + * 1. (done in `traitDefs` and `transformSym`) Map every concrete trait getter * * def x(): T = expr * @@ -46,32 +46,45 @@ import collection.mutable * For every trait M directly implemented by the class (see SymUtils.mixin), in * reverse linearization order, add the following definitions to C: * - * 3.1 (done in `traitInits`) For every concrete trait getter ` def x(): T` in M, - * in order of textual occurrence, produce the following: + * 3.1 (done in `traitInits`) For every parameter accessor ` def x(): T` in M, + * in order of textual occurrence, add * - * 3.1.1 If `x` is also a member of `C`, and M is a Dotty trait: + * def x() = e + * + * where `e` is the constructor argument in C that corresponds to `x`. Issue + * an error if no such argument exists. + * + * 3.2 (done in `traitInits`) For every concrete trait getter ` def x(): T` in M + * which is not a parameter accessor, in order of textual occurrence, produce the following: + * + * 3.2.1 If `x` is also a member of `C`, and M is a Dotty trait: * * def x(): T = super[M].initial$x() * - * 3.1.2 If `x` is also a member of `C`, and M is a Scala 2.x trait: + * 3.2.2 If `x` is also a member of `C`, and M is a Scala 2.x trait: * * def x(): T = _ * - * 3.1.3 If `x` is not a member of `C`, and M is a Dotty trait: + * 3.2.3 If `x` is not a member of `C`, and M is a Dotty trait: * * super[M].initial$x() * - * 3.1.4 If `x` is not a member of `C`, and M is a Scala2.x trait, nothing gets added. + * 3.2.4 If `x` is not a member of `C`, and M is a Scala2.x trait, nothing gets added. * * - * 3.2 (done in `superCallOpt`) The call: + * 3.3 (done in `superCallOpt`) The call: * * super[M]. * - * 3.3 (done in `setters`) For every concrete setter ` def x_=(y: T)` in M: + * 3.4 (done in `setters`) For every concrete setter ` def x_=(y: T)` in M: * * def x_=(y: T) = () * + * 4. (done in `transformTemplate` and `transformSym`) Drop all parameters from trait + * constructors. + * + * 5. (done in `transformSym`) Drop ParamAccessor flag from all parameter accessors in traits. + * * Conceptually, this is the second half of the previous mixin phase. It needs to run * after erasure because it copies references to possibly private inner classes and objects * into enclosing classes where they are not visible. This can only be done if all references @@ -86,7 +99,9 @@ class Mixin extends MiniPhaseTransform with SymTransformer { thisTransform => override def transformSym(sym: SymDenotation)(implicit ctx: Context): SymDenotation = if (sym.is(Accessor, butNot = Deferred) && sym.owner.is(Trait)) - sym.copySymDenotation(initFlags = sym.flags | Deferred).ensureNotPrivate + sym.copySymDenotation(initFlags = sym.flags &~ ParamAccessor | Deferred).ensureNotPrivate + else if (sym.isConstructor && sym.owner.is(Trait) && sym.info.firstParamTypes.nonEmpty) + sym.copySymDenotation(info = MethodType(Nil, sym.info.resultType)) else sym @@ -111,7 +126,7 @@ class Mixin extends MiniPhaseTransform with SymTransformer { thisTransform => def traitDefs(stats: List[Tree]): List[Tree] = { val initBuf = new mutable.ListBuffer[Tree] stats.flatMap({ - case stat: DefDef if stat.symbol.isGetter && !stat.rhs.isEmpty && !stat.symbol.is(Flags.Lazy) => + case stat: DefDef if stat.symbol.isGetter && !stat.rhs.isEmpty && !stat.symbol.is(Flags.Lazy) => // make initializer that has all effects of previous getter, // replace getter rhs with empty tree. val vsym = stat.symbol @@ -131,15 +146,22 @@ class Mixin extends MiniPhaseTransform with SymTransformer { thisTransform => }) ++ initBuf } - def transformSuper(tree: Tree): Tree = { + /** Map constructor call to a pair of a supercall and a list of arguments + * to be used as initializers of trait parameters if the target of the call + * is a trait. + */ + def transformConstructor(tree: Tree): (Tree, List[Tree]) = { val Apply(sel @ Select(New(_), nme.CONSTRUCTOR), args) = tree - superRef(tree.symbol, tree.pos).appliedToArgs(args) + val (callArgs, initArgs) = if (tree.symbol.owner.is(Trait)) (Nil, args) else (args, Nil) + (superRef(tree.symbol, tree.pos).appliedToArgs(callArgs), initArgs) } - val superCalls = ( + val superCallsAndArgs = ( for (p <- impl.parents if p.symbol.isConstructor) - yield p.symbol.owner -> transformSuper(p) + yield p.symbol.owner -> transformConstructor(p) ).toMap + val superCalls = superCallsAndArgs.mapValues(_._1) + val initArgs = superCallsAndArgs.mapValues(_._2) def superCallOpt(baseCls: Symbol): List[Tree] = superCalls.get(baseCls) match { case Some(call) => @@ -155,35 +177,63 @@ class Mixin extends MiniPhaseTransform with SymTransformer { thisTransform => def wasDeferred(sym: Symbol) = ctx.atPhase(thisTransform) { implicit ctx => sym is Deferred } - def traitInits(mixin: ClassSymbol): List[Tree] = + def traitInits(mixin: ClassSymbol): List[Tree] = { + var argNum = 0 + def nextArgument() = initArgs.get(mixin) match { + case Some(arguments) => + try arguments(argNum) finally argNum += 1 + case None => + val (msg, pos) = impl.parents.find(_.tpe.typeSymbol == mixin) match { + case Some(parent) => ("lacks argument list", parent.pos) + case None => + ("""is indirectly implemented, + |needs to be implemented directly so that arguments can be passed""".stripMargin, + cls.pos) + } + ctx.error(i"parameterized $mixin $msg", pos) + EmptyTree + } + for (getter <- mixin.info.decls.filter(getr => getr.isGetter && !wasDeferred(getr)).toList) yield { val isScala2x = mixin.is(Scala2x) def default = Underscore(getter.info.resultType) def initial = transformFollowing(superRef(initializer(getter)).appliedToNone) - if (isCurrent(getter) || getter.is(ExpandedName)) + + /** A call to the implementation of `getter` in `mixin`'s implementation class */ + def lazyGetterCall = { + def canbeImplClassGetter(sym: Symbol) = sym.info.firstParamTypes match { + case t :: Nil => t.isDirectRef(mixin) + case _ => false + } + val implClassGetter = mixin.implClass.info.nonPrivateDecl(getter.name) + .suchThat(canbeImplClassGetter).symbol + ref(mixin.implClass).select(implClassGetter).appliedTo(This(cls)) + } + + if (isCurrent(getter) || getter.is(ExpandedName)) { + val rhs = + if (ctx.atPhase(thisTransform)(implicit ctx => getter.is(ParamAccessor))) nextArgument() + else if (isScala2x) + if (getter.is(Lazy)) lazyGetterCall + else Underscore(getter.info.resultType) + else transformFollowing(superRef(initializer(getter)).appliedToNone) // transformFollowing call is needed to make memoize & lazy vals run - transformFollowing( - DefDef(implementation(getter.asTerm), - if (isScala2x) { - if (getter.is(Flags.Lazy)) { // lazy vals need to have a rhs that will be the lazy initializer - val sym = mixin.implClass.info.nonPrivateDecl(getter.name).suchThat(_.info.paramTypess match { - case List(List(t: TypeRef)) => t.isDirectRef(mixin) - case _ => false - }).symbol // lazy val can be overloaded - ref(mixin.implClass).select(sym).appliedTo(This(ctx.owner.asClass)) - } - else default - } else initial) - ) + transformFollowing(DefDef(implementation(getter.asTerm), rhs)) + } else if (isScala2x) EmptyTree else initial } + } def setters(mixin: ClassSymbol): List[Tree] = for (setter <- mixin.info.decls.filter(setr => setr.isSetter && !wasDeferred(setr)).toList) yield DefDef(implementation(setter.asTerm), unitLiteral.withPos(cls.pos)) cpy.Template(impl)( + constr = + if (cls.is(Trait) && impl.constr.vparamss.flatten.nonEmpty) + cpy.DefDef(impl.constr)(vparamss = Nil :: Nil) + else impl.constr, parents = impl.parents.map(p => TypeTree(p.tpe).withPos(p.pos)), body = if (cls is Trait) traitDefs(impl.body) diff --git a/src/dotty/tools/dotc/typer/Checking.scala b/src/dotty/tools/dotc/typer/Checking.scala index 3ef6d059a0f4..9047b8cb35c0 100644 --- a/src/dotty/tools/dotc/typer/Checking.scala +++ b/src/dotty/tools/dotc/typer/Checking.scala @@ -20,6 +20,7 @@ import annotation.unchecked import util.Positions._ import util.{Stats, SimpleMap} import util.common._ +import transform.SymUtils._ import Decorators._ import Uniques._ import ErrorReporting.{err, errorType, DiagnosticString} @@ -328,9 +329,15 @@ trait Checking { } } - def checkInstantiatable(cls: ClassSymbol, pos: Position): Unit = { - ??? // to be done in later phase: check that class `cls` is legal in a new. - } + def checkParentCall(call: Tree, caller: ClassSymbol)(implicit ctx: Context) = + if (!ctx.isAfterTyper) { + val called = call.tpe.classSymbol + if (caller is Trait) + ctx.error(i"$caller may not call constructor of $called", call.pos) + else if (called.is(Trait) && !caller.mixins.contains(called)) + ctx.error(i"""$called is already implemented by super${caller.superClass}, + |its constructor cannot be called again""".stripMargin, call.pos) + } } trait NoChecking extends Checking { @@ -343,4 +350,5 @@ trait NoChecking extends Checking { override def checkImplicitParamsNotSingletons(vparamss: List[List[ValDef]])(implicit ctx: Context): Unit = () override def checkFeasible(tp: Type, pos: Position, where: => String = "")(implicit ctx: Context): Type = tp override def checkNoDoubleDefs(cls: Symbol)(implicit ctx: Context): Unit = () + override def checkParentCall(call: Tree, caller: ClassSymbol)(implicit ctx: Context) = () } diff --git a/src/dotty/tools/dotc/typer/Typer.scala b/src/dotty/tools/dotc/typer/Typer.scala index 5f03d19e7c62..2bdd0d19714e 100644 --- a/src/dotty/tools/dotc/typer/Typer.scala +++ b/src/dotty/tools/dotc/typer/Typer.scala @@ -911,8 +911,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit if (tree.isType) typedType(tree)(superCtx) else { val result = typedExpr(tree)(superCtx) - if ((cls is Trait) && result.tpe.classSymbol.isRealClass && !ctx.isAfterTyper) - ctx.error(s"trait may not call constructor of ${result.tpe.classSymbol}", tree.pos) + checkParentCall(result, cls) result } diff --git a/test/dotc/tests.scala b/test/dotc/tests.scala index 0a6127580df9..1aa35e3ee66b 100644 --- a/test/dotc/tests.scala +++ b/test/dotc/tests.scala @@ -138,6 +138,8 @@ class tests extends CompilerTest { @Test def neg_instantiateAbstract = compileFile(negDir, "instantiateAbstract", xerrors = 8) @Test def neg_selfInheritance = compileFile(negDir, "selfInheritance", xerrors = 5) @Test def neg_shadowedImplicits = compileFile(negDir, "arrayclone-new", xerrors = 2) + @Test def neg_traitParamsTyper = compileFile(negDir, "traitParamsTyper", xerrors = 5) + @Test def neg_traitParamsMixin = compileFile(negDir, "traitParamsMixin", xerrors = 2) @Test def run_all = runFiles(runDir) diff --git a/tests/neg/traitParamsMixin.scala b/tests/neg/traitParamsMixin.scala new file mode 100644 index 000000000000..dfb9fbe2f66d --- /dev/null +++ b/tests/neg/traitParamsMixin.scala @@ -0,0 +1,12 @@ +trait T(x: Int) { + def f = x +} + +class C extends T // error + +trait U extends T + +class D extends U { // error + +} + diff --git a/tests/neg/traitParamsTyper.scala b/tests/neg/traitParamsTyper.scala new file mode 100644 index 000000000000..f87ba3691d7f --- /dev/null +++ b/tests/neg/traitParamsTyper.scala @@ -0,0 +1,16 @@ +trait T(x: Int) { + def f = x +} + +class C(x: Int) extends T() // error + +trait U extends C with T + +trait V extends C(1) with T(2) // two errors + +trait W extends T(3) // error + + +class E extends T(0) +class F extends E with T(1) // error + diff --git a/tests/run/traitParamInit.scala b/tests/run/traitParamInit.scala new file mode 100644 index 000000000000..37d8a425d5ff --- /dev/null +++ b/tests/run/traitParamInit.scala @@ -0,0 +1,30 @@ +object Trace { + private var results = List[Any]() + def apply[A](a: A) = {results ::= a; a} + def fetchAndClear(): Seq[Any] = try results.reverse finally results = Nil +} +trait T(a: Any) { + val ta = a + Trace(s"T.($ta)") + val t_val = Trace("T.val") +} + +trait U(a: Any) extends T { + val ua = a + Trace(s"U.($ua)") +} + +object Test { + def check(expected: Any) = { + val actual = Trace.fetchAndClear() + if (actual != expected) + sys.error(s"\n$actual\n$expected") + } + def main(args: Array[String]): Unit = { + new T(Trace("ta")) with U(Trace("ua")) {} + check(List("ta", "T.(ta)", "T.val", "ua", "U.(ua)")) + + new U(Trace("ua")) with T(Trace("ta")) {} + check(List("ta", "T.(ta)", "T.val", "ua", "U.(ua)")) + } +} diff --git a/tests/run/traitParams.scala b/tests/run/traitParams.scala new file mode 100644 index 000000000000..82c176461c24 --- /dev/null +++ b/tests/run/traitParams.scala @@ -0,0 +1,32 @@ +object State { + var s: Int = 0 +} + +trait T(x: Int, val y: Int) { + def f = x +} + +trait U extends T { + State.s += 1 + override def f = super.f + y +} +trait U2(a: Any) extends T { + def d = a // okay + val v = a // okay + a // used to crash +} + +import State._ +class C(x: Int) extends U with T(x, x * x + s) +class C2(x: Int) extends T(x, x * x + s) with U + +class D extends C(10) with T +class D2 extends C2(10) with T + +object Test { + def main(args: Array[String]): Unit = { + assert(new D().f == 110) + assert(new D2().f == 111) + } +} +