diff --git a/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala b/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala index 813f8ddf8780..ecbfbeb6d6e5 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala @@ -80,5 +80,10 @@ object TypeUtils { case self: TypeProxy => self.underlying.companionRef } + + /** Is this type a methodic type that takes implicit parameters (both old and new) at some point? */ + def takesImplicitParams(using Context): Boolean = self.stripPoly match + case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams + case _ => false } } diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 1eb4a37c9a2f..677beac079fb 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1237,12 +1237,81 @@ class Namer { typer: Typer => } } + /** Ensure that the first type in a list of parent types Ps points to a non-trait class. + * If that's not already the case, add one. The added class type CT is determined as follows. + * First, let C be the unique class such that + * - there is a parent P_i such that P_i derives from C, and + * - for every class D: If some parent P_j, j <= i derives from D, then C derives from D. + * Then, let CT be the smallest type which + * - has C as its class symbol, and + * - for all parents P_i: If P_i derives from C then P_i <:< CT. + */ + def ensureFirstIsClass(parents: List[Type]): List[Type] = + + def realClassParent(sym: Symbol): ClassSymbol = + if !sym.isClass then defn.ObjectClass + else if !sym.is(Trait) then sym.asClass + else sym.info.parents match + case parentRef :: _ => realClassParent(parentRef.typeSymbol) + case nil => defn.ObjectClass + + def improve(candidate: ClassSymbol, parent: Type): ClassSymbol = + val pcls = realClassParent(parent.classSymbol) + if (pcls derivesFrom candidate) pcls else candidate + + parents match + case p :: _ if p.classSymbol.isRealClass => parents + case _ => + val pcls = parents.foldLeft(defn.ObjectClass)(improve) + typr.println(i"ensure first is class $parents%, % --> ${parents map (_ baseType pcls)}%, %") + val first = TypeComparer.glb(defn.ObjectType :: parents.map(_.baseType(pcls))) + checkFeasibleParent(first, cls.srcPos, em" in inferred superclass $first") :: parents + end ensureFirstIsClass + + /** If `parents` contains references to traits that have supertraits with implicit parameters + * add those supertraits in linearization order unless they are already covered by other + * parent types. For instance, in + * + * class A + * trait B(using I) extends A + * trait C extends B + * class D extends A, C + * + * the class declaration of `D` is augmented to + * + * class D extends A, B, C + * + * so that an implicit `I` can be passed to `B`. See i7613.scala for more examples. + */ + def addUsingTraits(parents: List[Type]): List[Type] = + lazy val existing = parents.map(_.classSymbol).toSet + def recur(parents: List[Type]): List[Type] = parents match + case parent :: parents1 => + val psym = parent.classSymbol + val addedTraits = + if psym.is(Trait) then + psym.asClass.baseClasses.tail.iterator + .takeWhile(_.is(Trait)) + .filter(p => + p.primaryConstructor.info.takesImplicitParams + && !cls.superClass.isSubClass(p) + && !existing.contains(p)) + .toList.reverse + else Nil + addedTraits.map(parent.baseType) ::: parent :: recur(parents1) + case nil => + Nil + if cls.isRealClass then recur(parents) else parents + end addUsingTraits + completeConstructor(denot) denot.info = tempInfo val parentTypes = defn.adjustForTuple(cls, cls.typeParams, defn.adjustForBoxedUnit(cls, - ensureFirstIsClass(parents.map(checkedParentType(_)), cls.span) + addUsingTraits( + ensureFirstIsClass(parents.map(checkedParentType(_))) + ) ) ) typr.println(i"completing $denot, parents = $parents%, %, parentTypes = $parentTypes%, %") diff --git a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala index 8ff0dfc8c61b..a744ca39f41f 100644 --- a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala +++ b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala @@ -103,8 +103,8 @@ class ReTyper extends Typer with ReChecking { override def completeAnnotations(mdef: untpd.MemberDef, sym: Symbol)(using Context): Unit = () - override def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] = - parents + override def ensureConstrCall(cls: ClassSymbol, parent: Tree)(using Context): Tree = + parent override def handleUnexpectedFunType(tree: untpd.Apply, fun: Tree)(using Context): Tree = fun.tpe match { case mt: MethodType => diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index b7dbbe859ce4..50e84a86c13e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2212,20 +2212,18 @@ class Typer extends Namer * @param psym Its type symbol * @param cinfo The info of its constructor */ - def maybeCall(ref: Tree, psym: Symbol, cinfo: Type): Tree = cinfo.stripPoly match { + def maybeCall(ref: Tree, psym: Symbol): Tree = psym.primaryConstructor.info.stripPoly match case cinfo @ MethodType(Nil) if cinfo.resultType.isImplicitMethod => typedExpr(untpd.New(untpd.TypedSplice(ref)(using superCtx), Nil))(using superCtx) case cinfo @ MethodType(Nil) if !cinfo.resultType.isInstanceOf[MethodType] => ref case cinfo: MethodType => - if (!ctx.erasedTypes) { // after constructors arguments are passed in super call. + if !ctx.erasedTypes then // after constructors arguments are passed in super call. typr.println(i"constr type: $cinfo") report.error(ParameterizedTypeLacksArguments(psym), ref.srcPos) - } ref case _ => ref - } val seenParents = mutable.Set[Symbol]() @@ -2250,7 +2248,7 @@ class Typer extends Namer if (tree.isType) { checkSimpleKinded(result) // Not needed for constructor calls, as type arguments will be inferred. if (psym.is(Trait) && !cls.is(Trait) && !cls.superClass.isSubClass(psym)) - result = maybeCall(result, psym, psym.primaryConstructor.info) + result = maybeCall(result, psym) } else checkParentCall(result, cls) checkTraitInheritance(psym, cls, tree.srcPos) @@ -2258,6 +2256,27 @@ class Typer extends Namer result } + /** Augment `ptrees` to have the same class symbols as `parents`. Generate TypeTrees + * or New trees to fill in any parents for which no tree exists yet. + */ + def parentTrees(parents: List[Type], ptrees: List[Tree]): List[Tree] = parents match + case parent :: parents1 => + val psym = parent.classSymbol + def hasSameParent(ptree: Tree) = ptree.tpe.classSymbol == psym + ptrees match + case ptree :: ptrees1 if hasSameParent(ptree) => + ptree :: parentTrees(parents1, ptrees1) + case ptree :: ptrees1 if ptrees1.exists(hasSameParent) => + ptree :: parentTrees(parents, ptrees1) + case _ => + var added: Tree = TypeTree(parent).withSpan(cdef.nameSpan.focus) + if psym.is(Trait) && psym.primaryConstructor.info.takesImplicitParams then + // classes get a constructor separately using a different context + added = ensureConstrCall(cls, added) + added :: parentTrees(parents1, ptrees) + case _ => + ptrees + /** Checks if one of the decls is a type with the same name as class type member in selfType */ def classExistsOnSelf(decls: Scope, self: tpd.ValDef): Boolean = { val selfType = self.tpt.tpe @@ -2278,8 +2297,10 @@ class Typer extends Namer completeAnnotations(cdef, cls) val constr1 = typed(constr).asInstanceOf[DefDef] - val parentsWithClass = ensureFirstTreeIsClass(parents.mapconserve(typedParent).filterConserve(!_.isEmpty), cdef.nameSpan) - val parents1 = ensureConstrCall(cls, parentsWithClass)(using superCtx) + val parents0 = parentTrees( + cls.classInfo.declaredParents, + parents.mapconserve(typedParent).filterConserve(!_.isEmpty)) + val parents1 = ensureConstrCall(cls, parents0)(using superCtx) val firstParentTpe = parents1.head.tpe.dealias val firstParent = firstParentTpe.typeSymbol @@ -2348,52 +2369,23 @@ class Typer extends Namer protected def addAccessorDefs(cls: Symbol, body: List[Tree])(using Context): List[Tree] = ctx.compilationUnit.inlineAccessors.addAccessorDefs(cls, body) - /** Ensure that the first type in a list of parent types Ps points to a non-trait class. - * If that's not already the case, add one. The added class type CT is determined as follows. - * First, let C be the unique class such that - * - there is a parent P_i such that P_i derives from C, and - * - for every class D: If some parent P_j, j <= i derives from D, then C derives from D. - * Then, let CT be the smallest type which - * - has C as its class symbol, and - * - for all parents P_i: If P_i derives from C then P_i <:< CT. + /** If this is a real class, make sure its first parent is a + * constructor call. Cannot simply use a type. Overridden in ReTyper. */ - def ensureFirstIsClass(parents: List[Type], span: Span)(using Context): List[Type] = { - def realClassParent(cls: Symbol): ClassSymbol = - if (!cls.isClass) defn.ObjectClass - else if (!cls.is(Trait)) cls.asClass - else cls.info.parents match { - case parentRef :: _ => realClassParent(parentRef.typeSymbol) - case nil => defn.ObjectClass - } - def improve(candidate: ClassSymbol, parent: Type): ClassSymbol = { - val pcls = realClassParent(parent.classSymbol) - if (pcls derivesFrom candidate) pcls else candidate - } - parents match { - case p :: _ if p.classSymbol.isRealClass => parents - case _ => - val pcls = parents.foldLeft(defn.ObjectClass)(improve) - typr.println(i"ensure first is class $parents%, % --> ${parents map (_ baseType pcls)}%, %") - val first = TypeComparer.glb(defn.ObjectType :: parents.map(_.baseType(pcls))) - checkFeasibleParent(first, ctx.source.atSpan(span), em" in inferred superclass $first") :: parents - } - } + def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] = parents match + case parents @ (first :: others) => + parents.derivedCons(ensureConstrCall(cls, first), others) + case parents => + parents - /** Ensure that first parent tree refers to a real class. */ - def ensureFirstTreeIsClass(parents: List[Tree], span: Span)(using Context): List[Tree] = parents match { - case p :: ps if p.tpe.classSymbol.isRealClass => parents - case _ => TypeTree(ensureFirstIsClass(parents.tpes, span).head).withSpan(span.focus) :: parents - } - - /** If this is a real class, make sure its first parent is a + /** If this is a real class, make sure its first parent is a * constructor call. Cannot simply use a type. Overridden in ReTyper. */ - def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] = { - val firstParent :: otherParents = parents - if (firstParent.isType && !cls.is(Trait) && !cls.is(JavaDefined)) - typed(untpd.New(untpd.TypedSplice(firstParent), Nil)) :: otherParents - else parents - } + def ensureConstrCall(cls: ClassSymbol, parent: Tree)(using Context): Tree = + if (parent.isType && !cls.is(Trait) && !cls.is(JavaDefined)) + typed(untpd.New(untpd.TypedSplice(parent), Nil)) + else + parent def localDummy(cls: ClassSymbol, impl: untpd.Template)(using Context): Symbol = newLocalDummy(cls, impl.span) diff --git a/docs/docs/reference/other-new-features/trait-parameters.md b/docs/docs/reference/other-new-features/trait-parameters.md index 1655e338b32c..5ccde10b05c1 100644 --- a/docs/docs/reference/other-new-features/trait-parameters.md +++ b/docs/docs/reference/other-new-features/trait-parameters.md @@ -52,6 +52,36 @@ The correct way to write `E` is to extend both `Greeting` and class E extends Greeting("Bob"), FormalGreeting ``` +### Traits With Context Parameters + +This "explicit extension required" rule is relaxed if the missing trait contains only +[context parameters](../contextual/using-clauses). In that case the trait reference is +implicitly inserted as an additional parent with inferred arguments. For instance, +here's a variant of greetings where the addressee is a context parameter of type +`ImpliedName`: + +```scala +case class ImpliedName(name: String): + override def toString = name + +trait ImpliedGreeting(using val iname: ImpliedName): + def msg = s"How are you, $iname" + +trait ImpliedFormalGreeting extends ImpliedGreeting: + override def msg = s"How do you do, $iname" + +class F(using iname: ImpliedName) extends ImpliedFormalGreeting +``` + +The definition of `F` in the last line is implicitly expanded to +```scala +class F(using iname: ImpliedName) extends + Object, + ImpliedGreeting(using iname), + ImpliedFormalGreeting(using iname) +``` +Note the inserted reference to the super trait `ImpliedGreeting`, which was not mentioned explicitly. + ## Reference For more information, see [Scala SIP 25](http://docs.scala-lang.org/sips/pending/trait-parameters.html). diff --git a/tests/neg/i6060.scala b/tests/neg/i6060.scala index 003df32becd1..8e8ca3a45662 100644 --- a/tests/neg/i6060.scala +++ b/tests/neg/i6060.scala @@ -1,6 +1,6 @@ class I1(i2: Int) { def apply(i3: Int) = 1 - new I1(1)(2) {} // error: too many arguments in parent constructor + new I1(1)(2) {} // error: too many arguments in parent constructor // error } class I0(i1: Int) { diff --git a/tests/neg/i7613.check b/tests/neg/i7613.check new file mode 100644 index 000000000000..1cf86894d64d --- /dev/null +++ b/tests/neg/i7613.check @@ -0,0 +1,8 @@ +-- Error: tests/neg/i7613.scala:10:16 ---------------------------------------------------------------------------------- +10 | new BazLaws[A] {} // error // error + | ^ + | no implicit argument of type Baz[A] was found for parameter x$1 of constructor BazLaws in trait BazLaws +-- Error: tests/neg/i7613.scala:10:2 ----------------------------------------------------------------------------------- +10 | new BazLaws[A] {} // error // error + | ^ + | no implicit argument of type Bar[A] was found for parameter x$1 of constructor BarLaws in trait BarLaws diff --git a/tests/neg/i7613.scala b/tests/neg/i7613.scala new file mode 100644 index 000000000000..f50700d94219 --- /dev/null +++ b/tests/neg/i7613.scala @@ -0,0 +1,11 @@ +trait Foo[A] +trait Bar[A] extends Foo[A] +trait Baz[A] extends Bar[A] + +trait FooLaws[A](using Foo[A]) +trait BarLaws[A](using Bar[A]) extends FooLaws[A] +trait BazLaws[A](using Baz[A]) extends BarLaws[A] + +def instance[A](using Foo[A]): BazLaws[A] = + new BazLaws[A] {} // error // error + diff --git a/tests/pos/reference/trait-parameters.scala b/tests/pos/reference/trait-parameters.scala index 8b043a4279d6..dbd61a9853cd 100644 --- a/tests/pos/reference/trait-parameters.scala +++ b/tests/pos/reference/trait-parameters.scala @@ -16,4 +16,13 @@ class E extends Greeting("Bob") with FormalGreeting // class D2 extends C with Greeting("Bill") // error +case class ImpliedName(name: String): + override def toString = name +trait ImpliedGreeting(using val iname: ImpliedName): + def msg = s"How are you, $iname" + +trait ImpliedFormalGreeting extends ImpliedGreeting: + override def msg = s"How do you do, $iname" + +class F(using iname: ImpliedName) extends ImpliedFormalGreeting diff --git a/tests/run/i7613.check b/tests/run/i7613.check new file mode 100644 index 000000000000..fe54e2aa095e --- /dev/null +++ b/tests/run/i7613.check @@ -0,0 +1,5 @@ +D: B1 +superD: B1 +E: B2 +F: B1 +F: B2 diff --git a/tests/run/i7613.scala b/tests/run/i7613.scala new file mode 100644 index 000000000000..a77682b3cffa --- /dev/null +++ b/tests/run/i7613.scala @@ -0,0 +1,29 @@ +trait Foo[A] +trait Bar[A] extends Foo[A] +trait Baz[A] extends Bar[A] + +trait FooLaws[A](using Foo[A]) +trait BarLaws[A](using Bar[A]) extends FooLaws[A] +trait BazLaws[A](using Baz[A]) extends BarLaws[A] + +def instance1[A](using Baz[A]): BazLaws[A] = + new FooLaws[A] with BarLaws[A] with BazLaws[A] {} + +def instance2[A](using Baz[A]): BazLaws[A] = + new BazLaws[A] {} + +trait I: + def show(x: String): Unit +class A +trait B1(using I) extends A { summon[I].show("B1") } +trait B2(using I) extends B1 { summon[I].show("B2") } +trait C1 extends B1 +trait C2 extends B2 +class D(using I) extends A, C1 +class E(using I) extends D(using new I { def show(x: String) = println(s"superD: $x")}), C2 +class F(using I) extends A, C2 + +@main def Test = + D(using new I { def show(x: String) = println(s"D: $x")}) + E(using new I { def show(x: String) = println(s"E: $x")}) + F(using new I { def show(x: String) = println(s"F: $x")})