Skip to content

Overhaul workspace search for symbol references #1917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/PowerShellEditorServices/Server/PsesLanguageServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public async Task StartAsync()
.WithHandler<ShowHelpHandler>()
.WithHandler<ExpandAliasHandler>()
.WithHandler<PsesSemanticTokensHandler>()
.WithHandler<DidChangeWatchedFilesHandler>()
// NOTE: The OnInitialize delegate gets run when we first receive the
// _Initialize_ request:
// https://microsoft.github.io/language-server-protocol/specifications/specification-current/#initialize
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Management.Automation;
Expand Down Expand Up @@ -57,6 +58,39 @@ public record struct AliasMap(
internal static readonly ConcurrentDictionary<string, List<string>> s_cmdletToAliasCache = new(System.StringComparer.OrdinalIgnoreCase);
internal static readonly ConcurrentDictionary<string, string> s_aliasToCmdletCache = new(System.StringComparer.OrdinalIgnoreCase);

/// <summary>
/// Gets the actual command behind a fully module qualified command invocation, e.g.
/// <c>Microsoft.PowerShell.Management\Get-ChildItem</c> will return <c>Get-ChilddItem</c>
/// </summary>
/// <param name="invocationName">
/// The potentially module qualified command name at the site of invocation.
/// </param>
/// <param name="moduleName">
/// A reference that will contain the module name if the invocation is module qualified.
/// </param>
/// <returns>The actual command name.</returns>
public static string StripModuleQualification(string invocationName, out ReadOnlyMemory<char> moduleName)
{
int slashIndex = invocationName.IndexOf('\\');
if (slashIndex is -1)
{
moduleName = default;
return invocationName;
}

// If '\' is the last character then it's probably not a module qualified command.
if (slashIndex == invocationName.Length - 1)
{
moduleName = default;
return invocationName;
}

// Storing moduleName as ROMemory safes a string allocation in the common case where it
// is not needed.
moduleName = invocationName.AsMemory().Slice(0, slashIndex);
return invocationName.Substring(slashIndex + 1);
}

/// <summary>
/// Gets the CommandInfo instance for a command with a particular name.
/// </summary>
Expand Down
100 changes: 100 additions & 0 deletions src/PowerShellEditorServices/Services/Symbols/ReferenceTable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#nullable enable

using System;
using System.Collections.Concurrent;
using System.Management.Automation.Language;
using Microsoft.PowerShell.EditorServices.Services.TextDocument;
using Microsoft.PowerShell.EditorServices.Services.PowerShell.Utility;

namespace Microsoft.PowerShell.EditorServices.Services;

/// <summary>
/// Represents the symbols that are referenced and their locations within a single document.
/// </summary>
internal sealed class ReferenceTable
{
private readonly ScriptFile _parent;

private readonly ConcurrentDictionary<string, ConcurrentBag<IScriptExtent>> _symbolReferences = new(StringComparer.OrdinalIgnoreCase);

private bool _isInited;

public ReferenceTable(ScriptFile parent) => _parent = parent;

/// <summary>
/// Clears the reference table causing it to rescan the source AST when queried.
/// </summary>
public void TagAsChanged()
{
_symbolReferences.Clear();
_isInited = false;
}

// Prefer checking if the dictionary has contents to determine if initialized. The field
// `_isInited` is to guard against rescanning files with no command references, but will
// generally be less reliable of a check.
private bool IsInitialized => !_symbolReferences.IsEmpty || _isInited;

internal bool TryGetReferences(string command, out ConcurrentBag<IScriptExtent>? references)
{
EnsureInitialized();
return _symbolReferences.TryGetValue(command, out references);
}

internal void EnsureInitialized()
{
if (IsInitialized)
{
return;
}

_parent.ScriptAst.Visit(new ReferenceVisitor(this));
}

private void AddReference(string symbol, IScriptExtent extent)
{
_symbolReferences.AddOrUpdate(
symbol,
_ => new ConcurrentBag<IScriptExtent> { extent },
(_, existing) =>
{
existing.Add(extent);
return existing;
});
}

private sealed class ReferenceVisitor : AstVisitor
{
private readonly ReferenceTable _references;

public ReferenceVisitor(ReferenceTable references) => _references = references;

public override AstVisitAction VisitCommand(CommandAst commandAst)
{
string commandName = commandAst.GetCommandName();
if (string.IsNullOrEmpty(commandName))
{
return AstVisitAction.Continue;
}

_references.AddReference(
CommandHelpers.StripModuleQualification(commandName, out _),
commandAst.CommandElements[0].Extent);
return AstVisitAction.Continue;
}

public override AstVisitAction VisitVariableExpression(VariableExpressionAst variableExpressionAst)
{
// TODO: Consider tracking unscoped variable references only when they declared within
// the same function definition.
_references.AddReference(
$"${variableExpressionAst.VariablePath.UserPath}",
variableExpressionAst.Extent);

return AstVisitAction.Continue;
}
}
}
172 changes: 137 additions & 35 deletions src/PowerShellEditorServices/Services/Symbols/SymbolsService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ internal class SymbolsService

private readonly ConcurrentDictionary<string, ICodeLensProvider> _codeLensProviders;
private readonly ConcurrentDictionary<string, IDocumentSymbolProvider> _documentSymbolProviders;
private readonly ConfigurationService _configurationService;
#endregion

#region Constructors
Expand All @@ -65,6 +66,7 @@ public SymbolsService(
_runspaceContext = runspaceContext;
_executionService = executionService;
_workspaceService = workspaceService;
_configurationService = configurationService;

_codeLensProviders = new ConcurrentDictionary<string, ICodeLensProvider>();
if (configurationService.CurrentSettings.EnableReferencesCodeLens)
Expand Down Expand Up @@ -177,8 +179,15 @@ public async Task<List<SymbolReference>> FindReferencesOfSymbol(
_executionService,
cancellationToken).ConfigureAwait(false);

Dictionary<string, List<string>> cmdletToAliases = aliases.CmdletToAliases;
Dictionary<string, string> aliasToCmdlets = aliases.AliasToCmdlets;
string targetName = foundSymbol.SymbolName;
if (foundSymbol.SymbolType is SymbolType.Function)
{
targetName = CommandHelpers.StripModuleQualification(targetName, out _);
if (aliases.AliasToCmdlets.TryGetValue(foundSymbol.SymbolName, out string aliasDefinition))
{
targetName = aliasDefinition;
}
}

// We want to look for references first in referenced files, hence we use ordered dictionary
// TODO: File system case-sensitivity is based on filesystem not OS, but OS is a much cheaper heuristic
Expand All @@ -191,52 +200,63 @@ public async Task<List<SymbolReference>> FindReferencesOfSymbol(
fileMap[scriptFile.FilePath] = scriptFile;
}

foreach (string filePath in workspace.EnumeratePSFiles())
await ScanWorkspacePSFiles(cancellationToken).ConfigureAwait(false);

List<SymbolReference> symbolReferences = new();

// Using a nested method here to get a bit more readability and to avoid roslynator
// asserting we should use a giant nested ternary here.
static string[] GetIdentifiers(string symbolName, SymbolType symbolType, CommandHelpers.AliasMap aliases)
{
if (!fileMap.Contains(filePath))
if (symbolType is not SymbolType.Function)
{
// This async method is pretty dense with synchronous code
// so it's helpful to add some yields.
await Task.Yield();
cancellationToken.ThrowIfCancellationRequested();
if (!workspace.TryGetFile(filePath, out ScriptFile scriptFile))
{
// If we can't access the file for some reason, just ignore it
continue;
}
return new[] { symbolName };
}

fileMap[filePath] = scriptFile;
if (!aliases.CmdletToAliases.TryGetValue(symbolName, out List<string> foundAliasList))
{
return new[] { symbolName };
}

return foundAliasList.Prepend(symbolName)
.Distinct(StringComparer.OrdinalIgnoreCase)
.ToArray();
}

List<SymbolReference> symbolReferences = new();
foreach (object fileName in fileMap.Keys)
{
ScriptFile file = (ScriptFile)fileMap[fileName];
string[] allIdentifiers = GetIdentifiers(targetName, foundSymbol.SymbolType, aliases);

IEnumerable<SymbolReference> references = AstOperations.FindReferencesOfSymbol(
file.ScriptAst,
foundSymbol,
cmdletToAliases,
aliasToCmdlets);

foreach (SymbolReference reference in references)
foreach (ScriptFile file in _workspaceService.GetOpenedFiles())
{
foreach (string targetIdentifier in allIdentifiers)
{
try
if (!file.References.TryGetReferences(targetIdentifier, out ConcurrentBag<IScriptExtent> references))
{
reference.SourceLine = file.GetLine(reference.ScriptRegion.StartLineNumber);
continue;
}
catch (ArgumentOutOfRangeException e)

foreach (IScriptExtent extent in references)
{
reference.SourceLine = string.Empty;
_logger.LogException("Found reference is out of range in script file", e);
SymbolReference reference = new(
SymbolType.Function,
foundSymbol.SymbolName,
extent);

try
{
reference.SourceLine = file.GetLine(reference.ScriptRegion.StartLineNumber);
}
catch (ArgumentOutOfRangeException e)
{
reference.SourceLine = string.Empty;
_logger.LogException("Found reference is out of range in script file", e);
}
reference.FilePath = file.FilePath;
symbolReferences.Add(reference);
}
reference.FilePath = file.FilePath;
symbolReferences.Add(reference);
}

await Task.Yield();
cancellationToken.ThrowIfCancellationRequested();
await Task.Yield();
cancellationToken.ThrowIfCancellationRequested();
}
}

return symbolReferences;
Expand Down Expand Up @@ -495,6 +515,59 @@ await CommandHelpers.GetCommandInfoAsync(
return foundDefinition;
}

private Task _workspaceScanCompleted;

private async Task ScanWorkspacePSFiles(CancellationToken cancellationToken = default)
{
if (_configurationService.CurrentSettings.AnalyzeOpenDocumentsOnly)
{
return;
}

Task scanTask = _workspaceScanCompleted;
// It's not impossible for two scans to start at once but it should be exceedingly
// unlikely, and shouldn't break anything if it happens to. So we can save some
// lock time by accepting that possibility.
if (scanTask is null)
{
scanTask = Task.Run(
() =>
{
foreach (string file in _workspaceService.EnumeratePSFiles())
{
if (_workspaceService.TryGetFile(file, out ScriptFile scriptFile))
{
scriptFile.References.EnsureInitialized();
}
}
},
CancellationToken.None);

// Ignore the analyzer yelling that we're not awaiting this task, we'll get there.
#pragma warning disable CS4014
Interlocked.CompareExchange(ref _workspaceScanCompleted, scanTask, null);
#pragma warning restore CS4014
}

// In the simple case where the task is already completed or the token we're given cannot
// be cancelled, do a simple await.
if (scanTask.IsCompleted || !cancellationToken.CanBeCanceled)
{
await scanTask.ConfigureAwait(false);
return;
}

// If it's not yet done and we can be cancelled, create a new task to represent the
// cancellation. That way we can exit a request that relies on the scan without
// having to actually stop the work (and then request it again in a few seconds).
//
// TODO: There's a new API in net6 that lets you await a task with a cancellation token.
// we should #if that in if feasible.
TaskCompletionSource<bool> cancelled = new();
cancellationToken.Register(() => cancelled.TrySetCanceled());
await Task.WhenAny(scanTask, cancelled.Task).ConfigureAwait(false);
}

/// <summary>
/// Gets a path from a dot-source symbol.
/// </summary>
Expand Down Expand Up @@ -673,6 +746,35 @@ public static FunctionDefinitionAst GetFunctionDefinitionAtLine(

internal void OnConfigurationUpdated(object _, LanguageServerSettings e)
{
if (e.AnalyzeOpenDocumentsOnly)
{
Task scanInProgress = _workspaceScanCompleted;
if (scanInProgress is not null)
{
// Wait until after the scan completes to close unopened files.
_ = scanInProgress.ContinueWith(_ => CloseUnopenedFiles(), TaskScheduler.Default);
}
else
{
CloseUnopenedFiles();
}

_workspaceScanCompleted = null;

void CloseUnopenedFiles()
{
foreach (ScriptFile scriptFile in _workspaceService.GetOpenedFiles())
{
if (scriptFile.IsOpen)
{
continue;
}

_workspaceService.CloseFile(scriptFile);
}
}
}

if (e.EnableReferencesCodeLens)
{
if (_codeLensProviders.ContainsKey(ReferencesCodeLensProvider.Id))
Expand Down
Loading