Skip to content

Commit 9693cd8

Browse files
authored
Merge pull request #13474 from dotty-staging/dep-annots
Remove anomalies and gaps in handling annotations
2 parents a38c98e + d921598 commit 9693cd8

File tree

10 files changed

+115
-34
lines changed

10 files changed

+115
-34
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
113113
case _ => 0
114114
}
115115

116-
/** The (last) list of arguments of an application */
117-
def arguments(tree: Tree): List[Tree] = unsplice(tree) match {
118-
case Apply(_, args) => args
119-
case TypeApply(fn, _) => arguments(fn)
120-
case Block(_, expr) => arguments(expr)
116+
/** All term arguments of an application in a single flattened list */
117+
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
118+
case Apply(fn, args) => allArguments(fn) ::: args
119+
case TypeApply(fn, _) => allArguments(fn)
120+
case Block(_, expr) => allArguments(expr)
121121
case _ => Nil
122122
}
123123

compiler/src/dotty/tools/dotc/core/Annotations.scala

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@ import StdNames._
77
import dotty.tools.dotc.ast.tpd
88
import scala.util.Try
99
import util.Spans.Span
10+
import printing.{Showable, Printer}
11+
import printing.Texts.Text
12+
import annotation.internal.sharable
1013

1114
object Annotations {
1215

1316
def annotClass(tree: Tree)(using Context) =
1417
if (tree.symbol.isConstructor) tree.symbol.owner
1518
else tree.tpe.typeSymbol
1619

17-
abstract class Annotation {
20+
abstract class Annotation extends Showable {
1821
def tree(using Context): Tree
1922

2023
def symbol(using Context): Symbol = annotClass(tree)
@@ -26,7 +29,8 @@ object Annotations {
2629
def derivedAnnotation(tree: Tree)(using Context): Annotation =
2730
if (tree eq this.tree) this else Annotation(tree)
2831

29-
def arguments(using Context): List[Tree] = ast.tpd.arguments(tree)
32+
/** All arguments to this annotation in a single flat list */
33+
def arguments(using Context): List[Tree] = ast.tpd.allArguments(tree)
3034

3135
def argument(i: Int)(using Context): Option[Tree] = {
3236
val args = arguments
@@ -44,15 +48,48 @@ object Annotations {
4448
/** The tree evaluation has finished. */
4549
def isEvaluated: Boolean = true
4650

51+
/** Normally, type map over all tree nodes of this annotation, but can
52+
* be overridden. Returns EmptyAnnotation if type type map produces a range
53+
* type, since ranges cannot be types of trees.
54+
*/
55+
def mapWith(tm: TypeMap)(using Context) =
56+
val args = arguments
57+
if args.isEmpty then this
58+
else
59+
val findDiff = new TreeAccumulator[Type]:
60+
def apply(x: Type, tree: Tree)(using Context): Type =
61+
if tm.isRange(x) then x
62+
else
63+
val tp1 = tm(tree.tpe)
64+
foldOver(if tp1 =:= tree.tpe then x else tp1, tree)
65+
val diff = findDiff(NoType, args)
66+
if tm.isRange(diff) then EmptyAnnotation
67+
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
68+
else this
69+
70+
/** Does this annotation refer to a parameter of `tl`? */
71+
def refersToParamOf(tl: TermLambda)(using Context): Boolean =
72+
val args = arguments
73+
if args.isEmpty then false
74+
else tree.existsSubTree {
75+
case id: Ident => id.tpe match
76+
case TermParamRef(tl1, _) => tl eq tl1
77+
case _ => false
78+
case _ => false
79+
}
80+
81+
/** A string representation of the annotation. Overridden in BodyAnnotation.
82+
*/
83+
def toText(printer: Printer): Text = printer.annotText(this)
84+
4785
def ensureCompleted(using Context): Unit = tree
4886

4987
def sameAnnotation(that: Annotation)(using Context): Boolean =
5088
symbol == that.symbol && tree.sameTree(that.tree)
5189
}
5290

53-
case class ConcreteAnnotation(t: Tree) extends Annotation {
91+
case class ConcreteAnnotation(t: Tree) extends Annotation:
5492
def tree(using Context): Tree = t
55-
}
5693

5794
abstract class LazyAnnotation extends Annotation {
5895
protected var mySym: Symbol | (Context ?=> Symbol)
@@ -98,6 +135,7 @@ object Annotations {
98135
if (tree eq this.tree) this else ConcreteBodyAnnotation(tree)
99136
override def arguments(using Context): List[Tree] = Nil
100137
override def ensureCompleted(using Context): Unit = ()
138+
override def toText(printer: Printer): Text = "@Body"
101139
}
102140

103141
class ConcreteBodyAnnotation(body: Tree) extends BodyAnnotation {
@@ -194,6 +232,8 @@ object Annotations {
194232
apply(defn.SourceFileAnnot, Literal(Constant(path)))
195233
}
196234

235+
@sharable val EmptyAnnotation = Annotation(EmptyTree)
236+
197237
def ThrowsAnnotation(cls: ClassSymbol)(using Context): Annotation = {
198238
val tref = cls.typeRef
199239
Annotation(defn.ThrowsAnnot.typeRef.appliedTo(tref), Ident(tref))

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,13 @@ object TypeOps:
162162
// with Nulls (which have no base classes). Under -Yexplicit-nulls, we take
163163
// corrective steps, so no widening is wanted.
164164
simplify(l, theMap) | simplify(r, theMap)
165-
case AnnotatedType(parent, annot)
166-
if annot.symbol == defn.UncheckedVarianceAnnot && !ctx.mode.is(Mode.Type) && !theMap.isInstanceOf[SimplifyKeepUnchecked] =>
167-
simplify(parent, theMap)
165+
case tp @ AnnotatedType(parent, annot) =>
166+
val parent1 = simplify(parent, theMap)
167+
if annot.symbol == defn.UncheckedVarianceAnnot
168+
&& !ctx.mode.is(Mode.Type)
169+
&& !theMap.isInstanceOf[SimplifyKeepUnchecked]
170+
then parent1
171+
else tp.derivedAnnotatedType(parent1, annot)
168172
case _: MatchType =>
169173
val normed = tp.tryNormalize
170174
if (normed.exists) normed else mapOver

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3604,6 +3604,9 @@ object Types {
36043604
case tp: AppliedType => tp.fold(status, compute(_, _, theAcc))
36053605
case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional)
36063606
case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps
3607+
case AnnotatedType(parent, ann) =>
3608+
if ann.refersToParamOf(thisLambdaType) then TrueDeps
3609+
else compute(status, parent, theAcc)
36073610
case _: ThisType | _: BoundType | NoPrefix => status
36083611
case _ =>
36093612
(if theAcc != null then theAcc else DepAcc()).foldOver(status, tp)
@@ -3656,8 +3659,10 @@ object Types {
36563659
if (isResultDependent) {
36573660
val dropDependencies = new ApproximatingTypeMap {
36583661
def apply(tp: Type) = tp match {
3659-
case tp @ TermParamRef(thisLambdaType, _) =>
3662+
case tp @ TermParamRef(`thisLambdaType`, _) =>
36603663
range(defn.NothingType, atVariance(1)(apply(tp.underlying)))
3664+
case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) =>
3665+
mapOver(parent)
36613666
case _ => mapOver(tp)
36623667
}
36633668
}
@@ -5380,6 +5385,8 @@ object Types {
53805385
variance = saved
53815386
derivedLambdaType(tp)(ptypes1, this(restpe))
53825387

5388+
def isRange(tp: Type): Boolean = tp.isInstanceOf[Range]
5389+
53835390
/** Map this function over given type */
53845391
def mapOver(tp: Type): Type = {
53855392
record(s"TypeMap mapOver ${getClass}")
@@ -5423,8 +5430,9 @@ object Types {
54235430

54245431
case tp @ AnnotatedType(underlying, annot) =>
54255432
val underlying1 = this(underlying)
5426-
if (underlying1 eq underlying) tp
5427-
else derivedAnnotatedType(tp, underlying1, mapOver(annot))
5433+
val annot1 = annot.mapWith(this)
5434+
if annot1 eq EmptyAnnotation then underlying1
5435+
else derivedAnnotatedType(tp, underlying1, annot1)
54285436

54295437
case _: ThisType
54305438
| _: BoundType
@@ -5496,9 +5504,6 @@ object Types {
54965504
else newScopeWith(elems1: _*)
54975505
}
54985506

5499-
def mapOver(annot: Annotation): Annotation =
5500-
annot.derivedAnnotation(mapOver(annot.tree))
5501-
55025507
def mapOver(tree: Tree): Tree = treeTypeMap(tree)
55035508

55045509
/** Can be overridden. By default, only the prefix is mapped. */
@@ -5545,8 +5550,6 @@ object Types {
55455550

55465551
protected def emptyRange = range(defn.NothingType, defn.AnyType)
55475552

5548-
protected def isRange(tp: Type): Boolean = tp.isInstanceOf[Range]
5549-
55505553
protected def lower(tp: Type): Type = tp match {
55515554
case tp: Range => tp.lo
55525555
case _ => tp

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
539539
case _ => literalText(String.valueOf(const.value))
540540
}
541541

542-
def toText(annot: Annotation): Text = s"@${annot.symbol.name}" // for now
542+
/** Usual target for `Annotation#toText`, overridden in RefinedPrinter */
543+
def annotText(annot: Annotation): Text = s"@${annot.symbol.name}"
544+
545+
def toText(annot: Annotation): Text = annot.toText(this)
543546

544547
def toText(param: LambdaParam): Text =
545548
varianceSign(param.paramVariance)
@@ -570,7 +573,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
570573
Text()
571574

572575
nodeName ~ "(" ~ elems ~ tpSuffix ~ ")" ~ (Str(tree.sourcePos.toString) provided printDebug)
573-
}.close // todo: override in refined printer
576+
}.close
574577

575578
def toText(pos: SourcePosition): Text =
576579
if (!pos.exists) "<no position>"

compiler/src/dotty/tools/dotc/printing/Printer.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ abstract class Printer {
119119
/** A description of sym's location */
120120
def extendedLocationText(sym: Symbol): Text
121121

122+
/** Textual description of regular annotation in terms of its tree */
123+
def annotText(annot: Annotation): Text
124+
122125
/** Textual representation of denotation */
123126
def toText(denot: Denotation): Text
124127

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import typer.ProtoTypes._
2121
import Trees._
2222
import TypeApplications._
2323
import Decorators._
24-
import NameKinds.WildcardParamName
24+
import NameKinds.{WildcardParamName, DefaultGetterName}
2525
import util.Chars.isOperatorPart
2626
import transform.TypeUtils._
2727
import transform.SymUtils._
@@ -607,7 +607,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
607607
case tree: Template =>
608608
toTextTemplate(tree)
609609
case Annotated(arg, annot) =>
610-
toTextLocal(arg) ~~ annotText(annot)
610+
toTextLocal(arg) ~~ toText(annot)
611611
case EmptyTree =>
612612
"<empty>"
613613
case TypedSplice(t) =>
@@ -964,14 +964,18 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
964964
keywordStr("package ") ~ toTextPackageId(tree.pid) ~ bodyText
965965
}
966966

967+
/** Textual representation of an instance creation expression without the leading `new` */
967968
protected def constrText(tree: untpd.Tree): Text = toTextLocal(tree).stripPrefix(keywordStr("new ")) // DD
968969

969-
protected def annotText(tree: untpd.Tree): Text = "@" ~ constrText(tree) // DD
970-
971-
override def annotsText(sym: Symbol): Text =
972-
Text(sym.annotations.map(ann =>
973-
if ann.symbol == defn.BodyAnnot then Str(simpleNameString(ann.symbol))
974-
else annotText(ann.tree)))
970+
protected def annotText(sym: Symbol, tree: untpd.Tree): Text =
971+
def recur(t: untpd.Tree): Text = t match
972+
case Apply(fn, Nil) => recur(fn)
973+
case Apply(fn, args) =>
974+
val explicitArgs = args.filterNot(_.symbol.name.is(DefaultGetterName))
975+
recur(fn) ~ "(" ~ toTextGlobal(explicitArgs, ", ") ~ ")"
976+
case TypeApply(fn, args) => recur(fn) ~ "[" ~ toTextGlobal(args, ", ") ~ "]"
977+
case _ => s"@${sym.orElse(tree.symbol).name}"
978+
recur(tree)
975979

976980
protected def modText(mods: untpd.Modifiers, sym: Symbol, kw: String, isType: Boolean): Text = { // DD
977981
val suppressKw = if (enclDefIsClass) mods.isAllOf(LocalParam) else mods.is(Param)
@@ -984,12 +988,16 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
984988
if (rawFlags.is(Param)) flagMask = flagMask &~ Given &~ Erased
985989
val flags = rawFlags & flagMask
986990
var flagsText = toTextFlags(sym, flags)
987-
val annotations =
988-
if (sym.exists) sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)
989-
else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol))
990-
Text(annotations.map(annotText), " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw)
991+
val annotTexts =
992+
if sym.exists then
993+
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText)
994+
else
995+
mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol)).map(annotText(NoSymbol, _))
996+
Text(annotTexts, " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw)
991997
}
992998

999+
override def annotText(annot: Annotation): Text = annotText(annot.symbol, annot.tree)
1000+
9931001
def optText(name: Name)(encl: Text => Text): Text =
9941002
if (name.isEmpty) "" else encl(toText(name))
9951003

tests/neg/annot-printing.check

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
-- [E007] Type Mismatch Error: tests/neg/annot-printing.scala:5:46 -----------------------------------------------------
2+
5 |def x: Int @nowarn @main @Foo @Bar("hello") = "abc" // error
3+
| ^^^^^
4+
| Found: ("abc" : String)
5+
| Required: Int @nowarn() @main @Foo @Bar("hello")
6+
7+
longer explanation available when compiling with `-explain`

tests/neg/annot-printing.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import scala.annotation.*
2+
class Foo() extends Annotation
3+
class Bar(s: String) extends Annotation
4+
5+
def x: Int @nowarn @main @Foo @Bar("hello") = "abc" // error
6+

tests/pos/dependent-annot.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class C
2+
class ann(x: Any*) extends annotation.Annotation
3+
4+
def f(y: C, z: C) =
5+
def g(): C @ann(y, z) = ???
6+
val ac: ((x: C) => Array[String @ann(x)]) = ???
7+
val dc = ac(g())

0 commit comments

Comments
 (0)