diff --git a/NOTICE.md b/NOTICE.md index f4d0e6ed2b5a..64c1ede1a5eb 100644 --- a/NOTICE.md +++ b/NOTICE.md @@ -89,15 +89,19 @@ major authors were omitted by oversight. details. * dotty.tools.dotc.coverage: Coverage instrumentation utilities have been - adapted from the scoverage plugin for scala 2 [5], which is under the + adapted from the scoverage plugin for scala 2 [4], which is under the Apache 2.0 license. + * dooty.tools.pc: Presentation compiler implementation adapted from + scalameta/metals [5] mtags module, which is under the Apache 2.0 license. + * The Dotty codebase contains parts which are derived from - the ScalaPB protobuf library [4], which is under the Apache 2.0 license. + the ScalaPB protobuf library [6], which is under the Apache 2.0 license. [1] https://github.com/scala/scala [2] https://github.com/adriaanm/scala/tree/sbt-api-consolidate/src/compiler/scala/tools/sbt [3] https://github.com/sbt/sbt/tree/0.13/compile/interface/src/main/scala/xsbt -[4] https://github.com/lampepfl/dotty/pull/5783/files -[5] https://github.com/scoverage/scalac-scoverage-plugin +[4] https://github.com/scoverage/scalac-scoverage-plugin +[5] https://github.com/scalameta/metals +[6] https://github.com/lampepfl/dotty/pull/5783/files diff --git a/build.sbt b/build.sbt index 80a36739d5e8..d6a366305f96 100644 --- a/build.sbt +++ b/build.sbt @@ -28,6 +28,8 @@ val `scala3-bench-run` = Build.`scala3-bench-run` val dist = Build.dist val `community-build` = Build.`community-build` val `sbt-community-build` = Build.`sbt-community-build` +val `scala3-presentation-compiler` = Build.`scala3-presentation-compiler` +val `scala3-presentation-compiler-bootstrapped` = Build.`scala3-presentation-compiler-bootstrapped` val sjsSandbox = Build.sjsSandbox val sjsJUnitTests = Build.sjsJUnitTests diff --git a/compiler/src/dotty/tools/dotc/CompilationUnit.scala b/compiler/src/dotty/tools/dotc/CompilationUnit.scala index 8415646eb16c..c121fbaf7c00 100644 --- a/compiler/src/dotty/tools/dotc/CompilationUnit.scala +++ b/compiler/src/dotty/tools/dotc/CompilationUnit.scala @@ -5,6 +5,7 @@ import core._ import Contexts._ import SymDenotations.ClassDenotation import Symbols._ +import Comments.Comment import util.{FreshNameCreator, SourceFile, NoSource} import util.Spans.Span import ast.{tpd, untpd} @@ -69,6 +70,9 @@ class CompilationUnit protected (val source: SourceFile) { /** Can this compilation unit be suspended */ def isSuspendable: Boolean = true + /** List of all comments present in this compilation unit */ + var comments: List[Comment] = Nil + /** Suspends the compilation unit by thowing a SuspendException * and recording the suspended compilation unit */ diff --git a/compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala b/compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala index 132ff162be61..2a2860cd1ba3 100644 --- a/compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala +++ b/compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala @@ -145,7 +145,7 @@ class InteractiveDriver(val settings: List[String]) extends Driver { (fromSource ++ fromClassPath).distinct } - def run(uri: URI, sourceCode: String): List[Diagnostic] = run(uri, toSource(uri, sourceCode)) + def run(uri: URI, sourceCode: String): List[Diagnostic] = run(uri, SourceFile.virtual(uri, sourceCode)) def run(uri: URI, source: SourceFile): List[Diagnostic] = { import typer.ImportInfo._ @@ -297,9 +297,6 @@ class InteractiveDriver(val settings: List[String]) extends Driver { cleanupTree(tree) } - private def toSource(uri: URI, sourceCode: String): SourceFile = - SourceFile.virtual(Paths.get(uri).toString, sourceCode) - /** * Initialize this driver and compiler. * @@ -323,7 +320,7 @@ object InteractiveDriver { else try // We don't use file.file here since it'll be null - // for the VirtualFiles created by InteractiveDriver#toSource + // for the VirtualFiles created by SourceFile#virtual // TODO: To avoid these round trip conversions, we could add an // AbstractFile#toUri method and implement it by returning a constant // passed as a parameter to a constructor of VirtualFile diff --git a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala index a67bca34cae2..7caff4996b85 100644 --- a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala +++ b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala @@ -30,6 +30,7 @@ class Parser extends Phase { val p = new Parsers.Parser(unit.source) // p.in.debugTokenStream = true val tree = p.parse() + ctx.compilationUnit.comments = p.in.comments if (p.firstXmlPos.exists && !firstXmlPos.exists) firstXmlPos = p.firstXmlPos tree diff --git a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala index fac73bfb4992..0339fc0531f4 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala @@ -227,11 +227,11 @@ object Scanners { */ private var docstringMap: SortedMap[Int, Comment] = SortedMap.empty - /* A Buffer for comment positions */ - private val commentPosBuf = new mutable.ListBuffer[Span] + /* A Buffer for comments */ + private val commentBuf = new mutable.ListBuffer[Comment] - /** Return a list of all the comment positions */ - def commentSpans: List[Span] = commentPosBuf.toList + /** Return a list of all the comments */ + def comments: List[Comment] = commentBuf.toList private def addComment(comment: Comment): Unit = { val lookahead = lookaheadReader() @@ -246,7 +246,7 @@ object Scanners { def getDocComment(pos: Int): Option[Comment] = docstringMap.get(pos) /** A buffer for comments */ - private val commentBuf = CharBuffer(initialCharBufferSize) + private val currentCommentBuf = CharBuffer(initialCharBufferSize) def toToken(identifier: SimpleName): Token = def handleMigration(keyword: Token): Token = @@ -523,7 +523,7 @@ object Scanners { * * The following tokens can start an indentation region: * - * : = => <- if then else while do try catch + * : = => <- if then else while do try catch * finally for yield match throw return with * * Inserting an INDENT starts a new indentation region with the indentation of the current @@ -1019,7 +1019,7 @@ object Scanners { private def skipComment(): Boolean = { def appendToComment(ch: Char) = - if (keepComments) commentBuf.append(ch) + if (keepComments) currentCommentBuf.append(ch) def nextChar() = { appendToComment(ch) Scanner.this.nextChar() @@ -1047,9 +1047,9 @@ object Scanners { def finishComment(): Boolean = { if (keepComments) { val pos = Span(start, charOffset - 1, start) - val comment = Comment(pos, commentBuf.toString) - commentBuf.clear() - commentPosBuf += pos + val comment = Comment(pos, currentCommentBuf.toString) + currentCommentBuf.clear() + commentBuf += comment if (comment.isDocComment) addComment(comment) @@ -1065,7 +1065,7 @@ object Scanners { else if (ch == '*') { nextChar(); skipComment(); finishComment() } else { // This was not a comment, remove the `/` from the buffer - commentBuf.clear() + currentCommentBuf.clear() false } } diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 700b3fbf525f..6ec4ba6ac0ad 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -179,7 +179,7 @@ class PlainPrinter(_ctx: Context) extends Printer { if (printWithoutPrefix.contains(tp.symbol)) toText(tp.name) else - toTextPrefix(tp.prefix) ~ selectionString(tp) + toTextPrefixOf(tp) ~ selectionString(tp) case tp: TermParamRef => ParamRefNameString(tp) ~ lambdaHash(tp.binder) ~ ".type" case tp: TypeParamRef => @@ -353,7 +353,7 @@ class PlainPrinter(_ctx: Context) extends Printer { def toTextRef(tp: SingletonType): Text = controlled { tp match { case tp: TermRef => - toTextPrefix(tp.prefix) ~ selectionString(tp) + toTextPrefixOf(tp) ~ selectionString(tp) case tp: ThisType => nameString(tp.cls) + ".this" case SuperType(thistpe: SingletonType, _) => @@ -375,15 +375,6 @@ class PlainPrinter(_ctx: Context) extends Printer { } } - /** The string representation of this type used as a prefix, including separator */ - def toTextPrefix(tp: Type): Text = controlled { - homogenize(tp) match { - case NoPrefix => "" - case tp: SingletonType => toTextRef(tp) ~ "." - case tp => trimPrefix(toTextLocal(tp)) ~ "#" - } - } - def toTextCaptureRef(tp: Type): Text = homogenize(tp) match case tp: TermRef if tp.symbol == defn.captureRoot => Str("cap") @@ -393,6 +384,15 @@ class PlainPrinter(_ctx: Context) extends Printer { protected def isOmittablePrefix(sym: Symbol): Boolean = defn.unqualifiedOwnerTypes.exists(_.symbol == sym) || isEmptyPrefix(sym) + /** The string representation of type prefix, including separator */ + def toTextPrefixOf(tp: NamedType): Text = controlled { + homogenize(tp.prefix) match { + case NoPrefix => "" + case tp: SingletonType => toTextRef(tp) ~ "." + case tp => trimPrefix(toTextLocal(tp)) ~ "#" + } + } + protected def isEmptyPrefix(sym: Symbol): Boolean = sym.isEffectiveRoot || sym.isAnonymousClass || sym.name.isReplWrapperName diff --git a/compiler/src/dotty/tools/dotc/printing/Printer.scala b/compiler/src/dotty/tools/dotc/printing/Printer.scala index ab0c867ec31f..04cea9fb9702 100644 --- a/compiler/src/dotty/tools/dotc/printing/Printer.scala +++ b/compiler/src/dotty/tools/dotc/printing/Printer.scala @@ -4,7 +4,7 @@ package printing import core._ import Texts._, ast.Trees._ -import Types.{Type, SingletonType, LambdaParam}, +import Types.{Type, SingletonType, LambdaParam, NamedType}, Symbols.Symbol, Scopes.Scope, Constants.Constant, Names.Name, Denotations._, Annotations.Annotation, Contexts.Context import typer.Implicits.* @@ -101,7 +101,7 @@ abstract class Printer { def toTextRef(tp: SingletonType): Text /** Textual representation of a prefix of some reference, ending in `.` or `#` */ - def toTextPrefix(tp: Type): Text + def toTextPrefixOf(tp: NamedType): Text /** Textual representation of a reference in a capture set */ def toTextCaptureRef(tp: Type): Text diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 98478fae92e8..9240f20cdc49 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -117,21 +117,22 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { } } - override def toTextPrefix(tp: Type): Text = controlled { + override def toTextPrefixOf(tp: NamedType): Text = controlled { def isOmittable(sym: Symbol) = if printDebug then false else if homogenizedView then isEmptyPrefix(sym) // drop and anonymous classes, but not scala, Predef. else if sym.isPackageObject then isOmittablePrefix(sym.owner) else isOmittablePrefix(sym) - tp match { - case tp: ThisType if isOmittable(tp.cls) => + + tp.prefix match { + case thisType: ThisType if isOmittable(thisType.cls) => "" - case tp @ TermRef(pre, _) => - val sym = tp.symbol - if sym.isPackageObject && !homogenizedView && !printDebug then toTextPrefix(pre) + case termRef @ TermRef(pre, _) => + val sym = termRef.symbol + if sym.isPackageObject && !homogenizedView && !printDebug then toTextPrefixOf(termRef) else if (isOmittable(sym)) "" - else super.toTextPrefix(tp) - case _ => super.toTextPrefix(tp) + else super.toTextPrefixOf(tp) + case _ => super.toTextPrefixOf(tp) } } @@ -427,8 +428,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { case id @ Ident(name) => val txt = tree.typeOpt match { case tp: NamedType if name != nme.WILDCARD => - val pre = if (tp.symbol.is(JavaStatic)) tp.prefix.widen else tp.prefix - toTextPrefix(pre) ~ withPos(selectionString(tp), tree.sourcePos) + toTextPrefixOf(tp) ~ withPos(selectionString(tp), tree.sourcePos) case _ => toText(name) } diff --git a/compiler/src/dotty/tools/dotc/printing/SyntaxHighlighting.scala b/compiler/src/dotty/tools/dotc/printing/SyntaxHighlighting.scala index 53e6b9472f5e..7030776dd06c 100644 --- a/compiler/src/dotty/tools/dotc/printing/SyntaxHighlighting.scala +++ b/compiler/src/dotty/tools/dotc/printing/SyntaxHighlighting.scala @@ -83,8 +83,8 @@ object SyntaxHighlighting { } } - for (span <- scanner.commentSpans) - highlightPosition(span, CommentColor) + for (comment <- scanner.comments) + highlightPosition(comment.span, CommentColor) object TreeHighlighter extends untpd.UntypedTreeTraverser { import untpd._ diff --git a/compiler/src/dotty/tools/dotc/typer/ImportSuggestions.scala b/compiler/src/dotty/tools/dotc/typer/ImportSuggestions.scala index a9b53f0783bd..70addd442100 100644 --- a/compiler/src/dotty/tools/dotc/typer/ImportSuggestions.scala +++ b/compiler/src/dotty/tools/dotc/typer/ImportSuggestions.scala @@ -330,7 +330,7 @@ trait ImportSuggestions: def importString(ref: TermRef): String = val imported = if ref.symbol.is(ExtensionMethod) then - s"${ctx.printer.toTextPrefix(ref.prefix).show}${ref.symbol.name}" + s"${ctx.printer.toTextPrefixOf(ref).show}${ref.symbol.name}" else ctx.printer.toTextRef(ref).show s" import $imported" diff --git a/compiler/src/dotty/tools/dotc/util/DiffUtil.scala b/compiler/src/dotty/tools/dotc/util/DiffUtil.scala index cec86fa84443..31acc91caa2e 100644 --- a/compiler/src/dotty/tools/dotc/util/DiffUtil.scala +++ b/compiler/src/dotty/tools/dotc/util/DiffUtil.scala @@ -70,7 +70,9 @@ object DiffUtil { * differences are highlighted. */ def mkColoredLineDiff(expected: Seq[String], actual: Seq[String]): String = { - val expectedSize = EOF.length max expected.maxBy(_.length).length + val longestExpected = expected.map(_.length).maxOption.getOrElse(0) + val longestActual = actual.map(_.length).maxOption.getOrElse(0) + val expectedSize = EOF.length max longestActual max longestExpected actual.padTo(expected.length, "").zip(expected.padTo(actual.length, "")).map { case (act, exp) => mkColoredLineDiff(exp, act, expectedSize) }.mkString(System.lineSeparator) @@ -101,11 +103,75 @@ object DiffUtil { case Deleted(str) => deleted(str) }.mkString + (expectedDiff, actualDiff) val pad = " " * 0.max(expectedSize - expected.length) expectedDiff + pad + " | " + actualDiff } + private def ensureLineSeparator(str: String): String = + if str.endsWith(System.lineSeparator) then + str + else + str + System.lineSeparator + + /** + * Returns a colored diffs by comparison of lines instead of tokens. + * It will automatically group subsequential pairs of `Insert` and `Delete` + * in order to improve the readability + * + * @param expected The expected lines + * @param actual The actual lines + * @return A string with colored diffs between `expected` and `actual` grouped whenever possible + */ + def mkColoredHorizontalLineDiff(expected: String, actual: String): String = { + val indent = 2 + val tab = " " * indent + val insertIndent = "+" ++ (" " * (indent - 1)) + val deleteIndent = "-" ++ (" " * (indent - 1)) + + if actual.isEmpty then + (expected.linesIterator.map(line => added(insertIndent + line)).toList :+ deleted("--- EMPTY OUTPUT ---")) + .map(ensureLineSeparator).mkString + else if expected.isEmpty then + (added("--- NO VALUE EXPECTED ---") +: actual.linesIterator.map(line => deleted(deleteIndent + line)).toList) + .map(ensureLineSeparator).mkString + else + lazy val diff = { + val expectedTokens = expected.linesWithSeparators.toArray + val actualTokens = actual.linesWithSeparators.toArray + hirschberg(actualTokens, expectedTokens) + }.toList + + val transformedDiff = diff.flatMap { + case Modified(original, str) => Seq( + Inserted(ensureLineSeparator(original)), Deleted(ensureLineSeparator(str)) + ) + case other => Seq(other) + } + + val zipped = transformedDiff zip transformedDiff.drop(1) + + val (acc, inserts, deletions) = zipped.foldLeft((Seq[Patch](), Seq[Inserted](), Seq[Deleted]())): (acc, patches) => + val (currAcc, inserts, deletions) = acc + patches match + case (currentPatch: Inserted, nextPatch: Deleted) => + (currAcc, inserts :+ currentPatch, deletions) + case (currentPatch: Deleted, nextPatch: Inserted) => + (currAcc, inserts, deletions :+ currentPatch) + case (currentPatch, nextPatch) => + (currAcc :++ inserts :++ deletions :+ currentPatch, Seq.empty, Seq.empty) + + val stackedDiff = acc :++ inserts :++ deletions :+ diff.last + + stackedDiff.collect { + case Unmodified(str) => tab + str + case Inserted(str) => added(insertIndent + str) + case Deleted(str) => deleted(deleteIndent + str) + }.map(ensureLineSeparator).mkString + + } + def mkColoredCodeDiff(code: String, lastCode: String, printDiffDel: Boolean): String = { val tokens = splitTokens(code, Nil).toArray val lastTokens = splitTokens(lastCode, Nil).toArray diff --git a/compiler/src/dotty/tools/dotc/util/SourceFile.scala b/compiler/src/dotty/tools/dotc/util/SourceFile.scala index 42d07869f74e..3462036d7ba6 100644 --- a/compiler/src/dotty/tools/dotc/util/SourceFile.scala +++ b/compiler/src/dotty/tools/dotc/util/SourceFile.scala @@ -16,8 +16,9 @@ import scala.collection.mutable.ArrayBuffer import scala.util.chaining.given import java.io.File.separator +import java.net.URI import java.nio.charset.StandardCharsets -import java.nio.file.{FileSystemException, NoSuchFileException} +import java.nio.file.{FileSystemException, NoSuchFileException, Paths} import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import java.util.regex.Pattern @@ -222,6 +223,13 @@ object SourceFile { SourceFile(new VirtualFile(name.replace(separator, "/"), content.getBytes(StandardCharsets.UTF_8)), content.toCharArray) .tap(_._maybeInComplete = maybeIncomplete) + /** A helper method to create a virtual source file for given URI. + * It relies on SourceFile#virtual implementation to create the virtual file. + */ + def virtual(uri: URI, content: String): SourceFile = + val path = Paths.get(uri).toString + SourceFile.virtual(path, content) + /** Returns the relative path of `source` within the `reference` path * * It returns the absolute path of `source` if it is not contained in `reference`. diff --git a/presentation-compiler/src/main/dotty/tools/pc/AutoImports.scala b/presentation-compiler/src/main/dotty/tools/pc/AutoImports.scala new file mode 100644 index 000000000000..4a204105d7b2 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/AutoImports.scala @@ -0,0 +1,377 @@ +package dotty.tools.pc + +import scala.annotation.tailrec +import scala.jdk.CollectionConverters.* +import scala.meta.internal.pc.AutoImportPosition +import scala.meta.pc.PresentationCompilerConfig + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Comments.Comment +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.dotc.util.Spans +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j as l + +object AutoImports: + + object AutoImport: + def renameConfigMap(config: PresentationCompilerConfig)(using + Context + ): Map[Symbol, String] = + config.symbolPrefixes.asScala.flatMap { (from, to) => + val pkg = SemanticdbSymbols.inverseSemanticdbSymbol(from) + val rename = to.stripSuffix(".").stripSuffix("#") + List(pkg, pkg.map(_.moduleClass)).flatten + .filter(_ != NoSymbol) + .map((_, rename)) + }.toMap + end AutoImport + + sealed trait SymbolIdent: + def value: String + + object SymbolIdent: + case class Direct(value: String) extends SymbolIdent + case class Select(qual: SymbolIdent, name: String) extends SymbolIdent: + def value: String = s"${qual.value}.$name" + + def direct(name: String): SymbolIdent = Direct(name) + + def fullIdent(symbol: Symbol)(using Context): SymbolIdent = + val symbols = symbol.ownersIterator.toList + .takeWhile(_ != ctx.definitions.RootClass) + .reverse + + symbols match + case head :: tail => + tail.foldLeft(direct(head.nameBackticked))((acc, next) => + Select(acc, next.nameBackticked) + ) + case Nil => + SymbolIdent.direct("") + + end SymbolIdent + + sealed trait ImportSel: + def sym: Symbol + + object ImportSel: + final case class Direct(sym: Symbol) extends ImportSel + final case class Rename(sym: Symbol, rename: String) extends ImportSel + + case class SymbolImport( + sym: Symbol, + ident: SymbolIdent, + importSel: Option[ImportSel] + ): + + def name: String = ident.value + + object SymbolImport: + + def simple(sym: Symbol)(using Context): SymbolImport = + SymbolImport(sym, SymbolIdent.direct(sym.nameBackticked), None) + + /** + * Returns AutoImportsGenerator + * + * @param pos A source position where the autoImport is invoked + * @param text Source text of the file + * @param tree A typed tree of the file + * @param indexedContext A context of the position where the autoImport is invoked + * @param config A presentation compiler config, this is used for renames + */ + def generator( + pos: SourcePosition, + text: String, + tree: Tree, + comments: List[Comment], + indexedContext: IndexedContext, + config: PresentationCompilerConfig + ): AutoImportsGenerator = + + import indexedContext.ctx + + val importPos = autoImportPosition(pos, text, tree, comments) + val renameConfig: Map[Symbol, String] = AutoImport.renameConfigMap(config) + + val renames = + (sym: Symbol) => + indexedContext + .rename(sym) + .orElse(renameConfig.get(sym)) + + new AutoImportsGenerator( + pos, + importPos, + indexedContext, + renames + ) + end generator + + case class AutoImportEdits( + nameEdit: Option[l.TextEdit], + importEdit: Option[l.TextEdit] + ): + + def edits: List[l.TextEdit] = List(nameEdit, importEdit).flatten + + object AutoImportEdits: + + def apply(name: l.TextEdit, imp: l.TextEdit): AutoImportEdits = + AutoImportEdits(Some(name), Some(imp)) + def importOnly(edit: l.TextEdit): AutoImportEdits = + AutoImportEdits(None, Some(edit)) + def nameOnly(edit: l.TextEdit): AutoImportEdits = + AutoImportEdits(Some(edit), None) + + /** + * AutoImportsGenerator generates TextEdits of auto-imports + * for the given symbols. + * + * @param pos A source position where the autoImport is invoked + * @param importPosition A position to insert new imports + * @param indexedContext A context of the position where the autoImport is invoked + * @param renames A function that returns the name of the given symbol which is renamed on import statement. + */ + class AutoImportsGenerator( + val pos: SourcePosition, + importPosition: AutoImportPosition, + indexedContext: IndexedContext, + renames: Symbol => Option[String] + ): + + import indexedContext.ctx + + def forSymbol(symbol: Symbol): Option[List[l.TextEdit]] = + editsForSymbol(symbol).map(_.edits) + + /** + * @param symbol A missing symbol to auto-import + */ + def editsForSymbol(symbol: Symbol): Option[AutoImportEdits] = + val symbolImport = inferSymbolImport(symbol) + val nameEdit = symbolImport.ident match + case SymbolIdent.Direct(_) => None + case other => + Some(new l.TextEdit(pos.toLsp, other.value)) + + val importEdit = + symbolImport.importSel.flatMap(sel => renderImports(List(sel))) + if nameEdit.isDefined || importEdit.isDefined then + Some(AutoImportEdits(nameEdit, importEdit)) + else None + end editsForSymbol + + def inferSymbolImport(symbol: Symbol): SymbolImport = + indexedContext.lookupSym(symbol) match + case IndexedContext.Result.Missing => + // in java enum and enum case both have same flags + val enumOwner = symbol.owner.companion + def isJavaEnumCase: Boolean = + symbol.isAllOf(EnumVal) && enumOwner.is(Enum) + + val (name, sel) = + // For enums import owner instead of all members + if symbol.isAllOf(EnumCase) || isJavaEnumCase + then + val ownerImport = inferSymbolImport(enumOwner) + ( + SymbolIdent.Select( + ownerImport.ident, + symbol.nameBackticked(false) + ), + ownerImport.importSel, + ) + else + ( + SymbolIdent.direct(symbol.nameBackticked), + Some(ImportSel.Direct(symbol)), + ) + end val + + SymbolImport( + symbol, + name, + sel + ) + case IndexedContext.Result.Conflict => + val owner = symbol.owner + renames(owner) match + case Some(rename) => + val importSel = + if rename != owner.showName then + Some(ImportSel.Rename(owner, rename)).filter(_ => + !indexedContext.hasRename(owner, rename) + ) + else + Some(ImportSel.Direct(owner)).filter(_ => + !indexedContext.lookupSym(owner).exists + ) + + SymbolImport( + symbol, + SymbolIdent.Select( + SymbolIdent.direct(rename), + symbol.nameBackticked(false) + ), + importSel + ) + case None => + SymbolImport( + symbol, + SymbolIdent.direct(symbol.fullNameBackticked), + None + ) + end match + case IndexedContext.Result.InScope => + val direct = renames(symbol).getOrElse(symbol.nameBackticked) + SymbolImport(symbol, SymbolIdent.direct(direct), None) + end match + end inferSymbolImport + + def renderImports( + imports: List[ImportSel] + )(using Context): Option[l.TextEdit] = + if imports.nonEmpty then + val indent0 = " " * importPosition.indent + val editPos = pos.withSpan(Spans.Span(importPosition.offset)).toLsp + + // for worksheets, we need to remove 2 whitespaces, because it ends up being wrapped in an object + // see WorksheetProvider.worksheetScala3AdjustmentsForPC + val indent = + if pos.source.path.isWorksheet && + editPos.getStart().getCharacter() == 0 + then indent0.drop(2) + else indent0 + val topPadding = + if importPosition.padTop then "\n" + else "" + + val formatted = imports + .map { + case ImportSel.Direct(sym) => importName(sym) + case ImportSel.Rename(sym, rename) => + s"${importName(sym.owner)}.{${sym.nameBackticked(false)} => $rename}" + } + .map(sel => s"${indent}import $sel") + .mkString(topPadding, "\n", "\n") + + Some(new l.TextEdit(editPos, formatted)) + else None + end renderImports + + private def importName(sym: Symbol): String = + if indexedContext.importContext.toplevelClashes(sym) then + s"_root_.${sym.fullNameBackticked(false)}" + else sym.fullNameBackticked(false) + end AutoImportsGenerator + + private def autoImportPosition( + pos: SourcePosition, + text: String, + tree: Tree, + comments: List[Comment] + )(using Context): AutoImportPosition = + + @tailrec + def lastPackageDef( + prev: Option[PackageDef], + tree: Tree + ): Option[PackageDef] = + tree match + case curr @ PackageDef(_, (next: PackageDef) :: Nil) + if !curr.symbol.isPackageObject => + lastPackageDef(Some(curr), next) + case pkg: PackageDef if !pkg.symbol.isPackageObject => Some(pkg) + case _ => prev + + def firstObjectBody(tree: Tree)(using Context): Option[Template] = + tree match + case PackageDef(_, stats) => + stats.flatMap { + case s: PackageDef => firstObjectBody(s) + case TypeDef(_, t @ Template(defDef, _, _, _)) + if defDef.symbol.isConstructor => Some(t) + case _ => None + }.headOption + case _ => None + + def skipUsingDirectivesOffset = + comments + .takeWhile(comment => + !comment.isDocComment && comment.span.end < firstObjectBody(tree) + .fold(0)(_.span.start) + ) + .lastOption + .fold(0)(_.span.end + 1) + + def forScalaSource: Option[AutoImportPosition] = + lastPackageDef(None, tree).map { pkg => + val lastImportStatement = + pkg.stats.takeWhile(_.isInstanceOf[Import]).lastOption + val (lineNumber, padTop) = lastImportStatement match + case Some(stm) => (stm.endPos.line + 1, false) + case None if pkg.pid.symbol.isEmptyPackage => + (pos.source.offsetToLine(skipUsingDirectivesOffset), false) + case None => + val pos = pkg.pid.endPos + val line = + // pos point at the last NL + if pos.endColumn == 0 then math.max(0, pos.line - 1) + else pos.line + 1 + (line, true) + val offset = pos.source.lineToOffset(lineNumber) + new AutoImportPosition(offset, text, padTop) + } + + def forScript(isAmmonite: Boolean): Option[AutoImportPosition] = + firstObjectBody(tree).map { tmpl => + val lastImportStatement = + tmpl.body.takeWhile(_.isInstanceOf[Import]).lastOption + val offset = lastImportStatement match + case Some(stm) => + val offset = pos.source.lineToOffset(stm.endPos.line + 1) + offset + case None => + val scriptOffset = + if isAmmonite then + ScriptFirstImportPosition.ammoniteScStartOffset(text, comments) + else + ScriptFirstImportPosition.scalaCliScStartOffset(text, comments) + + scriptOffset.getOrElse { + val tmplPoint = tmpl.self.srcPos.span.point + if tmplPoint >= 0 && tmplPoint < pos.source.length + then pos.source.lineToOffset(tmpl.self.srcPos.line) + else 0 + } + new AutoImportPosition(offset, text, false) + } + end forScript + + val path = pos.source.path + + def fileStart = + AutoImportPosition( + skipUsingDirectivesOffset, + 0, + padTop = false + ) + + val scriptPos = + if path.isAmmoniteGeneratedFile then forScript(isAmmonite = true) + else if path.isScalaCLIGeneratedFile then forScript(isAmmonite = false) + else None + + scriptPos + .orElse(forScalaSource) + .getOrElse(fileStart) + end autoImportPosition + +end AutoImports diff --git a/presentation-compiler/src/main/dotty/tools/pc/AutoImportsProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/AutoImportsProvider.scala new file mode 100644 index 000000000000..6a1b91cba31f --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/AutoImportsProvider.scala @@ -0,0 +1,105 @@ +package dotty.tools.pc + +import java.nio.file.Paths + +import scala.collection.mutable +import scala.jdk.CollectionConverters.* +import scala.meta.internal.metals.ReportContext +import scala.meta.internal.pc.AutoImportsResultImpl +import scala.meta.pc.* + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.pc.AutoImports.* +import dotty.tools.pc.completions.CompletionPos +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j as l + +final class AutoImportsProvider( + search: SymbolSearch, + driver: InteractiveDriver, + name: String, + params: OffsetParams, + config: PresentationCompilerConfig, + buildTargetIdentifier: String +)(using ReportContext): + + def autoImports(isExtension: Boolean): List[AutoImportsResult] = + val uri = params.uri + val filePath = Paths.get(uri) + driver.run( + uri, + SourceFile.virtual(filePath.toString, params.text) + ) + val unit = driver.currentCtx.run.units.head + val tree = unit.tpdTree + + val pos = driver.sourcePosition(params) + + val newctx = driver.currentCtx.fresh.setCompilationUnit(unit) + val path = + Interactive.pathTo(newctx.compilationUnit.tpdTree, pos.span)(using newctx) + + val indexedContext = IndexedContext( + Interactive.contextOfPath(path)(using newctx) + ) + import indexedContext.ctx + + val isSeen = mutable.Set.empty[String] + val symbols = List.newBuilder[Symbol] + def visit(sym: Symbol): Boolean = + val name = sym.denot.fullName.show + if !isSeen(name) then + isSeen += name + symbols += sym + true + else false + def isExactMatch(sym: Symbol, query: String): Boolean = + sym.name.show == query + + val visitor = new CompilerSearchVisitor(visit) + if isExtension then + search.searchMethods(name, buildTargetIdentifier, visitor) + else search.search(name, buildTargetIdentifier, visitor) + val results = symbols.result.filter(isExactMatch(_, name)) + + if results.nonEmpty then + val correctedPos = CompletionPos.infer(pos, params, path).sourcePos + val mkEdit = + path match + // if we are in import section just specify full name + case (_: Ident) :: (_: Import) :: _ => + (sym: Symbol) => + val nameEdit = + new l.TextEdit(correctedPos.toLsp, sym.fullNameBackticked) + Some(List(nameEdit)) + case _ => + val generator = + AutoImports.generator( + correctedPos, + params.text, + tree, + unit.comments, + indexedContext.importContext, + config + ) + (sym: Symbol) => generator.forSymbol(sym) + end match + end mkEdit + + for + sym <- results + edits <- mkEdit(sym) + yield AutoImportsResultImpl( + sym.owner.showFullName, + edits.asJava + ) + else List.empty + end if + end autoImports + +end AutoImportsProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/CompilerSearchVisitor.scala b/presentation-compiler/src/main/dotty/tools/pc/CompilerSearchVisitor.scala new file mode 100644 index 000000000000..7920e67bc26a --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/CompilerSearchVisitor.scala @@ -0,0 +1,91 @@ +package dotty.tools.pc + +import java.util.logging.Level +import java.util.logging.Logger + +import scala.meta.internal.metals.Report +import scala.meta.internal.metals.ReportContext +import scala.meta.pc.* +import scala.util.control.NonFatal + +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.Symbols.* + +class CompilerSearchVisitor( + visitSymbol: Symbol => Boolean +)(using ctx: Context, reports: ReportContext) + extends SymbolSearchVisitor: + + val logger: Logger = Logger.getLogger(classOf[CompilerSearchVisitor].getName) + + private def isAccessible(sym: Symbol): Boolean = try + sym != NoSymbol && sym.isPublic && sym.isStatic + catch + case NonFatal(e) => + reports.incognito.create( + Report( + "is_public", + s"""Symbol: $sym""".stripMargin, + e + ) + ) + logger.log(Level.SEVERE, e.getMessage(), e) + false + + private def toSymbols( + pkg: String, + parts: List[String] + ): List[Symbol] = + def loop(owners: List[Symbol], parts: List[String]): List[Symbol] = + parts match + case head :: tl => + val next = owners.flatMap { sym => + val term = sym.info.member(termName(head)) + val tpe = sym.info.member(typeName(head)) + + List(term, tpe) + .filter(denot => denot.exists) + .map(_.symbol) + .filter(isAccessible) + } + loop(next, tl) + case Nil => owners + + val pkgSym = requiredPackage(pkg) + loop(List(pkgSym), parts) + end toSymbols + + def visitClassfile(pkgPath: String, filename: String): Int = + val pkg = normalizePackage(pkgPath) + + val innerPath = filename + .stripSuffix(".class") + .stripSuffix("$") + .split("\\$") + + val added = toSymbols(pkg, innerPath.toList).filter(visitSymbol) + added.size + + def visitWorkspaceSymbol( + path: java.nio.file.Path, + symbol: String, + kind: org.eclipse.lsp4j.SymbolKind, + range: org.eclipse.lsp4j.Range + ): Int = + val gsym = SemanticdbSymbols.inverseSemanticdbSymbol(symbol).headOption + gsym + .filter(isAccessible) + .map(visitSymbol) + .map(_ => 1) + .getOrElse(0) + + def shouldVisitPackage(pkg: String): Boolean = + isAccessible(requiredPackage(normalizePackage(pkg))) + + override def isCancelled: Boolean = false + + private def normalizePackage(pkg: String): String = + pkg.replace("/", ".").stripSuffix(".") + +end CompilerSearchVisitor diff --git a/presentation-compiler/src/main/dotty/tools/pc/CompletionItemResolver.scala b/presentation-compiler/src/main/dotty/tools/pc/CompletionItemResolver.scala new file mode 100644 index 000000000000..d393e9204c27 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/CompletionItemResolver.scala @@ -0,0 +1,90 @@ +package dotty.tools.pc + +import scala.meta.internal.pc.ItemResolver +import scala.meta.pc.PresentationCompilerConfig +import scala.meta.pc.SymbolDocumentation +import scala.meta.pc.SymbolSearch + +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.TermRef +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j.CompletionItem + +object CompletionItemResolver extends ItemResolver: + + override def adjustIndexOfJavaParams = 0 + + def resolve( + item: CompletionItem, + msym: String, + search: SymbolSearch, + metalsConfig: PresentationCompilerConfig + )(using Context): CompletionItem = + SemanticdbSymbols.inverseSemanticdbSymbol(msym) match + case gsym :: _ if gsym != NoSymbol => + search + .symbolDocumentation(gsym) + .orElse( + search.symbolDocumentation(gsym.companion) + ) match + case Some(info) if item.getDetail != null => + enrichDocs( + item, + info, + metalsConfig, + fullDocstring(gsym, search), + gsym.is(JavaDefined) + ) + case _ => + item + end match + + case _ => item + end match + end resolve + + private def fullDocstring(gsym: Symbol, search: SymbolSearch)(using + Context + ): String = + def docs(gsym: Symbol): String = + search.symbolDocumentation(gsym).fold("")(_.docstring()) + val gsymDoc = docs(gsym) + def keyword(gsym: Symbol): String = + if gsym.isClass then "class" + else if gsym.is(Trait) then "trait" + else if gsym.isAllOf(JavaInterface) then "interface" + else if gsym.is(Module) then "object" + else "" + val companion = gsym.companion + if companion == NoSymbol || gsym.is(JavaDefined) then + if gsymDoc.isEmpty then + if gsym.isAliasType then + fullDocstring(gsym.info.metalsDealias.typeSymbol, search) + else if gsym.is(Method) then + gsym.info.finalResultType match + case tr @ TermRef(_, sym) => + fullDocstring(tr.symbol, search) + case _ => + "" + else "" + else gsymDoc + else + val companionDoc = docs(companion) + if companionDoc.isEmpty then gsymDoc + else if gsymDoc.isEmpty then companionDoc + else + List( + s"""|### ${keyword(companion)} ${companion.name} + |$companionDoc + |""".stripMargin, + s"""|### ${keyword(gsym)} ${gsym.name} + |${gsymDoc} + |""".stripMargin + ).sorted.mkString("\n") + end if + end fullDocstring + +end CompletionItemResolver diff --git a/presentation-compiler/src/main/dotty/tools/pc/ConvertToNamedArgumentsProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/ConvertToNamedArgumentsProvider.scala new file mode 100644 index 000000000000..99cc82cdf6a1 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/ConvertToNamedArgumentsProvider.scala @@ -0,0 +1,88 @@ +package dotty.tools.pc + +import java.nio.file.Paths + +import scala.meta.internal.pc.CodeActionErrorMessages +import scala.meta.pc.OffsetParams + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Types.MethodType +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j as l + +final class ConvertToNamedArgumentsProvider( + driver: InteractiveDriver, + params: OffsetParams, + argIndices: Set[Int] +): + + def convertToNamedArguments: Either[String, List[l.TextEdit]] = + val uri = params.uri + val filePath = Paths.get(uri) + driver.run( + uri, + SourceFile.virtual(filePath.toString, params.text) + ) + val unit = driver.currentCtx.run.units.head + val newctx = driver.currentCtx.fresh.setCompilationUnit(unit) + val pos = driver.sourcePosition(params) + val trees = driver.openedTrees(uri) + val tree = Interactive.pathTo(trees, pos)(using newctx).headOption + + def paramss(fun: tpd.Tree)(using Context): List[String] = + fun.tpe match + case m: MethodType => m.paramNamess.flatten.map(_.toString) + case _ => + fun.symbol.rawParamss.flatten + .filter(!_.isTypeParam) + .map(_.nameBackticked) + + object FromNewApply: + def unapply(tree: tpd.Tree): Option[(tpd.Tree, List[tpd.Tree])] = + tree match + case fun @ tpd.Select(tpd.New(_), _) => + Some((fun, Nil)) + case tpd.TypeApply(FromNewApply(fun, argss), _) => + Some(fun, argss) + case tpd.Apply(FromNewApply(fun, argss), args) => + Some(fun, argss ++ args) + case _ => None + + def edits(tree: Option[tpd.Tree])(using + Context + ): Either[String, List[l.TextEdit]] = + def makeTextEdits(fun: tpd.Tree, args: List[tpd.Tree]) = + if fun.symbol.is(Flags.JavaDefined) then + Left(CodeActionErrorMessages.ConvertToNamedArguments.IsJavaObject) + else + Right( + args.zipWithIndex + .zip(paramss(fun)) + .collect { + case ((arg, index), param) if argIndices.contains(index) => + val position = arg.sourcePos.toLsp + position.setEnd(position.getStart()) + new l.TextEdit(position, s"$param = ") + } + ) + + tree match + case Some(t) => + t match + case FromNewApply(fun, args) => + makeTextEdits(fun, args) + case tpd.Apply(fun, args) => + makeTextEdits(fun, args) + case _ => Right(Nil) + case _ => Right(Nil) + end match + end edits + edits(tree)(using newctx) + end convertToNamedArguments +end ConvertToNamedArgumentsProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/ExtractMethodProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/ExtractMethodProvider.scala new file mode 100644 index 000000000000..cbdc39a90118 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/ExtractMethodProvider.scala @@ -0,0 +1,173 @@ +package dotty.tools.pc + +import java.nio.file.Paths + +import scala.meta.internal.metals.ReportContext +import scala.meta.internal.pc.ExtractMethodUtils +import scala.meta.pc.OffsetParams +import scala.meta.pc.RangeParams +import scala.meta.pc.SymbolSearch +import scala.meta as m + +import dotty.tools.dotc.ast.Trees.* +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.DeepFolder +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.Types.MethodType +import dotty.tools.dotc.core.Types.PolyType +import dotty.tools.dotc.core.Types.Type +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.pc.printer.ShortenedTypePrinter +import dotty.tools.pc.printer.ShortenedTypePrinter.IncludeDefaultParam +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j.TextEdit +import org.eclipse.lsp4j as l + +final class ExtractMethodProvider( + range: RangeParams, + extractionPos: OffsetParams, + driver: InteractiveDriver, + search: SymbolSearch, + noIndent: Boolean +)(using ReportContext) + extends ExtractMethodUtils: + + def extractMethod(): List[TextEdit] = + val text = range.text() + val uri = range.uri + val filePath = Paths.get(uri) + val source = SourceFile.virtual(filePath.toString, text) + driver.run(uri, source) + val unit = driver.currentCtx.run.units.head + val pos = driver.sourcePosition(range).startPos + val path = + Interactive.pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) + given locatedCtx: Context = + val newctx = driver.currentCtx.fresh.setCompilationUnit(unit) + Interactive.contextOfPath(path)(using newctx) + val indexedCtx = IndexedContext(locatedCtx) + val printer = + ShortenedTypePrinter(search, IncludeDefaultParam.Never)(using indexedCtx) + def prettyPrint(tpe: Type) = + def prettyPrintReturnType(tpe: Type): String = + tpe match + case mt: (MethodType | PolyType) => + prettyPrintReturnType(tpe.resultType) + case tpe => printer.tpe(tpe) + def printParams(params: List[Type]) = + params match + case p :: Nil => prettyPrintReturnType(p) + case _ => s"(${params.map(prettyPrintReturnType).mkString(", ")})" + + if tpe.paramInfoss.isEmpty + then prettyPrintReturnType(tpe) + else + val params = tpe.paramInfoss.map(printParams).mkString(" => ") + s"$params => ${prettyPrintReturnType(tpe)}" + end prettyPrint + + def extractFromBlock(t: tpd.Tree): List[tpd.Tree] = + t match + case Block(stats, expr) => + (stats :+ expr).filter(stat => range.encloses(stat.sourcePos)) + case temp: Template[?] => + temp.body.filter(stat => range.encloses(stat.sourcePos)) + case other => List(other) + + def localRefs( + ts: List[tpd.Tree], + defnPos: SourcePosition, + extractedPos: SourcePosition + ): (List[Symbol], List[Symbol]) = + def nonAvailable(sym: Symbol): Boolean = + val symPos = sym.sourcePos + symPos.exists && defnPos.contains(symPos) && !extractedPos + .contains(symPos) + def collectNames(symbols: Set[Symbol], tree: tpd.Tree): Set[Symbol] = + tree match + case id @ Ident(_) => + val sym = id.symbol + if nonAvailable(sym) && (sym.isTerm || sym.isTypeParam) + then symbols + sym + else symbols + case _ => symbols + + val traverser = new DeepFolder[Set[Symbol]](collectNames) + val allSymbols = ts + .foldLeft(Set.empty[Symbol])(traverser(_, _)) + + val methodParams = allSymbols.toList.filter(_.isTerm) + val methodParamTypes = methodParams + .flatMap(p => p :: p.paramSymss.flatten) + .map(_.info.typeSymbol) + .filter(tp => nonAvailable(tp) && tp.isTypeParam) + .distinct + // Type parameter can be a type of one of the parameters or a type parameter in extracted code + val typeParams = + allSymbols.filter(_.isTypeParam) ++ methodParamTypes + + ( + methodParams.sortBy(_.decodedName), + typeParams.toList.sortBy(_.decodedName), + ) + end localRefs + val edits = + for + enclosing <- path.find(src => src.sourcePos.encloses(range)) + extracted = extractFromBlock(enclosing) + head <- extracted.headOption + expr <- extracted.lastOption + shortenedPath = + path.takeWhile(src => extractionPos.offset() <= src.sourcePos.start) + stat = shortenedPath.lastOption.getOrElse(head) + yield + val defnPos = stat.sourcePos + val extractedPos = head.sourcePos.withEnd(expr.sourcePos.end) + val exprType = prettyPrint(expr.tpe.widen) + val name = + genName(indexedCtx.scopeSymbols.map(_.decodedName).toSet, "newMethod") + val (methodParams, typeParams) = + localRefs(extracted, stat.sourcePos, extractedPos) + val methodParamsText = methodParams + .map(sym => s"${sym.decodedName}: ${prettyPrint(sym.info)}") + .mkString(", ") + val typeParamsText = typeParams + .map(_.decodedName) match + case Nil => "" + case params => params.mkString("[", ", ", "]") + val exprParamsText = methodParams.map(_.decodedName).mkString(", ") + val newIndent = stat.startPos.startColumnPadding + val oldIndentLen = head.startPos.startColumnPadding.length() + val toExtract = + textToExtract( + range.text(), + head.startPos.start, + expr.endPos.end, + newIndent, + oldIndentLen + ) + val (obracket, cbracket) = + if noIndent && extracted.length > 1 then (" {", s"$newIndent}") + else ("", "") + val defText = + s"def $name$typeParamsText($methodParamsText): $exprType =$obracket\n${toExtract}\n$cbracket\n$newIndent" + val replacedText = s"$name($exprParamsText)" + List( + new l.TextEdit( + extractedPos.toLsp, + replacedText + ), + new l.TextEdit( + defnPos.startPos.toLsp, + defText + ) + ) + + edits.getOrElse(Nil) + end extractMethod +end ExtractMethodProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/HoverProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/HoverProvider.scala new file mode 100644 index 000000000000..cdd4b273bdcc --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/HoverProvider.scala @@ -0,0 +1,204 @@ +package dotty.tools.pc + +import java.util as ju + +import scala.meta.internal.metals.Report +import scala.meta.internal.metals.ReportContext +import scala.meta.internal.pc.ScalaHover +import scala.meta.pc.HoverSignature +import scala.meta.pc.OffsetParams +import scala.meta.pc.SymbolSearch + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Constants.* +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.StdNames.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.pc.printer.ShortenedTypePrinter +import dotty.tools.pc.printer.ShortenedTypePrinter.IncludeDefaultParam +import dotty.tools.pc.utils.MtagsEnrichments.* + +object HoverProvider: + + def hover( + params: OffsetParams, + driver: InteractiveDriver, + search: SymbolSearch + )(implicit reportContext: ReportContext): ju.Optional[HoverSignature] = + val uri = params.uri + val sourceFile = SourceFile.virtual(params.uri, params.text) + driver.run(uri, sourceFile) + + given ctx: Context = driver.currentCtx + val pos = driver.sourcePosition(params) + val trees = driver.openedTrees(uri) + val indexedContext = IndexedContext(ctx) + + def typeFromPath(path: List[Tree]) = + if path.isEmpty then NoType else path.head.tpe + + val path = Interactive.pathTo(trees, pos) + val tp = typeFromPath(path) + val tpw = tp.widenTermRefExpr + // For expression we need to find all enclosing applies to get the exact generic type + val enclosing = path.expandRangeToEnclosingApply(pos) + + if tp.isError || tpw == NoType || tpw.isError || path.isEmpty + then + def report = + val posId = + if path.isEmpty || path.head.sourcePos == null || !path.head.sourcePos.exists + then pos.start + else path.head.sourcePos.start + Report( + "empty-hover-scala3", + s"""|$uri + |pos: ${pos.toLsp} + | + |tp: $tp + |has error: ${tp.isError} + | + |tpw: $tpw + |has error: ${tpw.isError} + | + |path: + |- ${path.map(_.toString()).mkString("\n- ")} + |trees: + |- ${trees.map(_.toString()).mkString("\n- ")} + |""".stripMargin, + s"$uri::$posId" + ) + end report + reportContext.unsanitized.create(report, ifVerbose = true) + ju.Optional.empty() + else + val skipCheckOnName = + !pos.isPoint // don't check isHoveringOnName for RangeHover + + val printerContext = + driver.compilationUnits.get(uri) match + case Some(unit) => + val newctx = + ctx.fresh.setCompilationUnit(unit) + Interactive.contextOfPath(enclosing)(using newctx) + case None => ctx + val printer = ShortenedTypePrinter(search, IncludeDefaultParam.Include)( + using IndexedContext(printerContext) + ) + MetalsInteractive.enclosingSymbolsWithExpressionType( + enclosing, + pos, + indexedContext, + skipCheckOnName + ) match + case Nil => + fallbackToDynamics(path, printer) + case (symbol, tpe) :: _ + if symbol.name == nme.selectDynamic || symbol.name == nme.applyDynamic => + fallbackToDynamics(path, printer) + case symbolTpes @ ((symbol, tpe) :: _) => + val exprTpw = tpe.widenTermRefExpr.metalsDealias + val hoverString = + tpw match + // https://github.com/lampepfl/dotty/issues/8891 + case tpw: ImportType => + printer.hoverSymbol(symbol, symbol.paramRef) + case _ => + val (tpe, sym) = + if symbol.isType then (symbol.typeRef, symbol) + else enclosing.head.seenFrom(symbol) + + val finalTpe = + if tpe != NoType then tpe + else tpw + + printer.hoverSymbol(sym, finalTpe) + end match + end hoverString + + val docString = symbolTpes + .flatMap(symTpe => search.symbolDocumentation(symTpe._1)) + .map(_.docstring) + .mkString("\n") + printer.expressionType(exprTpw) match + case Some(expressionType) => + val forceExpressionType = + !pos.span.isZeroExtent || ( + !hoverString.endsWith(expressionType) && + !symbol.isType && + !symbol.is(Module) && + !symbol.flags.isAllOf(EnumCase) + ) + ju.Optional.of( + new ScalaHover( + expressionType = Some(expressionType), + symbolSignature = Some(hoverString), + docstring = Some(docString), + forceExpressionType = forceExpressionType + ) + ) + case _ => + ju.Optional.empty + end match + end match + end if + end hover + + extension (pos: SourcePosition) + private def isPoint: Boolean = pos.start == pos.end + + private def fallbackToDynamics( + path: List[Tree], + printer: ShortenedTypePrinter + )(using Context): ju.Optional[HoverSignature] = path match + case SelectDynamicExtractor(sel, n, name) => + def findRefinement(tp: Type): ju.Optional[HoverSignature] = + tp match + case RefinedType(info, refName, tpe) if name == refName.toString() => + val tpeString = + if n == nme.selectDynamic then s": ${printer.tpe(tpe.resultType)}" + else printer.tpe(tpe) + ju.Optional.of( + new ScalaHover( + expressionType = Some(tpeString), + symbolSignature = Some(s"def $name$tpeString") + ) + ) + case RefinedType(info, _, _) => + findRefinement(info) + case _ => ju.Optional.empty() + + findRefinement(sel.tpe.termSymbol.info.dealias) + case _ => + ju.Optional.empty() + +end HoverProvider + +object SelectDynamicExtractor: + def unapply(path: List[Tree])(using Context) = + path match + // tests `structural-types` and `structural-types1` in HoverScala3TypeSuite + case Select(_, _) :: Apply( + Select(Apply(reflSel, List(sel)), n), + List(Literal(Constant(name: String))) + ) :: _ + if (n == nme.selectDynamic || n == nme.applyDynamic) && + nme.reflectiveSelectable == reflSel.symbol.name => + Some(sel, n, name) + // tests `selectable`, `selectable2` and `selectable-full` in HoverScala3TypeSuite + case Select(_, _) :: Apply( + Select(sel, n), + List(Literal(Constant(name: String))) + ) :: _ if n == nme.selectDynamic || n == nme.applyDynamic => + Some(sel, n, name) + case _ => None + end match + end unapply +end SelectDynamicExtractor diff --git a/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala b/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala new file mode 100644 index 000000000000..01456a367bba --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala @@ -0,0 +1,222 @@ +package dotty.tools.pc + +import scala.annotation.tailrec +import scala.util.control.NonFatal + +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.typer.ImportInfo +import dotty.tools.pc.IndexedContext.Result +import dotty.tools.pc.utils.MtagsEnrichments.* + +sealed trait IndexedContext: + given ctx: Context + def scopeSymbols: List[Symbol] + def names: IndexedContext.Names + def rename(sym: Symbol): Option[String] + def outer: IndexedContext + + def findSymbol(name: String): Option[List[Symbol]] + + final def findSymbol(name: Name): Option[List[Symbol]] = + findSymbol(name.decoded) + + final def lookupSym(sym: Symbol): Result = + findSymbol(sym.decodedName) match + case Some(symbols) if symbols.exists(_ == sym) => + Result.InScope + case Some(symbols) + if symbols + .exists(s => isTypeAliasOf(s, sym) || isTermAliasOf(s, sym)) => + Result.InScope + // when all the conflicting symbols came from an old version of the file + case Some(symbols) if symbols.nonEmpty && symbols.forall(_.isStale) => + Result.Missing + case Some(_) => Result.Conflict + case None => Result.Missing + end lookupSym + + final def hasRename(sym: Symbol, as: String): Boolean = + rename(sym) match + case Some(v) => v == as + case None => false + + // detects import scope aliases like + // object Predef: + // val Nil = scala.collection.immutable.Nil + private def isTermAliasOf(termAlias: Symbol, sym: Symbol): Boolean = + termAlias.isTerm && ( + sym.info match + case clz: ClassInfo => clz.appliedRef =:= termAlias.info.resultType + case _ => false + ) + + private def isTypeAliasOf(alias: Symbol, sym: Symbol): Boolean = + alias.isAliasType && alias.info.metalsDealias.typeSymbol == sym + + final def isEmpty: Boolean = this match + case IndexedContext.Empty => true + case _ => false + + final def importContext: IndexedContext = + this match + case IndexedContext.Empty => this + case _ if ctx.owner.is(Package) => this + case _ => outer.importContext + + @tailrec + final def toplevelClashes(sym: Symbol): Boolean = + if sym == NoSymbol || sym.owner == NoSymbol || sym.owner.isRoot then + lookupSym(sym) match + case IndexedContext.Result.Conflict => true + case _ => false + else toplevelClashes(sym.owner) + +end IndexedContext + +object IndexedContext: + + def apply(ctx: Context): IndexedContext = + ctx match + case null => Empty + case NoContext => Empty + case _ => LazyWrapper(using ctx) + + case object Empty extends IndexedContext: + given ctx: Context = NoContext + def findSymbol(name: String): Option[List[Symbol]] = None + def scopeSymbols: List[Symbol] = List.empty + val names: Names = Names(Map.empty, Map.empty) + def rename(sym: Symbol): Option[String] = None + def outer: IndexedContext = this + + class LazyWrapper(using val ctx: Context) extends IndexedContext: + val outer: IndexedContext = IndexedContext(ctx.outer) + val names: Names = extractNames(ctx) + + def findSymbol(name: String): Option[List[Symbol]] = + names.symbols + .get(name) + .map(_.toList) + .orElse(outer.findSymbol(name)) + + def scopeSymbols: List[Symbol] = + val acc = Set.newBuilder[Symbol] + (this :: outers).foreach { ref => + acc ++= ref.names.symbols.values.flatten + } + acc.result.toList + + def rename(sym: Symbol): Option[String] = + names.renames + .get(sym) + .orElse(outer.rename(sym)) + + private def outers: List[IndexedContext] = + val builder = List.newBuilder[IndexedContext] + var curr = outer + while !curr.isEmpty do + builder += curr + curr = curr.outer + builder.result + end LazyWrapper + + enum Result: + case InScope, Conflict, Missing + def exists: Boolean = this match + case InScope | Conflict => true + case Missing => false + + case class Names( + symbols: Map[String, List[Symbol]], + renames: Map[Symbol, String] + ) + + private def extractNames(ctx: Context): Names = + def isAccessibleFromSafe(sym: Symbol, site: Type): Boolean = + try sym.isAccessibleFrom(site, superAccess = false) + catch + case NonFatal(e) => + false + + def accessibleSymbols(site: Type, tpe: Type)(using + Context + ): List[Symbol] = + tpe.decls.toList.filter(sym => isAccessibleFromSafe(sym, site)) + + def accesibleMembers(site: Type)(using Context): List[Symbol] = + site.allMembers + .filter(denot => + try isAccessibleFromSafe(denot.symbol, site) + catch + case NonFatal(e) => + false + ) + .map(_.symbol) + .toList + + def allAccessibleSymbols( + tpe: Type, + filter: Symbol => Boolean = _ => true + )(using Context): List[Symbol] = + val initial = accessibleSymbols(tpe, tpe).filter(filter) + val fromPackageObjects = + initial + .filter(_.isPackageObject) + .flatMap(sym => accessibleSymbols(tpe, sym.thisType)) + initial ++ fromPackageObjects + + def fromImport(site: Type, name: Name)(using Context): List[Symbol] = + List(site.member(name.toTypeName), site.member(name.toTermName)) + .flatMap(_.alternatives) + .map(_.symbol) + + def fromImportInfo( + imp: ImportInfo + )(using Context): List[(Symbol, Option[TermName])] = + val excludedNames = imp.excluded.map(_.decoded) + + if imp.isWildcardImport then + allAccessibleSymbols( + imp.site, + sym => !excludedNames.contains(sym.name.decoded) + ).map((_, None)) + else + imp.forwardMapping.toList.flatMap { (name, rename) => + val isRename = name != rename + if !isRename && !excludedNames.contains(name.decoded) then + fromImport(imp.site, name).map((_, None)) + else if isRename then + fromImport(imp.site, name).map((_, Some(rename))) + else Nil + } + end if + end fromImportInfo + + given Context = ctx + val (symbols, renames) = + if ctx.isImportContext then + val (syms, renames) = + fromImportInfo(ctx.importInfo) + .map((sym, rename) => (sym, rename.map(r => sym -> r.decoded))) + .unzip + (syms, renames.flatten.toMap) + else if ctx.owner.isClass then + val site = ctx.owner.thisType + (accesibleMembers(site), Map.empty) + else if ctx.scope != null then (ctx.scope.toList, Map.empty) + else (List.empty, Map.empty) + + val initial = Map.empty[String, List[Symbol]] + val values = + symbols.foldLeft(initial) { (acc, sym) => + val name = sym.decodedName + val syms = acc.getOrElse(name, List.empty) + acc.updated(name, sym :: syms) + } + Names(values, renames) + end extractNames +end IndexedContext diff --git a/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala new file mode 100644 index 000000000000..578353ef4c90 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala @@ -0,0 +1,329 @@ +package dotty.tools.pc + +import java.nio.file.Paths + +import scala.annotation.tailrec +import scala.meta.internal.metals.ReportContext +import scala.meta.pc.OffsetParams +import scala.meta.pc.PresentationCompilerConfig +import scala.meta.pc.SymbolSearch +import scala.meta as m + +import dotty.tools.dotc.ast.Trees.* +import dotty.tools.dotc.ast.untpd +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.NameOps.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.dotc.util.Spans +import dotty.tools.dotc.util.Spans.Span +import dotty.tools.pc.printer.ShortenedTypePrinter +import dotty.tools.pc.printer.ShortenedTypePrinter.IncludeDefaultParam +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j.TextEdit +import org.eclipse.lsp4j as l + +/** + * Tries to calculate edits needed to insert the inferred type annotation + * in all the places that it is possible such as: + * - value or variable declaration + * - methods + * - pattern matches + * - for comprehensions + * - lambdas + * + * The provider will not check if the type does not exist, since there is no way to + * get that data from the presentation compiler. The actual check is being done via + * scalameta parser in InsertInferredType code action. + * + * @param params position and actual source + * @param driver Scala 3 interactive compiler driver + * @param config presentation compielr configuration + */ +final class InferredTypeProvider( + params: OffsetParams, + driver: InteractiveDriver, + config: PresentationCompilerConfig, + symbolSearch: SymbolSearch +)(using ReportContext): + + case class AdjustTypeOpts( + text: String, + adjustedEndPos: l.Position + ) + + def inferredTypeEdits( + adjustOpt: Option[AdjustTypeOpts] = None + ): List[TextEdit] = + val retryType = adjustOpt.isEmpty + val uri = params.uri + val filePath = Paths.get(uri) + + val sourceText = adjustOpt.map(_.text).getOrElse(params.text) + val source = + SourceFile.virtual(filePath.toString, sourceText) + driver.run(uri, source) + val unit = driver.currentCtx.run.units.head + val pos = driver.sourcePosition(params) + val path = + Interactive.pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) + + given locatedCtx: Context = driver.localContext(params) + val indexedCtx = IndexedContext(locatedCtx) + val autoImportsGen = AutoImports.generator( + pos, + params.text, + unit.tpdTree, + unit.comments, + indexedCtx, + config + ) + + def removeType(nameEnd: Int, tptEnd: Int) = + sourceText.substring(0, nameEnd) + + sourceText.substring(tptEnd + 1, sourceText.length()) + + def optDealias(tpe: Type): Type = + def isInScope(tpe: Type): Boolean = + tpe match + case tref: TypeRef => + indexedCtx.lookupSym( + tref.currentSymbol + ) == IndexedContext.Result.InScope + case AppliedType(tycon, args) => + isInScope(tycon) && args.forall(isInScope) + case _ => true + if isInScope(tpe) + then tpe + else tpe.metalsDealias + + val printer = ShortenedTypePrinter( + symbolSearch, + includeDefaultParam = IncludeDefaultParam.ResolveLater, + isTextEdit = true + )(using indexedCtx) + + def imports: List[TextEdit] = + printer.imports(autoImportsGen) + + def printType(tpe: Type): String = + printer.tpe(tpe) + + path.headOption match + /* `val a = 1` or `var b = 2` + * turns into + * `val a: Int = 1` or `var b: Int = 2` + * + *`.map(a => a + a)` + * turns into + * `.map((a: Int) => a + a)` + */ + case Some(vl @ ValDef(sym, tpt, rhs)) => + val isParam = path match + case head :: next :: _ if next.symbol.isAnonymousFunction => true + case head :: (b @ Block(stats, expr)) :: next :: _ + if next.symbol.isAnonymousFunction => + true + case _ => false + def baseEdit(withParens: Boolean): TextEdit = + val keywordOffset = if isParam then 0 else 4 + val endPos = + findNamePos(params.text, vl, keywordOffset).endPos.toLsp + adjustOpt.foreach(adjust => endPos.setEnd(adjust.adjustedEndPos)) + new TextEdit( + endPos, + ": " + printType(optDealias(tpt.tpe)) + { + if withParens then ")" else "" + } + ) + + def checkForParensAndEdit( + applyEndingPos: Int, + toCheckFor: Char, + blockStartPos: SourcePosition + ) = + val text = params.text + val isParensFunction: Boolean = text(applyEndingPos) == toCheckFor + + val alreadyHasParens = + text(blockStartPos.start) == '(' + + if isParensFunction && !alreadyHasParens then + new TextEdit(blockStartPos.toLsp, "(") :: baseEdit(withParens = + true + ) :: Nil + else baseEdit(withParens = false) :: Nil + end checkForParensAndEdit + + def typeNameEdit: List[TextEdit] = + path match + // lambda `map(a => ???)` apply + case _ :: _ :: (block: untpd.Block) :: (appl: untpd.Apply) :: _ + if isParam => + checkForParensAndEdit(appl.fun.endPos.end, '(', block.startPos) + + // labda `map{a => ???}` apply + // Ensures that this becomes {(a: Int) => ???} since parentheses + // are required around the parameter of a lambda in Scala 3 + case valDef :: defDef :: (block: untpd.Block) :: (_: untpd.Block) :: (appl: untpd.Apply) :: _ + if isParam => + checkForParensAndEdit(appl.fun.endPos.end, '{', block.startPos) + + case _ => + baseEdit(withParens = false) :: Nil + + def simpleType = + typeNameEdit ::: imports + + rhs match + case t: Tree[?] + if t.typeOpt.isErroneous && retryType && !tpt.sourcePos.span.isZeroExtent => + inferredTypeEdits( + Some( + AdjustTypeOpts( + removeType(vl.namePos.end, tpt.sourcePos.end - 1), + tpt.sourcePos.toLsp.getEnd() + ) + ) + ) + case _ => simpleType + end match + /* `def a[T](param : Int) = param` + * turns into + * `def a[T](param : Int): Int = param` + */ + case Some(df @ DefDef(name, _, tpt, rhs)) => + def typeNameEdit = + /* NOTE: In Scala 3.1.3, `List((1,2)).map((<>,b) => ...)` + * turns into `List((1,2)).map((:Inta,b) => ...)`, + * because `tpt.SourcePos == df.namePos.startPos`, so we use `df.namePos.endPos` instead + * After dropping support for 3.1.3 this can be removed + */ + val end = + if tpt.endPos.end > df.namePos.end then tpt.endPos.toLsp + else df.namePos.endPos.toLsp + + adjustOpt.foreach(adjust => end.setEnd(adjust.adjustedEndPos)) + new TextEdit( + end, + ": " + printType(optDealias(tpt.tpe)) + ) + end typeNameEdit + + def lastColon = + var i = tpt.startPos.start + while i >= 0 && sourceText(i) != ':' do i -= 1 + i + rhs match + case t: Tree[?] + if t.typeOpt.isErroneous && retryType && !tpt.sourcePos.span.isZeroExtent => + inferredTypeEdits( + Some( + AdjustTypeOpts( + removeType(lastColon, tpt.sourcePos.end - 1), + tpt.sourcePos.toLsp.getEnd() + ) + ) + ) + case _ => + typeNameEdit :: imports + + /* `case t =>` + * turns into + * `case t: Int =>` + */ + case Some(bind @ Bind(name, body)) => + def baseEdit(withParens: Boolean) = + new TextEdit( + bind.endPos.toLsp, + ": " + printType(optDealias(body.tpe)) + { + if withParens then ")" else "" + } + ) + val typeNameEdit = path match + /* In case it's an infix pattern match + * we need to add () for example in: + * case (head : Int) :: tail => + */ + case _ :: (unappl @ UnApply(_, _, patterns)) :: _ + if patterns.size > 1 => + val firstEnd = patterns(0).endPos.end + val secondStart = patterns(1).startPos.start + val hasDot = params + .text() + .substring(firstEnd, secondStart) + .exists(_ == ',') + if !hasDot then + val leftParen = new TextEdit(body.startPos.toLsp, "(") + leftParen :: baseEdit(withParens = true) :: Nil + else baseEdit(withParens = false) :: Nil + + case _ => + baseEdit(withParens = false) :: Nil + typeNameEdit ::: imports + + /* `for(t <- 0 to 10)` + * turns into + * `for(t: Int <- 0 to 10)` + */ + case Some(i @ Ident(name)) => + val typeNameEdit = new TextEdit( + i.endPos.toLsp, + ": " + printType(optDealias(i.tpe.widen)) + ) + typeNameEdit :: imports + + case _ => + Nil + end match + end inferredTypeEdits + + private def findNamePos( + text: String, + tree: untpd.NamedDefTree, + kewordOffset: Int + )(using + Context + ): SourcePosition = + val realName = tree.name.stripModuleClassSuffix.toString.toList + + // `NamedDefTree.namePos` is incorrect for bacticked names + @tailrec + def lookup( + idx: Int, + start: Option[(Int, List[Char])], + withBacktick: Boolean + ): Option[SourcePosition] = + start match + case Some((start, nextMatch :: left)) => + if text.charAt(idx) == nextMatch then + lookup(idx + 1, Some((start, left)), withBacktick) + else lookup(idx + 1, None, withBacktick = false) + case Some((start, Nil)) => + val end = if withBacktick then idx + 1 else idx + val pos = tree.source.atSpan(Span(start, end, start)) + Some(pos) + case None if idx < text.length => + val ch = text.charAt(idx) + if ch == realName.head then + lookup(idx + 1, Some((idx, realName.tail)), withBacktick) + else if ch == '`' then lookup(idx + 1, None, withBacktick = true) + else lookup(idx + 1, None, withBacktick = false) + case _ => + None + + val matchedByText = + if realName.nonEmpty then + lookup(tree.sourcePos.start + kewordOffset, None, false) + else None + + matchedByText.getOrElse(tree.namePos) + end findNamePos + +end InferredTypeProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/MetalsDriver.scala b/presentation-compiler/src/main/dotty/tools/pc/MetalsDriver.scala new file mode 100644 index 000000000000..4f7ed751f958 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/MetalsDriver.scala @@ -0,0 +1,56 @@ +package dotty.tools.pc + +import java.net.URI +import java.util as ju + +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.reporting.Diagnostic +import dotty.tools.dotc.util.SourceFile + +/** + * MetalsDriver is a wrapper class that provides a compilation cache for InteractiveDriver. + * MetalsDriver skips running compilation if + * - the target URI of `run` is the same as the previous target URI + * - the content didn't change since the last compilation. + * + * This compilation cache enables Metals to skip compilation and re-use + * the typed tree under the situation like developers + * sequentially hover on the symbols in the same file without any changes. + * + * Note: we decided to cache only if the target URI is the same as in the previous run + * because of `InteractiveDriver.currentCtx` that should return the context that + * refers to the last compiled source file. + * It would be ideal if we could update currentCtx even when we skip the compilation, + * but we struggled to do that. See the discussion https://github.com/scalameta/metals/pull/4225#discussion_r941138403 + * To avoid the complexity related to currentCtx, + * we decided to cache only when the target URI only if the same as the previous run. + */ +class MetalsDriver( + override val settings: List[String] +) extends InteractiveDriver(settings): + + @volatile private var lastCompiledURI: URI = _ + + private def alreadyCompiled(uri: URI, content: Array[Char]): Boolean = + compilationUnits.get(uri) match + case Some(unit) + if lastCompiledURI == uri && + ju.Arrays.equals(unit.source.content(), content) => + true + case _ => false + + override def run(uri: URI, source: SourceFile): List[Diagnostic] = + val diags = + if alreadyCompiled(uri, source.content) then Nil + else super.run(uri, source) + lastCompiledURI = uri + diags + + override def run(uri: URI, sourceCode: String): List[Diagnostic] = + val diags = + if alreadyCompiled(uri, sourceCode.toCharArray()) then Nil + else super.run(uri, sourceCode) + lastCompiledURI = uri + diags + +end MetalsDriver diff --git a/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala b/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala new file mode 100644 index 000000000000..076c1d9a9b88 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala @@ -0,0 +1,279 @@ +package dotty.tools.pc + +import scala.annotation.tailrec + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.ast.untpd +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.Names.Name +import dotty.tools.dotc.core.StdNames +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.Type +import dotty.tools.dotc.interactive.SourceTree +import dotty.tools.dotc.util.SourceFile +import dotty.tools.dotc.util.SourcePosition + +object MetalsInteractive: + + def contextOfStat( + stats: List[Tree], + stat: Tree, + exprOwner: Symbol, + ctx: Context + ): Context = stats match + case Nil => + ctx + case first :: _ if first eq stat => + ctx.exprContext(stat, exprOwner) + case (imp: Import) :: rest => + contextOfStat( + rest, + stat, + exprOwner, + ctx.importContext(imp, inContext(ctx) { imp.symbol }) + ) + case _ :: rest => + contextOfStat(rest, stat, exprOwner, ctx) + + /** + * Check if the given `sourcePos` is on the name of enclosing tree. + * ``` + * // For example, if the postion is on `foo`, returns true + * def foo(x: Int) = { ... } + * ^ + * + * // On the other hand, it points to non-name position, return false. + * def foo(x: Int) = { ... } + * ^ + * ``` + * @param path - path to the position given by `Interactive.pathTo` + */ + def isOnName( + path: List[Tree], + sourcePos: SourcePosition, + source: SourceFile + )(using Context): Boolean = + def contains(tree: Tree): Boolean = tree match + case select: Select => + // using `nameSpan` as SourceTree for Select (especially symbolic-infix e.g. `::` of `1 :: Nil`) miscalculate positions + select.nameSpan.contains(sourcePos.span) + case tree: Ident => + tree.sourcePos.contains(sourcePos) + case tree: NamedDefTree => + tree.namePos.contains(sourcePos) + case tree: NameTree => + SourceTree(tree, source).namePos.contains(sourcePos) + // TODO: check the positions for NamedArg and Import + case _: NamedArg => true + case _: Import => true + case app: (Apply | TypeApply) => contains(app.fun) + case _ => false + end contains + + val enclosing = path + .dropWhile(t => !t.symbol.exists && !t.isInstanceOf[NamedArg]) + .headOption + .getOrElse(EmptyTree) + contains(enclosing) + end isOnName + + private lazy val isForName: Set[Name] = Set[Name]( + StdNames.nme.map, + StdNames.nme.withFilter, + StdNames.nme.flatMap, + StdNames.nme.foreach + ) + def isForSynthetic(gtree: Tree)(using Context): Boolean = + def isForComprehensionSyntheticName(select: Select): Boolean = + select.sourcePos.toSynthetic == select.qualifier.sourcePos.toSynthetic && isForName( + select.name + ) + gtree match + case Apply(fun, List(_: Block)) => isForSynthetic(fun) + case TypeApply(fun, _) => isForSynthetic(fun) + case gtree: Select if isForComprehensionSyntheticName(gtree) => true + case _ => false + + def enclosingSymbols( + path: List[Tree], + pos: SourcePosition, + indexed: IndexedContext, + skipCheckOnName: Boolean = false + ): List[Symbol] = + enclosingSymbolsWithExpressionType(path, pos, indexed, skipCheckOnName) + .map(_._1) + + /** + * Returns the list of tuple enclosing symbol and + * the symbol's expression type if possible. + */ + @tailrec + def enclosingSymbolsWithExpressionType( + path: List[Tree], + pos: SourcePosition, + indexed: IndexedContext, + skipCheckOnName: Boolean = false + ): List[(Symbol, Type)] = + import indexed.ctx + path match + // For a named arg, find the target `DefDef` and jump to the param + case NamedArg(name, _) :: Apply(fn, _) :: _ => + val funSym = fn.symbol + if funSym.is(Synthetic) && funSym.owner.is(CaseClass) then + val sym = funSym.owner.info.member(name).symbol + List((sym, sym.info)) + else + val paramSymbol = + for param <- funSym.paramSymss.flatten.find(_.name == name) + yield param + val sym = paramSymbol.getOrElse(fn.symbol) + List((sym, sym.info)) + + case (_: untpd.ImportSelector) :: (imp: Import) :: _ => + importedSymbols(imp, _.span.contains(pos.span)).map(sym => + (sym, sym.info) + ) + + case (imp: Import) :: _ => + importedSymbols(imp, _.span.contains(pos.span)).map(sym => + (sym, sym.info) + ) + + // wildcard param + case head :: _ if (head.symbol.is(Param) && head.symbol.is(Synthetic)) => + List((head.symbol, head.typeOpt)) + + case (head @ Select(target, name)) :: _ + if head.symbol.is(Synthetic) && name == StdNames.nme.apply => + val sym = target.symbol + if sym.is(Synthetic) && sym.is(Module) then + List((sym.companionClass, sym.companionClass.info)) + else List((target.symbol, target.typeOpt)) + + // L@@ft(...) + case (head @ ApplySelect(select)) :: _ + if select.qualifier.sourcePos.contains(pos) && + select.name == StdNames.nme.apply => + List((head.symbol, head.typeOpt)) + + // for Inlined we don't have a symbol, but it's needed to show proper type + case (head @ Inlined(call, bindings, expansion)) :: _ => + List((call.symbol, head.typeOpt)) + + // for comprehension + case (head @ ApplySelect(select)) :: _ if isForSynthetic(head) => + // If the cursor is on the qualifier, return the symbol for it + // `for { x <- List(1).head@@Option }` returns the symbol of `headOption` + if select.qualifier.sourcePos.contains(pos) then + List((select.qualifier.symbol, select.qualifier.typeOpt)) + // Otherwise, returns the symbol of for synthetics such as "withFilter" + else List((head.symbol, head.typeOpt)) + + // f@@oo.bar + case Select(target, _) :: _ + if target.span.isSourceDerived && + target.sourcePos.contains(pos) => + List((target.symbol, target.typeOpt)) + + /* In some cases type might be represented by TypeTree, however it's possible + * that the type tree will not be marked properly as synthetic even if it doesn't + * exist in the code. + * + * For example for `Li@@st(1)` we will get the type tree representing [Int] + * despite it not being in the code. + * + * To work around it we check if the current and parent spans match, if they match + * this most likely means that the type tree is synthetic, since it has efectively + * span of 0. + */ + case (tpt: TypeTree) :: parent :: _ + if tpt.span != parent.span && !tpt.symbol.is(Synthetic) => + List((tpt.symbol, tpt.tpe)) + + /* TypeTest class https://dotty.epfl.ch/docs/reference/other-new-features/type-test.html + * compiler automatically adds unapply if possible, we need to find the type symbol + */ + case (head @ CaseDef(pat, _, _)) :: _ + if pat.symbol.exists && defn.TypeTestClass == pat.symbol.owner => + pat match + case UnApply(fun, _, pats) => + val tpeSym = pats.head.typeOpt.typeSymbol + List((tpeSym, tpeSym.info)) + case _ => + Nil + + case path @ head :: tail => + if head.symbol.is(Synthetic) then + enclosingSymbolsWithExpressionType( + tail, + pos, + indexed, + skipCheckOnName + ) + else if head.symbol != NoSymbol then + if skipCheckOnName || + MetalsInteractive.isOnName( + path, + pos, + indexed.ctx.source + ) + then List((head.symbol, head.typeOpt)) + /* Type tree for List(1) has an Int type variable, which has span + * but doesn't exist in code. + * https://github.com/lampepfl/dotty/issues/15937 + */ + else if head.isInstanceOf[TypeTree] then + enclosingSymbolsWithExpressionType(tail, pos, indexed) + else Nil + else + val recovered = recoverError(head, indexed) + if recovered.isEmpty then + enclosingSymbolsWithExpressionType( + tail, + pos, + indexed, + skipCheckOnName + ) + else recovered.map(sym => (sym, sym.info)) + end if + case Nil => Nil + end match + end enclosingSymbolsWithExpressionType + + import dotty.tools.pc.utils.MtagsEnrichments.* + + private def recoverError( + tree: Tree, + indexed: IndexedContext + ): List[Symbol] = + import indexed.ctx + + tree match + case select: Select => + select.qualifier.typeOpt + .member(select.name) + .allSymbols + .filter(_ != NoSymbol) + case ident: Ident => indexed.findSymbol(ident.name).toList.flatten + case _ => Nil + end recoverError + + object ApplySelect: + def unapply(tree: Tree): Option[Select] = Option(tree).collect { + case select: Select => select + case Apply(select: Select, _) => select + case Apply(TypeApply(select: Select, _), _) => select + } + end ApplySelect + + object TreeApply: + def unapply(tree: Tree): Option[(Tree, List[Tree])] = + tree match + case TypeApply(qual, args) => Some(qual -> args) + case Apply(qual, args) => Some(qual -> args) + case UnApply(qual, implicits, args) => Some(qual -> (implicits ++ args)) + case AppliedTypeTree(qual, args) => Some(qual -> args) + case _ => None +end MetalsInteractive diff --git a/presentation-compiler/src/main/dotty/tools/pc/Params.scala b/presentation-compiler/src/main/dotty/tools/pc/Params.scala new file mode 100644 index 000000000000..9226bb8e231f --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/Params.scala @@ -0,0 +1,23 @@ +package dotty.tools.pc + +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.Symbols.Symbol + +case class Params( + labels: Seq[String], + kind: Params.Kind +) + +object Params: + enum Kind: + case TypeParameter, Normal, Implicit, Using + + def paramsKind(syms: List[Symbol])(using Context): Params.Kind = + syms match + case head :: _ => + if head.isType then Kind.TypeParameter + else if head.is(Given) then Kind.Using + else if head.is(Implicit) then Kind.Implicit + else Kind.Normal + case Nil => Kind.Normal diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcCollector.scala b/presentation-compiler/src/main/dotty/tools/pc/PcCollector.scala new file mode 100644 index 000000000000..7b9cf7acd4a7 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcCollector.scala @@ -0,0 +1,579 @@ +package dotty.tools.pc + +import java.nio.file.Paths + +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.pc.OffsetParams +import scala.meta.pc.VirtualFileParams +import scala.meta as m + +import dotty.tools.dotc.ast.NavigateAST +import dotty.tools.dotc.ast.Positioned +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.ast.untpd +import dotty.tools.dotc.ast.untpd.ExtMethods +import dotty.tools.dotc.ast.untpd.ImportSelector +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.NameOps.* +import dotty.tools.dotc.core.Names.* +import dotty.tools.dotc.core.StdNames.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.dotc.util.Spans.Span +import dotty.tools.pc.utils.MtagsEnrichments.* + +abstract class PcCollector[T]( + driver: InteractiveDriver, + params: VirtualFileParams +): + private val caseClassSynthetics: Set[Name] = Set(nme.apply, nme.copy) + val uri = params.uri() + val filePath = Paths.get(uri) + val sourceText = params.text + val source = + SourceFile.virtual(filePath.toString, sourceText) + driver.run(uri, source) + given ctx: Context = driver.currentCtx + + val unit = driver.currentCtx.run.units.head + val compilatonUnitContext = ctx.fresh.setCompilationUnit(unit) + val offset = params match + case op: OffsetParams => op.offset() + case _ => 0 + val offsetParams = + params match + case op: OffsetParams => op + case _ => + CompilerOffsetParams(params.uri(), params.text(), 0, params.token()) + val pos = driver.sourcePosition(offsetParams) + val rawPath = + Interactive + .pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) + .dropWhile(t => // NamedArg anyway doesn't have symbol + t.symbol == NoSymbol && !t.isInstanceOf[NamedArg] || + // same issue https://github.com/lampepfl/dotty/issues/15937 as below + t.isInstanceOf[TypeTree] + ) + + val path = rawPath match + // For type it will sometimes go into the wrong tree since TypeTree also contains the same span + // https://github.com/lampepfl/dotty/issues/15937 + case TypeApply(sel: Select, _) :: tail if sel.span.contains(pos.span) => + Interactive.pathTo(sel, pos.span) ::: rawPath + case _ => rawPath + def collect( + parent: Option[Tree] + )(tree: Tree, pos: SourcePosition, symbol: Option[Symbol]): T + + /** + * @return (adjusted position, should strip backticks) + */ + def adjust( + pos1: SourcePosition, + forRename: Boolean = false + ): (SourcePosition, Boolean) = + if !pos1.span.isCorrect then (pos1, false) + else + val pos0 = + val span = pos1.span + if span.exists && span.point > span.end then + pos1.withSpan( + span + .withStart(span.point) + .withEnd(span.point + (span.end - span.start)) + ) + else pos1 + + val pos = + if pos0.end > 0 && sourceText(pos0.end - 1) == ',' then + pos0.withEnd(pos0.end - 1) + else pos0 + val isBackticked = + sourceText(pos.start) == '`' && + pos.end > 0 && + sourceText(pos.end - 1) == '`' + // when the old name contains backticks, the position is incorrect + val isOldNameBackticked = sourceText(pos.start) != '`' && + pos.start > 0 && + sourceText(pos.start - 1) == '`' && + sourceText(pos.end) == '`' + if isBackticked && forRename then + (pos.withStart(pos.start + 1).withEnd(pos.end - 1), true) + else if isOldNameBackticked then + (pos.withStart(pos.start - 1).withEnd(pos.end + 1), false) + else (pos, false) + end adjust + + def symbolAlternatives(sym: Symbol) = + val all = + if sym.is(Flags.ModuleClass) then + Set(sym, sym.companionModule, sym.companionModule.companion) + else if sym.isClass then + Set(sym, sym.companionModule, sym.companion.moduleClass) + else if sym.is(Flags.Module) then + Set(sym, sym.companionClass, sym.moduleClass) + else if sym.isTerm && (sym.owner.isClass || sym.owner.isConstructor) + then + val info = + if sym.owner.isClass then sym.owner.info else sym.owner.owner.info + Set( + sym, + info.member(sym.asTerm.name.setterName).symbol, + info.member(sym.asTerm.name.getterName).symbol + ) ++ sym.allOverriddenSymbols.toSet + // type used in primary constructor will not match the one used in the class + else if sym.isTypeParam && sym.owner.isPrimaryConstructor then + Set(sym, sym.owner.owner.info.member(sym.name).symbol) + else Set(sym) + all.filter(s => s != NoSymbol && !s.isError) + end symbolAlternatives + + private def isGeneratedGiven(df: NamedDefTree)(using Context) = + val nameSpan = df.nameSpan + df.symbol.is(Flags.Given) && sourceText.substring( + nameSpan.start, + nameSpan.end + ) != df.name.toString() + + // First identify the symbol we are at, comments identify @@ as current cursor position + def soughtSymbols(path: List[Tree]): Option[(Set[Symbol], SourcePosition)] = + val sought = path match + /* reference of an extension paramter + * extension [EF](<>: List[EF]) + * def double(ys: List[EF]) = <> ++ ys + */ + case (id: Ident) :: _ + if id.symbol + .is(Flags.Param) && id.symbol.owner.is(Flags.ExtensionMethod) => + Some(findAllExtensionParamSymbols(id.sourcePos, id.name, id.symbol)) + /* simple identifier: + * val a = val@@ue + value + */ + case (id: Ident) :: _ => + Some(symbolAlternatives(id.symbol), id.sourcePos) + /* simple selector: + * object.val@@ue + */ + case (sel: Select) :: _ if selectNameSpan(sel).contains(pos.span) => + Some(symbolAlternatives(sel.symbol), pos.withSpan(sel.nameSpan)) + /* named argument: + * foo(nam@@e = "123") + */ + case (arg: NamedArg) :: (appl: Apply) :: _ => + val realName = arg.name.stripModuleClassSuffix.lastPart + if pos.span.start > arg.span.start && pos.span.end < arg.span.point + realName.length + then + appl.symbol.paramSymss.flatten.find(_.name == arg.name).map { s => + // if it's a case class we need to look for parameters also + if caseClassSynthetics(s.owner.name) && s.owner.is(Flags.Synthetic) + then + ( + Set( + s, + s.owner.owner.companion.info.member(s.name).symbol, + s.owner.owner.info.member(s.name).symbol + ) + .filter(_ != NoSymbol), + arg.sourcePos, + ) + else (Set(s), arg.sourcePos) + } + else None + end if + /* all definitions: + * def fo@@o = ??? + * class Fo@@o = ??? + * etc. + */ + case (df: NamedDefTree) :: _ + if df.nameSpan.contains(pos.span) && !isGeneratedGiven(df) => + Some(symbolAlternatives(df.symbol), pos.withSpan(df.nameSpan)) + /** + * For traversing annotations: + * @JsonNo@@tification("") + * def params() = ??? + */ + case (df: MemberDef) :: _ if df.span.contains(pos.span) => + val annotTree = df.mods.annotations.find { t => + t.span.contains(pos.span) + } + collectTrees(annotTree).flatMap { t => + soughtSymbols( + Interactive.pathTo(t, pos.span) + ) + }.headOption + + /* Import selectors: + * import scala.util.Tr@@y + */ + case (imp: Import) :: _ if imp.span.contains(pos.span) => + imp + .selector(pos.span) + .map(sym => (symbolAlternatives(sym), sym.sourcePos)) + + case _ => None + + sought match + case None => seekInExtensionParameters() + case _ => sought + + end soughtSymbols + + lazy val extensionMethods = + NavigateAST + .untypedPath(pos.span)(using compilatonUnitContext) + .collectFirst { case em @ ExtMethods(_, _) => em } + + private def findAllExtensionParamSymbols( + pos: SourcePosition, + name: Name, + sym: Symbol + ) = + val symbols = + for + methods <- extensionMethods.map(_.methods) + symbols <- collectAllExtensionParamSymbols( + unit.tpdTree, + ExtensionParamOccurence(name, pos, sym, methods) + ) + yield symbols + symbols.getOrElse((symbolAlternatives(sym), pos)) + end findAllExtensionParamSymbols + + private def seekInExtensionParameters() = + def collectParams( + extMethods: ExtMethods + ): Option[ExtensionParamOccurence] = + NavigateAST + .pathTo(pos.span, extMethods.paramss.flatten)(using + compilatonUnitContext + ) + .collectFirst { + case v: untpd.ValOrTypeDef => + ExtensionParamOccurence( + v.name, + v.namePos, + v.symbol, + extMethods.methods + ) + case i: untpd.Ident => + ExtensionParamOccurence( + i.name, + i.sourcePos, + i.symbol, + extMethods.methods + ) + } + + for + extensionMethodScope <- extensionMethods + occurence <- collectParams(extensionMethodScope) + symbols <- collectAllExtensionParamSymbols( + path.headOption.getOrElse(unit.tpdTree), + occurence + ) + yield symbols + end seekInExtensionParameters + + private def collectAllExtensionParamSymbols( + tree: tpd.Tree, + occurrence: ExtensionParamOccurence + ): Option[(Set[Symbol], SourcePosition)] = + occurrence match + case ExtensionParamOccurence(_, namePos, symbol, _) + if symbol != NoSymbol && !symbol.isError && !symbol.owner.is( + Flags.ExtensionMethod + ) => + Some((symbolAlternatives(symbol), namePos)) + case ExtensionParamOccurence(name, namePos, _, methods) => + val symbols = + for + method <- methods.toSet + symbol <- + Interactive.pathTo(tree, method.span) match + case (d: DefDef) :: _ => + d.paramss.flatten.collect { + case param if param.name.decoded == name.decoded => + param.symbol + } + case _ => Set.empty[Symbol] + if (symbol != NoSymbol && !symbol.isError) + withAlt <- symbolAlternatives(symbol) + yield withAlt + if symbols.nonEmpty then Some((symbols, namePos)) else None + end collectAllExtensionParamSymbols + + def result(): List[T] = + params match + case _: OffsetParams => resultWithSought() + case _ => resultAllOccurences().toList + + def resultAllOccurences(): Set[T] = + def noTreeFilter = (_: Tree) => true + def noSoughtFilter = (_: Symbol => Boolean) => true + + traverseSought(noTreeFilter, noSoughtFilter) + + def resultWithSought(): List[T] = + soughtSymbols(path) match + case Some((sought, _)) => + lazy val owners = sought + .flatMap { s => Set(s.owner, s.owner.companionModule) } + .filter(_ != NoSymbol) + lazy val soughtNames: Set[Name] = sought.map(_.name) + + /* + * For comprehensions have two owners, one for the enumerators and one for + * yield. This is a heuristic to find that out. + */ + def isForComprehensionOwner(named: NameTree) = + soughtNames(named.name) && + scala.util + .Try(named.symbol.owner) + .toOption + .exists(_.isAnonymousFunction) && + owners.exists(o => + o.span.exists && o.span.point == named.symbol.owner.span.point + ) + + def soughtOrOverride(sym: Symbol) = + sought(sym) || sym.allOverriddenSymbols.exists(sought(_)) + + def soughtTreeFilter(tree: Tree): Boolean = + tree match + case ident: Ident + if soughtOrOverride(ident.symbol) || + isForComprehensionOwner(ident) => + true + case sel: Select if soughtOrOverride(sel.symbol) => true + case df: NamedDefTree + if soughtOrOverride(df.symbol) && !df.symbol.isSetter => + true + case imp: Import if owners(imp.expr.symbol) => true + case _ => false + + def soughtFilter(f: Symbol => Boolean): Boolean = + sought.exists(f) + + traverseSought(soughtTreeFilter, soughtFilter).toList + + case None => Nil + + extension (span: Span) + def isCorrect = + !span.isZeroExtent && span.exists && span.start < sourceText.size && span.end <= sourceText.size + + def traverseSought( + filter: Tree => Boolean, + soughtFilter: (Symbol => Boolean) => Boolean + ): Set[T] = + def collectNamesWithParent( + occurences: Set[T], + tree: Tree, + parent: Option[Tree] + ): Set[T] = + def collect( + tree: Tree, + pos: SourcePosition, + symbol: Option[Symbol] = None + ) = + this.collect(parent)(tree, pos, symbol) + tree match + /** + * All indentifiers such as: + * val a = <> + */ + case ident: Ident if ident.span.isCorrect && filter(ident) => + // symbols will differ for params in different ext methods, but source pos will be the same + if soughtFilter(_.sourcePos == ident.symbol.sourcePos) + then + occurences + collect( + ident, + ident.sourcePos + ) + else occurences + /** + * All select statements such as: + * val a = hello.<> + */ + case sel: Select if sel.span.isCorrect && filter(sel) => + occurences + collect( + sel, + pos.withSpan(selectNameSpan(sel)) + ) + /* all definitions: + * def <> = ??? + * class <> = ??? + * etc. + */ + case df: NamedDefTree + if df.span.isCorrect && df.nameSpan.isCorrect && + filter(df) && !isGeneratedGiven(df) => + val annots = collectTrees(df.mods.annotations) + val traverser = + new PcCollector.DeepFolderWithParent[Set[T]]( + collectNamesWithParent + ) + annots.foldLeft( + occurences + collect( + df, + pos.withSpan(df.nameSpan) + ) + ) { case (set, tree) => + traverser(set, tree) + } + + /* Named parameters don't have symbol so we need to check the owner + * foo(<> = "abc") + * User(<> = "abc") + * etc. + */ + case apply: Apply => + val args: List[NamedArg] = apply.args.collect { + case arg: NamedArg + if soughtFilter(sym => + sym.name == arg.name && + // foo(name = "123") for normal params + (sym.owner == apply.symbol || + // Bar(name = "123") for case class, copy and apply methods + apply.symbol.is(Flags.Synthetic) && + (sym.owner == apply.symbol.owner.companion || sym.owner == apply.symbol.owner)) + ) => + arg + } + val named = args.map { arg => + val realName = arg.name.stripModuleClassSuffix.lastPart + val sym = apply.symbol.paramSymss.flatten + .find(_.name == realName) + collect( + arg, + pos + .withSpan( + arg.span + .withEnd(arg.span.start + realName.length) + .withPoint(arg.span.start) + ), + sym + ) + } + occurences ++ named + + /** + * For traversing annotations: + * @<>("") + * def params() = ??? + */ + case mdf: MemberDef if mdf.mods.annotations.nonEmpty => + val trees = collectTrees(mdf.mods.annotations) + val traverser = + new PcCollector.DeepFolderWithParent[Set[T]]( + collectNamesWithParent + ) + trees.foldLeft(occurences) { case (set, tree) => + traverser(set, tree) + } + /** + * For traversing import selectors: + * import scala.util.<> + */ + case imp: Import if filter(imp) => + imp.selectors + .collect { + case sel: ImportSelector + if soughtFilter(_.decodedName == sel.name.decoded) => + // Show both rename and main together + val spans = + if !sel.renamed.isEmpty then + Set(sel.renamed.span, sel.imported.span) + else Set(sel.imported.span) + // See https://github.com/scalameta/metals/pull/5100 + val symbol = imp.expr.symbol.info.member(sel.name).symbol match + // We can get NoSymbol when we import "_", "*"", "given" or when the names don't match + // eg. "@@" doesn't match "$at$at". + // Then we try to find member based on decodedName + case NoSymbol => + imp.expr.symbol.info.allMembers + .find(_.name.decoded == sel.name.decoded) + .map(_.symbol) + .getOrElse(NoSymbol) + case sym => sym + spans.filter(_.isCorrect).map { span => + collect( + imp, + pos.withSpan(span), + Some(symbol) + ) + } + } + .flatten + .toSet ++ occurences + case inl: Inlined => + val traverser = + new PcCollector.DeepFolderWithParent[Set[T]]( + collectNamesWithParent + ) + val trees = inl.call :: inl.bindings + trees.foldLeft(occurences) { case (set, tree) => + traverser(set, tree) + } + case o => + occurences + end match + end collectNamesWithParent + + val traverser = + new PcCollector.DeepFolderWithParent[Set[T]](collectNamesWithParent) + val all = traverser(Set.empty[T], unit.tpdTree) + all + end traverseSought + + // @note (tgodzik) Not sure currently how to get rid of the warning, but looks to correctly + // @nowarn + private def collectTrees(trees: Iterable[Positioned]): Iterable[Tree] = + trees.collect { case t: Tree => + t + } + + // NOTE: Connected to https://github.com/lampepfl/dotty/issues/16771 + // `sel.nameSpan` is calculated incorrectly in (1 + 2).toString + // See test DocumentHighlightSuite.select-parentheses + private def selectNameSpan(sel: Select): Span = + val span = sel.span + if span.exists then + val point = span.point + if sel.name.toTermName == nme.ERROR then Span(point) + else if sel.qualifier.span.start > span.point then // right associative + val realName = sel.name.stripModuleClassSuffix.lastPart + Span(span.start, span.start + realName.length, point) + else Span(point, span.end, point) + else span +end PcCollector + +object PcCollector: + private class WithParentTraverser[X](f: (X, Tree, Option[Tree]) => X) + extends TreeAccumulator[List[Tree]]: + def apply(x: List[Tree], tree: Tree)(using Context): List[Tree] = tree :: x + def traverse(acc: X, tree: Tree, parent: Option[Tree])(using Context): X = + val res = f(acc, tree, parent) + val children = foldOver(Nil, tree).reverse + children.foldLeft(res)((a, c) => traverse(a, c, Some(tree))) + + // Folds over the tree as `DeepFolder` but `f` takes also the parent. + class DeepFolderWithParent[X](f: (X, Tree, Option[Tree]) => X): + private val traverser = WithParentTraverser[X](f) + def apply(x: X, tree: Tree)(using Context) = + traverser.traverse(x, tree, None) +end PcCollector + +case class ExtensionParamOccurence( + name: Name, + pos: SourcePosition, + sym: Symbol, + methods: List[untpd.Tree] +) diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcDefinitionProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcDefinitionProvider.scala new file mode 100644 index 000000000000..5d80b7d9be48 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcDefinitionProvider.scala @@ -0,0 +1,170 @@ +package dotty.tools.pc + +import java.nio.file.Paths +import java.util.ArrayList + +import scala.jdk.CollectionConverters.* +import scala.meta.internal.pc.DefinitionResultImpl +import scala.meta.pc.DefinitionResult +import scala.meta.pc.OffsetParams +import scala.meta.pc.SymbolSearch + +import dotty.tools.dotc.ast.NavigateAST +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.ast.untpd +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags.ModuleClass +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j.Location + +class PcDefinitionProvider( + driver: InteractiveDriver, + params: OffsetParams, + search: SymbolSearch +): + + def definitions(): DefinitionResult = + definitions(findTypeDef = false) + + def typeDefinitions(): DefinitionResult = + definitions(findTypeDef = true) + + private def definitions(findTypeDef: Boolean): DefinitionResult = + val uri = params.uri + val filePath = Paths.get(uri) + driver.run( + uri, + SourceFile.virtual(filePath.toString, params.text) + ) + + val pos = driver.sourcePosition(params) + val path = + Interactive.pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) + + given ctx: Context = driver.localContext(params) + val indexedContext = IndexedContext(ctx) + val result = + if findTypeDef then findTypeDefinitions(path, pos, indexedContext) + else findDefinitions(path, pos, indexedContext) + + if result.locations().isEmpty() then fallbackToUntyped(pos)(using ctx) + else result + end definitions + + /** + * Some nodes might disapear from the typed tree, since they are mostly + * used as syntactic sugar. In those cases we check the untyped tree + * and try to get the symbol from there, which might actually be there, + * because these are the same nodes that go through the typer. + * + * This will happen for: + * - `.. derives Show` + * @param unit compilation unit of the file + * @param pos cursor position + * @return definition result + */ + private def fallbackToUntyped(pos: SourcePosition)( + using ctx: Context + ) = + lazy val untpdPath = NavigateAST + .untypedPath(pos.span) + .collect { case t: untpd.Tree => t } + + definitionsForSymbol(untpdPath.headOption.map(_.symbol).toList, pos) + end fallbackToUntyped + + private def findDefinitions( + path: List[Tree], + pos: SourcePosition, + indexed: IndexedContext + ): DefinitionResult = + import indexed.ctx + definitionsForSymbol( + MetalsInteractive.enclosingSymbols(path, pos, indexed), + pos + ) + end findDefinitions + + private def findTypeDefinitions( + path: List[Tree], + pos: SourcePosition, + indexed: IndexedContext + ): DefinitionResult = + import indexed.ctx + val enclosing = path.expandRangeToEnclosingApply(pos) + val typeSymbols = MetalsInteractive + .enclosingSymbolsWithExpressionType(enclosing, pos, indexed) + .map { case (_, tpe) => + tpe.typeSymbol + } + typeSymbols match + case Nil => + path.headOption match + case Some(value: Literal) => + definitionsForSymbol(List(value.tpe.widen.typeSymbol), pos) + case _ => DefinitionResultImpl.empty + case _ => + definitionsForSymbol(typeSymbols, pos) + + end findTypeDefinitions + + private def definitionsForSymbol( + symbols: List[Symbol], + pos: SourcePosition + )(using ctx: Context): DefinitionResult = + symbols match + case symbols @ (sym :: other) => + val isLocal = sym.source == pos.source + if isLocal then + val defs = + Interactive.findDefinitions(List(sym), driver, false, false) + defs.headOption match + case Some(srcTree) => + val pos = srcTree.namePos + pos.toLocation match + case None => DefinitionResultImpl.empty + case Some(loc) => + DefinitionResultImpl( + SemanticdbSymbols.symbolName(sym), + List(loc).asJava + ) + case None => + DefinitionResultImpl.empty + else + val res = new ArrayList[Location]() + semanticSymbolsSorted(symbols) + .foreach { sym => + res.addAll(search.definition(sym, params.uri())) + } + DefinitionResultImpl( + SemanticdbSymbols.symbolName(sym), + res + ) + end if + case Nil => DefinitionResultImpl.empty + end match + end definitionsForSymbol + + def semanticSymbolsSorted( + syms: List[Symbol] + )(using ctx: Context): List[String] = + syms + .map { sym => + // in case of having the same type and teerm symbol + // term comes first + // used only for ordering symbols that come from `Import` + val termFlag = + if sym.is(ModuleClass) then sym.sourceModule.isTerm + else sym.isTerm + (termFlag, SemanticdbSymbols.symbolName(sym)) + } + .sorted + .map(_._2) + +end PcDefinitionProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcDocumentHighlightProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcDocumentHighlightProvider.scala new file mode 100644 index 000000000000..71e36297cbba --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcDocumentHighlightProvider.scala @@ -0,0 +1,34 @@ +package dotty.tools.pc + +import scala.meta.pc.OffsetParams + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j.DocumentHighlight +import org.eclipse.lsp4j.DocumentHighlightKind + +final class PcDocumentHighlightProvider( + driver: InteractiveDriver, + params: OffsetParams +) extends PcCollector[DocumentHighlight](driver, params): + + def collect( + parent: Option[Tree] + )( + tree: Tree, + toAdjust: SourcePosition, + sym: Option[Symbol] + ): DocumentHighlight = + val (pos, _) = adjust(toAdjust) + tree match + case _: NamedDefTree => + DocumentHighlight(pos.toLsp, DocumentHighlightKind.Write) + case _ => DocumentHighlight(pos.toLsp, DocumentHighlightKind.Read) + + def highlights: List[DocumentHighlight] = + result().distinctBy(_.getRange()) +end PcDocumentHighlightProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcInlineValueProviderImpl.scala b/presentation-compiler/src/main/dotty/tools/pc/PcInlineValueProviderImpl.scala new file mode 100644 index 000000000000..2c140b47f1e9 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcInlineValueProviderImpl.scala @@ -0,0 +1,205 @@ +package dotty.tools.pc + +import scala.meta.internal.pc.Definition +import scala.meta.internal.pc.InlineValueProvider +import scala.meta.internal.pc.InlineValueProvider.Errors +import scala.meta.internal.pc.RangeOffset +import scala.meta.internal.pc.Reference +import scala.meta.pc.OffsetParams + +import dotty.tools.dotc.ast.NavigateAST +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.ast.untpd +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.StdNames +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j as l + +final class PcInlineValueProviderImpl( + val driver: InteractiveDriver, + val params: OffsetParams +) extends PcCollector[Occurence](driver, params) + with InlineValueProvider: + + val text = params.text.toCharArray() + + val position: l.Position = pos.toLsp.getStart() + + override def collect(parent: Option[Tree])( + tree: Tree, + pos: SourcePosition, + sym: Option[Symbol] + ): Occurence = + val (adjustedPos, _) = adjust(pos) + Occurence(tree, parent, adjustedPos) + + override def defAndRefs(): Either[String, (Definition, List[Reference])] = + val newctx = driver.currentCtx.fresh.setCompilationUnit(unit) + val allOccurences = result() + for + definition <- allOccurences + .collectFirst { case Occurence(defn: ValDef, _, pos) => + DefinitionTree(defn, pos) + } + .toRight(Errors.didNotFindDefinition) + symbols = symbolsUsedInDefn(definition.tree.rhs) + references <- getReferencesToInline(definition, allOccurences, symbols) + yield + val (deleteDefinition, refsEdits) = references + + val defPos = definition.tree.sourcePos + val defEdit = Definition( + defPos.toLsp, + adjustRhs(definition.tree.rhs.sourcePos), + RangeOffset(defPos.start, defPos.end), + definitionRequiresBrackets(definition.tree.rhs)(using newctx), + deleteDefinition + ) + + (defEdit, refsEdits) + end for + end defAndRefs + + private def definitionRequiresBrackets(tree: Tree)(using Context): Boolean = + NavigateAST + .untypedPath(tree.span) + .headOption + .map { + case _: untpd.If => true + case _: untpd.Function => true + case _: untpd.Match => true + case _: untpd.ForYield => true + case _: untpd.InfixOp => true + case _: untpd.ParsedTry => true + case _: untpd.Try => true + case _: untpd.Block => true + case _: untpd.Typed => true + case _ => false + } + .getOrElse(false) + + end definitionRequiresBrackets + + private def referenceRequiresBrackets(tree: Tree)(using Context): Boolean = + NavigateAST.untypedPath(tree.span) match + case (_: untpd.InfixOp) :: _ => true + case _ => + tree match + case _: Apply => StdNames.nme.raw.isUnary(tree.symbol.name) + case _: Select => true + case _: Ident => true + case _ => false + + end referenceRequiresBrackets + + private def adjustRhs(pos: SourcePosition) = + def extend(point: Int, acceptedChar: Char, step: Int): Int = + val newPoint = point + step + if newPoint > 0 && newPoint < text.length && text( + newPoint + ) == acceptedChar + then extend(newPoint, acceptedChar, step) + else point + val adjustedStart = extend(pos.start, '(', -1) + val adjustedEnd = extend(pos.end - 1, ')', 1) + 1 + text.slice(adjustedStart, adjustedEnd).mkString + + private def symbolsUsedInDefn( + rhs: Tree + ): List[Symbol] = + def collectNames( + symbols: List[Symbol], + tree: Tree + ): List[Symbol] = + tree match + case id: (Ident | Select) + if !id.symbol.is(Synthetic) && !id.symbol.is(Implicit) => + tree.symbol :: symbols + case _ => symbols + + val traverser = new DeepFolder[List[Symbol]](collectNames) + traverser(List(), rhs) + end symbolsUsedInDefn + + private def getReferencesToInline( + definition: DefinitionTree, + allOccurences: List[Occurence], + symbols: List[Symbol] + ): Either[String, (Boolean, List[Reference])] = + val defIsLocal = definition.tree.symbol.ownersIterator + .drop(1) + .exists(e => e.isTerm) + def allreferences = allOccurences.filterNot(_.isDefn) + def inlineAll() = + makeRefsEdits(allreferences, symbols).map((true, _)) + if definition.tree.sourcePos.toLsp.encloses(position) + then if defIsLocal then inlineAll() else Left(Errors.notLocal) + else + allreferences match + case ref :: Nil if defIsLocal => inlineAll() + case list => + for + ref <- list + .find(_.pos.toLsp.encloses(position)) + .toRight(Errors.didNotFindReference) + refEdits <- makeRefsEdits(List(ref), symbols) + yield (false, refEdits) + end if + end getReferencesToInline + + private def makeRefsEdits( + refs: List[Occurence], + symbols: List[Symbol] + ): Either[String, List[Reference]] = + val newctx = driver.currentCtx.fresh.setCompilationUnit(unit) + def buildRef(occurence: Occurence): Either[String, Reference] = + val path = + Interactive.pathTo(unit.tpdTree, occurence.pos.span)(using newctx) + val indexedContext = IndexedContext( + Interactive.contextOfPath(path)(using newctx) + ) + import indexedContext.ctx + val conflictingSymbols = symbols + .withFilter { + indexedContext.lookupSym(_) match + case IndexedContext.Result.Conflict => true + case _ => false + } + .map(_.fullNameBackticked) + if conflictingSymbols.isEmpty then + Right( + Reference( + occurence.pos.toLsp, + occurence.parent.map(p => + RangeOffset(p.sourcePos.start, p.sourcePos.end) + ), + occurence.parent + .map(p => referenceRequiresBrackets(p)(using newctx)) + .getOrElse(false) + ) + ) + else Left(Errors.variablesAreShadowed(conflictingSymbols.mkString(", "))) + end buildRef + refs.foldLeft((Right(List())): Either[String, List[Reference]])((acc, r) => + for + collectedEdits <- acc + currentEdit <- buildRef(r) + yield currentEdit :: collectedEdits + ) + end makeRefsEdits + +end PcInlineValueProviderImpl + +case class Occurence(tree: Tree, parent: Option[Tree], pos: SourcePosition): + def isDefn = + tree match + case _: ValDef => true + case _ => false + +case class DefinitionTree(tree: ValDef, pos: SourcePosition) diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcRenameProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcRenameProvider.scala new file mode 100644 index 000000000000..4477529d7124 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcRenameProvider.scala @@ -0,0 +1,54 @@ +package dotty.tools.pc + +import scala.meta.pc.OffsetParams + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags.* +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j as l + +final class PcRenameProvider( + driver: InteractiveDriver, + params: OffsetParams, + name: Option[String] +) extends PcCollector[l.TextEdit](driver, params): + private val forbiddenMethods = + Set("equals", "hashCode", "unapply", "unary_!", "!") + def canRenameSymbol(sym: Symbol)(using Context): Boolean = + (!sym.is(Method) || !forbiddenMethods(sym.decodedName)) + && (sym.ownersIterator.drop(1).exists(ow => ow.is(Method)) + || sym.source.path.isWorksheet) + + def prepareRename(): Option[l.Range] = + soughtSymbols(path).flatMap((symbols, pos) => + if symbols.forall(canRenameSymbol) then Some(pos.toLsp) + else None + ) + + val newName = name.map(_.stripBackticks.backticked).getOrElse("newName") + + def collect( + parent: Option[Tree] + )(tree: Tree, toAdjust: SourcePosition, sym: Option[Symbol]): l.TextEdit = + val (pos, stripBackticks) = adjust(toAdjust, forRename = true) + l.TextEdit( + pos.toLsp, + if stripBackticks then newName.stripBackticks else newName + ) + end collect + + def rename( + ): List[l.TextEdit] = + val (symbols, _) = soughtSymbols(path).getOrElse(Set.empty, pos) + if symbols.nonEmpty && symbols.forall(canRenameSymbol(_)) + then + val res = result() + res + else Nil + end rename +end PcRenameProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcSemanticTokensProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/PcSemanticTokensProvider.scala new file mode 100644 index 000000000000..5f47b4d0d8bb --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcSemanticTokensProvider.scala @@ -0,0 +1,152 @@ +package dotty.tools.pc + +import scala.meta.internal.pc.SemanticTokens.* +import scala.meta.internal.pc.TokenNode +import scala.meta.pc.Node +import scala.meta.pc.VirtualFileParams + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Symbols.NoSymbol +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.pc.utils.MtagsEnrichments.* + +import org.eclipse.lsp4j.SemanticTokenModifiers +import org.eclipse.lsp4j.SemanticTokenTypes + +/** + * Provides semantic tokens of file(@param params) + * according to the LSP specification. + */ +final class PcSemanticTokensProvider( + driver: InteractiveDriver, + params: VirtualFileParams +): + /** + * Declaration is set for: + * 1. parameters, + * 2. defs/vals/vars without rhs, + * 3. type parameters, + * In all those cases we don't have a specific value for sure. + */ + private def isDeclaration(tree: Tree) = tree match + case df: ValOrDefDef => df.rhs.isEmpty + case df: TypeDef => + df.rhs match + case _: Template => false + case _ => df.rhs.isEmpty + case _ => false + + /** + * Definition is set for: + * 1. defs/vals/vars/type with rhs. + * 2. pattern matches + * + * We don't want to set it for enum cases despite the fact + * that the compiler sees them as vals, as it's not clear + * if they should be declaration/definition at all. + */ + private def isDefinition(tree: Tree) = tree match + case df: Bind => true + case df: ValOrDefDef => + !df.rhs.isEmpty && !df.symbol.isAllOf(Flags.EnumCase) + case df: TypeDef => + df.rhs match + case _: Template => false + case _ => !df.rhs.isEmpty + case _ => false + + object Collector extends PcCollector[Option[Node]](driver, params): + override def collect( + parent: Option[Tree] + )(tree: Tree, pos: SourcePosition, symbol: Option[Symbol]): Option[Node] = + val sym = symbol.fold(tree.symbol)(identity) + if !pos.exists || sym == null || sym == NoSymbol then None + else + Some( + makeNode( + sym = sym, + pos = adjust(pos)._1, + isDefinition = isDefinition(tree), + isDeclaration = isDeclaration(tree) + ) + ) + end collect + end Collector + + given Context = Collector.ctx + + def provide(): List[Node] = + Collector + .result() + .flatten + .sortWith((n1, n2) => + if n1.start() == n2.start() then n1.end() < n2.end() + else n1.start() < n2.start() + ) + + def makeNode( + sym: Symbol, + pos: SourcePosition, + isDefinition: Boolean, + isDeclaration: Boolean + ): Node = + + var mod: Int = 0 + def addPwrToMod(tokenID: String) = + val place: Int = getModifierId(tokenID) + if place != -1 then mod += (1 << place) + // get Type + val typ = + if sym.is(Flags.Param) && !sym.isTypeParam + then + addPwrToMod(SemanticTokenModifiers.Readonly) + getTypeId(SemanticTokenTypes.Parameter) + else if sym.isTypeParam || sym.isSkolem then + getTypeId(SemanticTokenTypes.TypeParameter) + else if sym.is(Flags.Enum) || sym.isAllOf(Flags.EnumVal) + then getTypeId(SemanticTokenTypes.Enum) + else if sym.is(Flags.Trait) then + getTypeId(SemanticTokenTypes.Interface) // "interface" + else if sym.isClass then getTypeId(SemanticTokenTypes.Class) // "class" + else if sym.isType && !sym.is(Flags.Param) then + getTypeId(SemanticTokenTypes.Type) // "type" + else if sym.is(Flags.Mutable) then + getTypeId(SemanticTokenTypes.Variable) // "var" + else if sym.is(Flags.Package) then + getTypeId(SemanticTokenTypes.Namespace) // "package" + else if sym.is(Flags.Module) then + getTypeId(SemanticTokenTypes.Class) // "object" + else if sym.isRealMethod then + if sym.isGetter | sym.isSetter then + getTypeId(SemanticTokenTypes.Variable) + else getTypeId(SemanticTokenTypes.Method) // "def" + else if isPredefClass(sym) then + getTypeId(SemanticTokenTypes.Class) // "class" + else if sym.isTerm && + (!sym.is(Flags.Param) || sym.is(Flags.ParamAccessor)) + then + addPwrToMod(SemanticTokenModifiers.Readonly) + getTypeId(SemanticTokenTypes.Variable) // "val" + else -1 + + // Modifiers except by ReadOnly + if sym.is(Flags.Abstract) || sym.isAbstractOrParamType || + sym.isOneOf(Flags.AbstractOrTrait) + then addPwrToMod(SemanticTokenModifiers.Abstract) + if sym.annotations.exists(_.symbol.decodedName == "deprecated") + then addPwrToMod(SemanticTokenModifiers.Deprecated) + + if isDeclaration then addPwrToMod(SemanticTokenModifiers.Declaration) + if isDefinition then addPwrToMod(SemanticTokenModifiers.Definition) + + TokenNode(pos.start, pos.`end`, typ, mod) + end makeNode + + def isPredefClass(sym: Symbol)(using Context) = + sym.is(Flags.Method) && sym.info.resultType.typeSymbol.is(Flags.Module) + +end PcSemanticTokensProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/Scala3CompilerAccess.scala b/presentation-compiler/src/main/dotty/tools/pc/Scala3CompilerAccess.scala new file mode 100644 index 000000000000..ef5aaf4e5ed0 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/Scala3CompilerAccess.scala @@ -0,0 +1,39 @@ +package dotty.tools.pc + +import java.util.concurrent.ScheduledExecutorService + +import scala.concurrent.ExecutionContextExecutor +import scala.meta.internal.metals.ReportContext +import scala.meta.internal.pc.CompilerAccess +import scala.meta.pc.PresentationCompilerConfig + +import dotty.tools.dotc.reporting.StoreReporter + +class Scala3CompilerAccess( + config: PresentationCompilerConfig, + sh: Option[ScheduledExecutorService], + newCompiler: () => Scala3CompilerWrapper +)(using ec: ExecutionContextExecutor, rc: ReportContext) + extends CompilerAccess[StoreReporter, MetalsDriver]( + config, + sh, + newCompiler, + /* If running inside the executor, we need to reset the job queue + * Otherwise it will block indefinetely in case of infinite loops. + */ + shouldResetJobQueue = true + ): + + def newReporter = new StoreReporter(null) + + /** + * Handle the exception in order to make sure that + * we retry immediately. Otherwise, we will wait until + * the end of the timeout, which is 20s by default. + */ + protected def handleSharedCompilerException( + t: Throwable + ): Option[String] = None + + protected def ignoreException(t: Throwable): Boolean = false +end Scala3CompilerAccess diff --git a/presentation-compiler/src/main/dotty/tools/pc/Scala3CompilerWrapper.scala b/presentation-compiler/src/main/dotty/tools/pc/Scala3CompilerWrapper.scala new file mode 100644 index 000000000000..de4fb282edc9 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/Scala3CompilerWrapper.scala @@ -0,0 +1,28 @@ +package dotty.tools.pc + +import scala.meta.internal.pc.CompilerWrapper +import scala.meta.internal.pc.ReporterAccess + +import dotty.tools.dotc.reporting.StoreReporter + +class Scala3CompilerWrapper(driver: MetalsDriver) + extends CompilerWrapper[StoreReporter, MetalsDriver]: + + override def compiler(): MetalsDriver = driver + + override def resetReporter(): Unit = + val ctx = driver.currentCtx + ctx.reporter.removeBufferedMessages(using ctx) + + override def reporterAccess: ReporterAccess[StoreReporter] = + new ReporterAccess[StoreReporter]: + def reporter = driver.currentCtx.reporter.asInstanceOf[StoreReporter] + + override def askShutdown(): Unit = () + + override def isAlive(): Boolean = false + + override def stop(): Unit = {} + + override def presentationCompilerThread: Option[Thread] = None +end Scala3CompilerWrapper diff --git a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala new file mode 100644 index 000000000000..170a6418f2dc --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala @@ -0,0 +1,409 @@ +package dotty.tools.pc + +import java.io.File +import java.net.URI +import java.nio.file.Path +import java.util.Optional +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ExecutorService +import java.util.concurrent.ScheduledExecutorService +import java.util as ju + +import scala.concurrent.ExecutionContext +import scala.concurrent.ExecutionContextExecutor +import scala.jdk.CollectionConverters._ +import scala.meta.internal.metals.CompilerVirtualFileParams +import scala.meta.internal.metals.EmptyCancelToken +import scala.meta.internal.metals.EmptyReportContext +import scala.meta.internal.metals.ReportContext +import scala.meta.internal.metals.ReportLevel +import scala.meta.internal.metals.StdReportContext +import scala.meta.internal.pc.CompilerAccess +import scala.meta.internal.pc.DefinitionResultImpl +import scala.meta.internal.pc.EmptyCompletionList +import scala.meta.internal.pc.EmptySymbolSearch +import scala.meta.internal.pc.PresentationCompilerConfigImpl +import scala.meta.pc.* + +import dotty.tools.dotc.reporting.StoreReporter +import dotty.tools.pc.completions.CompletionProvider +import dotty.tools.pc.completions.OverrideCompletions +import dotty.tools.pc.util.BuildInfo + +import org.eclipse.lsp4j.DocumentHighlight +import org.eclipse.lsp4j.TextEdit +import org.eclipse.lsp4j as l + +case class ScalaPresentationCompiler( + buildTargetIdentifier: String = "", + classpath: Seq[Path] = Nil, + options: List[String] = Nil, + search: SymbolSearch = EmptySymbolSearch, + ec: ExecutionContextExecutor = ExecutionContext.global, + sh: Option[ScheduledExecutorService] = None, + config: PresentationCompilerConfig = PresentationCompilerConfigImpl(), + folderPath: Option[Path] = None, + reportsLevel: ReportLevel = ReportLevel.Info +) extends PresentationCompiler: + + def this() = this("", Nil, Nil) + + val scalaVersion = BuildInfo.scalaVersion + + private val forbiddenOptions = Set("-print-lines", "-print-tasty") + private val forbiddenDoubleOptions = Set("-release") + given ReportContext = + folderPath + .map(StdReportContext(_, reportsLevel)) + .getOrElse(EmptyReportContext) + + override def withReportsLoggerLevel(level: String): PresentationCompiler = + copy(reportsLevel = ReportLevel.fromString(level)) + + val compilerAccess: CompilerAccess[StoreReporter, MetalsDriver] = + Scala3CompilerAccess( + config, + sh, + () => new Scala3CompilerWrapper(newDriver) + )(using + ec + ) + + private def removeDoubleOptions(options: List[String]): List[String] = + options match + case head :: _ :: tail if forbiddenDoubleOptions(head) => + removeDoubleOptions(tail) + case head :: tail => head :: removeDoubleOptions(tail) + case Nil => options + + def newDriver: MetalsDriver = + val implicitSuggestionTimeout = List("-Ximport-suggestion-timeout", "0") + val defaultFlags = List("-color:never") + val filteredOptions = removeDoubleOptions( + options.filterNot(forbiddenOptions) + ) + val settings = + filteredOptions ::: defaultFlags ::: implicitSuggestionTimeout ::: "-classpath" :: classpath + .mkString( + File.pathSeparator + ) :: Nil + new MetalsDriver(settings) + + override def semanticTokens( + params: VirtualFileParams + ): CompletableFuture[ju.List[Node]] = + compilerAccess.withInterruptableCompiler(Some(params))( + new ju.ArrayList[Node](), + params.token() + ) { access => + val driver = access.compiler() + new PcSemanticTokensProvider(driver, params).provide().asJava + } + + override def getTasty( + targetUri: URI, + isHttpEnabled: Boolean + ): CompletableFuture[String] = + CompletableFuture.completedFuture { + TastyUtils.getTasty(targetUri, isHttpEnabled) + } + + def complete(params: OffsetParams): CompletableFuture[l.CompletionList] = + compilerAccess.withInterruptableCompiler(Some(params))( + EmptyCompletionList(), + params.token + ) { access => + val driver = access.compiler() + new CompletionProvider( + search, + driver, + params, + config, + buildTargetIdentifier, + folderPath + ).completions() + + } + + def definition(params: OffsetParams): CompletableFuture[DefinitionResult] = + compilerAccess.withInterruptableCompiler(Some(params))( + DefinitionResultImpl.empty, + params.token + ) { access => + val driver = access.compiler() + PcDefinitionProvider(driver, params, search).definitions() + } + + override def typeDefinition( + params: OffsetParams + ): CompletableFuture[DefinitionResult] = + compilerAccess.withInterruptableCompiler(Some(params))( + DefinitionResultImpl.empty, + params.token + ) { access => + val driver = access.compiler() + PcDefinitionProvider(driver, params, search).typeDefinitions() + } + + def documentHighlight( + params: OffsetParams + ): CompletableFuture[ju.List[DocumentHighlight]] = + compilerAccess.withInterruptableCompiler(Some(params))( + List.empty[DocumentHighlight].asJava, + params.token + ) { access => + val driver = access.compiler() + PcDocumentHighlightProvider(driver, params).highlights.asJava + } + + def shutdown(): Unit = + compilerAccess.shutdown() + + def restart(): Unit = + compilerAccess.shutdownCurrentCompiler() + + def diagnosticsForDebuggingPurposes(): ju.List[String] = + List[String]().asJava + + def semanticdbTextDocument( + filename: URI, + code: String + ): CompletableFuture[Array[Byte]] = + val virtualFile = CompilerVirtualFileParams(filename, code) + compilerAccess.withNonInterruptableCompiler(Some(virtualFile))( + Array.empty[Byte], + EmptyCancelToken + ) { access => + val driver = access.compiler() + val provider = SemanticdbTextDocumentProvider(driver, folderPath) + provider.textDocument(filename, code) + } + + def completionItemResolve( + item: l.CompletionItem, + symbol: String + ): CompletableFuture[l.CompletionItem] = + compilerAccess.withNonInterruptableCompiler(None)( + item, + EmptyCancelToken + ) { access => + val driver = access.compiler() + CompletionItemResolver.resolve(item, symbol, search, config)(using + driver.currentCtx + ) + } + + def autoImports( + name: String, + params: scala.meta.pc.OffsetParams, + isExtension: java.lang.Boolean + ): CompletableFuture[ + ju.List[scala.meta.pc.AutoImportsResult] + ] = + compilerAccess.withNonInterruptableCompiler(Some(params))( + List.empty[scala.meta.pc.AutoImportsResult].asJava, + params.token + ) { access => + val driver = access.compiler() + new AutoImportsProvider( + search, + driver, + name, + params, + config, + buildTargetIdentifier + ) + .autoImports(isExtension) + .asJava + } + + def implementAbstractMembers( + params: OffsetParams + ): CompletableFuture[ju.List[l.TextEdit]] = + val empty: ju.List[l.TextEdit] = new ju.ArrayList[l.TextEdit]() + compilerAccess.withNonInterruptableCompiler(Some(params))( + empty, + params.token + ) { pc => + val driver = pc.compiler() + OverrideCompletions.implementAllAt( + params, + driver, + search, + config + ) + } + end implementAbstractMembers + + override def insertInferredType( + params: OffsetParams + ): CompletableFuture[ju.List[l.TextEdit]] = + val empty: ju.List[l.TextEdit] = new ju.ArrayList[l.TextEdit]() + compilerAccess.withNonInterruptableCompiler(Some(params))( + empty, + params.token + ) { pc => + new InferredTypeProvider(params, pc.compiler(), config, search) + .inferredTypeEdits() + .asJava + } + + override def inlineValue( + params: OffsetParams + ): CompletableFuture[ju.List[l.TextEdit]] = + val empty: Either[String, List[l.TextEdit]] = Right(List()) + (compilerAccess + .withInterruptableCompiler(Some(params))(empty, params.token) { pc => + new PcInlineValueProviderImpl(pc.compiler(), params) + .getInlineTextEdits() + }) + .thenApply { + case Right(edits: List[TextEdit]) => edits.asJava + case Left(error: String) => throw new DisplayableException(error) + } + end inlineValue + + override def extractMethod( + range: RangeParams, + extractionPos: OffsetParams + ): CompletableFuture[ju.List[l.TextEdit]] = + val empty: ju.List[l.TextEdit] = new ju.ArrayList[l.TextEdit]() + compilerAccess.withInterruptableCompiler(Some(range))(empty, range.token) { + pc => + new ExtractMethodProvider( + range, + extractionPos, + pc.compiler(), + search, + options.contains("-no-indent"), + ) + .extractMethod() + .asJava + } + end extractMethod + + override def convertToNamedArguments( + params: OffsetParams, + argIndices: ju.List[Integer] + ): CompletableFuture[ju.List[l.TextEdit]] = + val empty: Either[String, List[l.TextEdit]] = Right(List()) + (compilerAccess + .withNonInterruptableCompiler(Some(params))(empty, params.token) { pc => + new ConvertToNamedArgumentsProvider( + pc.compiler(), + params, + argIndices.asScala.map(_.toInt).toSet + ).convertToNamedArguments + }) + .thenApplyAsync { + case Left(error: String) => throw new DisplayableException(error) + case Right(edits: List[l.TextEdit]) => edits.asJava + } + end convertToNamedArguments + override def selectionRange( + params: ju.List[OffsetParams] + ): CompletableFuture[ju.List[l.SelectionRange]] = + CompletableFuture.completedFuture { + compilerAccess.withSharedCompiler(params.asScala.headOption)( + List.empty[l.SelectionRange].asJava + ) { pc => + new SelectionRangeProvider( + pc.compiler(), + params, + ).selectionRange().asJava + } + } + end selectionRange + + def hover( + params: OffsetParams + ): CompletableFuture[ju.Optional[HoverSignature]] = + compilerAccess.withNonInterruptableCompiler(Some(params))( + ju.Optional.empty[HoverSignature](), + params.token + ) { access => + val driver = access.compiler() + HoverProvider.hover(params, driver, search) + } + end hover + + def prepareRename( + params: OffsetParams + ): CompletableFuture[ju.Optional[l.Range]] = + compilerAccess.withNonInterruptableCompiler(Some(params))( + Optional.empty[l.Range](), + params.token + ) { access => + val driver = access.compiler() + Optional.ofNullable( + PcRenameProvider(driver, params, None).prepareRename().orNull + ) + } + + def rename( + params: OffsetParams, + name: String + ): CompletableFuture[ju.List[l.TextEdit]] = + compilerAccess.withNonInterruptableCompiler(Some(params))( + List[l.TextEdit]().asJava, + params.token + ) { access => + val driver = access.compiler() + PcRenameProvider(driver, params, Some(name)).rename().asJava + } + + def newInstance( + buildTargetIdentifier: String, + classpath: ju.List[Path], + options: ju.List[String] + ): PresentationCompiler = + copy( + buildTargetIdentifier = buildTargetIdentifier, + classpath = classpath.asScala.toSeq, + options = options.asScala.toList + ) + + def signatureHelp(params: OffsetParams): CompletableFuture[l.SignatureHelp] = + compilerAccess.withNonInterruptableCompiler(Some(params))( + new l.SignatureHelp(), + params.token + ) { access => + val driver = access.compiler() + SignatureHelpProvider.signatureHelp(driver, params, search) + } + + override def didChange( + params: VirtualFileParams + ): CompletableFuture[ju.List[l.Diagnostic]] = + CompletableFuture.completedFuture(Nil.asJava) + + override def didClose(uri: URI): Unit = + compilerAccess.withNonInterruptableCompiler(None)( + (), + EmptyCancelToken + ) { access => access.compiler().close(uri) } + + override def withExecutorService( + executorService: ExecutorService + ): PresentationCompiler = + copy(ec = ExecutionContext.fromExecutorService(executorService)) + + override def withConfiguration( + config: PresentationCompilerConfig + ): PresentationCompiler = + copy(config = config) + + override def withScheduledExecutorService( + sh: ScheduledExecutorService + ): PresentationCompiler = + copy(sh = Some(sh)) + + def withSearch(search: SymbolSearch): PresentationCompiler = + copy(search = search) + + def withWorkspace(workspace: Path): PresentationCompiler = + copy(folderPath = Some(workspace)) + + override def isLoaded() = compilerAccess.isLoaded() + +end ScalaPresentationCompiler diff --git a/presentation-compiler/src/main/dotty/tools/pc/ScriptFirstImportPosition.scala b/presentation-compiler/src/main/dotty/tools/pc/ScriptFirstImportPosition.scala new file mode 100644 index 000000000000..2bb8023cee08 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/ScriptFirstImportPosition.scala @@ -0,0 +1,44 @@ +package dotty.tools.pc + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Comments.Comment + +object ScriptFirstImportPosition: + + val usingDirectives: List[String] = List("// using", "//> using") + val ammHeaders: List[String] = List("// scala", "// ammonite") + + def ammoniteScStartOffset( + text: String, + comments: List[Comment] + ): Option[Int] = + findStartOffset(text, comments, commentQuery = "/**/", ammHeaders) + + def scalaCliScStartOffset( + text: String, + comments: List[Comment] + ): Option[Int] = + findStartOffset( + text, + comments, + commentQuery = "/*