diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 44fa4e9b22fd..7569e16e663f 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1142,6 +1142,11 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { buf.toList } + def collectSubTrees[A](f: PartialFunction[Tree, A])(using Context): List[A] = + val buf = mutable.ListBuffer[A]() + foreachSubTree(f.runWith(buf += _)(_)) + buf.toList + /** Set this tree as the `defTree` of its symbol and return this tree */ def setDefTree(using Context): ThisTree = { val sym = tree.symbol diff --git a/compiler/src/dotty/tools/dotc/core/Flags.scala b/compiler/src/dotty/tools/dotc/core/Flags.scala index 8c1b715e3e30..025b944fe45b 100644 --- a/compiler/src/dotty/tools/dotc/core/Flags.scala +++ b/compiler/src/dotty/tools/dotc/core/Flags.scala @@ -363,7 +363,7 @@ object Flags { val (Enum @ _, EnumVal @ _, _) = newFlags(40, "enum") /** An export forwarder */ - val (Exported @ _, _, _) = newFlags(41, "exported") + val (Exported @ _, ExportedTerm @ _, ExportedType @ _) = newFlags(41, "exported") /** Labeled with `erased` modifier (erased value or class) */ val (Erased @ _, _, _) = newFlags(42, "erased") diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index 57c6b68e9ab8..32a2da8b46b6 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -19,9 +19,7 @@ import DenotTransformers.* import StdNames.* import NameOps.* import NameKinds.LazyImplicitName -import ast.tpd -import tpd.{Tree, TreeProvider, TreeOps} -import ast.TreeTypeMap +import ast.*, tpd.* import Constants.Constant import Variances.Variance import reporting.Message @@ -325,13 +323,26 @@ object Symbols extends SymUtils { /** A symbol related to `sym` that is defined in source code. * - * @see enclosingSourceSymbols + * @see [[interactive.Interactive.enclosingSourceSymbols]] */ @annotation.tailrec final def sourceSymbol(using Context): Symbol = if (!denot.exists) this else if (denot.is(ModuleVal)) this.moduleClass.sourceSymbol // The module val always has a zero-extent position + else if denot.is(ExportedType) then + denot.info.dropAlias.finalResultType.typeConstructor match + case tp: NamedType => tp.symbol.sourceSymbol + case _ => this + else if denot.is(ExportedTerm) then + val root = denot.maybeOwner match + case cls: ClassSymbol => cls.rootTreeContaining(name.toString) + case _ => EmptyTree + val targets = root.collectSubTrees: + case tree: DefDef if tree.symbol == denot.symbol => methPart(tree.rhs).tpe + targets.match + case (tp: NamedType) :: _ => tp.symbol.sourceSymbol + case _ => this else if (denot.is(Synthetic)) { val linked = denot.linkedClass if (linked.exists && !linked.is(Synthetic)) diff --git a/compiler/src/dotty/tools/dotc/interactive/SourceTree.scala b/compiler/src/dotty/tools/dotc/interactive/SourceTree.scala index 5480d4a43043..258d92a2d1a8 100644 --- a/compiler/src/dotty/tools/dotc/interactive/SourceTree.scala +++ b/compiler/src/dotty/tools/dotc/interactive/SourceTree.scala @@ -42,7 +42,12 @@ case class SourceTree(tree: tpd.Import | tpd.NameTree, source: SourceFile) { (treeSpan.end - nameLength, treeSpan.end) Span(start, end, start) } - source.atSpan(position) + // Don't widen the span, only narrow. + // E.g. The star in a wildcard export is 1 character, + // and that is the span of the type alias that results from it + // but the name may very well be larger, which we don't want. + val span1 = if treeSpan.contains(position) then position else treeSpan + source.atSpan(span1) } case _ => NoSourcePosition diff --git a/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala b/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala index 0e64a6c839ab..2c2897e401a1 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala @@ -1,19 +1,14 @@ -package dotty.tools.pc +package dotty.tools +package pc import scala.annotation.tailrec -import dotty.tools.dotc.ast.tpd -import dotty.tools.dotc.ast.tpd.* -import dotty.tools.dotc.ast.untpd -import dotty.tools.dotc.core.Contexts.* -import dotty.tools.dotc.core.Flags.* -import dotty.tools.dotc.core.Names.Name -import dotty.tools.dotc.core.StdNames -import dotty.tools.dotc.core.Symbols.* -import dotty.tools.dotc.core.Types.Type -import dotty.tools.dotc.interactive.SourceTree -import dotty.tools.dotc.util.SourceFile -import dotty.tools.dotc.util.SourcePosition +import dotc.* +import ast.*, tpd.* +import core.*, Contexts.*, Decorators.*, Flags.*, Names.*, Symbols.*, Types.* +import interactive.* +import util.* +import util.SourcePosition object MetalsInteractive: @@ -205,7 +200,10 @@ object MetalsInteractive: Nil case path @ head :: tail => - if head.symbol.is(Synthetic) then + if head.symbol.is(Exported) then + val sym = head.symbol.sourceSymbol + List((sym, sym.info)) + else if head.symbol.is(Synthetic) then enclosingSymbolsWithExpressionType( tail, pos, diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcDefinitionProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcDefinitionProvider.scala index c4266ce5d709..0de81ec39711 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/PcDefinitionProvider.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/PcDefinitionProvider.scala @@ -124,7 +124,7 @@ class PcDefinitionProvider( val isLocal = sym.source == pos.source if isLocal then val defs = - Interactive.findDefinitions(List(sym), driver, false, false) + Interactive.findDefinitions(List(sym), driver, false, false).filter(_.source == sym.source) defs.headOption match case Some(srcTree) => val pos = srcTree.namePos diff --git a/presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala index 358e159eb539..9636aea77c2e 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala @@ -199,6 +199,81 @@ class PcDefinitionSuite extends BasePcDefinitionSuite: |""".stripMargin ) + @Test def exportType0 = + check( + """object Foo: + | trait <> + |object Bar: + | export Foo.* + |class Test: + | import Bar.* + | def test = new Ca@@t {} + |""".stripMargin + ) + + @Test def exportType1 = + check( + """object Foo: + | trait <>[A] + |object Bar: + | export Foo.* + |class Test: + | import Bar.* + | def test = new Ca@@t[Int] {} + |""".stripMargin + ) + + @Test def exportTerm0Nullary = + check( + """trait Foo: + | def <>: Int + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th + |""".stripMargin + ) + + @Test def exportTerm0 = + check( + """trait Foo: + | def <>(): Int + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th() + |""".stripMargin + ) + + @Test def exportTerm1 = + check( + """trait Foo: + | def <>(x: Int): Int + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th(0) + |""".stripMargin + ) + + @Test def exportTerm1Poly = + check( + """trait Foo: + | def <>[A](x: A): A + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th(0) + |""".stripMargin + ) + + @Test def exportTerm1Overload = + check( + """trait Foo: + | def <>(x: Int): Int + | def meth(x: String): String + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th(0) + |""".stripMargin + ) + @Test def `named-arg-local` = check( """|