[NeoML] Add CleanUp in network learning tests
Signed-off-by: Kirill Golikov <kirill.golikov@abbyy.com>
favorart committed May 8, 2024
1 parent 80de54b commit c701fac
Showing 1 changed file with 87 additions and 23 deletions.
110 changes: 87 additions & 23 deletions NeoML/test/src/DnnSolverTest.cpp
@@ -1,4 +1,4 @@
-/* Copyright © 2021-2023 ABBYY
+/* Copyright © 2021-2024 ABBYY
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -21,13 +21,19 @@ limitations under the License.
using namespace NeoML;
using namespace NeoMLTest;

+namespace NeoMLTest {
+
// Returns the coefficient of a neuron with one input and one output.
static float getFcCoeff( const CFullyConnectedLayer* fcLayer )
{
NeoAssert( fcLayer != 0 );
return fcLayer->GetWeightsData()->GetData().GetValue();
}

+} // namespace NeoMLTest
+
+//---------------------------------------------------------------------------------------------------------------------
+
// Check build/change correctness with gradient accumulation enabled
TEST( CDnnSolverTest, NetworkModificationOnGradientAccumulation )
{
@@ -150,51 +156,68 @@ TEST( CDnnSolverTest, NetworkModificationOnGradientAccumulation )
EXPECT_EQ( e1, e2 );
}

-// Net for weight check.
-class CWeightCheckNet {
+//---------------------------------------------------------------------------------------------------------------------
+
+namespace NeoMLTest {
+
+// Network for weights check.
+class CWeightCheckNet final {
public:
-CWeightCheckNet();
-void SetSolver( CDnnSolver* solver ) { dnn.SetSolver( solver ); }
+CWeightCheckNet( CDnnSolver* solver );

float RunAndLearnOnce();
void GetWeights( CArray<float>& weights ) const;
+void CleanUp();

private:
CRandom random;
CDnn dnn;

+CPtr<CSourceLayer> data;
+CPtr<CSourceLayer> label;
CPtr<CFullyConnectedLayer> fc;
CPtr<CLossLayer> loss;

+void setInputs();
};

-CWeightCheckNet::CWeightCheckNet() :
+CWeightCheckNet::CWeightCheckNet( CDnnSolver* solver ) :
random( 0xAAAAAAAA ),
dnn( random, MathEngine() )
{
-CPtr<CSourceLayer> data = AddLayer<CSourceLayer>( "data", dnn );
+dnn.SetSolver( solver );
+
+data = AddLayer<CSourceLayer>( "data", dnn );
+label = AddLayer<CSourceLayer>( "label", dnn );
+
+fc = AddLayer<CFullyConnectedLayer>( "fc", { data } );
+fc->SetNumberOfElements( 2 );
+fc->SetZeroFreeTerm( true );
+
+loss = AddLayer<CCrossEntropyLossLayer>( "loss", { fc, label } );
+setInputs();
+}
+
+void CWeightCheckNet::setInputs()
+{
{
CPtr<CDnnBlob> dataBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 1, 2 );
CArray<float> buff( { 0.25f, -0.345f } );
dataBlob->CopyFrom( buff.GetPtr() );
data->SetBlob( dataBlob );
}

-CPtr<CSourceLayer> label = AddLayer<CSourceLayer>( "label", dnn );
{
CPtr<CDnnBlob> labelBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Int, 1, 1, 1 );
CArray<int> buff( { 0 } );
labelBlob->CopyFrom( buff.GetPtr() );
label->SetBlob( labelBlob );
}

-fc = AddLayer<CFullyConnectedLayer>( "fc", { data } );
-fc->SetNumberOfElements( 2 );
-fc->SetZeroFreeTerm( true );
{
CPtr<CDnnBlob> weightBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 2, 2 );
CArray<float> buff = { -0.5f, 0.9f, 0.3f, -0.7f };
weightBlob->CopyFrom( buff.GetPtr() );
fc->SetWeightsData( weightBlob );
}

-loss = AddLayer<CCrossEntropyLossLayer>( "loss", { fc, label } );
}

float CWeightCheckNet::RunAndLearnOnce()
@@ -210,23 +233,44 @@ void CWeightCheckNet::GetWeights( CArray<float>& weights ) const
weightBlob->CopyTo( weights.GetPtr() );
}

+void CWeightCheckNet::CleanUp()
+{
+dnn.CleanUp( /*totalCleanUp*/true );
+MathEngine().CleanUp();
+dnn.Random().Reset( 0xAAAAAAAA );
+dnn.GetSolver()->Reset();
+setInputs();
+}
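// A minimal usage sketch, not part of this commit; the solver choice and
// tolerance are illustrative. CleanUp() resets the weights, the random seed
// and the solver state, so a second learning pass should reproduce the first:
//
//   CWeightCheckNet net( new CDnnSimpleGradientSolver( MathEngine() ) );
//   const float firstLoss = net.RunAndLearnOnce();
//   net.CleanUp();
//   EXPECT_NEAR( firstLoss, net.RunAndLearnOnce(), 1e-6f );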

+//---------------------------------------------------------------------------------------------------------------------

void testSolver( CDnnSolver* solver, const CArray<CArray<float>>& expected )
{
-CWeightCheckNet net;
CArray<float> weights;
-net.SetSolver( solver );
+auto check = [&]( CWeightCheckNet& net )
+{
for( int i = 0; i < expected.Size(); ++i ) {
float loss = net.RunAndLearnOnce();
+loss; // reference the value to suppress an unused-variable warning
net.GetWeights( weights );
-ASSERT_EQ( expected[i].Size(), weights.Size() );
+EXPECT_EQ( expected[i].Size(), weights.Size() );
for( int j = 0; j < weights.Size(); ++j ) {
-ASSERT_TRUE( FloatEq( expected[i][j], weights[j] ) );
+EXPECT_NEAR( expected[i][j], weights[j], 1e-5 );
}
}
+};
+
+CWeightCheckNet net( solver );
+check( net );
+
+net.CleanUp();
+check( net );
}
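// A hypothetical driver for testSolver, not from this commit (the weight
// values are placeholders; the real expected trajectories live in the tests
// below). Each inner array is the 2x2 fc weight snapshot expected after one
// learning step; the check runs twice, on a fresh net and again after CleanUp():
//
//   CPtr<CDnnSimpleGradientSolver> sgd = new CDnnSimpleGradientSolver( MathEngine() );
//   sgd->SetLearningRate( 0.1f );
//   sgd->SetMomentDecayRate( 0.f );
//
//   CArray<CArray<float>> expected;
//   expected.SetSize( 1 );
//   expected[0].Add( -0.4f ); expected[0].Add( 0.8f );
//   expected[0].Add( 0.35f ); expected[0].Add( -0.75f );
//
//   testSolver( sgd, expected );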

-// ====================================================================================================================
+} // namespace NeoMLTest
+
+//---------------------------------------------------------------------------------------------------------------------

// Sgd.

TEST( CDnnSolverTest, SgdNoReg )
@@ -339,7 +383,8 @@ TEST( CDnnSolverTest, SgdCompatL2 )
testSolver( sgd, expected );
}

-// ====================================================================================================================
+//---------------------------------------------------------------------------------------------------------------------
+
// Adam.

TEST( CDnnSolverTest, AdamNoReg )
@@ -469,7 +514,8 @@ TEST( CDnnSolverTest, AdamCompatL2 )
testSolver( adam, expected );
}

-// ====================================================================================================================
+//---------------------------------------------------------------------------------------------------------------------
+
// Nadam.

TEST( CDnnSolverTest, NadamNoReg )
@@ -532,6 +578,12 @@ TEST( CDnnSolverTest, NadamL2 )
testSolver( adam, expected );
}

+//---------------------------------------------------------------------------------------------------------------------
+
+// Serialization.
+
+namespace NeoMLTest {
+
static bool checkBlobEquality( CDnnBlob& firstBlob, CDnnBlob& secondBlob )
{
if( !firstBlob.HasEqualDimensions( &secondBlob ) ) {
@@ -683,6 +735,10 @@ static void solverSerializationTestImpl( CPtr<CDnnSolver> firstSolver, bool trai
EXPECT_TRUE( checkLstmEquality( reverse, secondReverse ) );
}

+} // namespace NeoMLTest
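// The round-trip these serialization tests rely on, sketched under the
// assumption that CDnn::SerializeCheckpoint stores the layers together with
// the solver state (firstDnn/secondDnn are illustrative names):
//
//   CMemoryFile file;
//   {
//       CArchive archive( &file, CArchive::SD_Storing );
//       firstDnn.SerializeCheckpoint( archive );
//   }
//   file.SeekToBegin();
//   {
//       CArchive archive( &file, CArchive::SD_Loading );
//       secondDnn.SerializeCheckpoint( archive );
//   }
//   // after loading, both nets should continue training identically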

+// sgd.

TEST( CDnnSimpleGradientSolverTest, Serialization1 )
{
CPtr<CDnnSimpleGradientSolver> sgd = new CDnnSimpleGradientSolver( MathEngine() );
@@ -707,6 +763,10 @@ TEST( CDnnSimpleGradientSolverTest, Serialization2 )
solverSerializationTestImpl( sgd.Ptr(), false );
}

+//---------------------------------------------------------------------------------------------------------------------
+
+// Adam.
+
TEST( CDnnAdaptiveGradientSolverTest, Serialization1 )
{
CPtr<CDnnAdaptiveGradientSolver> adam = new CDnnAdaptiveGradientSolver( MathEngine() );
@@ -769,6 +829,10 @@ TEST( CDnnNesterovGradientSolverTest, Serialization2 )
solverSerializationTestImpl( nadam.Ptr(), false );
}

+//---------------------------------------------------------------------------------------------------------------------
+
+// Lamb.
+
TEST( CDnnLambGradientSolverTest, Serialization1 )
{
CPtr<CDnnLambGradientSolver> lamb = new CDnnLambGradientSolver( MathEngine() );
@@ -834,7 +898,7 @@ TEST( CDnnLambGradientSolverTest, Serialization4 )
solverSerializationTestImpl( lamb.Ptr(), true );
}

-// ====================================================================================================================
+//---------------------------------------------------------------------------------------------------------------------

TEST( CDnnSolverTest, CompositeLearningRate )
{
