Skip to content

Commit

Permalink
reduce_sum realization (#348)
Browse files Browse the repository at this point in the history
* reduce_sum realization

Signed-off-by: valinurovdenis <denis.valinurow@yandex.ru>

* reduce_sum realization

Signed-off-by: valinurovdenis <denis.valinurow@yandex.ru>

* reduce_sum realization

Signed-off-by: valinurovdenis <denis.valinurow@yandex.ru>
  • Loading branch information
valinurovdenis committed Jun 9, 2021
1 parent 48db01d commit 77fdc68
Show file tree
Hide file tree
Showing 16 changed files with 202 additions and 31 deletions.
4 changes: 2 additions & 2 deletions NeoML/Python/neoml/AutoDiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def max(a, b):

raise ValueError('At least one of `a` and `b` should be neoml.Blob.')

def sum(a):
def sum(a, axis=None):
"""Calculates the total sum of blob elements.
"""
if not type(a) is Blob:
Expand All @@ -97,7 +97,7 @@ def sum(a):
if a.size == 0:
raise ValueError("The blob shouldn't be empty.")

return Blob(PythonWrapper.blob_sum(a._internal))
return Blob(PythonWrapper.blob_sum(a._internal, -1 if axis is None else int(axis)))

def neg(a):
"""Returns the negative of a blob or a number.
Expand Down
4 changes: 2 additions & 2 deletions NeoML/Python/src/PyAutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ void InitializeTape(py::module& m)
return CPyBlob( second.MathEngineOwner(), const_cast<CDnnBlob*>(result.Ptr()) );
}, py::return_value_policy::reference );

m.def( "blob_sum", [](const CPyBlob& first) {
CPtr<const CDnnBlob> result( Sum( first.Blob() ) );
m.def( "blob_sum", [](const CPyBlob& first, int axis) {
CPtr<const CDnnBlob> result( Sum( first.Blob(), axis) );
return CPyBlob( first.MathEngineOwner(), const_cast<CDnnBlob*>(result.Ptr()) );
}, py::return_value_policy::reference );

Expand Down
1 change: 1 addition & 0 deletions NeoML/Python/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,6 +1972,7 @@ def test_autodiff_functions(self):
self.assertTrue( np.equal( ad.clip(const2, 3, 4).asarray(), 3 * ones ).all() )
self.assertTrue( np.equal( ad.top_k(const2, 3).asarray(), [2, 2, 2] ).all() )
self.assertTrue( np.equal( ad.binary_cross_entropy(const0, const0, False).asarray(), 0 * ones ).all() )
self.assertTrue( np.equal( ad.sum(blob, 1).asarray(), 3 * np.ones((2, 1, 1, 1, 1, 2, 3)) ).all() )

def test_cross_entropy_loss(self):
math_engine = neoml.MathEngine.CpuMathEngine(1)
Expand Down
2 changes: 1 addition & 1 deletion NeoML/include/NeoML/Dnn/AutoDiffFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ CPtr<const CDnnBlob> NEOML_API Max( const CDnnBlob* first, float second );
CPtr<const CDnnBlob> NEOML_API Max( float first, const CDnnBlob* second );

// Calculates the total sum of all blob elements and returns it as a single-element blob.
NEOML_API CPtr<const CDnnBlob> Sum( const CDnnBlob* first );
NEOML_API CPtr<const CDnnBlob> Sum( const CDnnBlob* first, int axis );

// Creates the blob each element of which is the negative value of the corresponding element of the specified blob.
// res[i] = -first[i]
Expand Down
60 changes: 51 additions & 9 deletions NeoML/src/Dnn/AutoDiffFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,16 +608,34 @@ CPtr<const CDnnBlob> NEOML_API Max( float first, const CDnnBlob* second )

class CTapeSum : public ITapeOperation {
public:
explicit CTapeSum( const CDnnBlob& first );
explicit CTapeSum( const CDnnBlob& first, int axis );

CPtr<CDnnBlob> Jacobian( const CTapeBlob* var ) const override;

static void GetDimensions( const CDnnBlob* first, int axis, int& followingDimension,
int& dimensions, int& precedingDimension );
private:
CPtr<const CDnnBlob> first;
int axis;
};

CTapeSum::CTapeSum( const CDnnBlob& _first ) :
first( &_first )
void CTapeSum::GetDimensions( const CDnnBlob* first, int axis, int& followingDimension,
int& dimension, int& precedingDimension )
{
followingDimension = 1;
for( int d = 0; d < axis; d++ ) {
followingDimension *= first->DimSize( d );
}
dimension = first->DimSize( axis );
precedingDimension = 1;
for( int d = axis + 1; d < BD_Count; d++ ) {
precedingDimension *= first->DimSize( d );
}
}

CTapeSum::CTapeSum( const CDnnBlob& _first, int axis ) :
first( &_first ),
axis( axis )
{
NeoAssert( dynamic_cast<const CTapeBlob*>(first.Ptr()) != 0 );
}
Expand All @@ -635,24 +653,48 @@ CPtr<CDnnBlob> CTapeSum::Jacobian( const CTapeBlob* var ) const
return jacobian;
}

CPtr<CDnnBlob> result = CDnnBlob::CreateBlob( jacobian->GetMathEngine(), { width } );
result->GetMathEngine().SumMatrixRows( 1, result->GetData(), jacobian->GetData(), height, width );
CPtr<CDnnBlob> result;
if( axis == -1 ) {
result = CDnnBlob::CreateBlob( jacobian->GetMathEngine(), { width } );
result->GetMathEngine().SumMatrixRows( 1, result->GetData(), jacobian->GetData(), height, width );
} else {
int precedingDimension;
int dimension;
int followingDimension;
GetDimensions( first, axis, followingDimension, dimension, precedingDimension );
result = CDnnBlob::CreateBlob( jacobian->GetMathEngine(), { height / dimension, 1, 1, 1, 1, 1, width } );
result->GetMathEngine().VectorSumAlongDimension( jacobian->GetData(), precedingDimension * width, dimension,
followingDimension, result->GetData() );
}
return result;
}

CPtr<const CDnnBlob> Sum( const CDnnBlob* first )
CPtr<const CDnnBlob> Sum( const CDnnBlob* first, int axis )
{
NeoAssert( first != 0 );
NeoAssert( axis >= -1 && axis < BD_Count );

IMathEngine& mathEngine = first->GetMathEngine();
const CTapeBlob* tapeBlob = dynamic_cast<const CTapeBlob*>( first );
IGradientTape* tape = tapeBlob != 0 ? tapeBlob->Tape() : 0;

CPtr<CTapeBlob> result( new CTapeBlob( tape, mathEngine, CBlobDesc( {1} ) ) );
mathEngine.VectorSum( first->GetData(), first->GetDataSize(), result->GetData() );
CPtr<CTapeBlob> result;
if( axis == -1 ) {
result = new CTapeBlob( tape, mathEngine, CBlobDesc( { 1 } ) );
mathEngine.VectorSum( first->GetData(), first->GetDataSize(), result->GetData() );
} else {
int precedingDimension;
int dimension;
int followingDimension;
CTapeSum::GetDimensions( first, axis, followingDimension, dimension, precedingDimension );
CBlobDesc desc = first->GetDesc();
desc.SetDimSize( axis, 1 );
result = new CTapeBlob( tape, mathEngine, desc );
mathEngine.VectorSumAlongDimension( first->GetData(), precedingDimension, dimension, followingDimension, result->GetData() );
}

if( tape != 0 ) {
CPtr<ITapeOperation> operation( new CTapeSum( *tapeBlob ) );
CPtr<ITapeOperation> operation( new CTapeSum( *tapeBlob, axis ) );
tape->Add( result, operation );
}

Expand Down
88 changes: 71 additions & 17 deletions NeoML/test/src/AutoDiffTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,41 +210,95 @@ TEST_F( CAutoDiffTest, TestSub1 )
}
}

TEST_F( CAutoDiffTest, TestSum )
TEST_F( CAutoDiffTest, TestSum1 )
{
for( int axis : { -1, 6 } ) {
CGradientTape tape;

const int VectorSize = 16;

CArray<float> xData;
xData.InsertAt( 1.0, 0, VectorSize );
CPtr<CDnnBlob> xBlob( CDnnBlob::CreateVector( MathEngine(), CT_Float, xData.Size() ) );
xBlob->CopyFrom( xData.GetPtr() );
CPtr<const CDnnBlob> x = tape.Variable( *xBlob );

float valuesA[VectorSize] = { 0.501, 0.001, 0.002, 0.003, 0.004, 0.004, 0.005, 0.505,
0.010, 0.010, 0.011, 0.490, 0.489, 0.488, 0.487, 0.491 };
CPtr<CDnnBlob> a( CDnnBlob::CreateVector( MathEngine(), CT_Float, VectorSize ) );
a->CopyFrom( valuesA );

float valuesB[VectorSize] = { 0.001, 0.401, 0.002, 0.003, 0.004, 0.004, 0.005, 0.010,
0.405, 0.010, 0.011, 0.289, 0.390, 0.288, 0.391, 0.291 };
CPtr<CDnnBlob> b( CDnnBlob::CreateVector( MathEngine(), CT_Float, VectorSize ) );
b->CopyFrom( valuesB );

CPtr<const CDnnBlob> ax = Mul( x, a );
CPtr<const CDnnBlob> bx = Mul( b, x );

CPtr<const CDnnBlob> top4ax = TopK( ax, 4 );
CPtr<const CDnnBlob> top4bx = TopK( bx, 4 );

CPtr<const CDnnBlob> loss = Sum( Add( top4ax, top4bx ), axis );

CArray<float> lossData;
lossData.SetSize( loss->GetDataSize() );
loss->CopyTo( lossData.GetPtr() );

float lossRes[1] = { 3.574 };
for( int i = 0; i < _countof( lossRes ); i++ ) {
ASSERT_NEAR( lossRes[i], lossData[i], 1e-3 );
}

CPtr<const CDnnBlob> grad = tape.Gradient( *loss, *x );

CArray<float> gradData;
gradData.SetSize( grad->GetDataSize() );
grad->CopyTo( gradData.GetPtr() );

float gradRes[VectorSize] = { 0.501, 0.401, 0, 0,
0, 0, 0, 0.505,
0.405, 0, 0, 0.49,
0.39, 0, 0.391, 0.491 };
for( int i = 0; i < _countof( gradRes ); i++ ) {
ASSERT_NEAR( gradRes[i], gradData[i], 1e-3 );
}
}
}

TEST_F( CAutoDiffTest, TestSum2 )
{
CGradientTape tape;

const int VectorSize = 16;

CArray<float> xData;
xData.InsertAt( 1.0, 0, VectorSize );
CPtr<CDnnBlob> xBlob( CDnnBlob::CreateVector( MathEngine(), CT_Float, xData.Size() ) );
auto dimensions = { 2, 1, 1, 4, 2 };
CPtr<CDnnBlob> xBlob( CDnnBlob::CreateTensor( MathEngine(), CT_Float, dimensions ) );
xBlob->CopyFrom( xData.GetPtr() );
CPtr<const CDnnBlob> x = tape.Variable( *xBlob );

float valuesA[VectorSize] = { 0.501, 0.001, 0.002, 0.003, 0.004, 0.004, 0.005, 0.505,
0.010, 0.010, 0.011, 0.490, 0.489, 0.488, 0.487, 0.491 };
CPtr<CDnnBlob> a( CDnnBlob::CreateVector( MathEngine(), CT_Float, VectorSize ) );
float valuesA[VectorSize] = { 0.28, 0.3, 0.2 , 0.73, 0.73, 0.72, 0.33, 0.11,
0.08, 0.49, 0.09, 0.76, 0.05, 0.65, 0.28, 0.97 };
CPtr<CDnnBlob> a( CDnnBlob::CreateTensor( MathEngine(), CT_Float, dimensions ) );
a->CopyFrom( valuesA );

float valuesB[VectorSize] = { 0.001, 0.401, 0.002, 0.003, 0.004, 0.004, 0.005, 0.010,
0.405, 0.010, 0.011, 0.289, 0.390, 0.288, 0.391, 0.291 };
CPtr<CDnnBlob> b( CDnnBlob::CreateVector( MathEngine(), CT_Float, VectorSize ) );
float valuesB[VectorSize] = { 0.1 , 0.08, 0.46, 0.26, 0.23, 0.08, 0.33, 0.34,
0.1 , 0.4, 0.37, 0.41, 0.32, 0.53, 0.43, 0.82 };
CPtr<CDnnBlob> b( CDnnBlob::CreateTensor( MathEngine(), CT_Float, dimensions ) );
b->CopyFrom( valuesB );

CPtr<const CDnnBlob> ax = Mul(x, a);
CPtr<const CDnnBlob> bx = Mul(b, x);

CPtr<const CDnnBlob> top4ax = TopK(ax, 4);
CPtr<const CDnnBlob> top4bx = TopK(bx, 4);

CPtr<const CDnnBlob> loss = Sum( Add(top4ax, top4bx) );
CPtr<const CDnnBlob> loss = Sum( Add(ax, bx), 3 );

CArray<float> lossData;
lossData.SetSize( loss->GetDataSize() );
loss->CopyTo( lossData.GetPtr() );

float lossRes[1] = { 3.574 };
float lossRes[4] = { 2.66, 2.62, 1.72, 5.03 };
for( int i = 0; i < _countof(lossRes); i++ ) {
ASSERT_NEAR( lossRes[i], lossData[i], 1e-3 );
}
Expand All @@ -255,10 +309,10 @@ TEST_F( CAutoDiffTest, TestSum )
gradData.SetSize( grad->GetDataSize() );
grad->CopyTo( gradData.GetPtr() );

float gradRes[VectorSize] = { 0.501, 0.401, 0, 0,
0, 0, 0, 0.505,
0.405, 0, 0, 0.49,
0.39, 0, 0.391, 0.491 };
float gradRes[VectorSize] = { 0.38, 0.38, 0.66, 0.99,
0.96, 0.8, 0.66, 0.45,
0.18, 0.89, 0.46, 1.17,
0.37, 1.18, 0.71, 1.79 };
for( int i = 0; i < _countof(gradRes); i++ ) {
ASSERT_NEAR( gradRes[i], gradData[i], 1e-3 );
}
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class NEOMATHENGINE_API IVectorMathEngine : public CCrtAllocatedObject {
// The resultHandle is not set to null
virtual void VectorSumAdd(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) = 0;
virtual void VectorNegSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) = 0;
virtual void VectorSumAlongDimension(const CConstFloatHandle& firstHandle, int precedingDimension, int dimension,
int followingDimension, const CFloatHandle& resultHandle) = 0;

// result = (first == second) ? 1.0 : 0.0 elementwise
virtual void VectorEqual( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/src/CPU/CpuMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class CCpuMathEngine : public IMathEngine, public IRawMemoryManager {
void VectorSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorSumAdd(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorNegSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorSumAlongDimension( const CConstFloatHandle& firstHandle, int precedingDimension, int dimension,
int followingDimension, const CFloatHandle& resultHandle ) override;
void VectorEqual( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
const CFloatHandle& resultHandle, int vectorSize ) override;
void VectorEqualValue( const CConstIntHandle& firstHandle,
Expand Down
20 changes: 20 additions & 0 deletions NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,26 @@ void CCpuMathEngine::VectorNegSum(const CConstFloatHandle& firstHandle, int vect
*GetRaw(resultHandle) = -*GetRaw(resultHandle);
}

void CCpuMathEngine::VectorSumAlongDimension( const CConstFloatHandle& firstHandle, int precedingDimension, int dimension,
int followingDimension, const CFloatHandle& resultHandle )
{
ASSERT_EXPR( firstHandle.GetMathEngine() == this );
ASSERT_EXPR( resultHandle.GetMathEngine() == this );

int firstIndex = 0;
int resultIndex = 0;

for( int i = 0; i < followingDimension; i++ ) {
VectorCopy( resultHandle + resultIndex, firstHandle + firstIndex, precedingDimension );
firstIndex += precedingDimension;
for( int j = 1; j < dimension; j++ ) {
VectorAdd( firstHandle + firstIndex, resultHandle + resultIndex, resultHandle + resultIndex, precedingDimension );
firstIndex += precedingDimension;
}
resultIndex += precedingDimension;
}
}

void CCpuMathEngine::VectorFillBernoulli( const CFloatHandle& result, float p, int vectorSize, float value, int seed )
{
ASSERT_EXPR( result.GetMathEngine() == this );
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class CCudaMathEngine : public IMathEngine, public IRawMemoryManager {
void VectorSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorSumAdd(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorNegSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorSumAlongDimension( const CConstFloatHandle& firstHandle, int precedingDimension, int dimension,
int followingDimension, const CFloatHandle& resultHandle ) override;
void VectorEqual( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
const CFloatHandle& resultHandle, int vectorSize ) override;
void VectorEqualValue( const CConstIntHandle& firstHandle,
Expand Down
15 changes: 15 additions & 0 deletions NeoMathEngine/src/GPU/CUDA/CudaMathEngineVectorMath.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,21 @@ void CCudaMathEngine::VectorNegSum(const CConstFloatHandle& firstHandle, int vec
(GetRaw(firstHandle), vectorSize, GetRaw(resultHandle), true, setZero);
}

void CCudaMathEngine::VectorSumAlongDimension( const CConstFloatHandle& firstHandle, int precedingDimension, int dimension,
int followingDimension, const CFloatHandle& resultHandle )
{
ASSERT_EXPR( firstHandle.GetMathEngine() == this );
ASSERT_EXPR( resultHandle.GetMathEngine() == this );
SetCudaDevice( device->DeviceNumber );

dim3 blockCount;
dim3 threadCount;
getCudaTaskGrid2D( blockCount, threadCount, precedingDimension, followingDimension );

VectorSumAlongDimensionKernel<<<blockCount, threadCount>>>
( GetRaw(firstHandle), precedingDimension, dimension, followingDimension, GetRaw(resultHandle) );
}

void CCudaMathEngine::VectorEqual( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
const CFloatHandle& resultHandle, int vectorSize )
{
Expand Down
17 changes: 17 additions & 0 deletions NeoMathEngine/src/GPU/CUDA/Kernels/CudaVectorMathKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,23 @@ __global__ void VectorSumKernel(const float* __restrict__ mem, int count, float*
atomicAdd(result, sum);
}
}

__global__ void VectorSumAlongDimensionKernel( const float* __restrict__ input, int precedingDims, int dims,
int followingDims, float* result )
{
int x;
int y;
if( GetCudaTaskIndex2D( precedingDims, followingDims, x, y ) ) {
input += y * dims * precedingDims + x;
result += y * precedingDims + x;
*result = 0;
for( int i = 0; i < dims; i++ ) {
*result += *input;
input += precedingDims;
}
}
}

const int VectorEqualCombineCount = 16;
__global__ void VectorEqualKernel( const int* __restrict__ first,
const int* __restrict__ second, float* __restrict__ result, int count )
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/src/GPU/Metal/MetalMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class CMetalMathEngine : public IMathEngine, public IRawMemoryManager {
void VectorSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorSumAdd(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorNegSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorSumAlongDimension(const CConstFloatHandle& firstHandle, int precedingDimension, int dimension,
int followingDimension, const CFloatHandle& resultHandle) override;
void VectorEqual( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
const CFloatHandle& resultHandle, int vectorSize ) override;
void VectorEqualValue( const CConstIntHandle& firstHandle,
Expand Down
6 changes: 6 additions & 0 deletions NeoMathEngine/src/GPU/Metal/MetalMathEngineVectorMath.mm
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@
ASSERT_EXPR( kernel.Run( 0, 0, 1 ) );
}

void CMetalMathEngine::VectorSumAlongDimension( const CConstFloatHandle& firstHandle, int precedingDimension, int dimension,
int followingDimension, const CFloatHandle& resultHandle )
{
ASSERT_EXPR( false );
}

void CMetalMathEngine::VectorEqual( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
const CFloatHandle& resultHandle, int vectorSize )
{
Expand Down
2 changes: 2 additions & 0 deletions NeoMathEngine/src/GPU/Vulkan/VulkanMathEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class CVulkanMathEngine : public IMathEngine, public IRawMemoryManager {
void VectorSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorSumAdd(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorNegSum(const CConstFloatHandle& firstHandle, int vectorSize, const CFloatHandle& resultHandle) override;
void VectorSumAlongDimension(const CConstFloatHandle& firstHandle, int precedingDimension, int dimension,
int followingDimension, const CFloatHandle& resultHandle) override;
void VectorEqual( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
const CFloatHandle& resultHandle, int vectorSize ) override;
void VectorEqualValue( const CConstIntHandle& firstHandle,
Expand Down
Loading

0 comments on commit 77fdc68

Please sign in to comment.