From 29c11cd862eec73c86d845984bef6f0caadd30db Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 12:30:53 -0700 Subject: [PATCH 01/10] Properly implement type inference for C# member access expressions. --- ...CSharpTypeInferenceService.TypeInferrer.cs | 39 ++- .../Extensions/ObjectExtensions.TypeSwitch.cs | 232 ++++++++++++++++++ 2 files changed, 264 insertions(+), 7 deletions(-) diff --git a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs index 949266cbcd207..886141275e2cb 100644 --- a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs +++ b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics; using System.Linq; using System.Threading; using Microsoft.CodeAnalysis.CSharp.Extensions; @@ -113,7 +114,8 @@ private IEnumerable GetTypesSimple(ExpressionSyntax expression) return SpecializedCollections.EmptyEnumerable(); } - protected override IEnumerable InferTypesWorker_DoNotCallDirectly(ExpressionSyntax expression) + protected override IEnumerable InferTypesWorker_DoNotCallDirectly( + ExpressionSyntax expression) { expression = expression.WalkUpParentheses(); var parent = expression.Parent; @@ -145,7 +147,7 @@ protected override IEnumerable InferTypesWorker_DoNotCallDirectly(E (InitializerExpressionSyntax initializerExpression) => InferTypeInInitializerExpression(initializerExpression, expression), (IsPatternExpressionSyntax isPatternExpression) => InferTypeInIsPatternExpression(isPatternExpression, expression), (LockStatementSyntax lockStatement) => InferTypeInLockStatement(lockStatement), - (MemberAccessExpressionSyntax memberAccessExpression) => InferTypeInMemberAccessExpression(memberAccessExpression), + (MemberAccessExpressionSyntax memberAccessExpression) => InferTypeInMemberAccessExpression(memberAccessExpression, expression), (NameEqualsSyntax nameEquals) => InferTypeInNameEquals(nameEquals), (ParenthesizedLambdaExpressionSyntax parenthesizedLambdaExpression) => InferTypeInParenthesizedLambdaExpression(parenthesizedLambdaExpression), (PostfixUnaryExpressionSyntax postfixUnary) => InferTypeInPostfixUnaryExpression(postfixUnary), @@ -214,6 +216,7 @@ protected override IEnumerable InferTypesWorker_DoNotCallDirectly(i (IfStatementSyntax ifStatement) => InferTypeInIfStatement(ifStatement, token), (InitializerExpressionSyntax initializerExpression) => InferTypeInInitializerExpression(initializerExpression, previousToken: token), (LockStatementSyntax lockStatement) => InferTypeInLockStatement(lockStatement, token), + (MemberAccessExpressionSyntax memberAccessExpression) => InferTypeInMemberAccessExpression(memberAccessExpression, previousToken: token), (NameColonSyntax nameColon) => InferTypeInNameColon(nameColon, token), (NameEqualsSyntax nameEquals) => InferTypeInNameEquals(nameEquals, token), (ObjectCreationExpressionSyntax objectCreation) => InferTypeInObjectCreationExpression(objectCreation, token), @@ -1426,15 +1429,37 @@ private IEnumerable InferTypeInNameColon(NameColonSyntax nameColon, return SpecializedCollections.EmptyEnumerable(); } - private IEnumerable InferTypeInMemberAccessExpression(MemberAccessExpressionSyntax expression) + private IEnumerable InferTypeInMemberAccessExpression( + MemberAccessExpressionSyntax memberAccessExpression, + ExpressionSyntax expressionOpt = null, + SyntaxToken? previousToken = null) { - var awaitExpression = expression.GetAncestor(); - if (awaitExpression != null) + if (previousToken != null) + { + if (previousToken.Value != memberAccessExpression.OperatorToken) + { + return SpecializedCollections.EmptyEnumerable(); + } + + // fall through + } + else { - return InferTypes(awaitExpression.Expression); + Debug.Assert(expressionOpt != null); + if (expressionOpt != memberAccessExpression.Name) + { + // we're not on the name portion of hte member access expressoin. + // i.e. we're in "Foo" in "Foo.Bar". We can't figure a name for this + // at all. + return SpecializedCollections.EmptyEnumerable(); + } + + // fall through } - return SpecializedCollections.EmptyEnumerable(); + // We're right after the dot in "Foo.Bar". The type for "Bar" should be + // whatever type we'd infer for "Foo.Bar" itself. + return InferTypes(memberAccessExpression); } private IEnumerable InferTypeInNameEquals(NameEqualsSyntax nameEquals, SyntaxToken? previousToken = null) diff --git a/src/Workspaces/Core/Portable/Shared/Extensions/ObjectExtensions.TypeSwitch.cs b/src/Workspaces/Core/Portable/Shared/Extensions/ObjectExtensions.TypeSwitch.cs index dcfbf74a817bd..c5bf8a40002d2 100644 --- a/src/Workspaces/Core/Portable/Shared/Extensions/ObjectExtensions.TypeSwitch.cs +++ b/src/Workspaces/Core/Portable/Shared/Extensions/ObjectExtensions.TypeSwitch.cs @@ -8326,6 +8326,238 @@ public static TResult TypeSwitch(this TBaseType obj, Func matchFunc1, Func matchFunc2, Func matchFunc3, Func matchFunc4, Func matchFunc5, Func matchFunc6, Func matchFunc7, Func matchFunc8, Func matchFunc9, Func matchFunc10, Func matchFunc11, Func matchFunc12, Func matchFunc13, Func matchFunc14, Func matchFunc15, Func matchFunc16, Func matchFunc17, Func matchFunc18, Func matchFunc19, Func matchFunc20, Func matchFunc21, Func matchFunc22, Func matchFunc23, Func matchFunc24, Func matchFunc25, Func matchFunc26, Func matchFunc27, Func matchFunc28, Func matchFunc29, Func matchFunc30, Func matchFunc31, Func matchFunc32, Func matchFunc33, Func matchFunc34, Func matchFunc35, Func matchFunc36, Func matchFunc37, Func matchFunc38, Func matchFunc39, Func matchFunc40, Func matchFunc41, Func matchFunc42, Func matchFunc43, Func matchFunc44, Func defaultFunc = null) + where TDerivedType1 : TBaseType + where TDerivedType2 : TBaseType + where TDerivedType3 : TBaseType + where TDerivedType4 : TBaseType + where TDerivedType5 : TBaseType + where TDerivedType6 : TBaseType + where TDerivedType7 : TBaseType + where TDerivedType8 : TBaseType + where TDerivedType9 : TBaseType + where TDerivedType10 : TBaseType + where TDerivedType11 : TBaseType + where TDerivedType12 : TBaseType + where TDerivedType13 : TBaseType + where TDerivedType14 : TBaseType + where TDerivedType15 : TBaseType + where TDerivedType16 : TBaseType + where TDerivedType17 : TBaseType + where TDerivedType18 : TBaseType + where TDerivedType19 : TBaseType + where TDerivedType20 : TBaseType + where TDerivedType21 : TBaseType + where TDerivedType22 : TBaseType + where TDerivedType23 : TBaseType + where TDerivedType24 : TBaseType + where TDerivedType25 : TBaseType + where TDerivedType26 : TBaseType + where TDerivedType27 : TBaseType + where TDerivedType28 : TBaseType + where TDerivedType29 : TBaseType + where TDerivedType30 : TBaseType + where TDerivedType31 : TBaseType + where TDerivedType32 : TBaseType + where TDerivedType33 : TBaseType + where TDerivedType34 : TBaseType + where TDerivedType35 : TBaseType + where TDerivedType36 : TBaseType + where TDerivedType37 : TBaseType + where TDerivedType38 : TBaseType + where TDerivedType39 : TBaseType + where TDerivedType40 : TBaseType + where TDerivedType41 : TBaseType + where TDerivedType42 : TBaseType + where TDerivedType43 : TBaseType + where TDerivedType44 : TBaseType + { + if (obj is TDerivedType1) + { + return matchFunc1((TDerivedType1)obj); + } + else if (obj is TDerivedType2) + { + return matchFunc2((TDerivedType2)obj); + } + else if (obj is TDerivedType3) + { + return matchFunc3((TDerivedType3)obj); + } + else if (obj is TDerivedType4) + { + return matchFunc4((TDerivedType4)obj); + } + else if (obj is TDerivedType5) + { + return matchFunc5((TDerivedType5)obj); + } + else if (obj is TDerivedType6) + { + return matchFunc6((TDerivedType6)obj); + } + else if (obj is TDerivedType7) + { + return matchFunc7((TDerivedType7)obj); + } + else if (obj is TDerivedType8) + { + return matchFunc8((TDerivedType8)obj); + } + else if (obj is TDerivedType9) + { + return matchFunc9((TDerivedType9)obj); + } + else if (obj is TDerivedType10) + { + return matchFunc10((TDerivedType10)obj); + } + else if (obj is TDerivedType11) + { + return matchFunc11((TDerivedType11)obj); + } + else if (obj is TDerivedType12) + { + return matchFunc12((TDerivedType12)obj); + } + else if (obj is TDerivedType13) + { + return matchFunc13((TDerivedType13)obj); + } + else if (obj is TDerivedType14) + { + return matchFunc14((TDerivedType14)obj); + } + else if (obj is TDerivedType15) + { + return matchFunc15((TDerivedType15)obj); + } + else if (obj is TDerivedType16) + { + return matchFunc16((TDerivedType16)obj); + } + else if (obj is TDerivedType17) + { + return matchFunc17((TDerivedType17)obj); + } + else if (obj is TDerivedType18) + { + return matchFunc18((TDerivedType18)obj); + } + else if (obj is TDerivedType19) + { + return matchFunc19((TDerivedType19)obj); + } + else if (obj is TDerivedType20) + { + return matchFunc20((TDerivedType20)obj); + } + else if (obj is TDerivedType21) + { + return matchFunc21((TDerivedType21)obj); + } + else if (obj is TDerivedType22) + { + return matchFunc22((TDerivedType22)obj); + } + else if (obj is TDerivedType23) + { + return matchFunc23((TDerivedType23)obj); + } + else if (obj is TDerivedType24) + { + return matchFunc24((TDerivedType24)obj); + } + else if (obj is TDerivedType25) + { + return matchFunc25((TDerivedType25)obj); + } + else if (obj is TDerivedType26) + { + return matchFunc26((TDerivedType26)obj); + } + else if (obj is TDerivedType27) + { + return matchFunc27((TDerivedType27)obj); + } + else if (obj is TDerivedType28) + { + return matchFunc28((TDerivedType28)obj); + } + else if (obj is TDerivedType29) + { + return matchFunc29((TDerivedType29)obj); + } + else if (obj is TDerivedType30) + { + return matchFunc30((TDerivedType30)obj); + } + else if (obj is TDerivedType31) + { + return matchFunc31((TDerivedType31)obj); + } + else if (obj is TDerivedType32) + { + return matchFunc32((TDerivedType32)obj); + } + else if (obj is TDerivedType33) + { + return matchFunc33((TDerivedType33)obj); + } + else if (obj is TDerivedType34) + { + return matchFunc34((TDerivedType34)obj); + } + else if (obj is TDerivedType35) + { + return matchFunc35((TDerivedType35)obj); + } + else if (obj is TDerivedType36) + { + return matchFunc36((TDerivedType36)obj); + } + else if (obj is TDerivedType37) + { + return matchFunc37((TDerivedType37)obj); + } + else if (obj is TDerivedType38) + { + return matchFunc38((TDerivedType38)obj); + } + else if (obj is TDerivedType39) + { + return matchFunc39((TDerivedType39)obj); + } + else if (obj is TDerivedType40) + { + return matchFunc40((TDerivedType40)obj); + } + else if (obj is TDerivedType41) + { + return matchFunc41((TDerivedType41)obj); + } + else if (obj is TDerivedType42) + { + return matchFunc42((TDerivedType42)obj); + } + else if (obj is TDerivedType43) + { + return matchFunc43((TDerivedType43)obj); + } + else if (obj is TDerivedType44) + { + return matchFunc44((TDerivedType44)obj); + } + else if (defaultFunc != null) + { + return defaultFunc(obj); + } + else + { + return default(TResult); + } + } #endregion } } \ No newline at end of file From 045003e636ea2ce96222c655fc1ce0efbd963a0f Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 12:35:39 -0700 Subject: [PATCH 02/10] Add VB side. --- ...lBasicTypeInferenceService.TypeInferrer.vb | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb index a7cdbcc345d12..8ec29590bdc52 100644 --- a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb +++ b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb @@ -801,14 +801,29 @@ Namespace Microsoft.CodeAnalysis.VisualBasic Return SpecializedCollections.SingletonEnumerable(Me.Compilation.GetSpecialType(SpecialType.System_Boolean)) End Function - Private Function InferTypeInMemberAccessExpression(expression As MemberAccessExpressionSyntax) As IEnumerable(Of ITypeSymbol) - Dim awaitExpression = expression.GetAncestor(Of AwaitExpressionSyntax) - Dim lambdaExpression = expression.GetAncestor(Of LambdaExpressionSyntax) - If Not awaitExpression?.Contains(lambdaExpression) AndAlso awaitExpression IsNot Nothing Then - Return InferTypes(awaitExpression.Expression) + Private Function InferTypeInMemberAccessExpression( + memberAccessExpression As MemberAccessExpressionSyntax, + Optional expressionOpt As ExpressionSyntax = Nothing, + Optional previousTokenOpt As SyntaxToken? = Nothing) As IEnumerable(Of ITypeSymbol) + + ' We need to be on the right of the dot to infer an appropriate type for + ' the member access expression. i.e. if we have "Foo.Bar" then we can + ' def infer what the type of 'Bar' should be (it's whatever type we infer + ' for 'Foo.Bar' itself. However, if we're on 'Foo' then we can't figure + ' out anything about its type. + If previousTokenOpt <> Nothing Then + If previousTokenOpt.Value <> memberAccessExpression.OperatorToken Then + Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol) + End If + ' fall through + Else + If expressionOpt IsNot memberAccessExpression.Name Then + Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol) + End If + ' fall through End If - Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol)() + Return InferTypes(memberAccessExpression) End Function Private Function InferTypeInNamedFieldInitializer(initializer As NamedFieldInitializerSyntax, Optional previousToken As SyntaxToken = Nothing) As IEnumerable(Of ITypeSymbol) From 1e4027f13e306f9e184a90d4700b8aea880ed7dd Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 12:36:10 -0700 Subject: [PATCH 03/10] Share comment. --- .../CSharpTypeInferenceService.TypeInferrer.cs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs index 886141275e2cb..64ca77bfbdeec 100644 --- a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs +++ b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs @@ -1434,6 +1434,11 @@ private IEnumerable InferTypeInMemberAccessExpression( ExpressionSyntax expressionOpt = null, SyntaxToken? previousToken = null) { + // We need to be on the right of the dot to infer an appropriate type for + // the member access expression. i.e. if we have "Foo.Bar" then we can + // def infer what the type of 'Bar' should be (it's whatever type we infer + // for 'Foo.Bar' itself. However, if we're on 'Foo' then we can't figure + // out anything about its type. if (previousToken != null) { if (previousToken.Value != memberAccessExpression.OperatorToken) @@ -1448,9 +1453,6 @@ private IEnumerable InferTypeInMemberAccessExpression( Debug.Assert(expressionOpt != null); if (expressionOpt != memberAccessExpression.Name) { - // we're not on the name portion of hte member access expressoin. - // i.e. we're in "Foo" in "Foo.Bar". We can't figure a name for this - // at all. return SpecializedCollections.EmptyEnumerable(); } From 91cad67623322bb16859cdc6e3f6b646402a7d6b Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 12:40:35 -0700 Subject: [PATCH 04/10] More VB side. --- .../VisualBasicTypeInferenceService.TypeInferrer.vb | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb index 8ec29590bdc52..a36a76d3eab2d 100644 --- a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb +++ b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb @@ -75,7 +75,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic Function(forStepClause As ForStepClauseSyntax) InferTypeInForStepClause(forStepClause), Function(ifStatement As ElseIfStatementSyntax) InferTypeInIfOrElseIfStatement(), Function(ifStatement As IfStatementSyntax) InferTypeInIfOrElseIfStatement(), - Function(memberAccessExpression As MemberAccessExpressionSyntax) InferTypeInMemberAccessExpression(memberAccessExpression), + Function(memberAccessExpression As MemberAccessExpressionSyntax) InferTypeInMemberAccessExpression(memberAccessExpression, expression), Function(namedFieldInitializer As NamedFieldInitializerSyntax) InferTypeInNamedFieldInitializer(namedFieldInitializer), Function(parenthesizedLambda As MultiLineLambdaExpressionSyntax) InferTypeInLambda(parenthesizedLambda), Function(prefixUnary As UnaryExpressionSyntax) InferTypeInUnaryExpression(prefixUnary), @@ -145,6 +145,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic Function(forStatement As ForStatementSyntax) InferTypeInForStatement(forStatement, previousToken:=token), Function(forStepClause As ForStepClauseSyntax) InferTypeInForStepClause(forStepClause, token), Function(ifStatement As IfStatementSyntax) InferTypeInIfOrElseIfStatement(token), + Function(memberAccessExpression As MemberAccessExpressionSyntax) InferTypeInMemberAccessExpression(memberAccessExpression, previousToken:=token), Function(nameColonEquals As NameColonEqualsSyntax) InferTypeInArgumentList(TryCast(nameColonEquals.Parent.Parent, ArgumentListSyntax), DirectCast(nameColonEquals.Parent, ArgumentSyntax)), Function(namedFieldInitializer As NamedFieldInitializerSyntax) InferTypeInNamedFieldInitializer(namedFieldInitializer, token), Function(objectCreation As ObjectCreationExpressionSyntax) InferTypes(objectCreation), @@ -203,7 +204,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic Dim targetExpression As ExpressionSyntax = Nothing If invocation.Expression IsNot Nothing Then targetExpression = invocation.Expression - ElseIf invocation.Parent.IsKind(SyntaxKind.ConditionalAccessExpression) + ElseIf invocation.Parent.IsKind(SyntaxKind.ConditionalAccessExpression) Then targetExpression = DirectCast(invocation.Parent, ConditionalAccessExpressionSyntax).Expression End If @@ -804,15 +805,15 @@ Namespace Microsoft.CodeAnalysis.VisualBasic Private Function InferTypeInMemberAccessExpression( memberAccessExpression As MemberAccessExpressionSyntax, Optional expressionOpt As ExpressionSyntax = Nothing, - Optional previousTokenOpt As SyntaxToken? = Nothing) As IEnumerable(Of ITypeSymbol) + Optional previousToken As SyntaxToken? = Nothing) As IEnumerable(Of ITypeSymbol) ' We need to be on the right of the dot to infer an appropriate type for ' the member access expression. i.e. if we have "Foo.Bar" then we can ' def infer what the type of 'Bar' should be (it's whatever type we infer ' for 'Foo.Bar' itself. However, if we're on 'Foo' then we can't figure ' out anything about its type. - If previousTokenOpt <> Nothing Then - If previousTokenOpt.Value <> memberAccessExpression.OperatorToken Then + If previousToken <> Nothing Then + If previousToken.Value <> memberAccessExpression.OperatorToken Then Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol) End If ' fall through From 41a1d64374e5241bc9cc80a290aa2a802fca2c95 Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 13:43:03 -0700 Subject: [PATCH 05/10] Fix inference in awaits. --- .../GenerateMethod/GenerateMethodTests.cs | 20 ++++++++++-- .../TypeInferrer/TypeInferrerTests.cs | 4 +-- .../GenerateMethod/GenerateMethodTests.vb | 31 +++++++++++-------- .../TypeInferrer/TypeInferrerTests.vb | 4 +-- ...CSharpTypeInferenceService.TypeInferrer.cs | 30 +++++++++++++----- ...lBasicTypeInferenceService.TypeInferrer.vb | 26 ++++++++++++---- 6 files changed, 82 insertions(+), 33 deletions(-) diff --git a/src/EditorFeatures/CSharpTest/Diagnostics/GenerateMethod/GenerateMethodTests.cs b/src/EditorFeatures/CSharpTest/Diagnostics/GenerateMethod/GenerateMethodTests.cs index 27fec685810cc..0ecdb5246d1c8 100644 --- a/src/EditorFeatures/CSharpTest/Diagnostics/GenerateMethod/GenerateMethodTests.cs +++ b/src/EditorFeatures/CSharpTest/Diagnostics/GenerateMethod/GenerateMethodTests.cs @@ -2757,7 +2757,15 @@ await TestAsync( public async Task TestGenerateMethodWithMethodChaining() { await TestAsync( -@"using System ; using System . Collections . Generic ; using System . Linq ; using System . Threading . Tasks ; class Program { static void Main ( string [ ] args ) { bool x = await [|Foo|] ( ) . ConfigureAwait ( false ) ; } } ", +@"using System ; +using System . Collections . Generic ; +using System . Linq ; +using System . Threading . Tasks ; +class Program { + static void Main ( string [ ] args ) { + bool x = await [|Foo|] ( ) . ConfigureAwait ( false ) ; + } +}", @"using System ; using System . Collections . Generic ; using System . Linq ; using System . Threading . Tasks ; class Program { static void Main ( string [ ] args ) { bool x = await Foo ( ) . ConfigureAwait ( false ) ; } private static Task < bool > Foo ( ) { throw new NotImplementedException ( ) ; } } "); } @@ -2766,8 +2774,14 @@ await TestAsync( public async Task TestGenerateMethodWithMethodChaining2() { await TestAsync( -@"using System ; using System . Threading . Tasks ; class C { static async void T ( ) { bool x = await [|M|] ( ) . ContinueWith ( a => { return true ; } ) . ContinueWith ( a => { return false ; } ) ; } } ", -@"using System ; using System . Threading . Tasks ; class C { static async void T ( ) { bool x = await M ( ) . ContinueWith ( a => { return true ; } ) . ContinueWith ( a => { return false ; } ) ; } private static Task < bool > M ( ) { throw new NotImplementedException ( ) ; } } "); +@"using System ; +using System . Threading . Tasks ; +class C { + static async void T ( ) { + bool x = await [|M|] ( ) . ContinueWith ( a => { return true ; } ) . ContinueWith ( a => { return false ; } ) ; + } +} ", +@"using System ; using System . Threading . Tasks ; class C { static async void T ( ) { bool x = await M ( ) . ContinueWith ( a => { return true ; } ) . ContinueWith ( a => { return false ; } ) ; } private static object M ( ) { throw new NotImplementedException ( ) ; } } "); } [WorkItem(529480, "http://vstfdevdiv:8080/DevDiv2/DevDiv/_workitems/edit/529480")] diff --git a/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs b/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs index a311e372c2d78..1fe691364bb3e 100644 --- a/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs +++ b/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs @@ -1591,7 +1591,7 @@ static async void T() bool x = await [|M()|].ConfigureAwait(false); } }"; - await TestAsync(text, "global::System.Threading.Tasks.Task"); + await TestAsync(text, "global::System.Threading.Tasks.Task", testPosition: false); } [Fact, Trait(Traits.Feature, Traits.Features.TypeInferenceService)] @@ -1609,7 +1609,7 @@ static async void T() bool x = await [|M|].ContinueWith(a => { return true; }).ContinueWith(a => { return false; }); } }"; - await TestAsync(text, "global::System.Threading.Tasks.Task"); + await TestAsync(text, "System.Object", testPosition: false); } [Fact, Trait(Traits.Feature, Traits.Features.TypeInferenceService)] diff --git a/src/EditorFeatures/VisualBasicTest/Diagnostics/GenerateMethod/GenerateMethodTests.vb b/src/EditorFeatures/VisualBasicTest/Diagnostics/GenerateMethod/GenerateMethodTests.vb index 559334c8642c4..8f8937b82937e 100644 --- a/src/EditorFeatures/VisualBasicTest/Diagnostics/GenerateMethod/GenerateMethodTests.vb +++ b/src/EditorFeatures/VisualBasicTest/Diagnostics/GenerateMethod/GenerateMethodTests.vb @@ -1,10 +1,8 @@ ' Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -Option Strict Off Imports Microsoft.CodeAnalysis.CodeFixes Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.GenerateMethod Imports Microsoft.CodeAnalysis.Diagnostics -Imports System.Threading.Tasks Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.GenerateMethod Public Class GenerateMethodTests @@ -2190,17 +2188,24 @@ index:=1) Public Async Function TestGenerateMethodWithMethodChaining() As Task Await TestAsync( -NewLines("Imports System \n Imports System.Linq \n Module M \n Async Sub T() \n Dim x As Boolean = Await [|F|]().ContinueWith(Function(a) True).ContinueWith(Function(a) False) \n End Sub \n End Module"), -NewLines("Imports System\nImports System.Linq\nImports System.Threading.Tasks\n\nModule M\n Async Sub T()\n Dim x As Boolean = Await F().ContinueWith(Function(a) True).ContinueWith(Function(a) False)\n End Sub\n\n Private Function F() As Task(Of Boolean)\n Throw New NotImplementedException()\n End Function\nEnd Module")) - End Function - - - - Public Async Function TestGenerateMethodWithMethodChaining2() As Task - Await TestAsync( -NewLines("Imports System \n Imports System.Linq \n Module M \n Async Sub T() \n Dim x As Boolean = Await [|F|]().ContinueWith(Function(a) True).ContinueWith(Function(a) False) \n End Sub \n End Module"), -NewLines("Imports System\nImports System.Linq\nImports System.Threading.Tasks\n\nModule M\n Async Sub T()\n Dim x As Boolean = Await F().ContinueWith(Function(a) True).ContinueWith(Function(a) False)\n End Sub\n\n Private ReadOnly Property F As Task(Of Boolean)\n Get\n Throw New NotImplementedException()\n End Get\n End Property\nEnd Module"), -index:=1) +"Imports System +Imports System.Linq +Module M + Async Sub T() + Dim x As Boolean = Await [|F|]().ConfigureAwait(False) + End Sub +End Module", +"Imports System +Imports System.Linq +Imports System.Threading.Tasks +Module M + Async Sub T() + Dim x As Boolean = Await F().ConfigureAwait(False) + End Sub + Private Function F() As Task(Of Boolean) + Throw New NotImplementedException() + End Function +End Module") End Function diff --git a/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb b/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb index 3d39e131eb99c..45246dfa3a5c9 100644 --- a/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb +++ b/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb @@ -705,7 +705,7 @@ Module M Dim x As Boolean = Await [|F|].ContinueWith(Function(a) True).ContinueWith(Function(a) False) End Sub End Module" - Await TestAsync(text, "Global.System.Threading.Tasks.Task(Of System.Boolean)", testPosition:=True) + Await TestAsync(text, "System.Object", testPosition:=False) End Function @@ -719,7 +719,7 @@ Module M Dim x As Boolean = Await [|F|].ConfigureAwait(False) End Sub End Module" - Await TestAsync(text, "Global.System.Threading.Tasks.Task(Of System.Boolean)", testPosition:=True) + Await TestAsync(text, "Global.System.Threading.Tasks.Task(Of System.Boolean)", testPosition:=False) End Function diff --git a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs index 64ca77bfbdeec..f9004cd7cdf5b 100644 --- a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs +++ b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Linq; using System.Threading; +using System.Threading.Tasks; using Microsoft.CodeAnalysis.CSharp.Extensions; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Shared.Extensions; @@ -1446,22 +1447,37 @@ private IEnumerable InferTypeInMemberAccessExpression( return SpecializedCollections.EmptyEnumerable(); } - // fall through + // We're right after the dot in "Foo.Bar". The type for "Bar" should be + // whatever type we'd infer for "Foo.Bar" itself. + return InferTypes(memberAccessExpression); } else { Debug.Assert(expressionOpt != null); - if (expressionOpt != memberAccessExpression.Name) + if (expressionOpt == memberAccessExpression.Expression) { + // If we're on the left side of a dot, it's possible in a few cases + // to figure out what type we should be. Specifically, if we have + // + // await foo.ConfigureAwait() + // + // then we can figure out what 'foo' should be based on teh await + // context. + + if (memberAccessExpression.Name.Identifier.Value.Equals(nameof(Task.ConfigureAwait)) && + memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression) && + memberAccessExpression.Parent.IsParentKind(SyntaxKind.AwaitExpression)) + { + return InferTypes((ExpressionSyntax)memberAccessExpression.Parent); + } + return SpecializedCollections.EmptyEnumerable(); } - // fall through + // We're right after the dot in "Foo.Bar". The type for "Bar" should be + // whatever type we'd infer for "Foo.Bar" itself. + return InferTypes(memberAccessExpression); } - - // We're right after the dot in "Foo.Bar". The type for "Bar" should be - // whatever type we'd infer for "Foo.Bar" itself. - return InferTypes(memberAccessExpression); } private IEnumerable InferTypeInNameEquals(NameEqualsSyntax nameEquals, SyntaxToken? previousToken = null) diff --git a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb index a36a76d3eab2d..6c6f9e1cce1d1 100644 --- a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb +++ b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb @@ -2,6 +2,7 @@ Imports System.Collections.Immutable Imports System.Threading +Imports System.Threading.Tasks Imports Microsoft.CodeAnalysis Imports Microsoft.CodeAnalysis.VisualBasic.Symbols Imports Microsoft.CodeAnalysis.VisualBasic.Syntax @@ -816,15 +817,28 @@ Namespace Microsoft.CodeAnalysis.VisualBasic If previousToken.Value <> memberAccessExpression.OperatorToken Then Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol) End If - ' fall through + + Return InferTypes(memberAccessExpression) Else - If expressionOpt IsNot memberAccessExpression.Name Then - Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol) + ' If we're on the left side of a dot, it's possible in a few cases + ' to figure out what type we should be. Specifically, if we have + ' + ' await foo.ConfigureAwait() + ' + ' then we can figure out what 'foo' should be based on teh await + ' context. + If expressionOpt Is memberAccessExpression.Expression Then + If memberAccessExpression.Name.Identifier.Value.Equals(NameOf(Task(Of Integer).ConfigureAwait)) AndAlso + memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression) AndAlso + memberAccessExpression.Parent.IsParentKind(SyntaxKind.AwaitExpression) Then + Return InferTypes(DirectCast(memberAccessExpression.Parent, ExpressionSyntax)) + End If + + Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol)() End If - ' fall through - End If - Return InferTypes(memberAccessExpression) + Return InferTypes(memberAccessExpression) + End If End Function Private Function InferTypeInNamedFieldInitializer(initializer As NamedFieldInitializerSyntax, Optional previousToken As SyntaxToken = Nothing) As IEnumerable(Of ITypeSymbol) From f29a6cb07a768619bb91c4456d9bff6bc4fc1a4c Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 20:56:35 -0700 Subject: [PATCH 06/10] Add better inference for task and enumerable scenarios. --- .../TypeInferrer/TypeInferrerTests.cs | 36 +++++- ...CSharpTypeInferenceService.TypeInferrer.cs | 121 +++++++++++++++++- .../Extensions/ICompilationExtensions.cs | 6 + .../Shared/Extensions/ISymbolExtensions.cs | 5 +- 4 files changed, 160 insertions(+), 8 deletions(-) diff --git a/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs b/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs index 1fe691364bb3e..dbb05b8ac3738 100644 --- a/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs +++ b/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs @@ -1609,7 +1609,7 @@ static async void T() bool x = await [|M|].ContinueWith(a => { return true; }).ContinueWith(a => { return false; }); } }"; - await TestAsync(text, "System.Object", testPosition: false); + await TestAsync(text, "global::System.Threading.Tasks.Task", testPosition: false); } [Fact, Trait(Traits.Feature, Traits.Features.TypeInferenceService)] @@ -1818,5 +1818,39 @@ static void Foo(System.ConsoleModifiers arg) }"; await TestAsync(text, "global::System.ConsoleModifiers", testNode: false); } + + [Fact, Trait(Traits.Feature, Traits.Features.TypeInferenceService)] + [WorkItem(6765, "https://github.com/dotnet/roslyn/issues/6765")] + public async Task TestWhereCall() + { + var text = + @" +using System.Collections.Generic; +class C +{ + void Foo() + { + [|ints|].Where(i => i > 10); + } +}"; + await TestAsync(text, "global::System.Collections.Generic.IEnumerable", testPosition: false); + } + + [Fact, Trait(Traits.Feature, Traits.Features.TypeInferenceService)] + [WorkItem(6765, "https://github.com/dotnet/roslyn/issues/6765")] + public async Task TestWhereCall2() + { + var text = + @" +using System.Collections.Generic; +class C +{ + void Foo() + { + [|ints|].Where(i => null); + } +}"; + await TestAsync(text, "global::System.Collections.Generic.IEnumerable", testPosition: false); + } } } diff --git a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs index f9004cd7cdf5b..83233595276c8 100644 --- a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs +++ b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs @@ -235,7 +235,8 @@ protected override IEnumerable InferTypesWorker_DoNotCallDirectly(i _ => SpecializedCollections.EmptyEnumerable()); } - private IEnumerable InferTypeInArgument(ArgumentSyntax argument, SyntaxToken? previousToken = null) + private IEnumerable InferTypeInArgument( + ArgumentSyntax argument, SyntaxToken? previousToken = null) { if (previousToken.HasValue) { @@ -1463,13 +1464,57 @@ private IEnumerable InferTypeInMemberAccessExpression( // // then we can figure out what 'foo' should be based on teh await // context. - - if (memberAccessExpression.Name.Identifier.Value.Equals(nameof(Task.ConfigureAwait)) && + var name = memberAccessExpression.Name.Identifier.Value; + if (name.Equals(nameof(Task.ConfigureAwait)) && memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression) && memberAccessExpression.Parent.IsParentKind(SyntaxKind.AwaitExpression)) { return InferTypes((ExpressionSyntax)memberAccessExpression.Parent); } + else if (name.Equals(nameof(Task.ContinueWith))) + { + // foo.ContinueWith(...) + // We want to infer Task. For now, we'll just do Task, + // in the future it would be nice to figure out the actual result + // type based on the argument to ContinueWith. + var taskOfT = this.Compilation.TaskOfTType(); + if (taskOfT != null) + { + return SpecializedCollections.SingletonEnumerable( + taskOfT.Construct(this.Compilation.ObjectType)); + } + } + else if (name.Equals(nameof(Enumerable.Select)) || + name.Equals(nameof(Enumerable.Where))) + { + var ienumerableType = this.Compilation.IEnumerableOfTType(); + + // foo.Select + // We want to infer IEnumerable. We can try to figure out what + // T if we get a delegate as the first argument to Select/Where. + if (ienumerableType != null && memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression)) + { + var invocation = (InvocationExpressionSyntax)memberAccessExpression.Parent; + if (invocation.ArgumentList.Arguments.Count > 0) + { + var argumentExpression = invocation.ArgumentList.Arguments[0].Expression; + var argumentTypes = GetTypes(argumentExpression); + var delegateType = argumentTypes.FirstOrDefault().GetDelegateType(this.Compilation); + var typeArg = delegateType?.TypeArguments.Length > 0 + ? delegateType.TypeArguments[0] + : this.Compilation.ObjectType; + + if (IsUnusableType(typeArg) && argumentExpression is LambdaExpressionSyntax) + { + typeArg = InferTypeForFirstParameterOfLambda((LambdaExpressionSyntax)argumentExpression) ?? + this.Compilation.ObjectType; + } + + return SpecializedCollections.SingletonEnumerable( + ienumerableType.Construct(typeArg)); + } + } + } return SpecializedCollections.EmptyEnumerable(); } @@ -1480,6 +1525,76 @@ private IEnumerable InferTypeInMemberAccessExpression( } } + private ITypeSymbol InferTypeForFirstParameterOfLambda( + LambdaExpressionSyntax lambdaExpression) + { + if (lambdaExpression is ParenthesizedLambdaExpressionSyntax) + { + return InferTypeForFirstParameterOfParenthesizedLambda( + (ParenthesizedLambdaExpressionSyntax)lambdaExpression); + } + else if (lambdaExpression is SimpleLambdaExpressionSyntax) + { + return InferTypeForFirstParameterOfSimpleLambda( + (SimpleLambdaExpressionSyntax)lambdaExpression); + } + + return null; + } + + private ITypeSymbol InferTypeForFirstParameterOfParenthesizedLambda( + ParenthesizedLambdaExpressionSyntax lambdaExpression) + { + return lambdaExpression.ParameterList.Parameters.Count == 0 + ? null + : InferTypeForFirstParameterOfLambda( + lambdaExpression, lambdaExpression.ParameterList.Parameters[0]); + } + + private ITypeSymbol InferTypeForFirstParameterOfSimpleLambda( + SimpleLambdaExpressionSyntax lambdaExpression) + { + return InferTypeForFirstParameterOfLambda( + lambdaExpression, lambdaExpression.Parameter); + } + + private ITypeSymbol InferTypeForFirstParameterOfLambda( + LambdaExpressionSyntax lambdaExpression, ParameterSyntax parameter) + { + return InferTypeForFirstParameterOfLambda( + parameter.Identifier.ValueText, lambdaExpression.Body); + } + + private ITypeSymbol InferTypeForFirstParameterOfLambda( + string parameterName, + SyntaxNode node) + { + if (node.IsKind(SyntaxKind.IdentifierName)) + { + var identifierName = (IdentifierNameSyntax)node; + if (identifierName.Identifier.ValueText.Equals(parameterName)) + { + return InferTypes(identifierName).FirstOrDefault(); + } + } + else + { + foreach (var child in node.ChildNodesAndTokens()) + { + if (child.IsNode) + { + var type = InferTypeForFirstParameterOfLambda(parameterName, child.AsNode()); + if (type != null) + { + return type; + } + } + } + } + + return null; + } + private IEnumerable InferTypeInNameEquals(NameEqualsSyntax nameEquals, SyntaxToken? previousToken = null) { if (previousToken == nameEquals.EqualsToken) diff --git a/src/Workspaces/Core/Portable/Shared/Extensions/ICompilationExtensions.cs b/src/Workspaces/Core/Portable/Shared/Extensions/ICompilationExtensions.cs index 9bb39803fcff0..31e118a15a828 100644 --- a/src/Workspaces/Core/Portable/Shared/Extensions/ICompilationExtensions.cs +++ b/src/Workspaces/Core/Portable/Shared/Extensions/ICompilationExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; @@ -114,6 +115,11 @@ public static INamedTypeSymbol TaskOfTType(this Compilation compilation) return compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1"); } + public static INamedTypeSymbol IEnumerableOfTType(this Compilation compilation) + { + return compilation.GetTypeByMetadataName("System.Collections.Generic.IEnumerable`1"); + } + public static INamedTypeSymbol SerializableAttributeType(this Compilation compilation) { return compilation.GetTypeByMetadataName("System.SerializableAttribute"); diff --git a/src/Workspaces/Core/Portable/Shared/Extensions/ISymbolExtensions.cs b/src/Workspaces/Core/Portable/Shared/Extensions/ISymbolExtensions.cs index 94e535044ada7..bf990ff269338 100644 --- a/src/Workspaces/Core/Portable/Shared/Extensions/ISymbolExtensions.cs +++ b/src/Workspaces/Core/Portable/Shared/Extensions/ISymbolExtensions.cs @@ -504,10 +504,7 @@ public static ITypeSymbol ConvertToType( .Skip(skip) .Select(p => p.Type) .Concat(method.ReturnType) - .Select(t => - t == null ? - compilation.GetSpecialType(SpecialType.System_Object) : - t) + .Select(t => t ?? compilation.GetSpecialType(SpecialType.System_Object)) .ToArray(); return functionType.Construct(types); } From 6bdb5a53a5de4eac37ef4f68e07af36d4246df65 Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 21:26:42 -0700 Subject: [PATCH 07/10] VB side of type inference. --- .../Portable/Syntax/SyntaxNodePartials.vb | 2 +- .../TypeInferrer/TypeInferrerTests.cs | 2 - .../TypeInferrer/TypeInferrerTests.vb | 28 +++- ...CSharpTypeInferenceService.TypeInferrer.cs | 124 +++++++++--------- ...lBasicTypeInferenceService.TypeInferrer.vb | 102 +++++++++++++- 5 files changed, 189 insertions(+), 69 deletions(-) diff --git a/src/Compilers/VisualBasic/Portable/Syntax/SyntaxNodePartials.vb b/src/Compilers/VisualBasic/Portable/Syntax/SyntaxNodePartials.vb index 223e3701685fd..9a784295b6a09 100644 --- a/src/Compilers/VisualBasic/Portable/Syntax/SyntaxNodePartials.vb +++ b/src/Compilers/VisualBasic/Portable/Syntax/SyntaxNodePartials.vb @@ -70,7 +70,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.Syntax End Function End Class - Public Partial Class SingleLineLambdaExpressionSyntax + Partial Public Class SingleLineLambdaExpressionSyntax ''' ''' Single line subs only have a single statement. However, when binding it is convenient to have a statement list. For example, ''' dim statements are not valid in a single line lambda. However, it is nice to be able to provide semantic info about the local. diff --git a/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs b/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs index dbb05b8ac3738..0fea5d4985f44 100644 --- a/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs +++ b/src/EditorFeatures/CSharpTest/TypeInferrer/TypeInferrerTests.cs @@ -1820,7 +1820,6 @@ static void Foo(System.ConsoleModifiers arg) } [Fact, Trait(Traits.Feature, Traits.Features.TypeInferenceService)] - [WorkItem(6765, "https://github.com/dotnet/roslyn/issues/6765")] public async Task TestWhereCall() { var text = @@ -1837,7 +1836,6 @@ void Foo() } [Fact, Trait(Traits.Feature, Traits.Features.TypeInferenceService)] - [WorkItem(6765, "https://github.com/dotnet/roslyn/issues/6765")] public async Task TestWhereCall2() { var text = diff --git a/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb b/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb index 45246dfa3a5c9..e8e679014d170 100644 --- a/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb +++ b/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb @@ -705,7 +705,7 @@ Module M Dim x As Boolean = Await [|F|].ContinueWith(Function(a) True).ContinueWith(Function(a) False) End Sub End Module" - Await TestAsync(text, "System.Object", testPosition:=False) + Await TestAsync(text, "Global.System.Threading.Tasks.Task(Of System.Object)", testPosition:=False) End Function @@ -736,5 +736,31 @@ End Module" End Class" Await TestAsync(text, "System.Object", testNode:=False, testPosition:=True) End Function + + + Public Async Function TestWhereCall() As Task + Dim text = +"imports System.Collections.Generic +class C + sub Foo() + [|ints|].Where(function(i) i > 10) + end sub +end class" + Await TestAsync(text, "Global.System.Collections.Generic.IEnumerable(Of System.Int32)", testPosition:=False) + End Function + + + Public Async Function TestWhereCall2() As Task + Dim text = +"imports System.Collections.Generic +class C + sub Foo() + [|ints|].Where(function(i) + return i > 10 + end function) + end sub +end class" + Await TestAsync(text, "Global.System.Collections.Generic.IEnumerable(Of System.Int32)", testPosition:=False) + End Function End Class End Namespace diff --git a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs index 83233595276c8..b737b2b917308 100644 --- a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs +++ b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs @@ -1457,72 +1457,78 @@ private IEnumerable InferTypeInMemberAccessExpression( Debug.Assert(expressionOpt != null); if (expressionOpt == memberAccessExpression.Expression) { - // If we're on the left side of a dot, it's possible in a few cases - // to figure out what type we should be. Specifically, if we have - // - // await foo.ConfigureAwait() - // - // then we can figure out what 'foo' should be based on teh await - // context. - var name = memberAccessExpression.Name.Identifier.Value; - if (name.Equals(nameof(Task.ConfigureAwait)) && - memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression) && - memberAccessExpression.Parent.IsParentKind(SyntaxKind.AwaitExpression)) - { - return InferTypes((ExpressionSyntax)memberAccessExpression.Parent); - } - else if (name.Equals(nameof(Task.ContinueWith))) + return InferTypeForExpressionOfMemberAccessExpression(memberAccessExpression); + } + + // We're right after the dot in "Foo.Bar". The type for "Bar" should be + // whatever type we'd infer for "Foo.Bar" itself. + return InferTypes(memberAccessExpression); + } + } + + private IEnumerable InferTypeForExpressionOfMemberAccessExpression( + MemberAccessExpressionSyntax memberAccessExpression) + { + // If we're on the left side of a dot, it's possible in a few cases + // to figure out what type we should be. Specifically, if we have + // + // await foo.ConfigureAwait() + // + // then we can figure out what 'foo' should be based on teh await + // context. + var name = memberAccessExpression.Name.Identifier.Value; + if (name.Equals(nameof(Task.ConfigureAwait)) && + memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression) && + memberAccessExpression.Parent.IsParentKind(SyntaxKind.AwaitExpression)) + { + return InferTypes((ExpressionSyntax)memberAccessExpression.Parent); + } + else if (name.Equals(nameof(Task.ContinueWith))) + { + // foo.ContinueWith(...) + // We want to infer Task. For now, we'll just do Task, + // in the future it would be nice to figure out the actual result + // type based on the argument to ContinueWith. + var taskOfT = this.Compilation.TaskOfTType(); + if (taskOfT != null) + { + return SpecializedCollections.SingletonEnumerable( + taskOfT.Construct(this.Compilation.ObjectType)); + } + } + else if (name.Equals(nameof(Enumerable.Select)) || + name.Equals(nameof(Enumerable.Where))) + { + var ienumerableType = this.Compilation.IEnumerableOfTType(); + + // foo.Select + // We want to infer IEnumerable. We can try to figure out what + // T if we get a delegate as the first argument to Select/Where. + if (ienumerableType != null && memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression)) + { + var invocation = (InvocationExpressionSyntax)memberAccessExpression.Parent; + if (invocation.ArgumentList.Arguments.Count > 0) { - // foo.ContinueWith(...) - // We want to infer Task. For now, we'll just do Task, - // in the future it would be nice to figure out the actual result - // type based on the argument to ContinueWith. - var taskOfT = this.Compilation.TaskOfTType(); - if (taskOfT != null) + var argumentExpression = invocation.ArgumentList.Arguments[0].Expression; + var argumentTypes = GetTypes(argumentExpression); + var delegateType = argumentTypes.FirstOrDefault().GetDelegateType(this.Compilation); + var typeArg = delegateType?.TypeArguments.Length > 0 + ? delegateType.TypeArguments[0] + : this.Compilation.ObjectType; + + if (IsUnusableType(typeArg) && argumentExpression is LambdaExpressionSyntax) { - return SpecializedCollections.SingletonEnumerable( - taskOfT.Construct(this.Compilation.ObjectType)); + typeArg = InferTypeForFirstParameterOfLambda((LambdaExpressionSyntax)argumentExpression) ?? + this.Compilation.ObjectType; } - } - else if (name.Equals(nameof(Enumerable.Select)) || - name.Equals(nameof(Enumerable.Where))) - { - var ienumerableType = this.Compilation.IEnumerableOfTType(); - // foo.Select - // We want to infer IEnumerable. We can try to figure out what - // T if we get a delegate as the first argument to Select/Where. - if (ienumerableType != null && memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression)) - { - var invocation = (InvocationExpressionSyntax)memberAccessExpression.Parent; - if (invocation.ArgumentList.Arguments.Count > 0) - { - var argumentExpression = invocation.ArgumentList.Arguments[0].Expression; - var argumentTypes = GetTypes(argumentExpression); - var delegateType = argumentTypes.FirstOrDefault().GetDelegateType(this.Compilation); - var typeArg = delegateType?.TypeArguments.Length > 0 - ? delegateType.TypeArguments[0] - : this.Compilation.ObjectType; - - if (IsUnusableType(typeArg) && argumentExpression is LambdaExpressionSyntax) - { - typeArg = InferTypeForFirstParameterOfLambda((LambdaExpressionSyntax)argumentExpression) ?? - this.Compilation.ObjectType; - } - - return SpecializedCollections.SingletonEnumerable( - ienumerableType.Construct(typeArg)); - } - } + return SpecializedCollections.SingletonEnumerable( + ienumerableType.Construct(typeArg)); } - - return SpecializedCollections.EmptyEnumerable(); } - - // We're right after the dot in "Foo.Bar". The type for "Bar" should be - // whatever type we'd infer for "Foo.Bar" itself. - return InferTypes(memberAccessExpression); } + + return SpecializedCollections.EmptyEnumerable(); } private ITypeSymbol InferTypeForFirstParameterOfLambda( diff --git a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb index 6c6f9e1cce1d1..68e722068199e 100644 --- a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb +++ b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb @@ -828,17 +828,107 @@ Namespace Microsoft.CodeAnalysis.VisualBasic ' then we can figure out what 'foo' should be based on teh await ' context. If expressionOpt Is memberAccessExpression.Expression Then - If memberAccessExpression.Name.Identifier.Value.Equals(NameOf(Task(Of Integer).ConfigureAwait)) AndAlso - memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression) AndAlso - memberAccessExpression.Parent.IsParentKind(SyntaxKind.AwaitExpression) Then - Return InferTypes(DirectCast(memberAccessExpression.Parent, ExpressionSyntax)) + Return InferTypeForExpressionOfMemberAccessExpression(memberAccessExpression) + End If + + Return InferTypes(memberAccessExpression) + End If + End Function + + Private Function InferTypeForExpressionOfMemberAccessExpression(memberAccessExpression As MemberAccessExpressionSyntax) As IEnumerable(Of ITypeSymbol) + Dim name = memberAccessExpression.Name.Identifier.Value + + If name.Equals(NameOf(Task(Of Integer).ConfigureAwait)) AndAlso + memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression) AndAlso + memberAccessExpression.Parent.IsParentKind(SyntaxKind.AwaitExpression) Then + Return InferTypes(DirectCast(memberAccessExpression.Parent, ExpressionSyntax)) + ElseIf name.Equals(NameOf(Task(Of Integer).ContinueWith)) Then + ' foo.ContinueWith(...) + ' We want to infer Task. For now, we'll just do Task, + ' in the future it would be nice to figure out the actual result + ' type based on the argument to ContinueWith. + Dim taskOfT = Me.Compilation.TaskOfTType() + If taskOfT IsNot Nothing Then + Return SpecializedCollections.SingletonEnumerable( + taskOfT.Construct(Me.Compilation.ObjectType)) + End If + ElseIf name.Equals(NameOf(Enumerable.Select)) OrElse + name.Equals(NameOf(Enumerable.Where)) Then + + Dim ienumerableType = Me.Compilation.IEnumerableOfTType() + + ' foo.Select + ' We want to infer IEnumerable. We can try to figure out what + ' T if we get a delegate as the first argument to Select/Where. + If ienumerableType IsNot Nothing AndAlso memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression) Then + Dim invocation = DirectCast(memberAccessExpression.Parent, InvocationExpressionSyntax) + If invocation.ArgumentList.Arguments.Count > 0 AndAlso + TypeOf invocation.ArgumentList.Arguments(0) Is SimpleArgumentSyntax Then + Dim argumentExpression = DirectCast(invocation.ArgumentList.Arguments(0), SimpleArgumentSyntax).Expression + Dim argumentTypes = GetTypes(argumentExpression) + Dim delegateType = argumentTypes.FirstOrDefault().GetDelegateType(Me.Compilation) + Dim typeArg = If(delegateType?.TypeArguments.Length > 0, + delegateType.TypeArguments(0), + Me.Compilation.ObjectType) + + If delegateType Is Nothing OrElse IsUnusableType(typeArg) Then + If TypeOf argumentExpression Is LambdaExpressionSyntax Then + typeArg = If(InferTypeForFirstParameterOfLambda(DirectCast(argumentExpression, LambdaExpressionSyntax)), + Me.Compilation.ObjectType) + End If + End If + + Return SpecializedCollections.SingletonEnumerable( + ienumerableType.Construct(typeArg)) End If + End If + End If - Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol)() + Return SpecializedCollections.EmptyEnumerable(Of ITypeSymbol)() + End Function + + Private Function InferTypeForFirstParameterOfLambda( + lambda As LambdaExpressionSyntax) As ITypeSymbol + If lambda.SubOrFunctionHeader.ParameterList.Parameters.Count > 0 Then + Dim parameter = lambda.SubOrFunctionHeader.ParameterList.Parameters(0) + Dim parameterName = parameter.Identifier.Identifier.ValueText + + If TypeOf lambda Is SingleLineLambdaExpressionSyntax Then + Dim singleLine = DirectCast(lambda, SingleLineLambdaExpressionSyntax) + Return InferTypeForFirstParameterOfLambda(parameterName, singleLine.Body) + ElseIf TypeOf lambda Is MultiLineLambdaExpressionSyntax Then + Dim multiLine = DirectCast(lambda, MultiLineLambdaExpressionSyntax) + For Each statement In multiLine.Statements + Dim type = InferTypeForFirstParameterOfLambda(parameterName, statement) + If type IsNot Nothing Then + Return type + End If + Next End If + End If - Return InferTypes(memberAccessExpression) + Return Nothing + End Function + + Private Function InferTypeForFirstParameterOfLambda( + parameterName As String, node As SyntaxNode) As ITypeSymbol + If node.IsKind(SyntaxKind.IdentifierName) Then + Dim identifier = DirectCast(node, IdentifierNameSyntax) + If CaseInsensitiveComparison.Equals(parameterName, identifier.Identifier.ValueText) Then + Return InferTypes(identifier).FirstOrDefault() + End If + Else + For Each child In node.ChildNodesAndTokens() + If child.IsNode Then + Dim type = InferTypeForFirstParameterOfLambda(parameterName, child.AsNode) + If type IsNot Nothing Then + Return type + End If + End If + Next End If + + Return Nothing End Function Private Function InferTypeInNamedFieldInitializer(initializer As NamedFieldInitializerSyntax, Optional previousToken As SyntaxToken = Nothing) As IEnumerable(Of ITypeSymbol) From e11beed3eff074a49a14b607833f6561dd3a1577 Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 21:29:37 -0700 Subject: [PATCH 08/10] Fix test. --- .../GenerateMethod/GenerateMethodTests.cs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/EditorFeatures/CSharpTest/Diagnostics/GenerateMethod/GenerateMethodTests.cs b/src/EditorFeatures/CSharpTest/Diagnostics/GenerateMethod/GenerateMethodTests.cs index 0ecdb5246d1c8..81b1aba9e1bd2 100644 --- a/src/EditorFeatures/CSharpTest/Diagnostics/GenerateMethod/GenerateMethodTests.cs +++ b/src/EditorFeatures/CSharpTest/Diagnostics/GenerateMethod/GenerateMethodTests.cs @@ -2766,7 +2766,15 @@ static void Main ( string [ ] args ) { bool x = await [|Foo|] ( ) . ConfigureAwait ( false ) ; } }", -@"using System ; using System . Collections . Generic ; using System . Linq ; using System . Threading . Tasks ; class Program { static void Main ( string [ ] args ) { bool x = await Foo ( ) . ConfigureAwait ( false ) ; } private static Task < bool > Foo ( ) { throw new NotImplementedException ( ) ; } } "); +@"using System ; +using System . Collections . Generic ; +using System . Linq ; +using System . Threading . Tasks ; +class Program { + static void Main ( string [ ] args ) { + bool x = await Foo ( ) . ConfigureAwait ( false ) ; + } + private static Task < bool > Foo ( ) { throw new NotImplementedException ( ) ; } } "); } [WorkItem(643, "https://github.com/dotnet/roslyn/issues/643")] @@ -2781,7 +2789,12 @@ static async void T ( ) { bool x = await [|M|] ( ) . ContinueWith ( a => { return true ; } ) . ContinueWith ( a => { return false ; } ) ; } } ", -@"using System ; using System . Threading . Tasks ; class C { static async void T ( ) { bool x = await M ( ) . ContinueWith ( a => { return true ; } ) . ContinueWith ( a => { return false ; } ) ; } private static object M ( ) { throw new NotImplementedException ( ) ; } } "); +@"using System ; +using System . Threading . Tasks ; +class C { + static async void T ( ) { + bool x = await M ( ) . ContinueWith ( a => { return true ; } ) . ContinueWith ( a => { return false ; } ) ; } + private static Task M ( ) { throw new NotImplementedException ( ) ; } } "); } [WorkItem(529480, "http://vstfdevdiv:8080/DevDiv2/DevDiv/_workitems/edit/529480")] From 3491dd2bcf487cecbca4a8861039a1f8c70af49b Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 20 Jun 2016 22:19:26 -0700 Subject: [PATCH 09/10] Only consider identifiers that bind to lambda parameters. --- .../CSharpTypeInferenceService.TypeInferrer.cs | 3 ++- .../VisualBasicTypeInferenceService.TypeInferrer.vb | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs index b737b2b917308..4ea5624427a06 100644 --- a/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs +++ b/src/Workspaces/CSharp/Portable/LanguageServices/CSharpTypeInferenceService.TypeInferrer.cs @@ -1578,7 +1578,8 @@ private ITypeSymbol InferTypeForFirstParameterOfLambda( if (node.IsKind(SyntaxKind.IdentifierName)) { var identifierName = (IdentifierNameSyntax)node; - if (identifierName.Identifier.ValueText.Equals(parameterName)) + if (identifierName.Identifier.ValueText.Equals(parameterName) && + SemanticModel.GetSymbolInfo(identifierName.Identifier).Symbol?.Kind == SymbolKind.Parameter) { return InferTypes(identifierName).FirstOrDefault(); } diff --git a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb index 68e722068199e..afd74097e2166 100644 --- a/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb +++ b/src/Workspaces/VisualBasic/Portable/LanguageServices/VisualBasicTypeInferenceService.TypeInferrer.vb @@ -914,7 +914,8 @@ Namespace Microsoft.CodeAnalysis.VisualBasic parameterName As String, node As SyntaxNode) As ITypeSymbol If node.IsKind(SyntaxKind.IdentifierName) Then Dim identifier = DirectCast(node, IdentifierNameSyntax) - If CaseInsensitiveComparison.Equals(parameterName, identifier.Identifier.ValueText) Then + If CaseInsensitiveComparison.Equals(parameterName, identifier.Identifier.ValueText) AndAlso + SemanticModel.GetSymbolInfo(identifier.Identifier).Symbol?.Kind = SymbolKind.Parameter Then Return InferTypes(identifier).FirstOrDefault() End If Else From a5b807121c6e917d3e8b2606170e2670ac04cca3 Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Tue, 21 Jun 2016 12:29:17 -0700 Subject: [PATCH 10/10] Add test. --- .../TypeInferrer/TypeInferrerTests.vb | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb b/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb index e8e679014d170..a8587cb76825c 100644 --- a/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb +++ b/src/EditorFeatures/VisualBasicTest/TypeInferrer/TypeInferrerTests.vb @@ -762,5 +762,17 @@ class C end class" Await TestAsync(text, "Global.System.Collections.Generic.IEnumerable(Of System.Int32)", testPosition:=False) End Function + + + Public Async Function TestMemberAccess1() As Task + Dim text = +"imports System.Collections.Generic +class C + sub Foo() + dim b as boolean = x.[||] + end sub +end class" + Await TestAsync(text, "System.Boolean", testNode:=False) + End Function End Class End Namespace