Skip to content

Commit

Permalink
[NeoML] BaseLayer more clean-up
Browse files Browse the repository at this point in the history
Signed-off-by: Kirill Golikov <kirill.golikov@abbyy.com>
  • Loading branch information
favorart committed Apr 23, 2024
1 parent f1e59a2 commit 37f03f6
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion NeoML/include/NeoML/Dnn/Dnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class NEOML_API CBaseLayer : public virtual IObject {
virtual size_t GetOutputBlobsSize() const;

// Releases all temporary resources allocated for the layer
virtual void CleanUp();
virtual void CleanUp( bool totalCleanUp = false );

// Returns the total size of trainable parameters in this layer
// Returns the total size of trainable parameters of its internal layers, if layer is composite or recurrent
Expand Down
2 changes: 2 additions & 0 deletions NeoML/include/NeoML/Dnn/Layers/SinkLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class NEOML_API CSinkLayer : public CBaseLayer {
// After each call to RunOnce this blob contains the results
const CPtr<CDnnBlob>& GetBlob() const;

void CleanUp() override { blob = nullptr; }

protected:
CPtr<CDnnBlob> blob;

Expand Down
12 changes: 11 additions & 1 deletion NeoML/src/Dnn/BaseLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,22 @@ size_t CBaseLayer::GetOutputBlobsSize() const
return result;
}

void CBaseLayer::CleanUp()
void CBaseLayer::CleanUp( bool totalCleanUp )
{
inputBlobs.DeleteAll();
inputBlobs.SetSize(inputDescs.Size());
outputBlobs.DeleteAll();
outputBlobs.SetSize(outputDescs.Size());

if ( totalCleanUp ) {
inputDiffBlobs.DeleteAll();
outputDiffBlobs.DeleteAll();
paramDiffBlobs.DeleteAll();
readyOutputDiffs.DeleteAll();
clearAllRuntimeBlobs();

ForceReshape();
}
}

size_t CBaseLayer::GetTrainableParametersSize() const
Expand Down
4 changes: 2 additions & 2 deletions NeoML/src/Dnn/Dnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,10 @@ void CDnn::RunAndLearnOnce()
solver->Train();
}

void CDnn::CleanUp()
void CDnn::CleanUp( bool totalCleanUp )
{
for( int i = 0; i < layers.Size(); i++ ) {
layers[i]->CleanUp();
layers[i]->CleanUp( totalCleanUp );
}
}

Expand Down

0 comments on commit 37f03f6

Please sign in to comment.