From 67704de6602a1c0168eccecff6fb50c335d76c53 Mon Sep 17 00:00:00 2001 From: Kirill Golikov Date: Fri, 10 May 2024 16:19:14 +0200 Subject: [PATCH] [NeoMathEngine] GetThreadBufferMemoryThreshold Signed-off-by: Kirill Golikov --- NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h | 2 ++ NeoMathEngine/src/CPU/CpuMathEngine.cpp | 6 ++++++ NeoMathEngine/src/CPU/CpuMathEngine.h | 1 + NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp | 6 ++++++ NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h | 1 + NeoMathEngine/src/GPU/Metal/MetalMathEngine.h | 1 + NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm | 6 ++++++ NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp | 6 ++++++ NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h | 1 + NeoMathEngine/src/MemoryPool.cpp | 8 ++++++++ NeoMathEngine/src/MemoryPool.h | 3 +++ 11 files changed, 41 insertions(+) diff --git a/NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h b/NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h index 1746186c3..850f94f97 100644 --- a/NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h +++ b/NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h @@ -1169,6 +1169,8 @@ class NEOMATHENGINE_API IMathEngine : public IDnnEngine { // memory blocks of a size <= this threshold would be allocated in buffers if 'reuse' mode enabled // memory blocks of a size > this threshold would be allocated in raw RAM memory (malloc/free) virtual void SetThreadBufferMemoryThreshold( size_t threshold ) = 0; + // Get the memory blocks' sizes threshold for this thread + virtual size_t GetThreadBufferMemoryThreshold() const = 0; virtual CMemoryHandle HeapAlloc( size_t count ) = 0; virtual void HeapFree( const CMemoryHandle& handle ) = 0; diff --git a/NeoMathEngine/src/CPU/CpuMathEngine.cpp b/NeoMathEngine/src/CPU/CpuMathEngine.cpp index 74711099f..d85bc38ef 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngine.cpp +++ b/NeoMathEngine/src/CPU/CpuMathEngine.cpp @@ -113,6 +113,12 @@ void CCpuMathEngine::SetThreadBufferMemoryThreshold( size_t threshold ) memoryPool->SetThreadBufferMemoryThreshold( threshold ); } +size_t CCpuMathEngine::GetThreadBufferMemoryThreshold() const +{ + std::lock_guard lock( mutex ); + return memoryPool->GetThreadBufferMemoryThreshold(); +} + CMemoryHandle CCpuMathEngine::HeapAlloc( size_t size ) { std::lock_guard lock( mutex ); diff --git a/NeoMathEngine/src/CPU/CpuMathEngine.h b/NeoMathEngine/src/CPU/CpuMathEngine.h index 5a6b68256..4eef40de5 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngine.h +++ b/NeoMathEngine/src/CPU/CpuMathEngine.h @@ -47,6 +47,7 @@ class CCpuMathEngine : public IMathEngine, public IRawMemoryManager { void SetReuseMemoryMode( bool enabled ) override; bool GetReuseMemoryMode() const override; void SetThreadBufferMemoryThreshold( size_t threshold ) override; + size_t GetThreadBufferMemoryThreshold() const override; CMemoryHandle HeapAlloc( size_t count ) override; void HeapFree( const CMemoryHandle& handle ) override; void TransferHandleToThisThread( const CMemoryHandle& handle, size_t size ) override; diff --git a/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp b/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp index f906596b2..c8d5538ea 100644 --- a/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp +++ b/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.cpp @@ -92,6 +92,12 @@ void CCudaMathEngine::SetThreadBufferMemoryThreshold( size_t threshold ) memoryPool->SetThreadBufferMemoryThreshold( threshold ); } +size_t CCudaMathEngine::GetThreadBufferMemoryThreshold() const +{ + std::lock_guard lock( mutex ); + return memoryPool->GetThreadBufferMemoryThreshold(); +} + CMemoryHandle CCudaMathEngine::HeapAlloc( size_t size ) { std::lock_guard lock( mutex ); diff --git a/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h b/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h index d00eb709c..ddaf13c1c 100644 --- a/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h +++ b/NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h @@ -54,6 +54,7 @@ class CCudaMathEngine : public IMathEngine, public IRawMemoryManager { void SetReuseMemoryMode( bool enable ) override; bool GetReuseMemoryMode() const override; void SetThreadBufferMemoryThreshold( size_t threshold ) override; + size_t GetThreadBufferMemoryThreshold() const override; CMemoryHandle HeapAlloc( size_t count ) override; void HeapFree( const CMemoryHandle& handle ) override; void TransferHandleToThisThread( const CMemoryHandle& handle, size_t size ) override; diff --git a/NeoMathEngine/src/GPU/Metal/MetalMathEngine.h b/NeoMathEngine/src/GPU/Metal/MetalMathEngine.h index bdece6020..d7af3e537 100644 --- a/NeoMathEngine/src/GPU/Metal/MetalMathEngine.h +++ b/NeoMathEngine/src/GPU/Metal/MetalMathEngine.h @@ -49,6 +49,7 @@ class CMetalMathEngine : public IMathEngine, public IRawMemoryManager { void SetReuseMemoryMode( bool enable ) override; bool GetReuseMemoryMode() const override; void SetThreadBufferMemoryThreshold( size_t threshold ) override; + size_t GetThreadBufferMemoryThreshold() const override; CMemoryHandle HeapAlloc( size_t count ) override; void HeapFree( const CMemoryHandle& handle ) override; void TransferHandleToThisThread( const CMemoryHandle& /*handle*/, size_t /*size*/ ) override { ASSERT_EXPR( false ); } diff --git a/NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm b/NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm index 9cff2cc3c..c2c803acc 100644 --- a/NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm +++ b/NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm @@ -103,6 +103,12 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info ) memoryPool->SetThreadBufferMemoryThreshold( threshold ); } +size_t CMetalMathEngine::GetThreadBufferMemoryThreshold() const +{ + std::lock_guard lock( *mutex ); + return memoryPool->GetThreadBufferMemoryThreshold(); +} + CMemoryHandle CMetalMathEngine::HeapAlloc( size_t size ) { std::lock_guard lock( *mutex ); diff --git a/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp b/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp index 262667f4f..c2a793397 100644 --- a/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp +++ b/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.cpp @@ -109,6 +109,12 @@ void CVulkanMathEngine::SetThreadBufferMemoryThreshold( size_t threshold ) memoryPool->SetThreadBufferMemoryThreshold( threshold ); } +size_t CVulkanMathEngine::GetThreadBufferMemoryThreshold() const +{ + std::lock_guard lock( mutex ); + return memoryPool->GetThreadBufferMemoryThreshold(); +} + CMemoryHandle CVulkanMathEngine::HeapAlloc( size_t size ) { std::lock_guard lock( mutex ); diff --git a/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h b/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h index 4551fd92a..530ca9e63 100644 --- a/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h +++ b/NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h @@ -61,6 +61,7 @@ class CVulkanMathEngine : public IMathEngine, public IRawMemoryManager { void SetReuseMemoryMode( bool enable ) override; bool GetReuseMemoryMode() const override; void SetThreadBufferMemoryThreshold( size_t threshold ) override; + size_t GetThreadBufferMemoryThreshold() const override; CMemoryHandle HeapAlloc( size_t count ) override; void HeapFree( const CMemoryHandle& handle ) override; void TransferHandleToThisThread( const CMemoryHandle& /*handle*/, size_t /*size*/ ) override { ASSERT_EXPR( false ); } diff --git a/NeoMathEngine/src/MemoryPool.cpp b/NeoMathEngine/src/MemoryPool.cpp index d6597ebe3..f92996c9e 100644 --- a/NeoMathEngine/src/MemoryPool.cpp +++ b/NeoMathEngine/src/MemoryPool.cpp @@ -130,6 +130,14 @@ void CMemoryPool::SetThreadBufferMemoryThreshold( size_t threshold ) getThreadData()->BufferMemoryThreshold = threshold; } +size_t CMemoryPool::GetThreadBufferMemoryThreshold() const +{ + const CThreadData* threadData = getThreadData(); + return ( threadData != nullptr ) + ? threadData->BufferMemoryThreshold + : CThreadData::DefaultBufferMemoryThreshold; +} + CMemoryHandle CMemoryPool::Alloc( size_t size ) { CThreadData& threadData = *getThreadData(); diff --git a/NeoMathEngine/src/MemoryPool.h b/NeoMathEngine/src/MemoryPool.h index 23355c227..9dfd83283 100644 --- a/NeoMathEngine/src/MemoryPool.h +++ b/NeoMathEngine/src/MemoryPool.h @@ -37,8 +37,11 @@ class CMemoryPool : public CCrtAllocatedObject { void SetReuseMemoryMode( bool enable ); // Get the memory reuse mode state for the current thread bool GetReuseMemoryMode() const; + // Change the memory blocks' sizes threshold for this thread from 1GB to the user size in bytes void SetThreadBufferMemoryThreshold( size_t threshold ); + // Get the memory blocks' sizes threshold for this thread + size_t GetThreadBufferMemoryThreshold() const; // Allocates the specified amount of memory CMemoryHandle Alloc( size_t size );