Skip to content

Commit

Permalink
GH-153 - clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
tpodolak committed Oct 2, 2022
1 parent a89afb5 commit 9395f5d
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#nullable enable
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
Expand Down Expand Up @@ -50,7 +51,7 @@ private IEnumerable<IInvocationOperation> GetPotentialOtherSubstituteInvocations
}
}

private IEnumerable<IOperation> GetOtherSubstitutionsForSymbol(Compilation compilation, IOperation rootOperation, ISymbol rootNodeSymbol)
private IEnumerable<IOperation> GetOtherSubstitutionsForSymbol(Compilation compilation, IOperation rootOperation, ISymbol? rootNodeSymbol)
{
if (rootNodeSymbol == null)
{
Expand All @@ -59,7 +60,7 @@ private IEnumerable<IOperation> GetOtherSubstitutionsForSymbol(Compilation compi

var rootIdentifierNode = GetLocalReferenceOperation(rootOperation);

var rootIdentifierSymbol = rootIdentifierNode?.ExtractSymbol();
var rootIdentifierSymbol = rootIdentifierNode.ExtractSymbol();

if (rootIdentifierSymbol == null)
{
Expand All @@ -78,7 +79,12 @@ private IEnumerable<IOperation> GetOtherSubstitutionsForSymbol(Compilation compi

var substitutedNode = _substitutionNodeFinder.FindForStandardExpression(operation);

var substituteNodeSymbol = substitutedNode?.ExtractSymbol();
if (substitutedNode == null)
{
yield break;
}

var substituteNodeSymbol = substitutedNode.ExtractSymbol();

if (substituteNodeSymbol == null)
{
Expand All @@ -101,18 +107,22 @@ private IEnumerable<IOperation> GetOtherSubstitutionsForSymbol(Compilation compi

private static IEnumerable<IOperation> GetConstructorOperations(Compilation compilation, ISymbol fieldReferenceOperation)
{
// TODO naming
foreach (var location in fieldReferenceOperation.ContainingType.GetMembers().OfType<IMethodSymbol>()
.Where(methodSymbol => methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor && methodSymbol.Locations.Length > 0)
SemanticModel? semanticModel = null;
foreach (var constructorLocation in fieldReferenceOperation.ContainingType.Constructors
.Where(methodSymbol => methodSymbol.Locations.Length > 0)
.SelectMany(x => x.Locations)
.Where(location => location.IsInSource))
{
var root = location.SourceTree.GetRoot();
var relatedNode = root.FindNode(location.SourceSpan);
var root = constructorLocation.SourceTree.GetRoot();
var relatedNode = root.FindNode(constructorLocation.SourceSpan);

// perf - take original semantic model whenever possible
// but keep in mind that we might traverse outside of the original one https://github.com/nsubstitute/NSubstitute.Analyzers/issues/56
semanticModel = semanticModel == null || semanticModel.SyntaxTree != constructorLocation.SourceTree
? compilation.TryGetSemanticModel(constructorLocation.SourceTree)
: semanticModel;

// TODO reuse semantic model
var semanticModel = compilation.GetSemanticModel(location.SourceTree);
var operation = semanticModel.GetOperation(relatedNode) ?? semanticModel.GetOperation(relatedNode.Parent);
var operation = semanticModel?.GetOperation(relatedNode) ?? semanticModel?.GetOperation(relatedNode.Parent);

if (operation is not null)
{
Expand All @@ -121,9 +131,9 @@ private static IEnumerable<IOperation> GetConstructorOperations(Compilation comp
}
}

private IOperation GetLocalReferenceOperation(IOperation node)
private IOperation? GetLocalReferenceOperation(IOperation? node)
{
var child = node.Children.FirstOrDefault();
var child = node?.Children.FirstOrDefault();
return child is ILocalReferenceOperation or IFieldReferenceOperation ? child : null;
}

Expand All @@ -140,7 +150,7 @@ private class ReEntrantCallVisitor : OperationWalker
private readonly Compilation _compilation;
private readonly HashSet<IOperation> _visitedOperations = new();
private readonly List<IInvocationOperation> _invocationOperation = new();
private readonly Dictionary<SyntaxTree, SemanticModel> _semanticModelCache = new(1);
private SemanticModel? _semanticModel;

public ImmutableList<IInvocationOperation> InvocationOperations => _invocationOperation.ToImmutableList();

Expand Down Expand Up @@ -175,24 +185,30 @@ private void VisitRelatedSymbols(IInvocationOperation invocationOperation)
var root = location.SourceTree.GetRoot();
var relatedNode = root.FindNode(location.SourceSpan);
var semanticModel = GetSemanticModel(relatedNode);

if (semanticModel == null)
{
continue;
}

var operation = semanticModel.GetOperation(relatedNode) ??
semanticModel.GetOperation(relatedNode.Parent);
Visit(operation);
}
}

private SemanticModel GetSemanticModel(SyntaxNode syntaxNode)
private SemanticModel? GetSemanticModel(SyntaxNode syntaxNode)
{
var syntaxTree = syntaxNode.SyntaxTree;
if (_semanticModelCache.TryGetValue(syntaxTree, out var semanticModel))

// perf - take original semantic model whenever possible
// but keep in mind that we might traverse outside of the original one https://github.com/nsubstitute/NSubstitute.Analyzers/issues/56
if (_semanticModel == null || _semanticModel.SyntaxTree != syntaxTree)
{
return semanticModel;
_semanticModel = _compilation.TryGetSemanticModel(syntaxTree);
}

semanticModel = _compilation.GetSemanticModel(syntaxTree);
_semanticModelCache[syntaxTree] = semanticModel;

return semanticModel;
return _semanticModel;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#nullable enable
using System;
using System.Collections.Generic;
using System.Linq;
Expand All @@ -13,7 +14,7 @@ internal class SubstitutionNodeFinder : ISubstitutionNodeFinder

public IEnumerable<IOperation> Find(
Compilation compilation,
IInvocationOperation invocationOperation)
IInvocationOperation? invocationOperation)
{
if (invocationOperation == null)
{
Expand Down Expand Up @@ -51,7 +52,7 @@ public IEnumerable<IOperation> Find(
return standardSubstitution != null ? new[] { standardSubstitution } : Enumerable.Empty<IOperation>();
}

public IEnumerable<IOperation> FindForWhenExpression(Compilation compilation, IInvocationOperation invocationOperation)
public IEnumerable<IOperation> FindForWhenExpression(Compilation compilation, IInvocationOperation? invocationOperation)
{
if (invocationOperation == null)
{
Expand Down Expand Up @@ -86,7 +87,7 @@ public IEnumerable<IOperation> FindForReceivedInOrderExpression(
return visitor.Operations;
}

public IOperation FindForStandardExpression(IInvocationOperation invocationOperation)
public IOperation? FindForStandardExpression(IInvocationOperation invocationOperation)
{
return invocationOperation.GetSubstituteOperation();
}
Expand All @@ -106,7 +107,7 @@ private static IEnumerable<ITypeSymbol> GetBaseTypesAndThis(ITypeSymbol type)
}
}

private IOperation FindForAndDoesExpression(IInvocationOperation invocationOperation)
private IOperation? FindForAndDoesExpression(IInvocationOperation invocationOperation)
{
if (invocationOperation.GetSubstituteOperation() is not IInvocationOperation parentInvocationOperation)
{
Expand All @@ -123,7 +124,7 @@ private class WhenVisitor : OperationWalker
private readonly bool _includeAll;
private readonly HashSet<IOperation> _operations = new();

private readonly Dictionary<SyntaxTree, SemanticModel> _semanticModelCache = new(1);
private SemanticModel? _semanticModel;

public WhenVisitor(
Compilation compilation,
Expand All @@ -150,9 +151,13 @@ public override void VisitMethodReference(IMethodReferenceOperation operation)
{
foreach (var methodDeclaringSyntaxReference in operation.Method.DeclaringSyntaxReferences)
{
// TODO async?
var syntaxNode = methodDeclaringSyntaxReference.GetSyntax();
var semanticModel = GetSemanticModel(syntaxNode.Parent);
if (semanticModel is null)
{
continue;
}

var referencedOperation = semanticModel.GetOperation(syntaxNode) ??
semanticModel.GetOperation(syntaxNode.Parent);
Visit(referencedOperation);
Expand Down Expand Up @@ -193,18 +198,18 @@ private void TryAdd(IOperation operation)
}
}

private SemanticModel GetSemanticModel(SyntaxNode syntaxNode)
private SemanticModel? GetSemanticModel(SyntaxNode syntaxNode)
{
var syntaxTree = syntaxNode.SyntaxTree;
if (_semanticModelCache.TryGetValue(syntaxTree, out var semanticModel))

// perf - take original semantic model whenever possible
// but keep in mind that we might traverse outside of the original one https://github.com/nsubstitute/NSubstitute.Analyzers/issues/56
if (_semanticModel == null || _semanticModel.SyntaxTree != syntaxTree)
{
return semanticModel;
_semanticModel = _compilation.TryGetSemanticModel(syntaxTree);
}

semanticModel = _compilation.GetSemanticModel(syntaxTree);
_semanticModelCache[syntaxTree] = semanticModel;

return semanticModel;
return _semanticModel;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#nullable enable
using Microsoft.CodeAnalysis;

namespace NSubstitute.Analyzers.Shared.Extensions;

internal static class CompilationExtensions
{
public static SemanticModel? TryGetSemanticModel(this Compilation compilation, SyntaxTree syntaxTree) =>
compilation.ContainsSyntaxTree(syntaxTree) ? compilation.GetSemanticModel(syntaxTree) : null;
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#nullable enable
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
Expand All @@ -22,7 +23,7 @@ public static bool IsEventAssignmentOperation(this IOperation operation)
public static IOperation GetSubstituteOperation(this IPropertyReferenceOperation propertyReferenceOperation) =>
propertyReferenceOperation.Instance;

public static IOperation GetSubstituteOperation(this IInvocationOperation invocationOperation)
public static IOperation? GetSubstituteOperation(this IInvocationOperation invocationOperation)
{
if (invocationOperation.Instance != null)
{
Expand Down Expand Up @@ -110,7 +111,7 @@ public static IEnumerable<IOperation> Ancestors(this IOperation operation)
}
}

public static ISymbol ExtractSymbol(this IOperation operation)
public static ISymbol? ExtractSymbol(this IOperation? operation)
{
var symbol = operation switch
{
Expand All @@ -122,10 +123,11 @@ public static ISymbol ExtractSymbol(this IOperation operation)
IFieldReferenceOperation fieldReferenceOperation => fieldReferenceOperation.Field,
_ => null
};

return symbol;
}

public static IEnumerable<IOperation> GetArrayElementValues(this IOperation operation)
public static IEnumerable<IOperation>? GetArrayElementValues(this IOperation operation)
{
return operation switch
{
Expand Down

0 comments on commit 9395f5d

Please sign in to comment.