Skip to content

Commit

Permalink
use tbb::concurrent_unordered_map instead of std::map so data races d…
Browse files Browse the repository at this point in the history
…o not happen across threads in profiling
  • Loading branch information
SteveBronder committed May 10, 2024
1 parent 35d37ce commit 8aca589
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
20 changes: 19 additions & 1 deletion stan/math/rev/core/profiling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/fun/value_of.hpp>
#include <stan/math/prim/err.hpp>
#include <tbb/concurrent_unordered_map.h>
#include <iostream>
#include <sstream>
#include <thread>
Expand Down Expand Up @@ -115,7 +116,24 @@ class profile_info {

using profile_key = std::pair<std::string, std::thread::id>;

using profile_map = std::map<profile_key, profile_info>;

namespace internal {
struct hash_profile_key {
std::size_t operator()(const profile_key& key) const {
return std::hash<std::string>()(key.first)
^ std::hash<std::thread::id>()(key.second);
}
};
struct equal_profile_key {
bool operator()(const profile_key& lhs, const profile_key& rhs) const {
return lhs.first == rhs.first && lhs.second == rhs.second;
}
};

}

using profile_map = tbb::concurrent_unordered_map<profile_key,
profile_info, internal::hash_profile_key, internal::equal_profile_key>;

/**
* Profiles C++ lines where the object is in scope.
Expand Down
22 changes: 11 additions & 11 deletions test/unit/math/rev/core/profiling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,31 @@
TEST(Profiling, double_basic) {
using stan::math::profile;
using stan::math::var;
stan::math::profile_map profiles;
stan::math::profile_map prof_map;
double a = 3.0, b = 2.0, c;
{
profile<double> p1("p1", profiles);
profile<double> p1("p1", prof_map);
c = log(exp(a)) * log(exp(b));
std::chrono::milliseconds timespan(10);
std::this_thread::sleep_for(timespan);
}
{
profile<int> p1("p1", profiles);
profile<int> p1("p1", prof_map);
c = log(exp(a)) * log(exp(b));
std::chrono::milliseconds timespan(10);
std::this_thread::sleep_for(timespan);
}

stan::math::profile_key key = {"p1", std::this_thread::get_id()};
EXPECT_NEAR(c, 6.0, 1E-8);
EXPECT_EQ(profiles[key].get_chain_stack_used(), 0);
EXPECT_EQ(profiles[key].get_nochain_stack_used(), 0);
EXPECT_FLOAT_EQ(profiles[key].get_rev_time(), 0.0);
EXPECT_EQ(profiles[key].get_num_rev_passes(), 0);
EXPECT_EQ(profiles[key].get_num_fwd_passes(), 2);
EXPECT_EQ(profiles[key].get_num_no_AD_fwd_passes(), 2);
EXPECT_EQ(profiles[key].get_num_AD_fwd_passes(), 0);
EXPECT_TRUE(profiles[key].get_fwd_time() > 0.0);
EXPECT_EQ(prof_map[key].get_chain_stack_used(), 0);
EXPECT_EQ(prof_map[key].get_nochain_stack_used(), 0);
EXPECT_FLOAT_EQ(prof_map[key].get_rev_time(), 0.0);
EXPECT_EQ(prof_map[key].get_num_rev_passes(), 0);
EXPECT_EQ(prof_map[key].get_num_fwd_passes(), 2);
EXPECT_EQ(prof_map[key].get_num_no_AD_fwd_passes(), 2);
EXPECT_EQ(prof_map[key].get_num_AD_fwd_passes(), 0);
EXPECT_TRUE(prof_map[key].get_fwd_time() > 0.0);
}

TEST(Profiling, var_basic) {
Expand Down

0 comments on commit 8aca589

Please sign in to comment.