diff --git a/Rules/UseShouldProcessCorrectly.cs b/Rules/UseShouldProcessCorrectly.cs index 6cacc2634..5f50a8f10 100644 --- a/Rules/UseShouldProcessCorrectly.cs +++ b/Rules/UseShouldProcessCorrectly.cs @@ -28,12 +28,14 @@ public class UseShouldProcessCorrectly : IScriptRule private FunctionReferenceDigraph funcDigraph; private List diagnosticRecords; private readonly Vertex shouldProcessVertex; + private readonly Vertex shouldContinueVertex; private readonly Vertex implicitShouldProcessVertex; public UseShouldProcessCorrectly() { diagnosticRecords = new List(); shouldProcessVertex = new Vertex("ShouldProcess", null); + shouldContinueVertex = new Vertex("ShouldContinue", null); implicitShouldProcessVertex = new Vertex("implicitShouldProcessVertex", null); } @@ -151,8 +153,10 @@ private DiagnosticRecord GetViolation(Vertex v) if (DeclaresSupportsShouldProcess(fast)) { bool callsShouldProcess = funcDigraph.IsConnected(v, shouldProcessVertex); + bool callsShouldContinue = funcDigraph.IsConnected(v, shouldContinueVertex); bool callsCommandWithShouldProcess = funcDigraph.IsConnected(v, implicitShouldProcessVertex); if (!callsShouldProcess + && !callsShouldContinue && !callsCommandWithShouldProcess) { return new DiagnosticRecord( @@ -168,7 +172,8 @@ private DiagnosticRecord GetViolation(Vertex v) } else { - if (callsShouldProcessDirectly(v)) + bool callsShouldProc = callsShouldProcessDirectly(v); + if (callsShouldProc || callsShouldContinueDirectly(v)) { // check if upstream function declares SupportShouldProcess // if so, this might just be a helper function @@ -179,14 +184,15 @@ private DiagnosticRecord GetViolation(Vertex v) } return new DiagnosticRecord( - string.Format( - CultureInfo.CurrentCulture, - Strings.ShouldProcessErrorHasCmdlet, - fast.Name), - GetShouldProcessCallExtent(fast), + string.Format( + CultureInfo.CurrentCulture, + Strings.ShouldProcessErrorHasCmdlet, + fast.Name), + callsShouldProc ? GetShouldProcessCallExtent(fast) : GetShouldContinueCallExtent(fast), GetName(), GetDianosticSeverity(), - fileName); + fileName) +; } } @@ -198,7 +204,20 @@ private DiagnosticRecord GetViolation(Vertex v) /// private static IScriptExtent GetShouldProcessCallExtent(FunctionDefinitionAst functionDefinitionAst) { - var invokeMemberExpressionAstFound = functionDefinitionAst.Find(IsShouldProcessCall, true); + var invokeMemberExpressionAstFound = functionDefinitionAst.Find(IsShouldProcessCall, true); + if (invokeMemberExpressionAstFound == null) + { + return functionDefinitionAst.Extent; + } + + return (invokeMemberExpressionAstFound as InvokeMemberExpressionAst).Member.Extent; + } + /// + /// Gets the extent of ShouldContinue call + /// + private static IScriptExtent GetShouldContinueCallExtent(FunctionDefinitionAst functionDefinitionAst) + { + var invokeMemberExpressionAstFound = functionDefinitionAst.Find(IsShouldContinueCall, true); if (invokeMemberExpressionAstFound == null) { return functionDefinitionAst.Extent; @@ -208,7 +227,7 @@ private static IScriptExtent GetShouldProcessCallExtent(FunctionDefinitionAst fu } /// - /// Returns true if ast if of the form $PSCmdlet.PSShouldProcess() + /// Returns true if ast is of the form $PSCmdlet.PSShouldProcess() /// private static bool IsShouldProcessCall(Ast ast) { @@ -224,7 +243,7 @@ private static bool IsShouldProcessCall(Ast ast) return false; } - if ("ShouldProcess".Equals(memberExprAst.Value, StringComparison.OrdinalIgnoreCase)) + if ("ShouldProcess".Equals(memberExprAst.Value, StringComparison.OrdinalIgnoreCase)) { return true; } @@ -232,10 +251,38 @@ private static bool IsShouldProcessCall(Ast ast) return false; } + /// + /// Returns true if ast is of the form $PSCmdlet.PSShouldProcess() + /// + private static bool IsShouldContinueCall(Ast ast) + { + var invokeMemberExpressionAst = ast as InvokeMemberExpressionAst; + if (invokeMemberExpressionAst == null) + { + return false; + } + + var memberExprAst = invokeMemberExpressionAst.Member as StringConstantExpressionAst; + if (memberExprAst == null) + { + return false; + } + + if ("ShouldContinue".Equals(memberExprAst.Value, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + + return false; + } private bool callsShouldProcessDirectly(Vertex vertex) { return funcDigraph.GetNeighbors(vertex).Contains(shouldProcessVertex); } + private bool callsShouldContinueDirectly(Vertex vertex) + { + return funcDigraph.GetNeighbors(vertex).Contains(shouldContinueVertex); + } /// /// Checks if an upstream function declares SupportsShouldProcess diff --git a/Tests/Rules/UseShouldProcessCorrectly.tests.ps1 b/Tests/Rules/UseShouldProcessCorrectly.tests.ps1 index cec228f47..3e061db0c 100644 --- a/Tests/Rules/UseShouldProcessCorrectly.tests.ps1 +++ b/Tests/Rules/UseShouldProcessCorrectly.tests.ps1 @@ -114,6 +114,100 @@ function Bar } } +Foo +'@ + $violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess + $violations.Count | Should -Be 1 + } + } + + Context "Where ShouldContinue is called by a downstream function" { + It "finds no violation for 1 level downstream call" { + $scriptDef = @' +function Foo +{ + [CmdletBinding(SupportsShouldProcess=$true)] + param() + + Bar +} + +function Bar +{ + [CmdletBinding(SupportsShouldProcess=$true)] + param() + + if ($PSCmdlet.ShouldContinue("", "")) + { + "Continue normally..." + } + else + { + "what would happen..." + } +} + +Foo +'@ + $violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess + $violations.Count | Should -Be 0 + } + + It "finds violation if downstream function does not declare SupportsShouldProcess" { + $scriptDef = @' +function Foo +{ + [CmdletBinding(SupportsShouldProcess=$true)] + param() + + Bar +} + +function Bar +{ + if ($PSCmdlet.ShouldContinue("", "")) + { + "Continue normally..." + } + else + { + "what would happen..." + } +} + +Foo +'@ + $violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess + $violations.Count | Should -Be 1 + } + + It "finds violation for 2 level downstream calls" { + $scriptDef = @' +function Foo +{ + [CmdletBinding(SupportsShouldProcess=$true)] + param() + + Baz +} + +function Baz +{ + Bar +} + +function Bar +{ + if ($PSCmdlet.ShouldContinue("", "")) + { + "Continue normally..." + } + else + { + "what would happen..." + } +} + Foo '@ $violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess @@ -146,8 +240,33 @@ function Foo } } - Context "When a builtin command that supports ShouldProcess is called" { - It "finds no violation when caller declares SupportsShouldProcess and callee is a cmdlet with ShouldProcess" { + Context "When nested function definition calls ShouldContinue" { + It "finds no violation" { + $scriptDef = @' +function Foo +{ + [CmdletBinding(SupportsShouldProcess)] + param() + begin + { + function Bar + { + if ($PSCmdlet.ShouldContinue('','')) + { + + } + } + bar + } +} +'@ + $violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess + $violations.Count | Should -Be 0 + } + } + + Context "When a builtin command that supports ShouldProcess/ShouldContinue is called" { + It "finds no violation when caller declares SupportsShouldProcess and callee is a cmdlet with ShouldProcess/ShouldContinue" { $scriptDef = @' function Remove-Foo { [CmdletBinding(SupportsShouldProcess)] @@ -278,5 +397,17 @@ function Foo $violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess $violations[0].Extent.Text | Should -Be 'ShouldProcess' } - } + + It "should mark only the ShouldContinue call" { + $scriptDef = @' +function Foo +{ + param() + if ($PSCmdlet.ShouldContinue('', '')) { Write-Output "Should Continue" } +} +'@ + $violations = Invoke-ScriptAnalyzer -ScriptDefinition $scriptDef -IncludeRule PSShouldProcess + $violations[0].Extent.Text | Should -Be 'ShouldContinue' + } + } }