Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FailFast instead of throw from non-HRESULT returning CCW methods #1021

Merged
merged 1 commit into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
FailFast instead of throw from non-HRESULT returning CCW methods
Closes #830
  • Loading branch information
AArnott committed Aug 11, 2023
commit 4119ae5604ca8303822d05e9ee6a49009063b899
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static AttributeTargetSpecifierSyntax AttributeTargetSpecifier(SyntaxToken identifier) => SyntaxFactory.AttributeTargetSpecifier(identifier, TokenWithSpace(SyntaxKind.ColonToken));

internal static ThrowStatementSyntax ThrowStatement() => SyntaxFactory.ThrowStatement(default, Token(SyntaxKind.ThrowKeyword), null, Semicolon);

internal static ThrowStatementSyntax ThrowStatement(ExpressionSyntax expression) => SyntaxFactory.ThrowStatement(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression, Semicolon);

internal static ThrowExpressionSyntax ThrowExpression(ExpressionSyntax expression) => SyntaxFactory.ThrowExpression(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression);
Expand Down
31 changes: 23 additions & 8 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -425,20 +425,35 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
//// hr.ThrowOnFailure();
: ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrLocal, HRThrowOnFailureMethodName)));

//// catch (Exception ex) { return (HRESULT)ex.HResult; }
IdentifierNameSyntax exLocal = IdentifierName("ex");
CatchClauseSyntax catchClause = CatchClause(CatchDeclaration(IdentifierName(nameof(Exception)).WithTrailingTrivia(Space), exLocal.Identifier), null, Block().AddStatements(
ReturnStatement(CastExpression(HresultTypeSyntax, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, exLocal, IdentifierName(nameof(Exception.HResult)))))));
BlockSyntax catchBlock = Block();
if (hrReturnType)
{
//// return (HRESULT)ex.HResult;
catchBlock = catchBlock.AddStatements(ReturnStatement(CastExpression(HresultTypeSyntax, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, exLocal, IdentifierName(nameof(Exception.HResult))))));
}
else
{
//// Environment.FailFast("COM object threw an exception from a non-HRESULT returning method.", ex);
//// throw;
catchBlock = catchBlock.AddStatements(
ExpressionStatement(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseName("global::System.Environment"), IdentifierName(nameof(Environment.FailFast))),
ArgumentList().AddArguments(
Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal("COM object threw an exception from a non-HRESULT returning method."))),
Argument(exLocal)))),
ThrowStatement());
}

//// catch (Exception ex) {
CatchClauseSyntax catchClause = CatchClause(CatchDeclaration(IdentifierName(nameof(Exception)).WithTrailingTrivia(Space), exLocal.Identifier), null, catchBlock);

BlockSyntax tryBlock = Block().AddStatements(
hrDecl,
ifNullReturnStatement).AddStatements(thunkInvokeAndReturn);

BlockSyntax ccwBody = hrReturnType
//// try { ... } catch { ... }
? Block().AddStatements(TryStatement(tryBlock, new SyntaxList<CatchClauseSyntax>(catchClause), null))
//// { .... } // any exception is thrown back to native code.
: tryBlock;
//// try { ... } catch { ... }
BlockSyntax ccwBody = Block().AddStatements(TryStatement(tryBlock, new SyntaxList<CatchClauseSyntax>(catchClause), null));

//// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
//// private static HRESULT Clone(IEnumEventObject* @this, IEnumEventObject** ppInterface)
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ internal WhitespaceRewriter()

public override SyntaxNode? VisitTryStatement(TryStatementSyntax node) => base.VisitTryStatement(this.WithIndentingTrivia(node));

public override SyntaxNode? VisitThrowStatement(ThrowStatementSyntax node) => base.VisitThrowStatement(this.WithIndentingTrivia(node));

public override SyntaxNode? VisitCatchClause(CatchClauseSyntax node) => base.VisitCatchClause(this.WithIndentingTrivia(node));

public override SyntaxNode? VisitFinallyClause(FinallyClauseSyntax node) => base.VisitFinallyClause(this.WithIndentingTrivia(node));
Expand Down
16 changes: 16 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,22 @@ public void InterestingComInterfaces(
this.AssertNoDiagnostics();
}

[Fact]
public void EnvironmentFailFast()
{
this.compilation = this.starterCompilations["net7.0"];
this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = false });

// Emit something into the Environment namespace, to invite collisions.
Assert.True(this.generator.TryGenerate("ENCLAVE_IDENTITY", CancellationToken.None));

// Emit the interface that can require Environment.FailFast.
Assert.True(this.generator.TryGenerate("ITypeInfo", CancellationToken.None));

this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
}

[Fact]
public void ComOutPtrTypedAsOutObject()
{
Expand Down