Skip to content

Commit

Permalink
Ability to Prevent Premature COM Uninitialization
Browse files Browse the repository at this point in the history
This cl adds the capability to prevent the premature uninitialization
of the COM library in ScopedCOMInitializer. Premature uninitialization
usually occurs in the presence of unbalanced CoInitialize/CoUnitialize
pairs. While we can prevent this from ocurring in first party-code,
there is no mechanism that protects us when executing third-party code
in a COM enabled thread such as in the case of the Quarantine process.

Bug: 1075487
Change-Id: Ibb3cf304c6bbabc126867de47e963a52c9409248
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2378270
Reviewed-by: Bruce Dawson <brucedawson@chromium.org>
Reviewed-by: Asanka Herath <asanka@chromium.org>
Reviewed-by: Greg Thompson <grt@chromium.org>
Commit-Queue: Andres Pico <anpico@microsoft.com>
Cr-Commit-Position: refs/heads/master@{#804589}
  • Loading branch information
Andres Pico authored and Commit Bot committed Sep 5, 2020
1 parent 77109ca commit 1cceb25
Show file tree
Hide file tree
Showing 9 changed files with 365 additions and 16 deletions.
3 changes: 3 additions & 0 deletions base/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,8 @@ component("base") {
"win/atl.h",
"win/atl_throw.cc",
"win/atl_throw.h",
"win/com_init_balancer.cc",
"win/com_init_balancer.h",
"win/com_init_check_hook.cc",
"win/com_init_check_hook.h",
"win/com_init_util.cc",
Expand Down Expand Up @@ -2988,6 +2990,7 @@ test("base_unittests") {
"threading/platform_thread_win_unittest.cc",
"time/time_win_unittest.cc",
"win/async_operation_unittest.cc",
"win/com_init_balancer_unittest.cc",
"win/com_init_check_hook_unittest.cc",
"win/com_init_util_unittest.cc",
"win/core_winrt_util_unittest.cc",
Expand Down
68 changes: 68 additions & 0 deletions base/win/com_init_balancer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <objbase.h>

#include "base/check_op.h"
#include "base/win/com_init_balancer.h"

namespace base {
namespace win {
namespace internal {

ComInitBalancer::ComInitBalancer(DWORD co_init) : co_init_(co_init) {
ULARGE_INTEGER spy_cookie = {};
HRESULT hr = ::CoRegisterInitializeSpy(this, &spy_cookie);
if (SUCCEEDED(hr))
spy_cookie_ = spy_cookie;
}

ComInitBalancer::~ComInitBalancer() {
DCHECK(!spy_cookie_.has_value());
}

void ComInitBalancer::Disable() {
if (spy_cookie_.has_value()) {
::CoRevokeInitializeSpy(spy_cookie_.value());
reference_count_ = 0;
spy_cookie_.reset();
}
}

DWORD ComInitBalancer::GetReferenceCountForTesting() const {
return reference_count_;
}

IFACEMETHODIMP
ComInitBalancer::PreInitialize(DWORD apartment_type, DWORD reference_count) {
return S_OK;
}

IFACEMETHODIMP
ComInitBalancer::PostInitialize(HRESULT result,
DWORD apartment_type,
DWORD new_reference_count) {
reference_count_ = new_reference_count;
return result;
}

IFACEMETHODIMP
ComInitBalancer::PreUninitialize(DWORD reference_count) {
if (reference_count == 1 && spy_cookie_.has_value()) {
// Increase the reference count to prevent premature and unbalanced
// uninitalization of the COM library.
::CoInitializeEx(nullptr, co_init_);
}
return S_OK;
}

IFACEMETHODIMP
ComInitBalancer::PostUninitialize(DWORD new_reference_count) {
reference_count_ = new_reference_count;
return S_OK;
}

} // namespace internal
} // namespace win
} // namespace base
73 changes: 73 additions & 0 deletions base/win/com_init_balancer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef BASE_WIN_COM_INIT_BALANCER_H_
#define BASE_WIN_COM_INIT_BALANCER_H_

#include <objidl.h>
#include <winnt.h>
#include <wrl/implements.h>

#include "base/base_export.h"
#include "base/optional.h"
#include "base/threading/thread_checker.h"
#include "base/win/windows_types.h"

namespace base {
namespace win {
namespace internal {

// Implementation class of the IInitializeSpy Interface that prevents premature
// uninitialization of the COM library, often caused by unbalanced
// CoInitialize/CoUninitialize pairs. The use of this class is encouraged in
// COM-supporting threads that execute third-party code.
//
// Disable() must be called before uninitializing the COM library in order to
// revoke the registered spy and allow for the successful uninitialization of
// the COM library.
class BASE_EXPORT ComInitBalancer
: public Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
IInitializeSpy> {
public:
// Constructs a COM initialize balancer. |co_init| defines the apartment's
// concurrency model used by the balancer.
explicit ComInitBalancer(DWORD co_init);

ComInitBalancer(const ComInitBalancer&) = delete;
ComInitBalancer& operator=(const ComInitBalancer&) = delete;

~ComInitBalancer() override;

// Disables balancer by revoking the registered spy and consequently
// unblocking attempts to uninitialize the COM library.
void Disable();

DWORD GetReferenceCountForTesting() const;

private:
// IInitializeSpy:
IFACEMETHODIMP PreInitialize(DWORD apartment_type,
DWORD reference_count) override;
IFACEMETHODIMP PostInitialize(HRESULT result,
DWORD apartment_type,
DWORD new_reference_count) override;
IFACEMETHODIMP PreUninitialize(DWORD reference_count) override;
IFACEMETHODIMP PostUninitialize(DWORD new_reference_count) override;

const DWORD co_init_;

// The current apartment reference count set after the completion of the last
// call made to CoInitialize or CoUninitialize.
DWORD reference_count_ = 0;

base::Optional<ULARGE_INTEGER> spy_cookie_;
THREAD_CHECKER(thread_checker_);
};

} // namespace internal
} // namespace win
} // namespace base

#endif // BASE_WIN_COM_INIT_BALANCER_H_
151 changes: 151 additions & 0 deletions base/win/com_init_balancer_unittest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/win/com_init_balancer.h"

#include <shlobj.h>
#include <wrl/client.h>

#include "base/test/gtest_util.h"
#include "base/win/com_init_util.h"
#include "base/win/scoped_com_initializer.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace base {
namespace win {

using Microsoft::WRL::ComPtr;

TEST(TestComInitBalancer, BalancedPairsWithComBalancerEnabled) {
{
// Assert COM has initialized correctly.
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ASSERT_TRUE(com_initializer.Succeeded());

// Create COM object successfully.
ComPtr<IUnknown> shell_link;
HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
IID_PPV_ARGS(&shell_link));
EXPECT_TRUE(SUCCEEDED(hr));
}

// ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
EXPECT_DCHECK_DEATH(AssertComInitialized());
}

TEST(TestComInitBalancer, UnbalancedPairsWithComBalancerEnabled) {
{
// Assert COM has initialized correctly.
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ASSERT_TRUE(com_initializer.Succeeded());

// Attempt to prematurely uninitialize the COM library.
::CoUninitialize();
::CoUninitialize();

// Assert COM is still initialized.
AssertComInitialized();

// Create COM object successfully.
ComPtr<IUnknown> shell_link;
HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
IID_PPV_ARGS(&shell_link));
EXPECT_TRUE(SUCCEEDED(hr));
}

// ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
EXPECT_DCHECK_DEATH(AssertComInitialized());
}

TEST(TestComInitBalancer, BalancedPairsWithComBalancerDisabled) {
{
// Assert COM has initialized correctly.
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kAllow);
ASSERT_TRUE(com_initializer.Succeeded());

// Create COM object successfully.
ComPtr<IUnknown> shell_link;
HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
IID_PPV_ARGS(&shell_link));
EXPECT_TRUE(SUCCEEDED(hr));
}

// ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
EXPECT_DCHECK_DEATH(AssertComInitialized());
}

TEST(TestComInitBalancer, UnbalancedPairsWithComBalancerDisabled) {
// Assert COM has initialized correctly.
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kAllow);
ASSERT_TRUE(com_initializer.Succeeded());

// Attempt to prematurely uninitialize the COM library.
::CoUninitialize();
::CoUninitialize();

// Assert COM is not initialized.
EXPECT_DCHECK_DEATH(AssertComInitialized());

// Create COM object unsuccessfully.
ComPtr<IUnknown> shell_link;
HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
IID_PPV_ARGS(&shell_link));
EXPECT_TRUE(FAILED(hr));
EXPECT_EQ(CO_E_NOTINITIALIZED, hr);
}

TEST(TestComInitBalancer, OneRegisteredSpyRefCount) {
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ASSERT_TRUE(com_initializer.Succeeded());

// Reference count should be 1 after initialization.
EXPECT_EQ(DWORD(1), com_initializer.GetCOMBalancerReferenceCountForTesting());

// Attempt to prematurely uninitialize the COM library.
::CoUninitialize();

// Expect reference count to remain at 1.
EXPECT_EQ(DWORD(1), com_initializer.GetCOMBalancerReferenceCountForTesting());
}

TEST(TestComInitBalancer, ThreeRegisteredSpiesRefCount) {
ScopedCOMInitializer com_initializer_1(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ScopedCOMInitializer com_initializer_2(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ScopedCOMInitializer com_initializer_3(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ASSERT_TRUE(com_initializer_1.Succeeded());
ASSERT_TRUE(com_initializer_2.Succeeded());
ASSERT_TRUE(com_initializer_3.Succeeded());

// Reference count should be 3 after initialization.
EXPECT_EQ(DWORD(3),
com_initializer_1.GetCOMBalancerReferenceCountForTesting());
EXPECT_EQ(DWORD(3),
com_initializer_2.GetCOMBalancerReferenceCountForTesting());
EXPECT_EQ(DWORD(3),
com_initializer_3.GetCOMBalancerReferenceCountForTesting());

// Attempt to prematurely uninitialize the COM library.
::CoUninitialize(); // Reference count -> 2.
::CoUninitialize(); // Reference count -> 1.
::CoUninitialize();

// Expect reference count to remain at 1.
EXPECT_EQ(DWORD(1),
com_initializer_1.GetCOMBalancerReferenceCountForTesting());
EXPECT_EQ(DWORD(1),
com_initializer_2.GetCOMBalancerReferenceCountForTesting());
EXPECT_EQ(DWORD(1),
com_initializer_3.GetCOMBalancerReferenceCountForTesting());
}

} // namespace win
} // namespace base
33 changes: 26 additions & 7 deletions base/win/scoped_com_initializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,53 @@

#include "base/win/scoped_com_initializer.h"

#include <wrl/implements.h>

#include "base/check_op.h"

namespace base {
namespace win {

ScopedCOMInitializer::ScopedCOMInitializer() {
Initialize(COINIT_APARTMENTTHREADED);
ScopedCOMInitializer::ScopedCOMInitializer(Uninitialization uninitialization) {
Initialize(COINIT_APARTMENTTHREADED, uninitialization);
}

ScopedCOMInitializer::ScopedCOMInitializer(SelectMTA mta) {
Initialize(COINIT_MULTITHREADED);
ScopedCOMInitializer::ScopedCOMInitializer(SelectMTA mta,
Uninitialization uninitialization) {
Initialize(COINIT_MULTITHREADED, uninitialization);
}

ScopedCOMInitializer::~ScopedCOMInitializer() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
if (Succeeded())
if (Succeeded()) {
if (com_balancer_) {
com_balancer_->Disable();
com_balancer_.Reset();
}
CoUninitialize();
}
}

bool ScopedCOMInitializer::Succeeded() const {
return SUCCEEDED(hr_);
}

void ScopedCOMInitializer::Initialize(COINIT init) {
DWORD ScopedCOMInitializer::GetCOMBalancerReferenceCountForTesting() const {
if (com_balancer_)
return com_balancer_->GetReferenceCountForTesting();
return 0;
}

void ScopedCOMInitializer::Initialize(COINIT init,
Uninitialization uninitialization) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
// COINIT_DISABLE_OLE1DDE is always added based on:
// https://docs.microsoft.com/en-us/windows/desktop/learnwin32/initializing-the-com-library
hr_ = CoInitializeEx(nullptr, init | COINIT_DISABLE_OLE1DDE);
if (uninitialization == Uninitialization::kBlockPremature) {
com_balancer_ = Microsoft::WRL::Details::Make<internal::ComInitBalancer>(
init | COINIT_DISABLE_OLE1DDE);
}
hr_ = ::CoInitializeEx(nullptr, init | COINIT_DISABLE_OLE1DDE);
DCHECK_NE(RPC_E_CHANGED_MODE, hr_) << "Invalid COM thread model change";
}

Expand Down
Loading

0 comments on commit 1cceb25

Please sign in to comment.