diff --git a/NeoML/include/NeoML/Dnn/Dnn.h b/NeoML/include/NeoML/Dnn/Dnn.h
index 98862d017..695968207 100644
--- a/NeoML/include/NeoML/Dnn/Dnn.h
+++ b/NeoML/include/NeoML/Dnn/Dnn.h
@@ -413,7 +413,7 @@ class NEOML_API CBaseLayer : public virtual IObject {
 	void link();
 	void addOutput(int number);
 	void unlink();
-	void cleanUp( bool total, bool unlink );
+	void cleanUp( bool total, bool linked );
 	void buildOrder();
 	void reshape();
 	void setInputDesc(int i);
diff --git a/NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h b/NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h
index b6e769b6a..850f94f97 100644
--- a/NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h
+++ b/NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h
@@ -1164,10 +1164,13 @@ class NEOMATHENGINE_API IMathEngine : public IDnnEngine {
 	// Turns on and off the memory reuse mode
 	// In this mode, the allocated memory blocks will not be deleted on HeapFree() and may be used until CleanUp()
 	virtual void SetReuseMemoryMode( bool enable ) = 0;
+	virtual bool GetReuseMemoryMode() const = 0;
 	// Specialize the size threshold in bytes for the current thread, so
 	// memory blocks of a size <= this threshold would be allocated in buffers if 'reuse' mode enabled
 	// memory blocks of a size > this threshold would be allocated in raw RAM memory (malloc/free)
 	virtual void SetThreadBufferMemoryThreshold( size_t threshold ) = 0;
+	// Get the memory blocks' sizes threshold for this thread
+	virtual size_t GetThreadBufferMemoryThreshold() const = 0;
 
 	virtual CMemoryHandle HeapAlloc( size_t count ) = 0;
 	virtual void HeapFree( const CMemoryHandle& handle ) = 0;
@@ -1228,7 +1231,7 @@ class NEOMATHENGINE_API IMathEngine : public IDnnEngine {
 	virtual void AllReduce( const CFloatHandle& handle, int size ) = 0;
 	virtual void Broadcast( const CFloatHandle& handle, int size, int root ) = 0;
 	virtual void AbortDistributed() {};
-	virtual bool IsDistributed() { return false; }
+	virtual bool IsDistributed() const { return false; }
 };
 
 //------------------------------------------------------------------------------------------------------------
diff --git a/NeoMathEngine/src/CPU/CpuMathEngine.cpp b/NeoMathEngine/src/CPU/CpuMathEngine.cpp
index 6c91e0bef..d85bc38ef 100644
--- a/NeoMathEngine/src/CPU/CpuMathEngine.cpp
+++ b/NeoMathEngine/src/CPU/CpuMathEngine.cpp
@@ -97,12 +97,28 @@ void CCpuMathEngine::SetReuseMemoryMode( bool enable )
 	memoryPool->SetReuseMemoryMode( enable );
 }
 
+bool CCpuMathEngine::GetReuseMemoryMode() const
+{
+	// Distributed CPU math engine always uses memory pools
+	if( IsDistributed() ) {
+		return true;
+	}
+	std::lock_guard lock( mutex );
+	return memoryPool->GetReuseMemoryMode();
+}
+
 void CCpuMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
 {
 	std::lock_guard lock( mutex );
 	memoryPool->SetThreadBufferMemoryThreshold( threshold );
 }
 
+size_t CCpuMathEngine::GetThreadBufferMemoryThreshold() const
+{
+	std::lock_guard lock( mutex );
+	return memoryPool->GetThreadBufferMemoryThreshold();
+}
+
 CMemoryHandle CCpuMathEngine::HeapAlloc( size_t size )
 {
 	std::lock_guard lock( mutex );
diff --git a/NeoMathEngine/src/CPU/CpuMathEngine.h b/NeoMathEngine/src/CPU/CpuMathEngine.h
index f55253d28..4eef40de5 100644
--- a/NeoMathEngine/src/CPU/CpuMathEngine.h
+++ b/NeoMathEngine/src/CPU/CpuMathEngine.h
@@ -45,7 +45,9 @@ class CCpuMathEngine : public IMathEngine, public IRawMemoryManager {
 	// IMathEngine interface methods
 	TMathEngineType GetType() const override { return MET_Cpu; }
 	void SetReuseMemoryMode( bool enabled ) override;
+	bool GetReuseMemoryMode() const override;
 	void SetThreadBufferMemoryThreshold( size_t threshold ) override;
+	size_t GetThreadBufferMemoryThreshold() const override;
 	CMemoryHandle HeapAlloc( size_t count ) override;
 	void HeapFree( const CMemoryHandle& handle ) override;
 	void TransferHandleToThisThread( const CMemoryHandle& handle, size_t size ) override;
@@ -628,7 +630,8 @@ class CCpuMathEngine : public IMathEngine, public IRawMemoryManager {
 	void Broadcast( const CFloatHandle& handle, int size, int root ) override;
 	void AbortDistributed() override;
 	CMathEngineDistributedInfo GetDistributedInfo() override { return distributedInfo; }
-	bool IsDistributed() override { return distributedInfo.Threads > 1; }
+	bool IsDistributed() const override { return distributedInfo.Threads > 1; }
+
 protected:
 	// IRawMemoryManager interface methods
 	CMemoryHandle Alloc( size_t size ) override;
diff --git a/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp b/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp
index ec3be2a8c..c8d5538ea 100644
--- a/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp
+++ b/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp
@@ -80,12 +80,24 @@ void CCudaMathEngine::SetReuseMemoryMode( bool )
 	// Always true, because allocation is sync
 }
 
+bool CCudaMathEngine::GetReuseMemoryMode() const
+{
+	// Always true, because allocation is sync
+	return true;
+}
+
 void CCudaMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
 {
 	std::lock_guard lock( mutex );
 	memoryPool->SetThreadBufferMemoryThreshold( threshold );
 }
 
+size_t CCudaMathEngine::GetThreadBufferMemoryThreshold() const
+{
+	std::lock_guard lock( mutex );
+	return memoryPool->GetThreadBufferMemoryThreshold();
+}
+
 CMemoryHandle CCudaMathEngine::HeapAlloc( size_t size )
 {
 	std::lock_guard lock( mutex );
diff --git a/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h b/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h
index 3bee1ab35..ddaf13c1c 100644
--- a/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h
+++ b/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h
@@ -52,7 +52,9 @@ class CCudaMathEngine : public IMathEngine, public IRawMemoryManager {
 	TMathEngineType GetType() const override { return MET_Cuda; }
 	void GetMathEngineInfo( CMathEngineInfo& info ) const override;
 	void SetReuseMemoryMode( bool enable ) override;
+	bool GetReuseMemoryMode() const override;
 	void SetThreadBufferMemoryThreshold( size_t threshold ) override;
+	size_t GetThreadBufferMemoryThreshold() const override;
 	CMemoryHandle HeapAlloc( size_t count ) override;
 	void HeapFree( const CMemoryHandle& handle ) override;
 	void TransferHandleToThisThread( const CMemoryHandle& handle, size_t size ) override;
@@ -630,15 +632,17 @@ class CCudaMathEngine : public IMathEngine, public IRawMemoryManager {
 		const CFloatHandle& input, const CFloatHandle& output ) override;
 
 	IPerformanceCounters* CreatePerformanceCounters( bool ) const override { return new CPerformanceCountersDefault(); }
 
+	// For Distributed only
 	void AllReduce( const CFloatHandle& handle, int size ) override;
 	void Broadcast( const CFloatHandle& handle, int size, int root ) override;
 	void AbortDistributed() override;
 	CMathEngineDistributedInfo GetDistributedInfo() override { return distributedInfo; }
-	bool IsDistributed() override { return distributedInfo.Threads > 1; }
+	bool IsDistributed() const override { return distributedInfo.Threads > 1; }
 #ifdef NEOML_USE_NCCL
 	void SetDistributedCommunicator( const ncclUniqueId& uniqueId, const CMathEngineDistributedInfo& info, std::shared_ptr<std::atomic<bool>> isAbort );
 #endif
+
 protected:
 	// IRawMemoryManager interface methods
 	CMemoryHandle Alloc( size_t size ) override;
diff --git a/NeoMathEngine/src/GPU/Metal/MetalMathEngine.h b/NeoMathEngine/src/GPU/Metal/MetalMathEngine.h
index 7fd5de691..d7af3e537 100644
--- a/NeoMathEngine/src/GPU/Metal/MetalMathEngine.h
+++ b/NeoMathEngine/src/GPU/Metal/MetalMathEngine.h
@@ -47,7 +47,9 @@ class CMetalMathEngine : public IMathEngine, public IRawMemoryManager {
 	// IMathEngine interface methods
 	TMathEngineType GetType() const override { return MET_Metal; }
 	void SetReuseMemoryMode( bool enable ) override;
+	bool GetReuseMemoryMode() const override;
 	void SetThreadBufferMemoryThreshold( size_t threshold ) override;
+	size_t GetThreadBufferMemoryThreshold() const override;
 	CMemoryHandle HeapAlloc( size_t count ) override;
 	void HeapFree( const CMemoryHandle& handle ) override;
 	void TransferHandleToThisThread( const CMemoryHandle& /*handle*/, size_t /*size*/ ) override { ASSERT_EXPR( false ); }
diff --git a/NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm b/NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm
index dc00b8b1b..c2c803acc 100644
--- a/NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm
+++ b/NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm
@@ -87,16 +87,28 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
 
 void CMetalMathEngine::SetReuseMemoryMode( bool enable )
 {
-	std::lock_guard lock( *mutex );
+	std::lock_guard lock( *mutex );
 	memoryPool->SetReuseMemoryMode( enable );
 }
 
-void CVulkanMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
+bool CMetalMathEngine::GetReuseMemoryMode() const
+{
+	std::lock_guard lock( *mutex );
+	return memoryPool->GetReuseMemoryMode();
+}
+
+void CMetalMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
 {
-	std::lock_guard lock( *mutex );
+	std::lock_guard lock( *mutex );
 	memoryPool->SetThreadBufferMemoryThreshold( threshold );
 }
 
+size_t CMetalMathEngine::GetThreadBufferMemoryThreshold() const
+{
+	std::lock_guard lock( *mutex );
+	return memoryPool->GetThreadBufferMemoryThreshold();
+}
+
 CMemoryHandle CMetalMathEngine::HeapAlloc( size_t size )
 {
 	std::lock_guard lock( *mutex );
@@ -118,7 +130,7 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
 
 CMemoryHandle CMetalMathEngine::StackAlloc( size_t size )
 {
-	std::lock_guard lock( *mutex );
+	std::lock_guard lock( *mutex );
 	CMemoryHandle result = deviceStackAllocator->Alloc( size );
 	if( result.IsNull() ) {
 		THROW_MEMORY_EXCEPTION;
@@ -128,7 +140,7 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
 
 void CMetalMathEngine::StackFree( const CMemoryHandle& ptr )
 {
-	std::lock_guard lock( *mutex );
+	std::lock_guard lock( *mutex );
 	deviceStackAllocator->Free( ptr );
 }
 
@@ -140,25 +152,25 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
 
 size_t CMetalMathEngine::GetPeakMemoryUsage() const
 {
-	std::lock_guard lock( *mutex );
+	std::lock_guard lock( *mutex );
 	return memoryPool->GetPeakMemoryUsage();
 }
 
 void CMetalMathEngine::ResetPeakMemoryUsage()
 {
-	std::lock_guard lock( *mutex );
+	std::lock_guard lock( *mutex );
 	memoryPool->ResetPeakMemoryUsage();
 }
 
 size_t CMetalMathEngine::GetCurrentMemoryUsage() const
 {
-	std::lock_guard lock( *mutex );
+	std::lock_guard lock( *mutex );
 	return memoryPool->GetCurrentMemoryUsage();
 }
 
 size_t CMetalMathEngine::GetMemoryInPools() const
 {
-	std::lock_guard lock( *mutex );
+	std::lock_guard lock( *mutex );
 	return memoryPool->GetMemoryInPools();
 }
 
diff --git a/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp b/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp
index a6490c938..c2a793397 100644
--- a/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp
+++ b/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp
@@ -97,12 +97,24 @@ void CVulkanMathEngine::SetReuseMemoryMode( bool enable )
 	memoryPool->SetReuseMemoryMode( enable );
 }
 
+bool CVulkanMathEngine::GetReuseMemoryMode() const
+{
+	std::lock_guard lock( mutex );
+	return memoryPool->GetReuseMemoryMode();
+}
+
 void CVulkanMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
 {
 	std::lock_guard lock( mutex );
 	memoryPool->SetThreadBufferMemoryThreshold( threshold );
 }
 
+size_t CVulkanMathEngine::GetThreadBufferMemoryThreshold() const
+{
+	std::lock_guard lock( mutex );
+	return memoryPool->GetThreadBufferMemoryThreshold();
+}
+
 CMemoryHandle CVulkanMathEngine::HeapAlloc( size_t size )
 {
 	std::lock_guard lock( mutex );
diff --git a/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h b/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h
index 7bc94be21..530ca9e63 100644
--- a/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h
+++ b/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h
@@ -59,7 +59,9 @@ class CVulkanMathEngine : public IMathEngine, public IRawMemoryManager {
 	// IMathEngine interface methods
 	TMathEngineType GetType() const override { return MET_Vulkan; }
 	void SetReuseMemoryMode( bool enable ) override;
+	bool GetReuseMemoryMode() const override;
 	void SetThreadBufferMemoryThreshold( size_t threshold ) override;
+	size_t GetThreadBufferMemoryThreshold() const override;
 	CMemoryHandle HeapAlloc( size_t count ) override;
 	void HeapFree( const CMemoryHandle& handle ) override;
 	void TransferHandleToThisThread( const CMemoryHandle& /*handle*/, size_t /*size*/ ) override { ASSERT_EXPR( false ); }
@@ -643,6 +645,7 @@ class CVulkanMathEngine : public IMathEngine, public IRawMemoryManager {
 		const CFloatHandle& ) override { ASSERT_EXPR( false ); }
 
 	IPerformanceCounters* CreatePerformanceCounters( bool ) const override { return new CPerformanceCountersDefault(); }
 
+	// For Distributed only
 	void AllReduce( const CFloatHandle& /*handle*/, int /*size*/ ) override {};
 	void Broadcast( const CFloatHandle& /*handle*/, int /*size*/, int /*root*/ ) override {};
diff --git a/NeoMathEngine/src/MemoryPool.cpp b/NeoMathEngine/src/MemoryPool.cpp
index 47c207318..f92996c9e 100644
--- a/NeoMathEngine/src/MemoryPool.cpp
+++ b/NeoMathEngine/src/MemoryPool.cpp
@@ -106,8 +106,8 @@ CMemoryPool::CMemoryPool( size_t _memoryLimit, IRawMemoryManager* _rawMemoryMana
 
 CMemoryPool::~CMemoryPool()
 {
-	for( auto curPool : pools ) {
-		cleanUp( curPool.first );
+	for( auto& curPool : pools ) {
+		cleanUp( &curPool.second );
 		for( auto curMemBufferPool : curPool.second.Pool ) {
 			delete curMemBufferPool;
 		}
@@ -116,20 +116,31 @@ CMemoryPool::~CMemoryPool()
 
 void CMemoryPool::SetReuseMemoryMode( bool enable )
 {
-	const std::thread::id id = std::this_thread::get_id();
-	getThreadData( id )->Enabled = enable;
+	getThreadData()->Enabled = enable;
+}
+
+bool CMemoryPool::GetReuseMemoryMode() const
+{
+	const CThreadData* threadData = getThreadData();
+	return ( threadData != nullptr ) ? threadData->Enabled : false;
 }
 
 void CMemoryPool::SetThreadBufferMemoryThreshold( size_t threshold )
 {
-	const std::thread::id id = std::this_thread::get_id();
-	getThreadData( id )->BufferMemoryThreshold = threshold;
+	getThreadData()->BufferMemoryThreshold = threshold;
+}
+
+size_t CMemoryPool::GetThreadBufferMemoryThreshold() const
+{
+	const CThreadData* threadData = getThreadData();
+	return ( threadData != nullptr )
+		? threadData->BufferMemoryThreshold
+		: CThreadData::DefaultBufferMemoryThreshold;
 }
 
 CMemoryHandle CMemoryPool::Alloc( size_t size )
 {
-	const std::thread::id id = std::this_thread::get_id();
-	CThreadData& threadData = *getThreadData( id );
+	CThreadData& threadData = *getThreadData();
 	CMemoryHandle result = tryAlloc( size, threadData );
 	if( !result.IsNull() ) {
 		return result;
@@ -159,19 +170,17 @@ void CMemoryPool::Free( const CMemoryHandle& handle )
 
 size_t CMemoryPool::GetMemoryInPools() const
 {
-	std::thread::id id = std::this_thread::get_id();
-	auto pool = pools.find( id );
-	if( pool == pools.end() ) {
+	const CThreadData* threadData = getThreadData();
+	if( threadData == nullptr ) {
 		return 0;
 	}
-	const TMemoryBufferPoolVector& threadPools = pool->second.Pool;
-	return std::accumulate( threadPools.begin(), threadPools.end(), size_t( 0 ),
+	return std::accumulate( threadData->Pool.begin(), threadData->Pool.end(), size_t( 0 ),
 		[] ( const size_t& sum, const CMemoryBufferPool* cur ) { return sum + cur->GetMemoryInPool(); } );
 }
 
 void CMemoryPool::CleanUp()
 {
-	cleanUp( std::this_thread::get_id() );
+	cleanUp( getThreadData( /*forceCreate*/false ) );
 }
 
 // Transfers handle from other thread owner to this thread
@@ -186,8 +195,7 @@ void CMemoryPool::TransferHandleToThisThread( const CMemoryHandle& handle, size_
 	ASSERT_EXPR( size <= otherThreadBufferPool->BufferSize );
 	size = otherThreadBufferPool->BufferSize; // set actual allocated size
 
-	const std::thread::id id = std::this_thread::get_id();
-	CThreadData& thisThreadData = *getThreadData( id );
+	CThreadData& thisThreadData = *getThreadData();
 	// If on this thread pools are turned off
 	if( !thisThreadData.Enabled ) {
 		// Transfer the handle from that thread's pool just to heap, so
@@ -214,37 +222,40 @@
 	}
 }
 
-CMemoryPool::CThreadData* CMemoryPool::getThreadData( std::thread::id id, bool forceCreate )
+const CMemoryPool::CThreadData* CMemoryPool::getThreadData() const
 {
+	auto it = pools.find( std::this_thread::get_id() );
+	return ( it == pools.end() ) ? nullptr : &( it->second );
+}
+
+CMemoryPool::CThreadData* CMemoryPool::getThreadData( bool forceCreate )
+{
+	std::thread::id id = std::this_thread::get_id();
 	auto it = pools.find( id );
 	if( it == pools.end() ) {
 		if( !forceCreate ) {
 			return nullptr;
 		}
-		createPools( id );
-		it = pools.find( id );
+		return createPools( id );
 	}
 	return &( it->second );
 }
 
-void CMemoryPool::createPools( std::thread::id id )
+CMemoryPool::CThreadData* CMemoryPool::createPools( std::thread::id id )
 {
 	CThreadData threadData;
 	threadData.Enabled = defaultReuseMemoryMode;
 	for( size_t i = 0; i < sizeof( BufferSizes ) / sizeof( *BufferSizes ); ++i ) {
 		threadData.Pool.push_back( new CMemoryBufferPool( BufferSizes[i] ) );
 	}
-
-	pools[id] = threadData;
+	return &( pools[id] = threadData );
 }
 
-void CMemoryPool::cleanUp( std::thread::id id )
+void CMemoryPool::cleanUp( CThreadData* threadData )
 {
-	CThreadData* const threadData = getThreadData( id, /*forceCreate*/false );
 	if( threadData == nullptr ) {
 		return;
 	}
-
 	for( CMemoryBufferPool* pool : threadData->Pool ) {
 		CMemoryBuffer* buffer = pool->TryAlloc();
 		while( buffer != 0 ) {
diff --git a/NeoMathEngine/src/MemoryPool.h b/NeoMathEngine/src/MemoryPool.h
index 97374dd11..9dfd83283 100644
--- a/NeoMathEngine/src/MemoryPool.h
+++ b/NeoMathEngine/src/MemoryPool.h
@@ -35,12 +35,16 @@ class CMemoryPool : public CCrtAllocatedObject {
 
 	// Turns on and off the memory reuse mode for the current thread
 	void SetReuseMemoryMode( bool enable );
+	// Get the memory reuse mode state for the current thread
+	bool GetReuseMemoryMode() const;
+
 	// Change the memory blocks' sizes threshold for this thread from 1GB to the user size in bytes
 	void SetThreadBufferMemoryThreshold( size_t threshold );
+	// Get the memory blocks' sizes threshold for this thread
+	size_t GetThreadBufferMemoryThreshold() const;
 
 	// Allocates the specified amount of memory
 	CMemoryHandle Alloc( size_t size );
-
 	// Frees the memory
 	void Free( const CMemoryHandle& handle );
 
@@ -103,9 +107,10 @@ class CMemoryPool : public CCrtAllocatedObject {
 	size_t peakMemoryUsage; // peak memory usage
 	TUsedAddressMap usedMap;
 
-	CThreadData* getThreadData( std::thread::id id, bool forceCreate = true );
-	void createPools( std::thread::id id );
-	void cleanUp( std::thread::id id );
+	const CThreadData* getThreadData() const;
+	CThreadData* getThreadData( bool forceCreate = true );
+	CThreadData* createPools( std::thread::id id );
+	void cleanUp( CThreadData* threadData );
 	CMemoryHandle tryAlloc( size_t size, CThreadData& data );
 	CMemoryHandle alloc( size_t size );
 	void freeMemory( size_t size, const CMemoryHandle& data );
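
Below is a minimal usage sketch, not part of the patch, showing how the new getters pair with the existing setters on IMathEngine. It assumes an engine obtained elsewhere (CPU, CUDA, Metal or Vulkan); the helper name configurePoolMemory is illustrative only. Note that per the patch CCudaMathEngine::GetReuseMemoryMode() always returns true, and the CPU engine reports true whenever it runs distributed, regardless of the value that was set.

// Usage sketch (illustrative only, not part of the diff)
#include <NeoMathEngine/NeoMathEngine.h>
#include <cassert>

// Turns on memory reuse for the calling thread, lowers its buffer threshold,
// and reads both settings back through the getters added by this change.
void configurePoolMemory( NeoML::IMathEngine& mathEngine )
{
	mathEngine.SetReuseMemoryMode( true );
	assert( mathEngine.GetReuseMemoryMode() );

	// Blocks larger than the threshold bypass the pools and go to malloc/free;
	// here the per-thread limit is lowered from the 1GB default to 256 MB.
	const size_t threshold = 256 * 1024 * 1024;
	mathEngine.SetThreadBufferMemoryThreshold( threshold );
	assert( mathEngine.GetThreadBufferMemoryThreshold() == threshold );
}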