Skip to content

Commit

Permalink
Add integer version of IMathEngine::SumMatrixRows (neoml-lib#716)
Browse files Browse the repository at this point in the history
Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>
  • Loading branch information
Valeriy Fedyunin authored Aug 1, 2022
1 parent 9dc80ee commit a6eb044
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 10 deletions.
2 changes: 2 additions & 0 deletions NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ class NEOMATHENGINE_API IBlasEngine : public IVectorMathEngine {
int matrixHeight, int matrixWidth) = 0;
virtual void SumMatrixRows(int batchSize, const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) = 0;
virtual void SumMatrixRows(int batchSize, const CIntHandle& resultHandle, const CConstIntHandle& matrixHandle,
int matrixHeight, int matrixWidth) = 0;

// Calculates the total of matrix columns
virtual void SumMatrixColumns(const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/src/CPU/CpuMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ class CCpuMathEngine : public IMathEngine, public IRawMemoryManager {
int matrixHeight, int matrixWidth) override;
void SumMatrixRows(int batchSize, const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void SumMatrixRows(int batchSize, const CIntHandle& resultHandle, const CConstIntHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void SumMatrixColumns(const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void MatrixColumnsEltwiseDivide( const CConstFloatHandle& matrix, int matrixHeight, int matrixWidth,
Expand Down
17 changes: 17 additions & 0 deletions NeoMathEngine/src/CPU/CpuMathEngineBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,23 @@ void CCpuMathEngine::SumMatrixRows(int batchSize,
SumMatrixRowsAdd(batchSize, resultHandle, matrixHandle, matrixHeight, matrixWidth);
}

void CCpuMathEngine::SumMatrixRows(int batchSize, const CIntHandle& resultHandle, const CConstIntHandle& matrixHandle,
int matrixHeight, int matrixWidth)
{
CCpuExecutionScope scope;

VectorFill( resultHandle, 0, batchSize * matrixWidth );
CConstIntHandle matrix = matrixHandle;
CIntHandle result = resultHandle;
for( int i = 0; i < batchSize; ++i ) {
for( int j = 0; j < matrixHeight; j++ ) {
VectorAdd(result, matrix, result, matrixWidth);
matrix += matrixWidth;
}
result += matrixWidth;
}
}

void CCpuMathEngine::SumMatrixRowsAdd(int batchSize,
const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle, int matrixHeight, int matrixWidth)
{
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ class CCudaMathEngine : public IMathEngine, public IRawMemoryManager {
int matrixHeight, int matrixWidth) override;
void SumMatrixRows(int batchSize, const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void SumMatrixRows(int batchSize, const CIntHandle& resultHandle, const CConstIntHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void SumMatrixColumns(const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void MatrixColumnsEltwiseDivide( const CConstFloatHandle& matrix, int matrixHeight, int matrixWidth,
Expand Down
19 changes: 19 additions & 0 deletions NeoMathEngine/src/GPU/CUDA/CudaMathEngineBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,25 @@ void CCudaMathEngine::SubVectorFromMatrixColumns(const CConstFloatHandle& matrix
(GetRaw(matrixHandle), GetRaw(resultHandle), matrixHeight, matrixWidth, GetRaw(vectorHandle));
}

void CCudaMathEngine::SumMatrixRows( int batchSize, const CIntHandle& resultHandle,
const CConstIntHandle& matrixHandle, int matrixHeight, int matrixWidth )
{
ASSERT_EXPR( matrixHandle.GetMathEngine() == this );
ASSERT_EXPR( resultHandle.GetMathEngine() == this );
SetCudaDevice( device->DeviceNumber );

VectorFill( resultHandle, 0, batchSize * matrixWidth );

const int height = ( matrixHeight + SumMatrixRowsAddCombineCount - 1 ) / SumMatrixRowsAddCombineCount;

dim3 blockCount;
dim3 threadCount;
getCudaTaskGrid3D( blockCount, threadCount, batchSize, height, matrixWidth );

SumMatrixRowsAddKernel<<<blockCount, threadCount>>>
( batchSize, GetRaw(resultHandle), GetRaw(matrixHandle), matrixHeight, matrixWidth );
}

void CCudaMathEngine::SumMatrixRowsAdd(
int batchSize, const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth )
Expand Down
5 changes: 3 additions & 2 deletions NeoMathEngine/src/GPU/CUDA/Kernels/CudaBlasKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ __global__ void SubVectorFromMatrixColumnsKernel(const float* __restrict__ matri
}

const int SumMatrixRowsAddCombineCount = 128;
template<class T>
__global__ void SumMatrixRowsAddKernel(
int batchSize, float* result, const float* __restrict__ matrix,
int batchSize, T* result, const T* __restrict__ matrix,
int matrixHeight, int matrixWidth )
{
const int height = ( matrixHeight + SumMatrixRowsAddCombineCount - 1 ) / SumMatrixRowsAddCombineCount;
Expand All @@ -198,7 +199,7 @@ __global__ void SumMatrixRowsAddKernel(
}

matrix += ( batchIndex * matrixHeight + rowIndex ) * matrixWidth + colIndex;
float sum = *matrix;
T sum = *matrix;
for(int j = rowIndex + 1; j < rowEndIndex; ++j) {
matrix += matrixWidth;
sum += *matrix;
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/src/GPU/Metal/MetalMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ class CMetalMathEngine : public IMathEngine, public IRawMemoryManager {
int matrixHeight, int matrixWidth) override;
void SumMatrixRows(int batchSize, const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void SumMatrixRows(int batchSize, const CIntHandle& resultHandle, const CConstIntHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void SumMatrixColumns(const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void MatrixColumnsEltwiseDivide( const CConstFloatHandle& matrix, int matrixHeight, int matrixWidth,
Expand Down
5 changes: 5 additions & 0 deletions NeoMathEngine/src/GPU/Metal/MetalMathEngineBlas.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,11 @@ C2DKernel kernel( *queue, "matrixKernelMultiplyTransposedLookupMatrixByVector",
SumMatrixRowsAdd(batchSize, resultHandle, matrixHandle, matrixHeight, matrixWidth);
}

void CMetalMathEngine::SumMatrixRows( int, const CIntHandle&, const CConstIntHandle&, int, int )
{
ASSERT_EXPR( false );
}

void CMetalMathEngine::SingularValueDecomposition( const CFloatHandle&, int, int, const CFloatHandle&, const CFloatHandle&,
const CFloatHandle&, const CFloatHandle&, bool, bool )
{
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ class CVulkanMathEngine : public IMathEngine, public IRawMemoryManager {
int matrixHeight, int matrixWidth) override;
void SumMatrixRows(int batchSize, const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void SumMatrixRows(int batchSize, const CIntHandle& resultHandle, const CConstIntHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void SumMatrixColumns(const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth) override;
void MatrixColumnsEltwiseDivide( const CConstFloatHandle& matrix, int matrixHeight, int matrixWidth,
Expand Down
5 changes: 5 additions & 0 deletions NeoMathEngine/src/GPU/Vulkan/VulkanMathEngineBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,11 @@ void CVulkanMathEngine::SumMatrixRows( int batchSize, const CFloatHandle& result
&param, sizeof(param), 0, 0, 0, 0, bufs, sizes, 2, matrixWidth, 1, batchSize);
}

void CVulkanMathEngine::SumMatrixRows( int, const CIntHandle&, const CConstIntHandle&, int, int )
{
ASSERT_EXPR( false );
}

void CVulkanMathEngine::SumMatrixColumns( const CFloatHandle& resultHandle, const CConstFloatHandle& matrixHandle,
int matrixHeight, int matrixWidth )
{
Expand Down
21 changes: 13 additions & 8 deletions NeoMathEngine/test/src/inference/SumMatrixRowsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ limitations under the License.
using namespace NeoML;
using namespace NeoMLTest;

static void sumMatrixRowsAddNaive( std::vector<float>& vector, const std::vector<float>& matrix, int batchSize, int height, int width )
template<class T>
static void sumMatrixRowsAddNaive( std::vector<T>& vector, const std::vector<T>& matrix,
int batchSize, int height, int width )
{
for( int b = 0; b < batchSize; ++b ) {
for( int h = 0; h < height; ++h ) {
Expand All @@ -29,6 +31,7 @@ static void sumMatrixRowsAddNaive( std::vector<float>& vector, const std::vector
}
}

template<class T>
static void sumMatrixRowsTestImpl( const CTestParams& params, int seed )
{
CRandom random( seed );
Expand All @@ -42,19 +45,20 @@ static void sumMatrixRowsTestImpl( const CTestParams& params, int seed )
const int width = random.UniformInt( widthInterval.Begin, widthInterval.End );
const int batchSize = random.UniformInt( batchSizeInterval.Begin, batchSizeInterval.End );

CREATE_FILL_FLOAT_ARRAY( matrix, valuesInterval.Begin, valuesInterval.End, batchSize * height * width, random )
CREATE_FILL_FLOAT_ARRAY( getVector, valuesInterval.Begin, valuesInterval.End, batchSize * width, random )
std::vector<float> expectedVector;
CREATE_FILL_ARRAY( T, matrix, valuesInterval.Begin, valuesInterval.End, batchSize * height * width, random )
CREATE_FILL_ARRAY( T, getVector, valuesInterval.Begin, valuesInterval.End, batchSize * width, random )
std::vector<T> expectedVector;
expectedVector = getVector;

for( size_t i = 0; i < expectedVector.size(); ++i ) {
expectedVector[i] = 0.f;
expectedVector[i] = static_cast<T>( 0 );
}
MathEngine().SumMatrixRows( batchSize, CARRAY_FLOAT_WRAPPER( getVector ), CARRAY_FLOAT_WRAPPER( matrix ), height, width );
MathEngine().SumMatrixRows( batchSize, CARRAY_WRAPPER( T, getVector ), CARRAY_WRAPPER( T, matrix ), height, width );
sumMatrixRowsAddNaive( expectedVector, matrix, batchSize, height, width );

for( int i = 0; i < batchSize * width; ++i ) {
ASSERT_NEAR( expectedVector[i], getVector[i], 1e-3 );
ASSERT_NEAR( static_cast<double>( expectedVector[i] ),
static_cast<double>( getVector[i] ), 1e-3 );
}
}

Expand Down Expand Up @@ -88,5 +92,6 @@ INSTANTIATE_TEST_CASE_P( CSumMatrixRowsTestInstantiation, CSumMatrixRowsTest,

TEST_P( CSumMatrixRowsTest, Random )
{
RUN_TEST_IMPL( sumMatrixRowsTestImpl )
RUN_TEST_IMPL( sumMatrixRowsTestImpl<float> )
RUN_TEST_IMPL( sumMatrixRowsTestImpl<int> )
}

0 comments on commit a6eb044

Please sign in to comment.