diff --git a/src/NSubstitute.Analyzers.Shared/DiagnosticAnalyzers/ReEntrantCallFinder.cs b/src/NSubstitute.Analyzers.Shared/DiagnosticAnalyzers/ReEntrantCallFinder.cs index 5d078b0b..961124fa 100644 --- a/src/NSubstitute.Analyzers.Shared/DiagnosticAnalyzers/ReEntrantCallFinder.cs +++ b/src/NSubstitute.Analyzers.Shared/DiagnosticAnalyzers/ReEntrantCallFinder.cs @@ -1,3 +1,4 @@ +#nullable enable using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -50,7 +51,7 @@ private IEnumerable GetPotentialOtherSubstituteInvocations } } - private IEnumerable GetOtherSubstitutionsForSymbol(Compilation compilation, IOperation rootOperation, ISymbol rootNodeSymbol) + private IEnumerable GetOtherSubstitutionsForSymbol(Compilation compilation, IOperation rootOperation, ISymbol? rootNodeSymbol) { if (rootNodeSymbol == null) { @@ -59,7 +60,7 @@ private IEnumerable GetOtherSubstitutionsForSymbol(Compilation compi var rootIdentifierNode = GetLocalReferenceOperation(rootOperation); - var rootIdentifierSymbol = rootIdentifierNode?.ExtractSymbol(); + var rootIdentifierSymbol = rootIdentifierNode.ExtractSymbol(); if (rootIdentifierSymbol == null) { @@ -78,7 +79,12 @@ private IEnumerable GetOtherSubstitutionsForSymbol(Compilation compi var substitutedNode = _substitutionNodeFinder.FindForStandardExpression(operation); - var substituteNodeSymbol = substitutedNode?.ExtractSymbol(); + if (substitutedNode == null) + { + yield break; + } + + var substituteNodeSymbol = substitutedNode.ExtractSymbol(); if (substituteNodeSymbol == null) { @@ -101,18 +107,22 @@ private IEnumerable GetOtherSubstitutionsForSymbol(Compilation compi private static IEnumerable GetConstructorOperations(Compilation compilation, ISymbol fieldReferenceOperation) { - // TODO naming - foreach (var location in fieldReferenceOperation.ContainingType.GetMembers().OfType() - .Where(methodSymbol => methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor && methodSymbol.Locations.Length > 0) + SemanticModel? semanticModel = null; + foreach (var constructorLocation in fieldReferenceOperation.ContainingType.Constructors + .Where(methodSymbol => methodSymbol.Locations.Length > 0) .SelectMany(x => x.Locations) .Where(location => location.IsInSource)) { - var root = location.SourceTree.GetRoot(); - var relatedNode = root.FindNode(location.SourceSpan); + var root = constructorLocation.SourceTree.GetRoot(); + var relatedNode = root.FindNode(constructorLocation.SourceSpan); + + // perf - take original semantic model whenever possible + // but keep in mind that we might traverse outside of the original one https://github.com/nsubstitute/NSubstitute.Analyzers/issues/56 + semanticModel = semanticModel == null || semanticModel.SyntaxTree != constructorLocation.SourceTree + ? compilation.TryGetSemanticModel(constructorLocation.SourceTree) + : semanticModel; - // TODO reuse semantic model - var semanticModel = compilation.GetSemanticModel(location.SourceTree); - var operation = semanticModel.GetOperation(relatedNode) ?? semanticModel.GetOperation(relatedNode.Parent); + var operation = semanticModel?.GetOperation(relatedNode) ?? semanticModel?.GetOperation(relatedNode.Parent); if (operation is not null) { @@ -121,9 +131,9 @@ private static IEnumerable GetConstructorOperations(Compilation comp } } - private IOperation GetLocalReferenceOperation(IOperation node) + private IOperation? GetLocalReferenceOperation(IOperation? node) { - var child = node.Children.FirstOrDefault(); + var child = node?.Children.FirstOrDefault(); return child is ILocalReferenceOperation or IFieldReferenceOperation ? child : null; } @@ -140,7 +150,7 @@ private class ReEntrantCallVisitor : OperationWalker private readonly Compilation _compilation; private readonly HashSet _visitedOperations = new(); private readonly List _invocationOperation = new(); - private readonly Dictionary _semanticModelCache = new(1); + private SemanticModel? _semanticModel; public ImmutableList InvocationOperations => _invocationOperation.ToImmutableList(); @@ -175,24 +185,30 @@ private void VisitRelatedSymbols(IInvocationOperation invocationOperation) var root = location.SourceTree.GetRoot(); var relatedNode = root.FindNode(location.SourceSpan); var semanticModel = GetSemanticModel(relatedNode); + + if (semanticModel == null) + { + continue; + } + var operation = semanticModel.GetOperation(relatedNode) ?? semanticModel.GetOperation(relatedNode.Parent); Visit(operation); } } - private SemanticModel GetSemanticModel(SyntaxNode syntaxNode) + private SemanticModel? GetSemanticModel(SyntaxNode syntaxNode) { var syntaxTree = syntaxNode.SyntaxTree; - if (_semanticModelCache.TryGetValue(syntaxTree, out var semanticModel)) + + // perf - take original semantic model whenever possible + // but keep in mind that we might traverse outside of the original one https://github.com/nsubstitute/NSubstitute.Analyzers/issues/56 + if (_semanticModel == null || _semanticModel.SyntaxTree != syntaxTree) { - return semanticModel; + _semanticModel = _compilation.TryGetSemanticModel(syntaxTree); } - semanticModel = _compilation.GetSemanticModel(syntaxTree); - _semanticModelCache[syntaxTree] = semanticModel; - - return semanticModel; + return _semanticModel; } } } \ No newline at end of file diff --git a/src/NSubstitute.Analyzers.Shared/DiagnosticAnalyzers/SubstitutionNodeFinder.cs b/src/NSubstitute.Analyzers.Shared/DiagnosticAnalyzers/SubstitutionNodeFinder.cs index 92e5c4bb..83b80a85 100644 --- a/src/NSubstitute.Analyzers.Shared/DiagnosticAnalyzers/SubstitutionNodeFinder.cs +++ b/src/NSubstitute.Analyzers.Shared/DiagnosticAnalyzers/SubstitutionNodeFinder.cs @@ -1,3 +1,4 @@ +#nullable enable using System; using System.Collections.Generic; using System.Linq; @@ -13,7 +14,7 @@ internal class SubstitutionNodeFinder : ISubstitutionNodeFinder public IEnumerable Find( Compilation compilation, - IInvocationOperation invocationOperation) + IInvocationOperation? invocationOperation) { if (invocationOperation == null) { @@ -51,7 +52,7 @@ public IEnumerable Find( return standardSubstitution != null ? new[] { standardSubstitution } : Enumerable.Empty(); } - public IEnumerable FindForWhenExpression(Compilation compilation, IInvocationOperation invocationOperation) + public IEnumerable FindForWhenExpression(Compilation compilation, IInvocationOperation? invocationOperation) { if (invocationOperation == null) { @@ -86,7 +87,7 @@ public IEnumerable FindForReceivedInOrderExpression( return visitor.Operations; } - public IOperation FindForStandardExpression(IInvocationOperation invocationOperation) + public IOperation? FindForStandardExpression(IInvocationOperation invocationOperation) { return invocationOperation.GetSubstituteOperation(); } @@ -106,7 +107,7 @@ private static IEnumerable GetBaseTypesAndThis(ITypeSymbol type) } } - private IOperation FindForAndDoesExpression(IInvocationOperation invocationOperation) + private IOperation? FindForAndDoesExpression(IInvocationOperation invocationOperation) { if (invocationOperation.GetSubstituteOperation() is not IInvocationOperation parentInvocationOperation) { @@ -123,7 +124,7 @@ private class WhenVisitor : OperationWalker private readonly bool _includeAll; private readonly HashSet _operations = new(); - private readonly Dictionary _semanticModelCache = new(1); + private SemanticModel? _semanticModel; public WhenVisitor( Compilation compilation, @@ -150,9 +151,13 @@ public override void VisitMethodReference(IMethodReferenceOperation operation) { foreach (var methodDeclaringSyntaxReference in operation.Method.DeclaringSyntaxReferences) { - // TODO async? var syntaxNode = methodDeclaringSyntaxReference.GetSyntax(); var semanticModel = GetSemanticModel(syntaxNode.Parent); + if (semanticModel is null) + { + continue; + } + var referencedOperation = semanticModel.GetOperation(syntaxNode) ?? semanticModel.GetOperation(syntaxNode.Parent); Visit(referencedOperation); @@ -193,18 +198,18 @@ private void TryAdd(IOperation operation) } } - private SemanticModel GetSemanticModel(SyntaxNode syntaxNode) + private SemanticModel? GetSemanticModel(SyntaxNode syntaxNode) { var syntaxTree = syntaxNode.SyntaxTree; - if (_semanticModelCache.TryGetValue(syntaxTree, out var semanticModel)) + + // perf - take original semantic model whenever possible + // but keep in mind that we might traverse outside of the original one https://github.com/nsubstitute/NSubstitute.Analyzers/issues/56 + if (_semanticModel == null || _semanticModel.SyntaxTree != syntaxTree) { - return semanticModel; + _semanticModel = _compilation.TryGetSemanticModel(syntaxTree); } - semanticModel = _compilation.GetSemanticModel(syntaxTree); - _semanticModelCache[syntaxTree] = semanticModel; - - return semanticModel; + return _semanticModel; } } } \ No newline at end of file diff --git a/src/NSubstitute.Analyzers.Shared/Extensions/CompilationExtensions.cs b/src/NSubstitute.Analyzers.Shared/Extensions/CompilationExtensions.cs new file mode 100644 index 00000000..884f7dc7 --- /dev/null +++ b/src/NSubstitute.Analyzers.Shared/Extensions/CompilationExtensions.cs @@ -0,0 +1,10 @@ +#nullable enable +using Microsoft.CodeAnalysis; + +namespace NSubstitute.Analyzers.Shared.Extensions; + +internal static class CompilationExtensions +{ + public static SemanticModel? TryGetSemanticModel(this Compilation compilation, SyntaxTree syntaxTree) => + compilation.ContainsSyntaxTree(syntaxTree) ? compilation.GetSemanticModel(syntaxTree) : null; +} \ No newline at end of file diff --git a/src/NSubstitute.Analyzers.Shared/Extensions/IOperationExtensions.cs b/src/NSubstitute.Analyzers.Shared/Extensions/IOperationExtensions.cs index 1c32ba01..1e30604e 100644 --- a/src/NSubstitute.Analyzers.Shared/Extensions/IOperationExtensions.cs +++ b/src/NSubstitute.Analyzers.Shared/Extensions/IOperationExtensions.cs @@ -1,3 +1,4 @@ +#nullable enable using System.Collections.Generic; using System.Linq; using Microsoft.CodeAnalysis; @@ -22,7 +23,7 @@ public static bool IsEventAssignmentOperation(this IOperation operation) public static IOperation GetSubstituteOperation(this IPropertyReferenceOperation propertyReferenceOperation) => propertyReferenceOperation.Instance; - public static IOperation GetSubstituteOperation(this IInvocationOperation invocationOperation) + public static IOperation? GetSubstituteOperation(this IInvocationOperation invocationOperation) { if (invocationOperation.Instance != null) { @@ -110,7 +111,7 @@ public static IEnumerable Ancestors(this IOperation operation) } } - public static ISymbol ExtractSymbol(this IOperation operation) + public static ISymbol? ExtractSymbol(this IOperation? operation) { var symbol = operation switch { @@ -122,10 +123,11 @@ public static ISymbol ExtractSymbol(this IOperation operation) IFieldReferenceOperation fieldReferenceOperation => fieldReferenceOperation.Field, _ => null }; + return symbol; } - public static IEnumerable GetArrayElementValues(this IOperation operation) + public static IEnumerable? GetArrayElementValues(this IOperation operation) { return operation switch {