diff --git a/base/BUILD.gn b/base/BUILD.gn index c277f828f011f8..d287afb434df68 100644 --- a/base/BUILD.gn +++ b/base/BUILD.gn @@ -955,6 +955,8 @@ component("base") { "win/atl.h", "win/atl_throw.cc", "win/atl_throw.h", + "win/com_init_balancer.cc", + "win/com_init_balancer.h", "win/com_init_check_hook.cc", "win/com_init_check_hook.h", "win/com_init_util.cc", @@ -2988,6 +2990,7 @@ test("base_unittests") { "threading/platform_thread_win_unittest.cc", "time/time_win_unittest.cc", "win/async_operation_unittest.cc", + "win/com_init_balancer_unittest.cc", "win/com_init_check_hook_unittest.cc", "win/com_init_util_unittest.cc", "win/core_winrt_util_unittest.cc", diff --git a/base/win/com_init_balancer.cc b/base/win/com_init_balancer.cc new file mode 100644 index 00000000000000..cd6d6b66b2de3f --- /dev/null +++ b/base/win/com_init_balancer.cc @@ -0,0 +1,68 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include "base/check_op.h" +#include "base/win/com_init_balancer.h" + +namespace base { +namespace win { +namespace internal { + +ComInitBalancer::ComInitBalancer(DWORD co_init) : co_init_(co_init) { + ULARGE_INTEGER spy_cookie = {}; + HRESULT hr = ::CoRegisterInitializeSpy(this, &spy_cookie); + if (SUCCEEDED(hr)) + spy_cookie_ = spy_cookie; +} + +ComInitBalancer::~ComInitBalancer() { + DCHECK(!spy_cookie_.has_value()); +} + +void ComInitBalancer::Disable() { + if (spy_cookie_.has_value()) { + ::CoRevokeInitializeSpy(spy_cookie_.value()); + reference_count_ = 0; + spy_cookie_.reset(); + } +} + +DWORD ComInitBalancer::GetReferenceCountForTesting() const { + return reference_count_; +} + +IFACEMETHODIMP +ComInitBalancer::PreInitialize(DWORD apartment_type, DWORD reference_count) { + return S_OK; +} + +IFACEMETHODIMP +ComInitBalancer::PostInitialize(HRESULT result, + DWORD apartment_type, + DWORD new_reference_count) { + reference_count_ = new_reference_count; + return result; +} + +IFACEMETHODIMP +ComInitBalancer::PreUninitialize(DWORD reference_count) { + if (reference_count == 1 && spy_cookie_.has_value()) { + // Increase the reference count to prevent premature and unbalanced + // uninitalization of the COM library. + ::CoInitializeEx(nullptr, co_init_); + } + return S_OK; +} + +IFACEMETHODIMP +ComInitBalancer::PostUninitialize(DWORD new_reference_count) { + reference_count_ = new_reference_count; + return S_OK; +} + +} // namespace internal +} // namespace win +} // namespace base diff --git a/base/win/com_init_balancer.h b/base/win/com_init_balancer.h new file mode 100644 index 00000000000000..cf7860cbdad6fe --- /dev/null +++ b/base/win/com_init_balancer.h @@ -0,0 +1,73 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef BASE_WIN_COM_INIT_BALANCER_H_ +#define BASE_WIN_COM_INIT_BALANCER_H_ + +#include +#include +#include + +#include "base/base_export.h" +#include "base/optional.h" +#include "base/threading/thread_checker.h" +#include "base/win/windows_types.h" + +namespace base { +namespace win { +namespace internal { + +// Implementation class of the IInitializeSpy Interface that prevents premature +// uninitialization of the COM library, often caused by unbalanced +// CoInitialize/CoUninitialize pairs. The use of this class is encouraged in +// COM-supporting threads that execute third-party code. +// +// Disable() must be called before uninitializing the COM library in order to +// revoke the registered spy and allow for the successful uninitialization of +// the COM library. +class BASE_EXPORT ComInitBalancer + : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IInitializeSpy> { + public: + // Constructs a COM initialize balancer. |co_init| defines the apartment's + // concurrency model used by the balancer. + explicit ComInitBalancer(DWORD co_init); + + ComInitBalancer(const ComInitBalancer&) = delete; + ComInitBalancer& operator=(const ComInitBalancer&) = delete; + + ~ComInitBalancer() override; + + // Disables balancer by revoking the registered spy and consequently + // unblocking attempts to uninitialize the COM library. + void Disable(); + + DWORD GetReferenceCountForTesting() const; + + private: + // IInitializeSpy: + IFACEMETHODIMP PreInitialize(DWORD apartment_type, + DWORD reference_count) override; + IFACEMETHODIMP PostInitialize(HRESULT result, + DWORD apartment_type, + DWORD new_reference_count) override; + IFACEMETHODIMP PreUninitialize(DWORD reference_count) override; + IFACEMETHODIMP PostUninitialize(DWORD new_reference_count) override; + + const DWORD co_init_; + + // The current apartment reference count set after the completion of the last + // call made to CoInitialize or CoUninitialize. + DWORD reference_count_ = 0; + + base::Optional spy_cookie_; + THREAD_CHECKER(thread_checker_); +}; + +} // namespace internal +} // namespace win +} // namespace base + +#endif // BASE_WIN_COM_INIT_BALANCER_H_ diff --git a/base/win/com_init_balancer_unittest.cc b/base/win/com_init_balancer_unittest.cc new file mode 100644 index 00000000000000..30ae5ffb4e6cf0 --- /dev/null +++ b/base/win/com_init_balancer_unittest.cc @@ -0,0 +1,151 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/win/com_init_balancer.h" + +#include +#include + +#include "base/test/gtest_util.h" +#include "base/win/com_init_util.h" +#include "base/win/scoped_com_initializer.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace base { +namespace win { + +using Microsoft::WRL::ComPtr; + +TEST(TestComInitBalancer, BalancedPairsWithComBalancerEnabled) { + { + // Assert COM has initialized correctly. + ScopedCOMInitializer com_initializer( + ScopedCOMInitializer::Uninitialization::kBlockPremature); + ASSERT_TRUE(com_initializer.Succeeded()); + + // Create COM object successfully. + ComPtr shell_link; + HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, + IID_PPV_ARGS(&shell_link)); + EXPECT_TRUE(SUCCEEDED(hr)); + } + + // ScopedCOMInitializer has gone out of scope and COM has been uninitialized. + EXPECT_DCHECK_DEATH(AssertComInitialized()); +} + +TEST(TestComInitBalancer, UnbalancedPairsWithComBalancerEnabled) { + { + // Assert COM has initialized correctly. + ScopedCOMInitializer com_initializer( + ScopedCOMInitializer::Uninitialization::kBlockPremature); + ASSERT_TRUE(com_initializer.Succeeded()); + + // Attempt to prematurely uninitialize the COM library. + ::CoUninitialize(); + ::CoUninitialize(); + + // Assert COM is still initialized. + AssertComInitialized(); + + // Create COM object successfully. + ComPtr shell_link; + HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, + IID_PPV_ARGS(&shell_link)); + EXPECT_TRUE(SUCCEEDED(hr)); + } + + // ScopedCOMInitializer has gone out of scope and COM has been uninitialized. + EXPECT_DCHECK_DEATH(AssertComInitialized()); +} + +TEST(TestComInitBalancer, BalancedPairsWithComBalancerDisabled) { + { + // Assert COM has initialized correctly. + ScopedCOMInitializer com_initializer( + ScopedCOMInitializer::Uninitialization::kAllow); + ASSERT_TRUE(com_initializer.Succeeded()); + + // Create COM object successfully. + ComPtr shell_link; + HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, + IID_PPV_ARGS(&shell_link)); + EXPECT_TRUE(SUCCEEDED(hr)); + } + + // ScopedCOMInitializer has gone out of scope and COM has been uninitialized. + EXPECT_DCHECK_DEATH(AssertComInitialized()); +} + +TEST(TestComInitBalancer, UnbalancedPairsWithComBalancerDisabled) { + // Assert COM has initialized correctly. + ScopedCOMInitializer com_initializer( + ScopedCOMInitializer::Uninitialization::kAllow); + ASSERT_TRUE(com_initializer.Succeeded()); + + // Attempt to prematurely uninitialize the COM library. + ::CoUninitialize(); + ::CoUninitialize(); + + // Assert COM is not initialized. + EXPECT_DCHECK_DEATH(AssertComInitialized()); + + // Create COM object unsuccessfully. + ComPtr shell_link; + HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, + IID_PPV_ARGS(&shell_link)); + EXPECT_TRUE(FAILED(hr)); + EXPECT_EQ(CO_E_NOTINITIALIZED, hr); +} + +TEST(TestComInitBalancer, OneRegisteredSpyRefCount) { + ScopedCOMInitializer com_initializer( + ScopedCOMInitializer::Uninitialization::kBlockPremature); + ASSERT_TRUE(com_initializer.Succeeded()); + + // Reference count should be 1 after initialization. + EXPECT_EQ(DWORD(1), com_initializer.GetCOMBalancerReferenceCountForTesting()); + + // Attempt to prematurely uninitialize the COM library. + ::CoUninitialize(); + + // Expect reference count to remain at 1. + EXPECT_EQ(DWORD(1), com_initializer.GetCOMBalancerReferenceCountForTesting()); +} + +TEST(TestComInitBalancer, ThreeRegisteredSpiesRefCount) { + ScopedCOMInitializer com_initializer_1( + ScopedCOMInitializer::Uninitialization::kBlockPremature); + ScopedCOMInitializer com_initializer_2( + ScopedCOMInitializer::Uninitialization::kBlockPremature); + ScopedCOMInitializer com_initializer_3( + ScopedCOMInitializer::Uninitialization::kBlockPremature); + ASSERT_TRUE(com_initializer_1.Succeeded()); + ASSERT_TRUE(com_initializer_2.Succeeded()); + ASSERT_TRUE(com_initializer_3.Succeeded()); + + // Reference count should be 3 after initialization. + EXPECT_EQ(DWORD(3), + com_initializer_1.GetCOMBalancerReferenceCountForTesting()); + EXPECT_EQ(DWORD(3), + com_initializer_2.GetCOMBalancerReferenceCountForTesting()); + EXPECT_EQ(DWORD(3), + com_initializer_3.GetCOMBalancerReferenceCountForTesting()); + + // Attempt to prematurely uninitialize the COM library. + ::CoUninitialize(); // Reference count -> 2. + ::CoUninitialize(); // Reference count -> 1. + ::CoUninitialize(); + + // Expect reference count to remain at 1. + EXPECT_EQ(DWORD(1), + com_initializer_1.GetCOMBalancerReferenceCountForTesting()); + EXPECT_EQ(DWORD(1), + com_initializer_2.GetCOMBalancerReferenceCountForTesting()); + EXPECT_EQ(DWORD(1), + com_initializer_3.GetCOMBalancerReferenceCountForTesting()); +} + +} // namespace win +} // namespace base diff --git a/base/win/scoped_com_initializer.cc b/base/win/scoped_com_initializer.cc index 80c97495a500a2..6d80ff41899e04 100644 --- a/base/win/scoped_com_initializer.cc +++ b/base/win/scoped_com_initializer.cc @@ -4,34 +4,53 @@ #include "base/win/scoped_com_initializer.h" +#include + #include "base/check_op.h" namespace base { namespace win { -ScopedCOMInitializer::ScopedCOMInitializer() { - Initialize(COINIT_APARTMENTTHREADED); +ScopedCOMInitializer::ScopedCOMInitializer(Uninitialization uninitialization) { + Initialize(COINIT_APARTMENTTHREADED, uninitialization); } -ScopedCOMInitializer::ScopedCOMInitializer(SelectMTA mta) { - Initialize(COINIT_MULTITHREADED); +ScopedCOMInitializer::ScopedCOMInitializer(SelectMTA mta, + Uninitialization uninitialization) { + Initialize(COINIT_MULTITHREADED, uninitialization); } ScopedCOMInitializer::~ScopedCOMInitializer() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); - if (Succeeded()) + if (Succeeded()) { + if (com_balancer_) { + com_balancer_->Disable(); + com_balancer_.Reset(); + } CoUninitialize(); + } } bool ScopedCOMInitializer::Succeeded() const { return SUCCEEDED(hr_); } -void ScopedCOMInitializer::Initialize(COINIT init) { +DWORD ScopedCOMInitializer::GetCOMBalancerReferenceCountForTesting() const { + if (com_balancer_) + return com_balancer_->GetReferenceCountForTesting(); + return 0; +} + +void ScopedCOMInitializer::Initialize(COINIT init, + Uninitialization uninitialization) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); // COINIT_DISABLE_OLE1DDE is always added based on: // https://docs.microsoft.com/en-us/windows/desktop/learnwin32/initializing-the-com-library - hr_ = CoInitializeEx(nullptr, init | COINIT_DISABLE_OLE1DDE); + if (uninitialization == Uninitialization::kBlockPremature) { + com_balancer_ = Microsoft::WRL::Details::Make( + init | COINIT_DISABLE_OLE1DDE); + } + hr_ = ::CoInitializeEx(nullptr, init | COINIT_DISABLE_OLE1DDE); DCHECK_NE(RPC_E_CHANGED_MODE, hr_) << "Invalid COM thread model change"; } diff --git a/base/win/scoped_com_initializer.h b/base/win/scoped_com_initializer.h index 3bb57954939ce2..d1b0694bc17806 100644 --- a/base/win/scoped_com_initializer.h +++ b/base/win/scoped_com_initializer.h @@ -6,10 +6,12 @@ #define BASE_WIN_SCOPED_COM_INITIALIZER_H_ #include +#include #include "base/base_export.h" #include "base/macros.h" #include "base/threading/thread_checker.h" +#include "base/win/com_init_balancer.h" #include "base/win/scoped_windows_thread_environment.h" namespace base { @@ -18,6 +20,10 @@ namespace win { // Initializes COM in the constructor (STA or MTA), and uninitializes COM in the // destructor. // +// It is strongly encouraged to block premature uninitialization of the COM +// libraries in threads that execute third-party code, as a way to protect +// against unbalanced CoInitialize/CoUninitialize pairs. +// // WARNING: This should only be used once per thread, ideally scoped to a // similar lifetime as the thread itself. You should not be using this in // random utility functions that make COM calls -- instead ensure these @@ -27,21 +33,39 @@ class BASE_EXPORT ScopedCOMInitializer : public ScopedWindowsThreadEnvironment { // Enum value provided to initialize the thread as an MTA instead of STA. enum SelectMTA { kMTA }; - // Constructor for STA initialization. - ScopedCOMInitializer(); + // Enum values which enumerates uninitialization modes for the COM library. + enum class Uninitialization { + + // Default value. Used in threads where no third-party code is executed. + kAllow, + + // Blocks premature uninitialization of the COM libraries before going out + // of scope. Used in threads where third-party code is executed. + kBlockPremature, + }; - // Constructor for MTA initialization. - explicit ScopedCOMInitializer(SelectMTA mta); + // Constructors for STA initialization. + explicit ScopedCOMInitializer( + Uninitialization uninitialization = Uninitialization::kAllow); + + // Constructors for MTA initialization. + explicit ScopedCOMInitializer( + SelectMTA mta, + Uninitialization uninitialization = Uninitialization::kAllow); ~ScopedCOMInitializer() override; // ScopedWindowsThreadEnvironment: bool Succeeded() const override; + // Used for testing. Returns the COM balancer's apartment thread ref count. + DWORD GetCOMBalancerReferenceCountForTesting() const; + private: - void Initialize(COINIT init); + void Initialize(COINIT init, Uninitialization uninitialization); - HRESULT hr_; + HRESULT hr_ = S_OK; + Microsoft::WRL::ComPtr com_balancer_; THREAD_CHECKER(thread_checker_); DISALLOW_COPY_AND_ASSIGN(ScopedCOMInitializer); diff --git a/components/services/quarantine/quarantine_impl.cc b/components/services/quarantine/quarantine_impl.cc index e75cd5c1bbe388..8324caabc6c032 100644 --- a/components/services/quarantine/quarantine_impl.cc +++ b/components/services/quarantine/quarantine_impl.cc @@ -13,7 +13,6 @@ #include "components/services/quarantine/quarantine.h" #if defined(OS_WIN) -#include "base/win/scoped_com_initializer.h" #include "components/services/quarantine/public/cpp/quarantine_features_win.h" #endif // OS_WIN @@ -53,8 +52,6 @@ void QuarantineImpl::QuarantineFile( if (base::FeatureList::IsEnabled(quarantine::kOutOfProcessQuarantine)) { // In out of process case, we are running in a utility process, // so directly call QuarantineFile and send the result. - base::win::ScopedCOMInitializer com_initializer; - QuarantineFileResult result = quarantine::QuarantineFile( full_path, source_url, referrer_url, client_guid); diff --git a/components/services/quarantine/quarantine_impl.h b/components/services/quarantine/quarantine_impl.h index a1c3bcf010e102..3c31d4255ad795 100644 --- a/components/services/quarantine/quarantine_impl.h +++ b/components/services/quarantine/quarantine_impl.h @@ -7,10 +7,15 @@ #include +#include "build/build_config.h" #include "components/services/quarantine/public/mojom/quarantine.mojom.h" #include "mojo/public/cpp/bindings/pending_receiver.h" #include "mojo/public/cpp/bindings/receiver.h" +#if defined(OS_WIN) +#include "base/win/scoped_com_initializer.h" +#endif // OS_WIN + namespace quarantine { class QuarantineImpl : public mojom::Quarantine { @@ -30,6 +35,11 @@ class QuarantineImpl : public mojom::Quarantine { private: mojo::Receiver receiver_{this}; +#if defined(OS_WIN) + base::win::ScopedCOMInitializer com_initializer_{ + base::win::ScopedCOMInitializer::Uninitialization::kBlockPremature}; +#endif // OS_WIN + DISALLOW_COPY_AND_ASSIGN(QuarantineImpl); }; diff --git a/components/services/quarantine/quarantine_win_unittest.cc b/components/services/quarantine/quarantine_win_unittest.cc index f931ece5649898..a5eb082d1d1448 100644 --- a/components/services/quarantine/quarantine_win_unittest.cc +++ b/components/services/quarantine/quarantine_win_unittest.cc @@ -14,6 +14,7 @@ #include "base/test/scoped_feature_list.h" #include "base/test/test_file_util.h" #include "base/test/test_reg_util_win.h" +#include "base/win/scoped_com_initializer.h" #include "base/win/win_util.h" #include "base/win/windows_version.h" #include "components/services/quarantine/public/cpp/quarantine_features_win.h" @@ -158,6 +159,9 @@ class QuarantineWinTest : public ::testing::Test { base::ScopedTempDir scoped_temp_dir_; + base::win::ScopedCOMInitializer com_initializer_{ + base::win::ScopedCOMInitializer::Uninitialization::kBlockPremature}; + // Due to caching, these sites zone must be set for all tests, so that the // order the tests are run does not matter. std::unique_ptr scoped_zone_for_trusted_site_;