Skip to content

Commit

Permalink
Add first version of thread safety and tests
Browse files Browse the repository at this point in the history
Using a mutex to guard the internal map. Not yet a policy with
possibility to disable locking from the outside
  • Loading branch information
tmadlener committed Apr 4, 2022
1 parent a5fcd71 commit 0e7c4be
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 5 deletions.
18 changes: 13 additions & 5 deletions include/podio/Frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <initializer_list>
#include <memory>
#include <mutex>
#include <string>
#include <type_traits>
#include <unordered_map>
Expand Down Expand Up @@ -73,6 +74,7 @@ class Frame {

CollectionMapT m_collections{}; ///< The internal map for storing unpacked collections
podio::GenericParameters m_parameters{}; ///< The generic parameter store for this frame
mutable std::mutex m_mapMutex{}; ///< The mutex for guarding the internal collection map
};

std::unique_ptr<FrameConcept> m_self; ///< The internal concept pointer through which all the work is done
Expand Down Expand Up @@ -171,18 +173,24 @@ const CollT& Frame::put(CollT&& coll, const std::string& name) {
}

const podio::CollectionBase* Frame::FrameModel::get(const std::string& name) const {
if (const auto it = m_collections.find(name); it != m_collections.end()) {
return it->second.get();
{
std::lock_guard lock(m_mapMutex);
if (const auto it = m_collections.find(name); it != m_collections.end()) {
return it->second.get();
}
}

return nullptr;
}

const podio::CollectionBase* Frame::FrameModel::put(std::unique_ptr<podio::CollectionBase> coll,
const std::string& name) {
auto [it, success] = m_collections.try_emplace(name, std::move(coll));
if (success) {
return it->second.get();
{
std::lock_guard lock(m_mapMutex);
auto [it, success] = m_collections.try_emplace(name, std::move(coll));
if (success) {
return it->second.get();
}
}

return nullptr;
Expand Down
108 changes: 108 additions & 0 deletions tests/frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include "datamodel/ExampleClusterCollection.h"
#include "datamodel/ExampleHitCollection.h"

#include <string>
#include <thread>
#include <vector>

TEST_CASE("Frame collections", "[frame][basics]") {
auto event = podio::Frame();
auto clusters = ExampleClusterCollection();
Expand Down Expand Up @@ -38,6 +42,51 @@ TEST_CASE("Frame parameters", "[frame][basics]") {
REQUIRE(strings[2] == "three");
}

TEST_CASE("Frame collections multithreaded insert", "[frame][basics][multithread]") {
constexpr int nThreads = 10;
std::vector<std::thread> threads;
threads.reserve(10);

auto frame = podio::Frame();

// Fill collections from different threads
for (int i = 0; i < nThreads; ++i) {
threads.emplace_back(
[&frame](int j) {
auto clusters = ExampleClusterCollection();
clusters.create(j * 3.14);
clusters.create(j * 3.14);
frame.put(std::move(clusters), "clusters_" + std::to_string(j));

auto hits = ExampleHitCollection();
hits.create(j * 100ULL);
hits.create(j * 100ULL);
hits.create(j * 100ULL);
frame.put(std::move(hits), "hits_" + std::to_string(j));
},
i);
}

for (auto& t : threads) {
t.join();
}

// Check the frame contents after all threads have finished
for (int i = 0; i < nThreads; ++i) {
auto& hits = frame.get<ExampleHitCollection>("hits_" + std::to_string(i));
REQUIRE(hits.size() == 3);
for (const auto h : hits) {
REQUIRE(h.cellID() == i * 100ULL);
}

auto& clusters = frame.get<ExampleClusterCollection>("clusters_" + std::to_string(i));
REQUIRE(clusters.size() == 2);
for (const auto c : clusters) {
REQUIRE(c.energy() == i * 3.14);
}
}
}

// Helper function to create a frame in the tests below
auto createFrame() {
auto frame = podio::Frame();
Expand Down Expand Up @@ -70,6 +119,65 @@ auto createFrame() {
return frame;
}

TEST_CASE("Frame collections multithreaded insert and read", "[frame][basics][multithread]") {
constexpr int nThreads = 10;
std::vector<std::thread> threads;
threads.reserve(10);

// create a pre-populated frame
auto frame = createFrame();

// Fill collections from different threads
for (int i = 0; i < nThreads; ++i) {
threads.emplace_back(
[&frame](int j) {
auto clusters = ExampleClusterCollection();
clusters.create(j * 3.14);
clusters.create(j * 3.14);
frame.put(std::move(clusters), "clusters_" + std::to_string(j));

// Retrieve a few collections in between and do just a very basic testing
auto& existingClu = frame.get<ExampleClusterCollection>("clusters");
REQUIRE(existingClu.size() == 2);
auto& existingHits = frame.get<ExampleHitCollection>("hits");
REQUIRE(existingHits.size() == 2);

auto hits = ExampleHitCollection();
hits.create(j * 100ULL);
hits.create(j * 100ULL);
hits.create(j * 100ULL);
frame.put(std::move(hits), "hits_" + std::to_string(j));

// Fill in a lot of new collections to trigger a rehashing of the
// internal map, which invalidates iterators
constexpr int nColls = 100;
for (int k = 0; k < nColls; ++k) {
frame.put(ExampleHitCollection(), "h_" + std::to_string(j) + "_" + std::to_string(k));
}
},
i);
}

for (auto& t : threads) {
t.join();
}

// Check the frame contents after all threads have finished
for (int i = 0; i < nThreads; ++i) {
auto& hits = frame.get<ExampleHitCollection>("hits_" + std::to_string(i));
REQUIRE(hits.size() == 3);
for (const auto h : hits) {
REQUIRE(h.cellID() == i * 100ULL);
}

auto& clusters = frame.get<ExampleClusterCollection>("clusters_" + std::to_string(i));
REQUIRE(clusters.size() == 2);
for (const auto c : clusters) {
REQUIRE(c.energy() == i * 3.14);
}
}
}

// Helper function to keep the tests below a bit easier to read and not having
// to repeat this bit all the time. This checks that the contents are the ones
// that would be expected from the createFrame above
Expand Down

0 comments on commit 0e7c4be

Please sign in to comment.