Skip to content

Fix #8531: Annnotations on class value parameters go to the constructor #8534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ object desugar {
else originalTparams
}
else originalTparams

// Annotations on class _type_ parameters are set on the derived parameters
// but not on the constructor parameters. The reverse is true for
// annotations on class _value_ parameters.
val constrTparams = impliedTparams.map(toDefParam(_, keepAnnotations = false))
val constrVparamss =
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
Expand All @@ -444,7 +448,14 @@ object desugar {
ctx.error(CaseClassMissingNonImplicitParamList(cdef), namePos)
ListOfNil
}
else originalVparamss.nestedMap(toDefParam(_, keepAnnotations = false, keepDefault = true))
else originalVparamss.nestedMap(toDefParam(_, keepAnnotations = true, keepDefault = true))
val derivedTparams =
constrTparams.zipWithConserve(impliedTparams)((tparam, impliedParam) =>
derivedTypeParam(tparam).withAnnotations(impliedParam.mods.annotations))
val derivedVparamss =
constrVparamss.nestedMap(vparam =>
derivedTermParam(vparam).withAnnotations(Nil))

val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss)

val (normalizedBody, enumCases, enumCompanionRef) = {
Expand Down Expand Up @@ -480,14 +491,6 @@ object desugar {

def anyRef = ref(defn.AnyRefAlias.typeRef)

// Annotations are dropped from the constructor parameters but should be
// preserved in all derived parameters.
val derivedTparams =
constrTparams.zipWithConserve(impliedTparams)((tparam, impliedParam) =>
derivedTypeParam(tparam).withAnnotations(impliedParam.mods.annotations))
val derivedVparamss =
constrVparamss.nestedMap(vparam => derivedTermParam(vparam))

val arity = constrVparamss.head.length

val classTycon: Tree = TypeRefTree() // watching is set at end of method
Expand Down Expand Up @@ -779,8 +782,10 @@ object desugar {
val originalVparamsIt = originalVparamss.iterator.flatten
derivedVparamss match {
case first :: rest =>
first.map(_.withMods(originalVparamsIt.next().mods | caseAccessor)) ++
rest.flatten.map(_.withMods(originalVparamsIt.next().mods))
// Annotations on the class _value_ parameters are not set on the parameter accessors
def mods(vdef: ValDef) = vdef.mods.withAnnotations(Nil)
first.map(_.withMods(mods(originalVparamsIt.next()) | caseAccessor)) ++
rest.flatten.map(_.withMods(mods(originalVparamsIt.next())))
case _ =>
Nil
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ class Definitions {
@tu lazy val TASTYLongSignatureAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.internal.TASTYLongSignature")
@tu lazy val TailrecAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.tailrec")
@tu lazy val ThreadUnsafeAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.threadUnsafe")
@tu lazy val TransientParamAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.constructorOnly")
@tu lazy val ConstructorOnlyAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.constructorOnly")
@tu lazy val CompileTimeOnlyAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.compileTimeOnly")
@tu lazy val SwitchAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.switch")
@tu lazy val ThrowsAnnot: ClassSymbol = ctx.requiredClass("scala.throws")
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ object SymDenotations {
final def setParamssFromDefs(tparams: List[TypeDef[?]], vparamss: List[List[ValDef[?]]])(using Context): Unit =
setParamss(tparams.map(_.symbol), vparamss.map(_.map(_.symbol)))

/** A pair consistsing of type paremeter symbols and value parameter symbol lists
/** A pair consisting of type parameter symbols and value parameter symbol lists
* of this method definition, or (Nil, Nil) for other symbols.
* Makes use of `rawParamss` when present, or constructs fresh parameter symbols otherwise.
* This method can be allocation-heavy.
Expand Down
16 changes: 10 additions & 6 deletions compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ class ExtractSemanticDB extends Phase:
|| sym == defn.Any_typeCast
|| qualifier.exists(excludeQual)

private def traverseAnnotsOf(sym: Symbol)(using Context): Unit =
for annot <- sym.annotations do
if annot.tree.span.exists
&& annot.tree.span.hasLength
annot.tree match
case tree: Typed => () // hack for inline code
case tree => traverse(tree)

override def traverse(tree: Tree)(using Context): Unit =

inline def traverseCtorParamTpt(ctorSym: Symbol, tpt: Tree): Unit =
Expand All @@ -115,12 +123,7 @@ class ExtractSemanticDB extends Phase:
else
traverse(tpt)

for annot <- tree.symbol.annotations do
if annot.tree.span.exists
&& annot.tree.span.hasLength
annot.tree match
case tree: Typed => () // hack for inline code
case tree => traverse(tree)
traverseAnnotsOf(tree.symbol)

tree match
case tree: PackageDef =>
Expand Down Expand Up @@ -563,6 +566,7 @@ class ExtractSemanticDB extends Phase:
vparams <- vparamss
vparam <- vparams
do
traverseAnnotsOf(vparam.symbol)
if !excludeSymbol(vparam.symbol)
val symkinds =
getters.get(vparam.name).fold(SymbolKind.emptySet)(getter =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Symbol_paramSymss(self: Symbol)(using ctx: Context): (List[Symbol], List[List[Symbol]]) =
self.paramSymss

def Symbol_primaryConstructor(self: Symbol)(using Context): Symbol =
self.primaryConstructor

def Symbol_caseFields(self: Symbol)(using ctx: Context): List[Symbol] =
if (!self.isClass) Nil
else self.asClass.paramAccessors.collect {
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/Constructors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,12 @@ class Constructors extends MiniPhase with IdentityDenotTransformer { thisPhase =
Nil
}
else {
if (acc.hasAnnotation(defn.TransientParamAnnot))
ctx.error(em"transient parameter $acc is retained as field in class ${acc.owner}", acc.sourcePos)
val param = acc.subst(accessors, paramSyms)
if (param.hasAnnotation(defn.ConstructorOnlyAnnot))
ctx.error(em"${acc.name} is marked `@constructorOnly` but it is retained as a field in ${acc.owner}", acc.sourcePos)
val target = if (acc.is(Method)) acc.field else acc
if (!target.exists) Nil // this case arises when the parameter accessor is an alias
else {
val param = acc.subst(accessors, paramSyms)
val assigns = Assign(ref(target), ref(param)).withSpan(tree.span) :: Nil
if (acc.name != nme.OUTER) assigns
else {
Expand Down
6 changes: 5 additions & 1 deletion library/src/scala/tasty/Reflection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2223,12 +2223,16 @@ class Reflection(private[scala] val internal: CompilerInterface) { self =>
def methods(using ctx: Context): List[Symbol] =
internal.Symbol_methods(sym)

/** A pair consistsing of type paremeter symbols and value parameter symbol lists
/** A pair consisting of type parameter symbols and value parameter symbol lists
* of this method definition, or (Nil, Nil) for other symbols.
*/
def paramSymss(using ctx: Context): (List[Symbol], List[List[Symbol]]) =
internal.Symbol_paramSymss(sym)

/** The primary constructor of a class or trait, `noSymbol` if not applicable. */
def primaryConstructor(using Context): Symbol =
internal.Symbol_primaryConstructor(sym)

/** Fields of a case class type -- only the ones declared in primary constructor */
def caseFields(using ctx: Context): List[Symbol] =
internal.Symbol_caseFields(sym)
Expand Down
5 changes: 4 additions & 1 deletion library/src/scala/tasty/reflect/CompilerInterface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1305,11 +1305,14 @@ trait CompilerInterface {
/** Get all non-private methods declared or inherited */
def Symbol_methods(self: Symbol)(using ctx: Context): List[Symbol]

/** A pair consistsing of type paremeter symbols and value parameter symbol lists
/** A pair consisting of type parameter symbols and value parameter symbol lists
* of this method definition, or (Nil, Nil) for other symbols.
*/
def Symbol_paramSymss(self: Symbol)(using ctx: Context): (List[Symbol], List[List[Symbol]])

/** The primary constructor of a class or trait, `noSymbol` if not applicable. */
def Symbol_primaryConstructor(self: Symbol)(using Context): Symbol

/** Fields of a case class type -- only the ones declared in primary constructor */
def Symbol_caseFields(self: Symbol)(using ctx: Context): List[Symbol]

Expand Down
5 changes: 5 additions & 0 deletions tests/run/i8531/Named.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

@Retention(RetentionPolicy.RUNTIME)
public @interface Named {}
9 changes: 9 additions & 0 deletions tests/run/i8531/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Foo(@Named s: String)

object Test {
def main(args: Array[String]): Unit = {
val ctor = classOf[Foo].getDeclaredConstructors()(0)
val annots = ctor.getParameterAnnotations()(0)
assert(annots.length == 1, annots.length)
}
}