Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NeoMathEngine] GetReuseMemoryMode and GetCurrentMemoryUsage #1065

Merged
merged 4 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[NeoMathEngine] GetThreadBufferMemoryThreshold
Signed-off-by: Kirill Golikov <kirill.golikov@abbyy.com>
  • Loading branch information
favorart committed May 10, 2024
commit ba7fbaecd79e7e236871bef7f79cdd09d38328d3
2 changes: 2 additions & 0 deletions NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,8 @@ class NEOMATHENGINE_API IMathEngine : public IDnnEngine {
// memory blocks of a size <= this threshold would be allocated in buffers if 'reuse' mode enabled
// memory blocks of a size > this threshold would be allocated in raw RAM memory (malloc/free)
virtual void SetThreadBufferMemoryThreshold( size_t threshold ) = 0;
// Get the memory blocks' sizes threshold for this thread
virtual size_t GetThreadBufferMemoryThreshold() const = 0;

virtual CMemoryHandle HeapAlloc( size_t count ) = 0;
virtual void HeapFree( const CMemoryHandle& handle ) = 0;
Expand Down
6 changes: 6 additions & 0 deletions NeoMathEngine/src/CPU/CpuMathEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ void CCpuMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
memoryPool->SetThreadBufferMemoryThreshold( threshold );
}

size_t CCpuMathEngine::GetThreadBufferMemoryThreshold() const
{
std::lock_guard<std::mutex> lock( mutex );
return memoryPool->GetThreadBufferMemoryThreshold();
}

CMemoryHandle CCpuMathEngine::HeapAlloc( size_t size )
{
std::lock_guard<std::mutex> lock( mutex );
Expand Down
1 change: 1 addition & 0 deletions NeoMathEngine/src/CPU/CpuMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class CCpuMathEngine : public IMathEngine, public IRawMemoryManager {
void SetReuseMemoryMode( bool enabled ) override;
bool GetReuseMemoryMode() const override;
void SetThreadBufferMemoryThreshold( size_t threshold ) override;
size_t GetThreadBufferMemoryThreshold() const override;
CMemoryHandle HeapAlloc( size_t count ) override;
void HeapFree( const CMemoryHandle& handle ) override;
void TransferHandleToThisThread( const CMemoryHandle& handle, size_t size ) override;
Expand Down
6 changes: 6 additions & 0 deletions NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ void CCudaMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
memoryPool->SetThreadBufferMemoryThreshold( threshold );
}

size_t CCudaMathEngine::GetThreadBufferMemoryThreshold() const
{
std::lock_guard<std::mutex> lock( mutex );
return memoryPool->GetThreadBufferMemoryThreshold();
}

CMemoryHandle CCudaMathEngine::HeapAlloc( size_t size )
{
std::lock_guard<std::mutex> lock( mutex );
Expand Down
1 change: 1 addition & 0 deletions NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class CCudaMathEngine : public IMathEngine, public IRawMemoryManager {
void SetReuseMemoryMode( bool enable ) override;
bool GetReuseMemoryMode() const override;
void SetThreadBufferMemoryThreshold( size_t threshold ) override;
size_t GetThreadBufferMemoryThreshold() const override;
CMemoryHandle HeapAlloc( size_t count ) override;
void HeapFree( const CMemoryHandle& handle ) override;
void TransferHandleToThisThread( const CMemoryHandle& handle, size_t size ) override;
Expand Down
1 change: 1 addition & 0 deletions NeoMathEngine/src/GPU/Metal/MetalMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class CMetalMathEngine : public IMathEngine, public IRawMemoryManager {
void SetReuseMemoryMode( bool enable ) override;
bool GetReuseMemoryMode() const override;
void SetThreadBufferMemoryThreshold( size_t threshold ) override;
size_t GetThreadBufferMemoryThreshold() const override;
CMemoryHandle HeapAlloc( size_t count ) override;
void HeapFree( const CMemoryHandle& handle ) override;
void TransferHandleToThisThread( const CMemoryHandle& /*handle*/, size_t /*size*/ ) override { ASSERT_EXPR( false ); }
Expand Down
6 changes: 6 additions & 0 deletions NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
memoryPool->SetThreadBufferMemoryThreshold( threshold );
}

size_t CMetalMathEngine::GetThreadBufferMemoryThreshold() const
{
std::lock_guard<CMutex> lock( *mutex );
return memoryPool->GetThreadBufferMemoryThreshold();
}

CMemoryHandle CMetalMathEngine::HeapAlloc( size_t size )
{
std::lock_guard<CMutex> lock( *mutex );
Expand Down
6 changes: 6 additions & 0 deletions NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ void CVulkanMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
memoryPool->SetThreadBufferMemoryThreshold( threshold );
}

size_t CVulkanMathEngine::GetThreadBufferMemoryThreshold() const
{
std::lock_guard<std::mutex> lock( mutex );
return memoryPool->GetThreadBufferMemoryThreshold();
}

CMemoryHandle CVulkanMathEngine::HeapAlloc( size_t size )
{
std::lock_guard<std::mutex> lock( mutex );
Expand Down
1 change: 1 addition & 0 deletions NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class CVulkanMathEngine : public IMathEngine, public IRawMemoryManager {
void SetReuseMemoryMode( bool enable ) override;
bool GetReuseMemoryMode() const override;
void SetThreadBufferMemoryThreshold( size_t threshold ) override;
size_t GetThreadBufferMemoryThreshold() const override;
CMemoryHandle HeapAlloc( size_t count ) override;
void HeapFree( const CMemoryHandle& handle ) override;
void TransferHandleToThisThread( const CMemoryHandle& /*handle*/, size_t /*size*/ ) override { ASSERT_EXPR( false ); }
Expand Down
8 changes: 8 additions & 0 deletions NeoMathEngine/src/MemoryPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ void CMemoryPool::SetThreadBufferMemoryThreshold( size_t threshold )
getThreadData()->BufferMemoryThreshold = threshold;
}

size_t CMemoryPool::GetThreadBufferMemoryThreshold() const
{
const CThreadData* threadData = getThreadData();
return ( threadData != nullptr )
? threadData->BufferMemoryThreshold
: CThreadData::DefaultBufferMemoryThreshold;
}

CMemoryHandle CMemoryPool::Alloc( size_t size )
{
CThreadData& threadData = *getThreadData();
Expand Down
3 changes: 3 additions & 0 deletions NeoMathEngine/src/MemoryPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ class CMemoryPool : public CCrtAllocatedObject {
void SetReuseMemoryMode( bool enable );
// Get the memory reuse mode state for the current thread
bool GetReuseMemoryMode() const;

// Change the memory blocks' sizes threshold for this thread from 1GB to the user size in bytes
void SetThreadBufferMemoryThreshold( size_t threshold );
// Get the memory blocks' sizes threshold for this thread
size_t GetThreadBufferMemoryThreshold() const;

// Allocates the specified amount of memory
CMemoryHandle Alloc( size_t size );
Expand Down