Skip to content

Commit b4dd475

Browse files
committed
Clean up and fix FindReferencesVisitor
This consolidated the constructors so we don't have some unused (and confusing) boolean differentiating them, and instead just provide one constructor with default values. Also clean up `AstOperationsTests`.
1 parent 7131359 commit b4dd475

File tree

4 files changed

+80
-119
lines changed

4 files changed

+80
-119
lines changed

src/PowerShellEditorServices/Services/Symbols/SymbolsService.cs

+1-4
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,7 @@ public static IReadOnlyList<SymbolReference> FindOccurrencesInFile(
264264
return null;
265265
}
266266

267-
return AstOperations.FindReferencesOfSymbol(
268-
file.ScriptAst,
269-
foundSymbol,
270-
needsAliases: false).ToArray();
267+
return AstOperations.FindReferencesOfSymbol(file.ScriptAst, foundSymbol).ToArray();
271268
}
272269

273270
/// <summary>

src/PowerShellEditorServices/Services/Symbols/Vistors/AstOperations.cs

+9-29
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Collections.Concurrent;
56
using System.Collections.Generic;
67
using System.Diagnostics;
78
using System.Linq;
@@ -153,42 +154,21 @@ public static SymbolReference FindCommandAtPosition(Ast scriptAst, int lineNumbe
153154
/// </summary>
154155
/// <param name="scriptAst">The abstract syntax tree of the given script</param>
155156
/// <param name="symbolReference">The symbol that we are looking for referneces of</param>
156-
/// <param name="CmdletToAliasDictionary">Dictionary maping cmdlets to aliases for finding alias references</param>
157-
/// <param name="AliasToCmdletDictionary">Dictionary maping aliases to cmdlets for finding alias references</param>
157+
/// <param name="cmdletToAliasDictionary">Dictionary maping cmdlets to aliases for finding alias references</param>
158+
/// <param name="aliasToCmdletDictionary">Dictionary maping aliases to cmdlets for finding alias references</param>
158159
/// <returns></returns>
159160
public static IEnumerable<SymbolReference> FindReferencesOfSymbol(
160161
Ast scriptAst,
161162
SymbolReference symbolReference,
162-
Dictionary<String, List<String>> CmdletToAliasDictionary,
163-
Dictionary<String, String> AliasToCmdletDictionary)
163+
ConcurrentDictionary<string, List<string>> cmdletToAliasDictionary = default,
164+
ConcurrentDictionary<string, string> aliasToCmdletDictionary = default)
164165
{
165166
// find the symbol evaluators for the node types we are handling
166-
FindReferencesVisitor referencesVisitor =
167-
new FindReferencesVisitor(
168-
symbolReference,
169-
CmdletToAliasDictionary,
170-
AliasToCmdletDictionary);
171-
scriptAst.Visit(referencesVisitor);
167+
FindReferencesVisitor referencesVisitor = new(
168+
symbolReference,
169+
cmdletToAliasDictionary,
170+
aliasToCmdletDictionary);
172171

173-
return referencesVisitor.FoundReferences;
174-
}
175-
176-
/// <summary>
177-
/// Finds all references (not including aliases) in a script for the given symbol
178-
/// </summary>
179-
/// <param name="scriptAst">The abstract syntax tree of the given script</param>
180-
/// <param name="foundSymbol">The symbol that we are looking for referneces of</param>
181-
/// <param name="needsAliases">If this reference search needs aliases.
182-
/// This should always be false and used for occurence requests</param>
183-
/// <returns>A collection of SymbolReference objects that are refrences to the symbolRefrence
184-
/// not including aliases</returns>
185-
public static IEnumerable<SymbolReference> FindReferencesOfSymbol(
186-
ScriptBlockAst scriptAst,
187-
SymbolReference foundSymbol,
188-
bool needsAliases)
189-
{
190-
FindReferencesVisitor referencesVisitor =
191-
new FindReferencesVisitor(foundSymbol);
192172
scriptAst.Visit(referencesVisitor);
193173

194174
return referencesVisitor.FoundReferences;

src/PowerShellEditorServices/Services/Symbols/Vistors/FindReferencesVisitor.cs

+64-76
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Collections.Concurrent;
56
using System.Collections.Generic;
67
using System.Management.Automation.Language;
78

@@ -12,48 +13,47 @@ namespace Microsoft.PowerShell.EditorServices.Services.Symbols
1213
/// </summary>
1314
internal class FindReferencesVisitor : AstVisitor
1415
{
15-
private SymbolReference symbolRef;
16-
private Dictionary<String, List<String>> CmdletToAliasDictionary;
17-
private Dictionary<String, String> AliasToCmdletDictionary;
18-
private string symbolRefCommandName;
19-
private bool needsAliases;
16+
private readonly SymbolReference _symbolRef;
17+
private readonly ConcurrentDictionary<string, List<string>> _cmdletToAliasDictionary;
18+
private readonly ConcurrentDictionary<string, string> _aliasToCmdletDictionary;
19+
private readonly string _symbolRefCommandName;
20+
private readonly bool _needsAliases;
2021

2122
public List<SymbolReference> FoundReferences { get; set; }
2223

2324
/// <summary>
2425
/// Constructor used when searching for aliases is needed
2526
/// </summary>
2627
/// <param name="symbolReference">The found symbolReference that other symbols are being compared to</param>
27-
/// <param name="CmdletToAliasDictionary">Dictionary maping cmdlets to aliases for finding alias references</param>
28-
/// <param name="AliasToCmdletDictionary">Dictionary maping aliases to cmdlets for finding alias references</param>
28+
/// <param name="cmdletToAliasDictionary">Dictionary maping cmdlets to aliases for finding alias references</param>
29+
/// <param name="aliasToCmdletDictionary">Dictionary maping aliases to cmdlets for finding alias references</param>
2930
public FindReferencesVisitor(
3031
SymbolReference symbolReference,
31-
Dictionary<String, List<String>> CmdletToAliasDictionary,
32-
Dictionary<String, String> AliasToCmdletDictionary)
32+
ConcurrentDictionary<string, List<string>> cmdletToAliasDictionary = default,
33+
ConcurrentDictionary<string, string> aliasToCmdletDictionary = default)
3334
{
34-
this.symbolRef = symbolReference;
35-
this.FoundReferences = new List<SymbolReference>();
36-
this.needsAliases = true;
37-
this.CmdletToAliasDictionary = CmdletToAliasDictionary;
38-
this.AliasToCmdletDictionary = AliasToCmdletDictionary;
39-
40-
// Try to get the symbolReference's command name of an alias,
41-
// if a command name does not exists (if the symbol isn't an alias to a command)
42-
// set symbolRefCommandName to and empty string value
43-
AliasToCmdletDictionary.TryGetValue(symbolReference.ScriptRegion.Text, out symbolRefCommandName);
44-
if (symbolRefCommandName == null) { symbolRefCommandName = string.Empty; }
35+
_symbolRef = symbolReference;
36+
FoundReferences = new List<SymbolReference>();
4537

46-
}
38+
if (cmdletToAliasDictionary is null || aliasToCmdletDictionary is null)
39+
{
40+
_needsAliases = false;
41+
return;
42+
}
4743

48-
/// <summary>
49-
/// Constructor used when searching for aliases is not needed
50-
/// </summary>
51-
/// <param name="foundSymbol">The found symbolReference that other symbols are being compared to</param>
52-
public FindReferencesVisitor(SymbolReference foundSymbol)
53-
{
54-
this.symbolRef = foundSymbol;
55-
this.FoundReferences = new List<SymbolReference>();
56-
this.needsAliases = false;
44+
_needsAliases = true;
45+
_cmdletToAliasDictionary = cmdletToAliasDictionary;
46+
_aliasToCmdletDictionary = aliasToCmdletDictionary;
47+
48+
// Try to get the symbolReference's command name of an alias. If a command name does not
49+
// exists (if the symbol isn't an alias to a command) set symbolRefCommandName to an
50+
// empty string.
51+
aliasToCmdletDictionary.TryGetValue(symbolReference.ScriptRegion.Text, out _symbolRefCommandName);
52+
53+
if (_symbolRefCommandName == null)
54+
{
55+
_symbolRefCommandName = string.Empty;
56+
}
5757
}
5858

5959
/// <summary>
@@ -68,50 +68,44 @@ public override AstVisitAction VisitCommand(CommandAst commandAst)
6868
Ast commandNameAst = commandAst.CommandElements[0];
6969
string commandName = commandNameAst.Extent.Text;
7070

71-
if(symbolRef.SymbolType.Equals(SymbolType.Function))
71+
if (_symbolRef.SymbolType.Equals(SymbolType.Function))
7272
{
73-
if (needsAliases)
73+
if (_needsAliases)
7474
{
75-
// Try to get the commandAst's name and aliases,
76-
// if a command does not exists (if the symbol isn't an alias to a command)
77-
// set command to and empty string value string command
78-
// if the aliases do not exist (if the symvol isn't a command that has aliases)
75+
// Try to get the commandAst's name and aliases.
76+
//
77+
// If a command does not exist (if the symbol isn't an alias to a command) set
78+
// command to an empty string value string command.
79+
//
80+
// If the aliases do not exist (if the symbol isn't a command that has aliases)
7981
// set aliases to an empty List<string>
80-
string command;
81-
List<string> alaises;
82-
CmdletToAliasDictionary.TryGetValue(commandName, out alaises);
83-
AliasToCmdletDictionary.TryGetValue(commandName, out command);
84-
if (alaises == null) { alaises = new List<string>(); }
82+
_cmdletToAliasDictionary.TryGetValue(commandName, out List<string> aliases);
83+
_aliasToCmdletDictionary.TryGetValue(commandName, out string command);
84+
if (aliases == null) { aliases = new List<string>(); }
8585
if (command == null) { command = string.Empty; }
8686

87-
if (symbolRef.SymbolType.Equals(SymbolType.Function))
87+
// Check if the found symbol's name is the same as the commandAst's name OR
88+
// if the symbol's name is an alias for this commandAst's name (commandAst is a cmdlet) OR
89+
// if the symbol's name is the same as the commandAst's cmdlet name (commandAst is a alias)
90+
if (commandName.Equals(_symbolRef.SymbolName, StringComparison.OrdinalIgnoreCase)
91+
// Note that PowerShell command names and aliases are case insensitive.
92+
|| aliases.Exists((match) => string.Equals(match, _symbolRef.ScriptRegion.Text, StringComparison.OrdinalIgnoreCase))
93+
|| command.Equals(_symbolRef.ScriptRegion.Text, StringComparison.OrdinalIgnoreCase)
94+
|| (!string.IsNullOrEmpty(command)
95+
&& command.Equals(_symbolRefCommandName, StringComparison.OrdinalIgnoreCase)))
8896
{
89-
// Check if the found symbol's name is the same as the commandAst's name OR
90-
// if the symbol's name is an alias for this commandAst's name (commandAst is a cmdlet) OR
91-
// if the symbol's name is the same as the commandAst's cmdlet name (commandAst is a alias)
92-
if (commandName.Equals(symbolRef.SymbolName, StringComparison.CurrentCultureIgnoreCase) ||
93-
alaises.Contains(symbolRef.ScriptRegion.Text.ToLower()) ||
94-
command.Equals(symbolRef.ScriptRegion.Text, StringComparison.CurrentCultureIgnoreCase) ||
95-
(!string.IsNullOrEmpty(command) && command.Equals(symbolRefCommandName, StringComparison.CurrentCultureIgnoreCase)))
96-
{
97-
this.FoundReferences.Add(new SymbolReference(
98-
SymbolType.Function,
99-
commandNameAst.Extent));
100-
}
97+
FoundReferences.Add(new SymbolReference(SymbolType.Function, commandNameAst.Extent));
10198
}
102-
10399
}
104100
else // search does not include aliases
105101
{
106-
if (commandName.Equals(symbolRef.SymbolName, StringComparison.CurrentCultureIgnoreCase))
102+
if (commandName.Equals(_symbolRef.SymbolName, StringComparison.OrdinalIgnoreCase))
107103
{
108-
this.FoundReferences.Add(new SymbolReference(
109-
SymbolType.Function,
110-
commandNameAst.Extent));
104+
FoundReferences.Add(new SymbolReference(SymbolType.Function, commandNameAst.Extent));
111105
}
112106
}
113-
114107
}
108+
115109
return base.VisitCommand(commandAst);
116110
}
117111

@@ -135,12 +129,10 @@ public override AstVisitAction VisitFunctionDefinition(FunctionDefinitionAst fun
135129
File = functionDefinitionAst.Extent.File
136130
};
137131

138-
if (symbolRef.SymbolType.Equals(SymbolType.Function) &&
139-
nameExtent.Text.Equals(symbolRef.SymbolName, StringComparison.CurrentCultureIgnoreCase))
132+
if (_symbolRef.SymbolType.Equals(SymbolType.Function) &&
133+
nameExtent.Text.Equals(_symbolRef.SymbolName, StringComparison.CurrentCultureIgnoreCase))
140134
{
141-
this.FoundReferences.Add(new SymbolReference(
142-
SymbolType.Function,
143-
nameExtent));
135+
FoundReferences.Add(new SymbolReference(SymbolType.Function, nameExtent));
144136
}
145137
return base.VisitFunctionDefinition(functionDefinitionAst);
146138
}
@@ -153,12 +145,10 @@ public override AstVisitAction VisitFunctionDefinition(FunctionDefinitionAst fun
153145
/// <returns>A visit action that continues the search for references</returns>
154146
public override AstVisitAction VisitCommandParameter(CommandParameterAst commandParameterAst)
155147
{
156-
if (symbolRef.SymbolType.Equals(SymbolType.Parameter) &&
157-
commandParameterAst.Extent.Text.Equals(symbolRef.SymbolName, StringComparison.CurrentCultureIgnoreCase))
148+
if (_symbolRef.SymbolType.Equals(SymbolType.Parameter) &&
149+
commandParameterAst.Extent.Text.Equals(_symbolRef.SymbolName, StringComparison.CurrentCultureIgnoreCase))
158150
{
159-
this.FoundReferences.Add(new SymbolReference(
160-
SymbolType.Parameter,
161-
commandParameterAst.Extent));
151+
FoundReferences.Add(new SymbolReference(SymbolType.Parameter, commandParameterAst.Extent));
162152
}
163153
return AstVisitAction.Continue;
164154
}
@@ -171,12 +161,10 @@ public override AstVisitAction VisitCommandParameter(CommandParameterAst command
171161
/// <returns>A visit action that continues the search for references</returns>
172162
public override AstVisitAction VisitVariableExpression(VariableExpressionAst variableExpressionAst)
173163
{
174-
if(symbolRef.SymbolType.Equals(SymbolType.Variable) &&
175-
variableExpressionAst.Extent.Text.Equals(symbolRef.SymbolName, StringComparison.CurrentCultureIgnoreCase))
164+
if (_symbolRef.SymbolType.Equals(SymbolType.Variable)
165+
&& variableExpressionAst.Extent.Text.Equals(_symbolRef.SymbolName, StringComparison.CurrentCultureIgnoreCase))
176166
{
177-
this.FoundReferences.Add(new SymbolReference(
178-
SymbolType.Variable,
179-
variableExpressionAst.Extent));
167+
FoundReferences.Add(new SymbolReference(SymbolType.Variable, variableExpressionAst.Extent));
180168
}
181169
return AstVisitAction.Continue;
182170
}
@@ -186,7 +174,7 @@ private static (int, int) GetStartColumnAndLineNumbersFromAst(FunctionDefinition
186174
{
187175
int startColumnNumber = ast.Extent.StartColumnNumber;
188176
int startLineNumber = ast.Extent.StartLineNumber;
189-
int astOffset = 0;
177+
int astOffset;
190178

191179
if (ast.IsFilter)
192180
{

test/PowerShellEditorServices.Test/Services/Symbols/AstOperationsTests.cs

+6-10
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
using Microsoft.PowerShell.EditorServices.Services.Symbols;
88
using OmniSharp.Extensions.LanguageServer.Protocol.Models;
99
using Xunit;
10-
using Xunit.Abstractions;
1110

1211
namespace Microsoft.PowerShell.EditorServices.Test.Services.Symbols
1312
{
13+
[Trait("Category", "AstOperations")]
1414
public class AstOperationsTests
1515
{
16-
private static string s_scriptString = @"function BasicFunction {}
16+
private const string s_scriptString = @"function BasicFunction {}
1717
BasicFunction
1818
1919
function FunctionWithExtraSpace
@@ -36,9 +36,8 @@ function FunctionWithExtraSpace
3636
3737
FunctionNameOnDifferentLine
3838
";
39-
private static ScriptBlockAst s_ast = (ScriptBlockAst) ScriptBlock.Create(s_scriptString).Ast;
39+
private static readonly ScriptBlockAst s_ast = (ScriptBlockAst) ScriptBlock.Create(s_scriptString).Ast;
4040

41-
[Trait("Category", "AstOperations")]
4241
[Theory]
4342
[InlineData(2, 3, "BasicFunction")]
4443
[InlineData(7, 18, "FunctionWithExtraSpace")]
@@ -50,14 +49,13 @@ public void CanFindSymbolAtPostion(int lineNumber, int columnNumber, string expe
5049
Assert.Equal(expectedName, reference.SymbolName);
5150
}
5251

53-
[Trait("Category", "AstOperations")]
5452
[Theory]
55-
[MemberData(nameof(FindReferencesOfSymbolAtPostionData), parameters: 3)]
53+
[MemberData(nameof(FindReferencesOfSymbolAtPostionData))]
5654
public void CanFindReferencesOfSymbolAtPostion(int lineNumber, int columnNumber, Position[] positions)
5755
{
5856
SymbolReference symbol = AstOperations.FindSymbolAtPosition(s_ast, lineNumber, columnNumber);
5957

60-
IEnumerable<SymbolReference> references = AstOperations.FindReferencesOfSymbol(s_ast, symbol, needsAliases: false);
58+
IEnumerable<SymbolReference> references = AstOperations.FindReferencesOfSymbol(s_ast, symbol);
6159

6260
int positionsIndex = 0;
6361
foreach (SymbolReference reference in references)
@@ -69,9 +67,7 @@ public void CanFindReferencesOfSymbolAtPostion(int lineNumber, int columnNumber,
6967
}
7068
}
7169

72-
public static object[][] FindReferencesOfSymbolAtPostionData => s_findReferencesOfSymbolAtPostionData;
73-
74-
private static readonly object[][] s_findReferencesOfSymbolAtPostionData = new object[][]
70+
public static object[][] FindReferencesOfSymbolAtPostionData { get; } = new object[][]
7571
{
7672
new object[] { 2, 3, new[] { new Position(1, 10), new Position(2, 1) } },
7773
new object[] { 7, 18, new[] { new Position(4, 19), new Position(7, 3) } },

0 commit comments

Comments
 (0)