Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Gh-12] detect reentrant calls #20

Merged
merged 8 commits into from
Jul 18, 2018
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using NSubstitute.Analyzers.Shared.DiagnosticAnalyzers;

namespace NSubstitute.Analyzers.CSharp.DiagnosticAnalyzers
{
internal class ReEntrantCallFinder : AbstractReEntrantCallFinder
{
protected override ImmutableList<ISymbol> GetReEntrantSymbols(SemanticModel semanticModel, SyntaxNode rootNode)
{
var visitor = new ReEntrantCallVisitor(this, semanticModel);
visitor.Visit(rootNode);
return visitor.InvocationSymbols;
}

private class ReEntrantCallVisitor : CSharpSyntaxWalker
{
private readonly ReEntrantCallFinder _reEntrantCallFinder;
private readonly SemanticModel _semanticModel;
private readonly HashSet<SyntaxNode> _visitedNodes = new HashSet<SyntaxNode>();
private readonly List<ISymbol> _invocationSymbols = new List<ISymbol>();

public ImmutableList<ISymbol> InvocationSymbols => _invocationSymbols.ToImmutableList();

public ReEntrantCallVisitor(ReEntrantCallFinder reEntrantCallFinder, SemanticModel semanticModel)
{
_reEntrantCallFinder = reEntrantCallFinder;
_semanticModel = semanticModel;
}

public override void VisitInvocationExpression(InvocationExpressionSyntax node)
{
var symbolInfo = _semanticModel.GetSymbolInfo(node);
if (_reEntrantCallFinder.IsReturnsLikeMethod(_semanticModel, symbolInfo.Symbol))
{
_invocationSymbols.Add(symbolInfo.Symbol);
}

base.VisitInvocationExpression(node);
}

public override void DefaultVisit(SyntaxNode node)
{
VisitRelatedSymbols(node);
base.DefaultVisit(node);
}

private void VisitRelatedSymbols(SyntaxNode syntaxNode)
{
if (_visitedNodes.Contains(syntaxNode) == false &&
(syntaxNode.IsKind(SyntaxKind.IdentifierName) ||
syntaxNode.IsKind(SyntaxKind.ElementAccessExpression) ||
syntaxNode.IsKind(SyntaxKind.SimpleMemberAccessExpression)))
{
_visitedNodes.Add(syntaxNode);
foreach (var relatedNode in _reEntrantCallFinder.GetRelatedNodes(_semanticModel, syntaxNode))
{
Visit(relatedNode);
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using NSubstitute.Analyzers.Shared.DiagnosticAnalyzers;

namespace NSubstitute.Analyzers.CSharp.DiagnosticAnalyzers
{
[DiagnosticAnalyzer(LanguageNames.CSharp)]
internal class ReEntrantSetupAnalyzer : AbstractReEntrantSetupAnalyzer<SyntaxKind, InvocationExpressionSyntax>
{
public ReEntrantSetupAnalyzer()
: base(new DiagnosticDescriptorsProvider())
{
}

protected override AbstractReEntrantCallFinder GetReEntrantCallFinder()
{
return new ReEntrantCallFinder();
}

protected override SyntaxKind InvocationExpressionKind { get; } = SyntaxKind.InvocationExpression;

protected override IEnumerable<SyntaxNode> ExtractArguments(InvocationExpressionSyntax invocationExpressionSyntax)
{
return invocationExpressionSyntax.ArgumentList.Arguments.Select(arg => arg.Expression);
}
}
}
27 changes: 27 additions & 0 deletions src/NSubstitute.Analyzers.CSharp/Resources.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions src/NSubstitute.Analyzers.CSharp/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -273,4 +273,16 @@
<value>Non-virtual setup specification.</value>
<comment>The title of the diagnostic.</comment>
</data>
<data name="ReEntrantSubstituteCallDescription" xml:space="preserve">
<value>Re-entrant substitute call.</value>
<comment>An optional longer localizable description of the diagnostic.</comment>
</data>
<data name="ReEntrantSubstituteCallMessageFormat" xml:space="preserve">
<value>{0}() is set with a method that itself calls {1}. This can cause problems with NSubstitute. Consider replacing with a lambda: {0}(x => {2}).</value>
<comment>The format-able message the diagnostic displays.</comment>
</data>
<data name="ReEntrantSubstituteCallTitle" xml:space="preserve">
<value>Re-entrant substitute call.</value>
<comment>The title of the diagnostic.</comment>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ internal class AbstractDiagnosticDescriptorsProvider<T> : IDiagnosticDescriptors
public DiagnosticDescriptor NonVirtualReceivedSetupSpecification { get; } = DiagnosticDescriptors<T>.NonVirtualReceivedSetupSpecification;

public DiagnosticDescriptor NonVirtualWhenSetupSpecification { get; } = DiagnosticDescriptors<T>.NonVirtualWhenSetupSpecification;

public DiagnosticDescriptor ReEntrantSubstituteCall { get; } = DiagnosticDescriptors<T>.ReEntrantSubstituteCall;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace NSubstitute.Analyzers.Shared.DiagnosticAnalyzers
{
internal abstract class AbstractReEntrantCallFinder
{
private static readonly ImmutableDictionary<string, string> MethodNames = new Dictionary<string, string>()
{
[MetadataNames.NSubstituteReturnsMethod] = MetadataNames.NSubstituteSubstituteExtensionsFullTypeName,
[MetadataNames.NSubstituteReturnsForAnyArgsMethod] = MetadataNames.NSubstituteSubstituteExtensionsFullTypeName,
[MetadataNames.NSubstituteDoMethod] = MetadataNames.NSubstituteWhenCalledType
}.ToImmutableDictionary();

public ImmutableList<ISymbol> GetReEntrantCalls(SemanticModel semanticModel, SyntaxNode rootNode)
{
var typeInfo = semanticModel.GetTypeInfo(rootNode);
if (IsCalledViaDelegate(semanticModel, typeInfo))
{
return ImmutableList<ISymbol>.Empty;
}

return GetReEntrantSymbols(semanticModel, rootNode);
}

protected abstract ImmutableList<ISymbol> GetReEntrantSymbols(SemanticModel semanticModel, SyntaxNode rootNode);

protected IEnumerable<SyntaxNode> GetRelatedNodes(SemanticModel semanticModel, SyntaxNode syntaxNode)
{
var symbol = semanticModel.GetSymbolInfo(syntaxNode);
if (symbol.Symbol != null && symbol.Symbol.Locations.Any())
{
foreach (var symbolLocation in symbol.Symbol.Locations.Where(location => location.SourceTree != null))
{
var root = symbolLocation.SourceTree.GetRoot();
var relatedNode = root.FindNode(symbolLocation.SourceSpan);
if (relatedNode != null)
{
yield return relatedNode;
}
}
}
}

protected bool IsReturnsLikeMethod(SemanticModel semanticModel, ISymbol symbol)
{
if (symbol == null || MethodNames.TryGetValue(symbol.Name, out var containingType) == false)
{
return false;
}

return symbol.ContainingAssembly?.Name.Equals(MetadataNames.NSubstituteAssemblyName, StringComparison.OrdinalIgnoreCase) == true &&
(symbol.ContainingType?.ToString().Equals(containingType, StringComparison.OrdinalIgnoreCase) == true ||
(symbol.ContainingType?.ConstructedFrom.Name)?.Equals(containingType, StringComparison.OrdinalIgnoreCase) == true);
}

private static bool IsCalledViaDelegate(SemanticModel semanticModel, TypeInfo typeInfo)
{
var typeSymbol = typeInfo.Type ?? typeInfo.ConvertedType;
var isCalledViaDelegate = typeSymbol != null &&
typeSymbol.TypeKind == TypeKind.Delegate &&
typeSymbol is INamedTypeSymbol namedTypeSymbol &&
namedTypeSymbol.ConstructedFrom.Equals(semanticModel.Compilation.GetTypeByMetadataName("System.Func`2")) &&
IsCallInfoParameter(namedTypeSymbol.TypeArguments.First());

return isCalledViaDelegate;
}

private static bool IsCallInfoParameter(ITypeSymbol symbol)
{
return symbol.ContainingAssembly?.Name.Equals(MetadataNames.NSubstituteAssemblyName, StringComparison.OrdinalIgnoreCase) == true &&
symbol.ToString().Equals(MetadataNames.NSubstituteCoreFullTypeName, StringComparison.OrdinalIgnoreCase) == true;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;

namespace NSubstitute.Analyzers.Shared.DiagnosticAnalyzers
{
internal abstract class AbstractReEntrantSetupAnalyzer<TSyntaxKind, TInvocationExpressionSyntax> : AbstractDiagnosticAnalyzer
where TInvocationExpressionSyntax : SyntaxNode
where TSyntaxKind : struct
{
private static readonly ImmutableHashSet<string> MethodNames = ImmutableHashSet.Create(
MetadataNames.NSubstituteReturnsMethod,
MetadataNames.NSubstituteReturnsForAnyArgsMethod);

private AbstractReEntrantCallFinder ReEntrantCallFinder => _reEntrantCallFinderProxy.Value;

private readonly Lazy<AbstractReEntrantCallFinder> _reEntrantCallFinderProxy;

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics =>
ImmutableArray.Create(DiagnosticDescriptorsProvider.ReEntrantSubstituteCall);

protected AbstractReEntrantSetupAnalyzer(IDiagnosticDescriptorsProvider diagnosticDescriptorsProvider)
: base(diagnosticDescriptorsProvider)
{
_reEntrantCallFinderProxy = new Lazy<AbstractReEntrantCallFinder>(GetReEntrantCallFinder);
}

protected abstract AbstractReEntrantCallFinder GetReEntrantCallFinder();

protected abstract TSyntaxKind InvocationExpressionKind { get; }

public override void Initialize(AnalysisContext context)
{
context.RegisterSyntaxNodeAction(AnalyzeInvocation, InvocationExpressionKind);
}

protected abstract IEnumerable<SyntaxNode> ExtractArguments(TInvocationExpressionSyntax invocationExpressionSyntax);

private void AnalyzeInvocation(SyntaxNodeAnalysisContext syntaxNodeContext)
{
var invocationExpression = (TInvocationExpressionSyntax)syntaxNodeContext.Node;
var methodSymbolInfo = syntaxNodeContext.SemanticModel.GetSymbolInfo(invocationExpression);

if (methodSymbolInfo.Symbol?.Kind != SymbolKind.Method)
{
return;
}

var methodSymbol = (IMethodSymbol)methodSymbolInfo.Symbol;

if (IsReturnsLikeMethod(syntaxNodeContext, invocationExpression, methodSymbol.Name) == false)
{
return;
}

var allArguments = ExtractArguments(invocationExpression);
var argumentsForAnalysis = methodSymbol.MethodKind == MethodKind.ReducedExtension ? allArguments : allArguments.Skip(1);

foreach (var argument in argumentsForAnalysis)
{
var reentrantSymbol = ReEntrantCallFinder.GetReEntrantCalls(syntaxNodeContext.SemanticModel, argument).FirstOrDefault();
if (reentrantSymbol != null)
{
var diagnostic = Diagnostic.Create(
DiagnosticDescriptorsProvider.ReEntrantSubstituteCall,
argument.GetLocation(),
methodSymbol.Name,
reentrantSymbol.Name,
argument.ToString());

syntaxNodeContext.ReportDiagnostic(diagnostic);
}
}
}

private bool IsReturnsLikeMethod(SyntaxNodeAnalysisContext syntaxNodeContext, SyntaxNode syntax, string memberName)
{
if (MethodNames.Contains(memberName) == false)
{
return false;
}

var symbol = syntaxNodeContext.SemanticModel.GetSymbolInfo(syntax);

return symbol.Symbol?.ContainingAssembly?.Name.Equals(MetadataNames.NSubstituteAssemblyName, StringComparison.OrdinalIgnoreCase) == true &&
symbol.Symbol?.ContainingType?.ToString().Equals(MetadataNames.NSubstituteSubstituteExtensionsFullTypeName, StringComparison.OrdinalIgnoreCase) == true;
}
}
}
8 changes: 8 additions & 0 deletions src/NSubstitute.Analyzers.Shared/DiagnosticDescriptors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ internal class DiagnosticDescriptors<T>
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true);

public static DiagnosticDescriptor ReEntrantSubstituteCall { get; } =
CreateDiagnosticDescriptor(
name: nameof(ReEntrantSubstituteCall),
id: DiagnosticIdentifiers.ReEntrantSubstituteCall,
category: DiagnosticCategories.Usage,
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true);

private static DiagnosticDescriptor CreateDiagnosticDescriptor(
string name, string id, string category, DiagnosticSeverity defaultSeverity, bool isEnabledByDefault)
{
Expand Down
1 change: 1 addition & 0 deletions src/NSubstitute.Analyzers.Shared/DiagnosticIdentifiers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ internal class DiagnosticIdentifiers
public static readonly string SubstituteConstructorArgumentsForDelegate = "NS010";
public static readonly string NonVirtualReceivedSetupSpecification = "NS011";
public static readonly string NonVirtualWhenSetupSpecification = "NS012";
public static readonly string ReEntrantSubstituteCall = "NS013";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ internal interface IDiagnosticDescriptorsProvider
DiagnosticDescriptor NonVirtualReceivedSetupSpecification { get; }

DiagnosticDescriptor NonVirtualWhenSetupSpecification { get; }

DiagnosticDescriptor ReEntrantSubstituteCall { get; }
}
}
3 changes: 3 additions & 0 deletions src/NSubstitute.Analyzers.Shared/MetadataNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ internal class MetadataNames
{
public const string NSubstituteAssemblyName = "NSubstitute";
public const string NSubstituteSubstituteExtensionsFullTypeName = "NSubstitute.SubstituteExtensions";
public const string NSubstituteCoreFullTypeName = "NSubstitute.Core.CallInfo";
public const string NSubstituteSubstituteFullTypeName = "NSubstitute.Substitute";
public const string NSubstituteReturnsMethod = "Returns";
public const string NSubstituteReturnsForAnyArgsMethod = "ReturnsForAnyArgs";
public const string NSubstituteDoMethod = "Do";
public const string NSubstituteReceivedMethod = "Received";
public const string NSubstituteReceivedWithAnyArgsMethod = "ReceivedWithAnyArgs";
public const string NSubstituteDidNotReceiveMethod = "DidNotReceive";
Expand All @@ -17,5 +19,6 @@ internal class MetadataNames
public const string CastleDynamicProxyGenAssembly2Name = "DynamicProxyGenAssembly2";
public const string NSubstituteWhenMethod = "When";
public const string NSubstituteWhenForAnyArgsMethod = "WhenForAnyArgs";
public const string NSubstituteWhenCalledType = "WhenCalled";
}
}
Loading