From c701facf3ebeace11d60b10dad776704df97c285 Mon Sep 17 00:00:00 2001 From: Kirill Golikov Date: Wed, 8 May 2024 14:51:12 +0200 Subject: [PATCH] [NeoML] Add CleanUp in network learning tests Signed-off-by: Kirill Golikov --- NeoML/test/src/DnnSolverTest.cpp | 110 ++++++++++++++++++++++++------- 1 file changed, 87 insertions(+), 23 deletions(-) diff --git a/NeoML/test/src/DnnSolverTest.cpp b/NeoML/test/src/DnnSolverTest.cpp index b932be687c..4b756aa393 100644 --- a/NeoML/test/src/DnnSolverTest.cpp +++ b/NeoML/test/src/DnnSolverTest.cpp @@ -1,4 +1,4 @@ -/* Copyright © 2021-2023 ABBYY +/* Copyright © 2021-2024 ABBYY Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,8 @@ limitations under the License. using namespace NeoML; using namespace NeoMLTest; +namespace NeoMLTest { + // Returns coefficient of neuron with one in and one out. static float getFcCoeff( const CFullyConnectedLayer* fcLayer ) { @@ -28,6 +30,10 @@ static float getFcCoeff( const CFullyConnectedLayer* fcLayer ) return fcLayer->GetWeightsData()->GetData().GetValue(); } +} // namespace NeoMLTest + +//--------------------------------------------------------------------------------------------------------------------- + // Check build/change correctness with gradient accumulation enabled TEST( CDnnSolverTest, NetworkModificationOnGradientAccumulation ) { @@ -150,51 +156,68 @@ TEST( CDnnSolverTest, NetworkModificationOnGradientAccumulation ) EXPECT_EQ( e1, e2 ); } -// Net for weight check. -class CWeightCheckNet { +//--------------------------------------------------------------------------------------------------------------------- + +namespace NeoMLTest { + +// Network for weights check. +class CWeightCheckNet final { public: - CWeightCheckNet(); - void SetSolver( CDnnSolver* solver ) { dnn.SetSolver( solver ); } + CWeightCheckNet( CDnnSolver* solver ); + float RunAndLearnOnce(); void GetWeights( CArray& weights ) const; + void CleanUp(); + private: CRandom random; CDnn dnn; + + CPtr data; + CPtr label; CPtr fc; CPtr loss; + + void setInputs(); }; -CWeightCheckNet::CWeightCheckNet() : +CWeightCheckNet::CWeightCheckNet( CDnnSolver* solver ) : random( 0xAAAAAAAA ), dnn( random, MathEngine() ) { - CPtr data = AddLayer( "data", dnn ); + dnn.SetSolver( solver ); + + data = AddLayer( "data", dnn ); + label = AddLayer( "label", dnn ); + + fc = AddLayer( "fc", { data } ); + fc->SetNumberOfElements( 2 ); + fc->SetZeroFreeTerm( true ); + + loss = AddLayer( "loss", { fc, label } ); + setInputs(); +} + +void CWeightCheckNet::setInputs() +{ { CPtr dataBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 1, 2 ); CArray buff( { 0.25f, -0.345f } ); dataBlob->CopyFrom( buff.GetPtr() ); data->SetBlob( dataBlob ); } - - CPtr label = AddLayer( "label", dnn ); { CPtr labelBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Int, 1, 1, 1 ); CArray buff( { 0 } ); labelBlob->CopyFrom( buff.GetPtr() ); label->SetBlob( labelBlob ); } - - fc = AddLayer( "fc", { data } ); - fc->SetNumberOfElements( 2 ); - fc->SetZeroFreeTerm( true ); { CPtr weightBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 2, 2 ); CArray buff = { -0.5f, 0.9f, 0.3f, -0.7f }; weightBlob->CopyFrom( buff.GetPtr() ); fc->SetWeightsData( weightBlob ); } - - loss = AddLayer( "loss", { fc, label } ); } float CWeightCheckNet::RunAndLearnOnce() @@ -210,23 +233,44 @@ void CWeightCheckNet::GetWeights( CArray& weights ) const weightBlob->CopyTo( weights.GetPtr() ); } +void CWeightCheckNet::CleanUp() +{ + dnn.CleanUp( /*totalCleanUp*/true ); + MathEngine().CleanUp(); + dnn.Random().Reset( 0xAAAAAAAA ); + dnn.GetSolver()->Reset(); + setInputs(); +} + +//--------------------------------------------------------------------------------------------------------------------- + void testSolver( CDnnSolver* solver, const CArray>& expected ) { - CWeightCheckNet net; CArray weights; - net.SetSolver( solver ); + auto check = [&]( CWeightCheckNet& net ) + { for( int i = 0; i < expected.Size(); ++i ) { float loss = net.RunAndLearnOnce(); loss; net.GetWeights( weights ); - ASSERT_EQ( expected[i].Size(), weights.Size() ); + EXPECT_EQ( expected[i].Size(), weights.Size() ); for( int j = 0; j < weights.Size(); ++j ) { - ASSERT_TRUE( FloatEq( expected[i][j], weights[j] ) ); + EXPECT_NEAR( expected[i][j], weights[j], 1e-5 ); } } + }; + + CWeightCheckNet net( solver ); + check( net ); + + net.CleanUp(); + check( net ); } -// ==================================================================================================================== +} // namespace NeoMLTest + +//--------------------------------------------------------------------------------------------------------------------- + // Sgd. TEST( CDnnSolverTest, SgdNoReg ) @@ -339,7 +383,8 @@ TEST( CDnnSolverTest, SgdCompatL2 ) testSolver( sgd, expected ); } -// ==================================================================================================================== +//--------------------------------------------------------------------------------------------------------------------- + // Adam. TEST( CDnnSolverTest, AdamNoReg ) @@ -469,7 +514,8 @@ TEST( CDnnSolverTest, AdamCompatL2 ) testSolver( adam, expected ); } -// ==================================================================================================================== +//--------------------------------------------------------------------------------------------------------------------- + // Nadam. TEST( CDnnSolverTest, NadamNoReg ) @@ -532,6 +578,12 @@ TEST( CDnnSolverTest, NadamL2 ) testSolver( adam, expected ); } +//--------------------------------------------------------------------------------------------------------------------- + +// Serialization. + +namespace NeoMLTest { + static bool checkBlobEquality( CDnnBlob& firstBlob, CDnnBlob& secondBlob ) { if( !firstBlob.HasEqualDimensions( &secondBlob ) ) { @@ -683,6 +735,10 @@ static void solverSerializationTestImpl( CPtr firstSolver, bool trai EXPECT_TRUE( checkLstmEquality( reverse, secondReverse ) ); } +} // namespace NeoMLTest + +// sgd. + TEST( CDnnSimpleGradientSolverTest, Serialization1 ) { CPtr sgd = new CDnnSimpleGradientSolver( MathEngine() ); @@ -707,6 +763,10 @@ TEST( CDnnSimpleGradientSolverTest, Serialization2 ) solverSerializationTestImpl( sgd.Ptr(), false ); } +//--------------------------------------------------------------------------------------------------------------------- + +// Adam. + TEST( CDnnAdaptiveGradientSolverTest, Serialization1 ) { CPtr adam = new CDnnAdaptiveGradientSolver( MathEngine() ); @@ -769,6 +829,10 @@ TEST( CDnnNesterovGradientSolverTest, Serialization2 ) solverSerializationTestImpl( nadam.Ptr(), false ); } +//--------------------------------------------------------------------------------------------------------------------- + +// Lamb. + TEST( CDnnLambGradientSolverTest, Serialization1 ) { CPtr lamb = new CDnnLambGradientSolver( MathEngine() ); @@ -834,7 +898,7 @@ TEST( CDnnLambGradientSolverTest, Serialization4 ) solverSerializationTestImpl( lamb.Ptr(), true ); } -// ==================================================================================================================== +//--------------------------------------------------------------------------------------------------------------------- TEST( CDnnSolverTest, CompositeLearningRate ) {