diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index 1a73267f..88eaedd0 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -134,6 +134,8 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla internal static AttributeTargetSpecifierSyntax AttributeTargetSpecifier(SyntaxToken identifier) => SyntaxFactory.AttributeTargetSpecifier(identifier, TokenWithSpace(SyntaxKind.ColonToken)); + internal static ThrowStatementSyntax ThrowStatement() => SyntaxFactory.ThrowStatement(default, Token(SyntaxKind.ThrowKeyword), null, Semicolon); + internal static ThrowStatementSyntax ThrowStatement(ExpressionSyntax expression) => SyntaxFactory.ThrowStatement(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression, Semicolon); internal static ThrowExpressionSyntax ThrowExpression(ExpressionSyntax expression) => SyntaxFactory.ThrowExpression(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Com.cs b/src/Microsoft.Windows.CsWin32/Generator.Com.cs index e6593e01..fa0eaf58 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Com.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Com.cs @@ -425,20 +425,35 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn) //// hr.ThrowOnFailure(); : ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrLocal, HRThrowOnFailureMethodName))); - //// catch (Exception ex) { return (HRESULT)ex.HResult; } IdentifierNameSyntax exLocal = IdentifierName("ex"); - CatchClauseSyntax catchClause = CatchClause(CatchDeclaration(IdentifierName(nameof(Exception)).WithTrailingTrivia(Space), exLocal.Identifier), null, Block().AddStatements( - ReturnStatement(CastExpression(HresultTypeSyntax, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, exLocal, IdentifierName(nameof(Exception.HResult))))))); + BlockSyntax catchBlock = Block(); + if (hrReturnType) + { + //// return (HRESULT)ex.HResult; + catchBlock = catchBlock.AddStatements(ReturnStatement(CastExpression(HresultTypeSyntax, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, exLocal, IdentifierName(nameof(Exception.HResult)))))); + } + else + { + //// Environment.FailFast("COM object threw an exception from a non-HRESULT returning method.", ex); + //// throw; + catchBlock = catchBlock.AddStatements( + ExpressionStatement(InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseName("global::System.Environment"), IdentifierName(nameof(Environment.FailFast))), + ArgumentList().AddArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal("COM object threw an exception from a non-HRESULT returning method."))), + Argument(exLocal)))), + ThrowStatement()); + } + + //// catch (Exception ex) { + CatchClauseSyntax catchClause = CatchClause(CatchDeclaration(IdentifierName(nameof(Exception)).WithTrailingTrivia(Space), exLocal.Identifier), null, catchBlock); BlockSyntax tryBlock = Block().AddStatements( hrDecl, ifNullReturnStatement).AddStatements(thunkInvokeAndReturn); - BlockSyntax ccwBody = hrReturnType - //// try { ... } catch { ... } - ? Block().AddStatements(TryStatement(tryBlock, new SyntaxList(catchClause), null)) - //// { .... } // any exception is thrown back to native code. - : tryBlock; + //// try { ... } catch { ... } + BlockSyntax ccwBody = Block().AddStatements(TryStatement(tryBlock, new SyntaxList(catchClause), null)); //// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })] //// private static HRESULT Clone(IEnumEventObject* @this, IEnumEventObject** ppInterface) diff --git a/src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs b/src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs index 7362d27d..2d0a8be7 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs @@ -206,6 +206,8 @@ internal WhitespaceRewriter() public override SyntaxNode? VisitTryStatement(TryStatementSyntax node) => base.VisitTryStatement(this.WithIndentingTrivia(node)); + public override SyntaxNode? VisitThrowStatement(ThrowStatementSyntax node) => base.VisitThrowStatement(this.WithIndentingTrivia(node)); + public override SyntaxNode? VisitCatchClause(CatchClauseSyntax node) => base.VisitCatchClause(this.WithIndentingTrivia(node)); public override SyntaxNode? VisitFinallyClause(FinallyClauseSyntax node) => base.VisitFinallyClause(this.WithIndentingTrivia(node)); diff --git a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs index 9c727d13..e32d9939 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs @@ -294,6 +294,22 @@ public void InterestingComInterfaces( this.AssertNoDiagnostics(); } + [Fact] + public void EnvironmentFailFast() + { + this.compilation = this.starterCompilations["net7.0"]; + this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = false }); + + // Emit something into the Environment namespace, to invite collisions. + Assert.True(this.generator.TryGenerate("ENCLAVE_IDENTITY", CancellationToken.None)); + + // Emit the interface that can require Environment.FailFast. + Assert.True(this.generator.TryGenerate("ITypeInfo", CancellationToken.None)); + + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + } + [Fact] public void ComOutPtrTypedAsOutObject() {