Skip to content

Commit 1a84caa

Browse files
authored
Merge pull request #13348 from dotty-staging/dep-annots
Remove anomalies and gaps in handling annotations
2 parents 321a92c + 00c9adb commit 1a84caa

19 files changed

+231
-42
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
113113
case _ => 0
114114
}
115115

116-
/** The (last) list of arguments of an application */
117-
def arguments(tree: Tree): List[Tree] = unsplice(tree) match {
118-
case Apply(_, args) => args
119-
case TypeApply(fn, _) => arguments(fn)
120-
case Block(_, expr) => arguments(expr)
116+
/** All term arguments of an application in a single flattened list */
117+
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
118+
case Apply(fn, args) => allArguments(fn) ::: args
119+
case TypeApply(fn, _) => allArguments(fn)
120+
case Block(_, expr) => allArguments(expr)
121121
case _ => Nil
122122
}
123123

compiler/src/dotty/tools/dotc/core/Annotations.scala

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@ import StdNames._
77
import dotty.tools.dotc.ast.tpd
88
import scala.util.Try
99
import util.Spans.Span
10+
import printing.{Showable, Printer}
11+
import printing.Texts.Text
12+
import annotation.internal.sharable
1013

1114
object Annotations {
1215

1316
def annotClass(tree: Tree)(using Context) =
1417
if (tree.symbol.isConstructor) tree.symbol.owner
1518
else tree.tpe.typeSymbol
1619

17-
abstract class Annotation {
20+
abstract class Annotation extends Showable {
1821
def tree(using Context): Tree
1922

2023
def symbol(using Context): Symbol = annotClass(tree)
@@ -26,7 +29,8 @@ object Annotations {
2629
def derivedAnnotation(tree: Tree)(using Context): Annotation =
2730
if (tree eq this.tree) this else Annotation(tree)
2831

29-
def arguments(using Context): List[Tree] = ast.tpd.arguments(tree)
32+
/** All arguments to this annotation in a single flat list */
33+
def arguments(using Context): List[Tree] = ast.tpd.allArguments(tree)
3034

3135
def argument(i: Int)(using Context): Option[Tree] = {
3236
val args = arguments
@@ -44,15 +48,48 @@ object Annotations {
4448
/** The tree evaluation has finished. */
4549
def isEvaluated: Boolean = true
4650

51+
/** Normally, type map over all tree nodes of this annotation, but can
52+
* be overridden. Returns EmptyAnnotation if type type map produces a range
53+
* type, since ranges cannot be types of trees.
54+
*/
55+
def mapWith(tm: TypeMap)(using Context) =
56+
val args = arguments
57+
if args.isEmpty then this
58+
else
59+
val findDiff = new TreeAccumulator[Type]:
60+
def apply(x: Type, tree: Tree)(using Context): Type =
61+
if tm.isRange(x) then x
62+
else
63+
val tp1 = tm(tree.tpe)
64+
foldOver(if tp1 =:= tree.tpe then x else tp1, tree)
65+
val diff = findDiff(NoType, args)
66+
if tm.isRange(diff) then EmptyAnnotation
67+
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
68+
else this
69+
70+
/** Does this annotation refer to a parameter of `tl`? */
71+
def refersToParamOf(tl: TermLambda)(using Context): Boolean =
72+
val args = arguments
73+
if args.isEmpty then false
74+
else tree.existsSubTree {
75+
case id: Ident => id.tpe match
76+
case TermParamRef(tl1, _) => tl eq tl1
77+
case _ => false
78+
case _ => false
79+
}
80+
81+
/** A string representation of the annotation. Overridden in BodyAnnotation.
82+
*/
83+
def toText(printer: Printer): Text = printer.annotText(this)
84+
4785
def ensureCompleted(using Context): Unit = tree
4886

4987
def sameAnnotation(that: Annotation)(using Context): Boolean =
5088
symbol == that.symbol && tree.sameTree(that.tree)
5189
}
5290

53-
case class ConcreteAnnotation(t: Tree) extends Annotation {
91+
case class ConcreteAnnotation(t: Tree) extends Annotation:
5492
def tree(using Context): Tree = t
55-
}
5693

5794
abstract class LazyAnnotation extends Annotation {
5895
protected var mySym: Symbol | (Context ?=> Symbol)
@@ -98,6 +135,7 @@ object Annotations {
98135
if (tree eq this.tree) this else ConcreteBodyAnnotation(tree)
99136
override def arguments(using Context): List[Tree] = Nil
100137
override def ensureCompleted(using Context): Unit = ()
138+
override def toText(printer: Printer): Text = "@Body"
101139
}
102140

103141
class ConcreteBodyAnnotation(body: Tree) extends BodyAnnotation {
@@ -194,6 +232,8 @@ object Annotations {
194232
apply(defn.SourceFileAnnot, Literal(Constant(path)))
195233
}
196234

235+
@sharable val EmptyAnnotation = Annotation(EmptyTree)
236+
197237
def ThrowsAnnotation(cls: ClassSymbol)(using Context): Annotation = {
198238
val tref = cls.typeRef
199239
Annotation(defn.ThrowsAnnot.typeRef.appliedTo(tref), Ident(tref))

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,13 @@ object TypeOps:
164164
// with Nulls (which have no base classes). Under -Yexplicit-nulls, we take
165165
// corrective steps, so no widening is wanted.
166166
simplify(l, theMap) | simplify(r, theMap)
167-
case AnnotatedType(parent, annot)
168-
if annot.symbol == defn.UncheckedVarianceAnnot && !ctx.mode.is(Mode.Type) && !theMap.isInstanceOf[SimplifyKeepUnchecked] =>
169-
simplify(parent, theMap)
167+
case tp @ AnnotatedType(parent, annot) =>
168+
val parent1 = simplify(parent, theMap)
169+
if annot.symbol == defn.UncheckedVarianceAnnot
170+
&& !ctx.mode.is(Mode.Type)
171+
&& !theMap.isInstanceOf[SimplifyKeepUnchecked]
172+
then parent1
173+
else tp.derivedAnnotatedType(parent1, annot)
170174
case _: MatchType =>
171175
val normed = tp.tryNormalize
172176
if (normed.exists) normed else mapOver

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3601,6 +3601,9 @@ object Types {
36013601
case tp: AppliedType => tp.fold(status, compute(_, _, theAcc))
36023602
case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional)
36033603
case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps
3604+
case AnnotatedType(parent, ann) =>
3605+
if ann.refersToParamOf(thisLambdaType) then TrueDeps
3606+
else compute(status, parent, theAcc)
36043607
case _: ThisType | _: BoundType | NoPrefix => status
36053608
case _ =>
36063609
(if theAcc != null then theAcc else DepAcc()).foldOver(status, tp)
@@ -3653,8 +3656,10 @@ object Types {
36533656
if (isResultDependent) {
36543657
val dropDependencies = new ApproximatingTypeMap {
36553658
def apply(tp: Type) = tp match {
3656-
case tp @ TermParamRef(thisLambdaType, _) =>
3659+
case tp @ TermParamRef(`thisLambdaType`, _) =>
36573660
range(defn.NothingType, atVariance(1)(apply(tp.underlying)))
3661+
case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) =>
3662+
mapOver(parent)
36583663
case _ => mapOver(tp)
36593664
}
36603665
}
@@ -5379,6 +5384,8 @@ object Types {
53795384
variance = saved
53805385
derivedLambdaType(tp)(ptypes1, this(restpe))
53815386

5387+
def isRange(tp: Type): Boolean = tp.isInstanceOf[Range]
5388+
53825389
/** Map this function over given type */
53835390
def mapOver(tp: Type): Type = {
53845391
record(s"TypeMap mapOver ${getClass}")
@@ -5422,8 +5429,9 @@ object Types {
54225429

54235430
case tp @ AnnotatedType(underlying, annot) =>
54245431
val underlying1 = this(underlying)
5425-
if (underlying1 eq underlying) tp
5426-
else derivedAnnotatedType(tp, underlying1, mapOver(annot))
5432+
val annot1 = annot.mapWith(this)
5433+
if annot1 eq EmptyAnnotation then underlying1
5434+
else derivedAnnotatedType(tp, underlying1, annot1)
54275435

54285436
case _: ThisType
54295437
| _: BoundType
@@ -5495,9 +5503,6 @@ object Types {
54955503
else newScopeWith(elems1: _*)
54965504
}
54975505

5498-
def mapOver(annot: Annotation): Annotation =
5499-
annot.derivedAnnotation(mapOver(annot.tree))
5500-
55015506
def mapOver(tree: Tree): Tree = treeTypeMap(tree)
55025507

55035508
/** Can be overridden. By default, only the prefix is mapped. */
@@ -5544,8 +5549,6 @@ object Types {
55445549

55455550
protected def emptyRange = range(defn.NothingType, defn.AnyType)
55465551

5547-
protected def isRange(tp: Type): Boolean = tp.isInstanceOf[Range]
5548-
55495552
protected def lower(tp: Type): Type = tp match {
55505553
case tp: Range => tp.lo
55515554
case _ => tp

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
543543
case _ => literalText(String.valueOf(const.value))
544544
}
545545

546-
def toText(annot: Annotation): Text = s"@${annot.symbol.name}" // for now
546+
/** Usual target for `Annotation#toText`, overridden in RefinedPrinter */
547+
def annotText(annot: Annotation): Text = s"@${annot.symbol.name}"
548+
549+
def toText(annot: Annotation): Text = annot.toText(this)
547550

548551
def toText(param: LambdaParam): Text =
549552
varianceSign(param.paramVariance)
@@ -574,7 +577,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
574577
Text()
575578

576579
nodeName ~ "(" ~ elems ~ tpSuffix ~ ")" ~ (Str(tree.sourcePos.toString) provided printDebug)
577-
}.close // todo: override in refined printer
580+
}.close
578581

579582
def toText(pos: SourcePosition): Text =
580583
if (!pos.exists) "<no position>"

compiler/src/dotty/tools/dotc/printing/Printer.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ abstract class Printer {
119119
/** A description of sym's location */
120120
def extendedLocationText(sym: Symbol): Text
121121

122+
/** Textual description of regular annotation in terms of its tree */
123+
def annotText(annot: Annotation): Text
124+
122125
/** Textual representation of denotation */
123126
def toText(denot: Denotation): Text
124127

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import typer.ProtoTypes._
2121
import Trees._
2222
import TypeApplications._
2323
import Decorators._
24-
import NameKinds.WildcardParamName
24+
import NameKinds.{WildcardParamName, DefaultGetterName}
2525
import util.Chars.isOperatorPart
2626
import transform.TypeUtils._
2727
import transform.SymUtils._
@@ -607,7 +607,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
607607
case tree: Template =>
608608
toTextTemplate(tree)
609609
case Annotated(arg, annot) =>
610-
toTextLocal(arg) ~~ annotText(annot)
610+
toTextLocal(arg) ~~ annotText(annot.symbol.enclosingClass, annot)
611611
case EmptyTree =>
612612
"<empty>"
613613
case TypedSplice(t) =>
@@ -964,14 +964,22 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
964964
keywordStr("package ") ~ toTextPackageId(tree.pid) ~ bodyText
965965
}
966966

967+
/** Textual representation of an instance creation expression without the leading `new` */
967968
protected def constrText(tree: untpd.Tree): Text = toTextLocal(tree).stripPrefix(keywordStr("new ")) // DD
968969

969-
protected def annotText(tree: untpd.Tree): Text = "@" ~ constrText(tree) // DD
970-
971-
override def annotsText(sym: Symbol): Text =
972-
Text(sym.annotations.map(ann =>
973-
if ann.symbol == defn.BodyAnnot then Str(simpleNameString(ann.symbol))
974-
else annotText(ann.tree)))
970+
protected def annotText(sym: Symbol, tree: untpd.Tree): Text =
971+
def recur(t: untpd.Tree): Text = t match
972+
case Apply(fn, Nil) => recur(fn)
973+
case Apply(fn, args) =>
974+
val explicitArgs = args.filterNot(_.symbol.name.is(DefaultGetterName))
975+
recur(fn) ~ "(" ~ toTextGlobal(explicitArgs, ", ") ~ ")"
976+
case TypeApply(fn, args) => recur(fn) ~ "[" ~ toTextGlobal(args, ", ") ~ "]"
977+
case Select(qual, nme.CONSTRUCTOR) => recur(qual)
978+
case New(tpt) => recur(tpt)
979+
case _ =>
980+
val annotSym = sym.orElse(tree.symbol.enclosingClass)
981+
s"@${if annotSym.exists then annotSym.name.toString else t.show}"
982+
recur(tree)
975983

976984
protected def modText(mods: untpd.Modifiers, sym: Symbol, kw: String, isType: Boolean): Text = { // DD
977985
val suppressKw = if (enclDefIsClass) mods.isAllOf(LocalParam) else mods.is(Param)
@@ -984,12 +992,16 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
984992
if (rawFlags.is(Param)) flagMask = flagMask &~ Given &~ Erased
985993
val flags = rawFlags & flagMask
986994
var flagsText = toTextFlags(sym, flags)
987-
val annotations =
988-
if (sym.exists) sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)
989-
else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol))
990-
Text(annotations.map(annotText), " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw)
995+
val annotTexts =
996+
if sym.exists then
997+
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText)
998+
else
999+
mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol)).map(annotText(NoSymbol, _))
1000+
Text(annotTexts, " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw)
9911001
}
9921002

1003+
override def annotText(annot: Annotation): Text = annotText(annot.symbol, annot.tree)
1004+
9931005
def optText(name: Name)(encl: Text => Text): Text =
9941006
if (name.isEmpty) "" else encl(toText(name))
9951007

compiler/test/dotty/tools/dotc/printing/PrintingTest.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,18 @@ import scala.io.Source
1919
import org.junit.Test
2020

2121
class PrintingTest {
22-
val testsDir = "tests/printing"
23-
val options = List("-Xprint:typer", "-color:never", "-classpath", TestConfiguration.basicClasspath)
2422

25-
private def compileFile(path: JPath): Boolean = {
23+
def options(phase: String) =
24+
List(s"-Xprint:$phase", "-color:never", "-classpath", TestConfiguration.basicClasspath)
25+
26+
private def compileFile(path: JPath, phase: String): Boolean = {
2627
val baseFilePath = path.toString.stripSuffix(".scala")
2728
val checkFilePath = baseFilePath + ".check"
2829
val byteStream = new ByteArrayOutputStream()
2930
val reporter = TestReporter.reporter(new PrintStream(byteStream), INFO)
3031

3132
try {
32-
Main.process((path.toString::options).toArray, reporter, null)
33+
Main.process((path.toString::options(phase)).toArray, reporter, null)
3334
} catch {
3435
case e: Throwable =>
3536
println(s"Compile $path exception:")
@@ -40,11 +41,10 @@ class PrintingTest {
4041
FileDiff.checkAndDump(path.toString, actualLines.toIndexedSeq, checkFilePath)
4142
}
4243

43-
@Test
44-
def printing: Unit = {
44+
def testIn(testsDir: String, phase: String) =
4545
val res = Directory(testsDir).list.toList
4646
.filter(f => f.extension == "scala")
47-
.map { f => compileFile(f.jpath) }
47+
.map { f => compileFile(f.jpath, phase) }
4848

4949
val failed = res.filter(!_)
5050

@@ -53,5 +53,12 @@ class PrintingTest {
5353
assert(failed.length == 0, msg)
5454

5555
println(msg)
56-
}
56+
57+
end testIn
58+
59+
@Test
60+
def printing: Unit = testIn("tests/printing", "typer")
61+
62+
@Test
63+
def untypedPrinting: Unit = testIn("tests/printing/untyped", "parser")
5764
}

tests/neg/annot-printing.check

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
-- [E007] Type Mismatch Error: tests/neg/annot-printing.scala:5:46 -----------------------------------------------------
2+
5 |def x: Int @nowarn @main @Foo @Bar("hello") = "abc" // error
3+
| ^^^^^
4+
| Found: ("abc" : String)
5+
| Required: Int @nowarn() @main @Foo @Bar("hello")
6+
7+
longer explanation available when compiling with `-explain`

tests/neg/annot-printing.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import scala.annotation.*
2+
class Foo() extends Annotation
3+
class Bar(s: String) extends Annotation
4+
5+
def x: Int @nowarn @main @Foo @Bar("hello") = "abc" // error
6+

tests/pos/dependent-annot.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class C
2+
class ann(x: Any*) extends annotation.Annotation
3+
4+
def f(y: C, z: C) =
5+
def g(): C @ann(y, z) = ???
6+
val ac: ((x: C) => Array[String @ann(x)]) = ???
7+
val dc = ac(g())

tests/printing/annot-printing.check

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
[[syntax trees at end of typer]] // tests/printing/annot-printing.scala
2+
package <empty> {
3+
import scala.annotation.*
4+
class Foo() extends annotation.Annotation() {}
5+
class Bar(s: String) extends annotation.Annotation() {
6+
private[this] val s: String
7+
}
8+
class Xyz(i: Int) extends annotation.Annotation() {
9+
private[this] val i: Int
10+
}
11+
final lazy module val Xyz: Xyz = new Xyz()
12+
final module class Xyz() extends AnyRef() { this: Xyz.type =>
13+
def $lessinit$greater$default$1: Int @uncheckedVariance = 23
14+
}
15+
final lazy module val annot-printing$package: annot-printing$package =
16+
new annot-printing$package()
17+
final module class annot-printing$package() extends Object() {
18+
this: annot-printing$package.type =>
19+
def x: Int @nowarn() @main @Xyz() @Foo @Bar("hello") = ???
20+
}
21+
}
22+

tests/printing/annot-printing.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import scala.annotation.*
2+
3+
class Foo() extends Annotation
4+
class Bar(s: String) extends Annotation
5+
class Xyz(i: Int = 23) extends Annotation
6+
7+
def x: Int @nowarn @main @Xyz() @Foo @Bar("hello") = ???

0 commit comments

Comments
 (0)