diff --git a/stan/math/rev/core/profiling.hpp b/stan/math/rev/core/profiling.hpp index ea51aa79499..14536ac9df7 100644 --- a/stan/math/rev/core/profiling.hpp +++ b/stan/math/rev/core/profiling.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -115,7 +116,24 @@ class profile_info { using profile_key = std::pair; -using profile_map = std::map; + +namespace internal { + struct hash_profile_key { + std::size_t operator()(const profile_key& key) const { + return std::hash()(key.first) + ^ std::hash()(key.second); + } + }; + struct equal_profile_key { + bool operator()(const profile_key& lhs, const profile_key& rhs) const { + return lhs.first == rhs.first && lhs.second == rhs.second; + } + }; + +} + +using profile_map = tbb::concurrent_unordered_map; /** * Profiles C++ lines where the object is in scope. diff --git a/test/unit/math/rev/core/profiling_test.cpp b/test/unit/math/rev/core/profiling_test.cpp index b8f6c1e2f3d..a8d2ed4fe7c 100644 --- a/test/unit/math/rev/core/profiling_test.cpp +++ b/test/unit/math/rev/core/profiling_test.cpp @@ -6,16 +6,16 @@ TEST(Profiling, double_basic) { using stan::math::profile; using stan::math::var; - stan::math::profile_map profiles; + stan::math::profile_map prof_map; double a = 3.0, b = 2.0, c; { - profile p1("p1", profiles); + profile p1("p1", prof_map); c = log(exp(a)) * log(exp(b)); std::chrono::milliseconds timespan(10); std::this_thread::sleep_for(timespan); } { - profile p1("p1", profiles); + profile p1("p1", prof_map); c = log(exp(a)) * log(exp(b)); std::chrono::milliseconds timespan(10); std::this_thread::sleep_for(timespan); @@ -23,14 +23,14 @@ TEST(Profiling, double_basic) { stan::math::profile_key key = {"p1", std::this_thread::get_id()}; EXPECT_NEAR(c, 6.0, 1E-8); - EXPECT_EQ(profiles[key].get_chain_stack_used(), 0); - EXPECT_EQ(profiles[key].get_nochain_stack_used(), 0); - EXPECT_FLOAT_EQ(profiles[key].get_rev_time(), 0.0); - EXPECT_EQ(profiles[key].get_num_rev_passes(), 0); - EXPECT_EQ(profiles[key].get_num_fwd_passes(), 2); - EXPECT_EQ(profiles[key].get_num_no_AD_fwd_passes(), 2); - EXPECT_EQ(profiles[key].get_num_AD_fwd_passes(), 0); - EXPECT_TRUE(profiles[key].get_fwd_time() > 0.0); + EXPECT_EQ(prof_map[key].get_chain_stack_used(), 0); + EXPECT_EQ(prof_map[key].get_nochain_stack_used(), 0); + EXPECT_FLOAT_EQ(prof_map[key].get_rev_time(), 0.0); + EXPECT_EQ(prof_map[key].get_num_rev_passes(), 0); + EXPECT_EQ(prof_map[key].get_num_fwd_passes(), 2); + EXPECT_EQ(prof_map[key].get_num_no_AD_fwd_passes(), 2); + EXPECT_EQ(prof_map[key].get_num_AD_fwd_passes(), 0); + EXPECT_TRUE(prof_map[key].get_fwd_time() > 0.0); } TEST(Profiling, var_basic) {