Skip to content

PsShouldProcess - ShouldContinue not included #1305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 57 additions & 10 deletions Rules/UseShouldProcessCorrectly.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ public class UseShouldProcessCorrectly : IScriptRule
private FunctionReferenceDigraph funcDigraph;
private List<DiagnosticRecord> diagnosticRecords;
private readonly Vertex shouldProcessVertex;
private readonly Vertex shouldContinueVertex;
private readonly Vertex implicitShouldProcessVertex;

public UseShouldProcessCorrectly()
{
diagnosticRecords = new List<DiagnosticRecord>();
shouldProcessVertex = new Vertex("ShouldProcess", null);
shouldContinueVertex = new Vertex("ShouldContinue", null);
implicitShouldProcessVertex = new Vertex("implicitShouldProcessVertex", null);
}

Expand Down Expand Up @@ -151,8 +153,10 @@ private DiagnosticRecord GetViolation(Vertex v)
if (DeclaresSupportsShouldProcess(fast))
{
bool callsShouldProcess = funcDigraph.IsConnected(v, shouldProcessVertex);
bool callsShouldContinue = funcDigraph.IsConnected(v, shouldContinueVertex);
bool callsCommandWithShouldProcess = funcDigraph.IsConnected(v, implicitShouldProcessVertex);
if (!callsShouldProcess
&& !callsShouldContinue
&& !callsCommandWithShouldProcess)
{
return new DiagnosticRecord(
Expand All @@ -168,7 +172,8 @@ private DiagnosticRecord GetViolation(Vertex v)
}
else
{
if (callsShouldProcessDirectly(v))
bool callsShouldProc = callsShouldProcessDirectly(v);
if (callsShouldProc || callsShouldContinueDirectly(v))
{
// check if upstream function declares SupportShouldProcess
// if so, this might just be a helper function
Expand All @@ -179,14 +184,15 @@ private DiagnosticRecord GetViolation(Vertex v)
}

return new DiagnosticRecord(
string.Format(
CultureInfo.CurrentCulture,
Strings.ShouldProcessErrorHasCmdlet,
fast.Name),
GetShouldProcessCallExtent(fast),
string.Format(
CultureInfo.CurrentCulture,
Strings.ShouldProcessErrorHasCmdlet,
fast.Name),
callsShouldProc ? GetShouldProcessCallExtent(fast) : GetShouldContinueCallExtent(fast),
GetName(),
GetDianosticSeverity(),
fileName);
fileName)
;
}
}

Expand All @@ -198,7 +204,20 @@ private DiagnosticRecord GetViolation(Vertex v)
/// </summary>
private static IScriptExtent GetShouldProcessCallExtent(FunctionDefinitionAst functionDefinitionAst)
{
var invokeMemberExpressionAstFound = functionDefinitionAst.Find(IsShouldProcessCall, true);
var invokeMemberExpressionAstFound = functionDefinitionAst.Find(IsShouldProcessCall, true);
if (invokeMemberExpressionAstFound == null)
{
return functionDefinitionAst.Extent;
}

return (invokeMemberExpressionAstFound as InvokeMemberExpressionAst).Member.Extent;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return (invokeMemberExpressionAstFound as InvokeMemberExpressionAst).Member.Extent;
return ((InvokeMemberExpressionAst)invokeMemberExpressionAstFound).Member.Extent;

}
/// <summary>
/// Gets the extent of ShouldContinue call
/// </summary>
private static IScriptExtent GetShouldContinueCallExtent(FunctionDefinitionAst functionDefinitionAst)
{
var invokeMemberExpressionAstFound = functionDefinitionAst.Find(IsShouldContinueCall, true);
if (invokeMemberExpressionAstFound == null)
{
return functionDefinitionAst.Extent;
Expand All @@ -208,7 +227,7 @@ private static IScriptExtent GetShouldProcessCallExtent(FunctionDefinitionAst fu
}

/// <summary>
/// Returns true if ast if of the form $PSCmdlet.PSShouldProcess()
/// Returns true if ast is of the form $PSCmdlet.PSShouldProcess()
/// </summary>
private static bool IsShouldProcessCall(Ast ast)
{
Expand All @@ -224,18 +243,46 @@ private static bool IsShouldProcessCall(Ast ast)
return false;
}

if ("ShouldProcess".Equals(memberExprAst.Value, StringComparison.OrdinalIgnoreCase))
if ("ShouldProcess".Equals(memberExprAst.Value, StringComparison.OrdinalIgnoreCase))
{
return true;
}

return false;
}

/// <summary>
/// Returns true if ast is of the form $PSCmdlet.PSShouldProcess()
/// </summary>
private static bool IsShouldContinueCall(Ast ast)
{
var invokeMemberExpressionAst = ast as InvokeMemberExpressionAst;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nicer to do this:

return ast is InvokeMemberExpressionAst invokeMemberExpressionAst
    && invokeMemberExpressionAst.Member is StringConstantExpressionAst memberExprAst
    && string.Equals(memberExprAst.Value, "ShouldContinue", StringComparison.OrdinalIgnoreCase);

if (invokeMemberExpressionAst == null)
{
return false;
}

var memberExprAst = invokeMemberExpressionAst.Member as StringConstantExpressionAst;
if (memberExprAst == null)
{
return false;
}

if ("ShouldContinue".Equals(memberExprAst.Value, StringComparison.OrdinalIgnoreCase))
{
return true;
}

return false;
}
private bool callsShouldProcessDirectly(Vertex vertex)
{
return funcDigraph.GetNeighbors(vertex).Contains(shouldProcessVertex);
}
private bool callsShouldContinueDirectly(Vertex vertex)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private bool callsShouldContinueDirectly(Vertex vertex)
private bool callsShouldContinueDirectly(Vertex vertex)

{
return funcDigraph.GetNeighbors(vertex).Contains(shouldContinueVertex);
}

/// <summary>
/// Checks if an upstream function declares SupportsShouldProcess
Expand Down
137 changes: 134 additions & 3 deletions Tests/Rules/UseShouldProcessCorrectly.tests.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,100 @@ function Bar
}
}

Foo
'@
$violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess
$violations.Count | Should -Be 1
}
}

Context "Where ShouldContinue is called by a downstream function" {
It "finds no violation for 1 level downstream call" {
$scriptDef = @'
function Foo
{
[CmdletBinding(SupportsShouldProcess=$true)]
param()

Bar
}

function Bar
{
[CmdletBinding(SupportsShouldProcess=$true)]
param()

if ($PSCmdlet.ShouldContinue("", ""))
{
"Continue normally..."
}
else
{
"what would happen..."
}
}

Foo
'@
$violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess
$violations.Count | Should -Be 0
}

It "finds violation if downstream function does not declare SupportsShouldProcess" {
$scriptDef = @'
function Foo
{
[CmdletBinding(SupportsShouldProcess=$true)]
param()

Bar
}

function Bar
{
if ($PSCmdlet.ShouldContinue("", ""))
{
"Continue normally..."
}
else
{
"what would happen..."
}
}

Foo
'@
$violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess
$violations.Count | Should -Be 1
}

It "finds violation for 2 level downstream calls" {
$scriptDef = @'
function Foo
{
[CmdletBinding(SupportsShouldProcess=$true)]
param()

Baz
}

function Baz
{
Bar
}

function Bar
{
if ($PSCmdlet.ShouldContinue("", ""))
{
"Continue normally..."
}
else
{
"what would happen..."
}
}

Foo
'@
$violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess
Expand Down Expand Up @@ -146,8 +240,33 @@ function Foo
}
}

Context "When a builtin command that supports ShouldProcess is called" {
It "finds no violation when caller declares SupportsShouldProcess and callee is a cmdlet with ShouldProcess" {
Context "When nested function definition calls ShouldContinue" {
It "finds no violation" {
$scriptDef = @'
function Foo
{
[CmdletBinding(SupportsShouldProcess)]
param()
begin
{
function Bar
{
if ($PSCmdlet.ShouldContinue('',''))
{

}
}
bar
}
}
'@
$violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess
$violations.Count | Should -Be 0
}
}

Context "When a builtin command that supports ShouldProcess/ShouldContinue is called" {
It "finds no violation when caller declares SupportsShouldProcess and callee is a cmdlet with ShouldProcess/ShouldContinue" {
$scriptDef = @'
function Remove-Foo {
[CmdletBinding(SupportsShouldProcess)]
Expand Down Expand Up @@ -278,5 +397,17 @@ function Foo
$violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess
$violations[0].Extent.Text | Should -Be 'ShouldProcess'
}
}

It "should mark only the ShouldContinue call" {
$scriptDef = @'
function Foo
{
param()
if ($PSCmdlet.ShouldContinue('', '')) { Write-Output "Should Continue" }
}
'@
$violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess
$violations[0].Extent.Text | Should -Be 'ShouldContinue'
}
}
}