Add LoRA (#987)
* [DEBUG] simply save the code

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Add LoRA with unit tests

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Add CLoraBuilder::DisableNonLoraTraining

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Fix lora serializer test

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Update comments a bit

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Fix tests

Signed-off-by: Valery Fedyunin <valery.fedyunin@abbyy.com>

* Fix clang warning

Signed-off-by: Valery Fedyunin <valery.fedyunin@abbyy.com>

* Comment a bit more

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Don't require access to different offsets during LoRA serialization

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Fix Alpha coefficient

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Fix CCompositeLayer's forceBackward flags

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Support multiple inputs in CDropoutLayer

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Support multiple inputs in CLinearLayer

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Clarify comments to CLoraFullyConnectedLayer

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Remove default number of elements for baseFc

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Fix serialization test

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Fix method naming in CLoraFullyConnectedLayer

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Revert "Support multiple inputs in CDropoutLayer"

This reverts commit a68c029.

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Revert "Support multiple inputs in CLinearLayer"

This reverts commit 8462ad9.

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Fix serialization test

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Support LoRA serialization for distributed training

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Add CLoraSerializer::SerializeCheckpoint

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Move LoRA builder and serializer out of compact version

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Revert "Move LoRA builder and serializer out of compact version"

This reverts commit f104a54.

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Fix Android/iOS

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Use more complex nets during distributed checkpoint test

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

* Remove debug output

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>

---------

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>
Signed-off-by: Valery Fedyunin <valery.fedyunin@abbyy.com>
Valeriy Fedyunin committed Oct 24, 2023
1 parent a035923 commit 45b82ac
Showing 15 changed files with 1,525 additions and 34 deletions.
5 changes: 5 additions & 0 deletions NeoML/include/NeoML/Dnn/DnnDistributed.h
@@ -20,6 +9 @@ limitations under the License.

namespace NeoML {

// Forward declaration
class CLoraSerializer;

// Interface for setting input to a neural network
class IDistributedDataset {
public:
@@ -95,6 +98,8 @@ class NEOML_API CDistributedTraining {
CString errorMessage;

void initialize( CArchive& archive, int count, TDistributedInitializer initializer, int seed );

friend class CLoraSerializer;
};

} // namespace NeoML
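
The forward declaration and friend declaration above give CLoraSerializer (declared later in this commit, in DnnLora.h) access to the internals of CDistributedTraining. A minimal usage sketch, not part of the commit itself: it assumes an already initialized CDistributedTraining and an open CArchive, and elides the training setup entirely.

#include <NeoML/NeoML.h>
using namespace NeoML;

// Store only the LoRA weights of a distributed training session.
void SaveDistributedLoraWeights( CDistributedTraining& distributed, CArchive& archive )
{
    CLoraSerializer serializer;
    // Returns the number of LoRA layers whose weights were stored
    const int storedLayers = serializer.Serialize( distributed, archive );
    ( void ) storedLayers;
}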
114 changes: 114 additions & 0 deletions NeoML/include/NeoML/Dnn/DnnLora.h
@@ -0,0 +1,114 @@
/* Copyright © 2023 ABBYY
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include <NeoML/NeoMLDefs.h>
#include <NeoML/Dnn/Dnn.h>

namespace NeoML {

// Forward declaration
class CDistributedTraining;

// Implementation of Low-Rank Adaptation (LoRA)
// https://arxiv.org/pdf/2106.09685v2.pdf

struct NEOML_API CLoraParams {
int Rank; // Size of the intermediate vector between the A and B matrices of LoRA
float Alpha; // Coefficient: the LoRA output is multiplied by Alpha / Rank
float Dropout; // Dropout rate applied to the input before the matrix multiplications

explicit CLoraParams( int rank = 1, float alpha = 1.f, float dropout = 0.f )
: Rank( rank ), Alpha( alpha ), Dropout( dropout ) {}

void Serialize( CArchive& archive );
};

// Mechanism for adding, removing, and merging LoRA in networks
// It works with CDnnLayerGraph, which allows modifying either a whole CDnn or specific composites (e.g. CTransformerEncoderLayer)
class NEOML_API CLoraBuilder {
public:
CLoraBuilder();
// Special constructor that sets the list of composites this builder is allowed to modify
// See the BuildAll* methods for more information
explicit CLoraBuilder( const CArray<CString>& _compositeClasses );

// Adds LoRA weights to a specific layer
// Fully-connected
void BuildFcWrapper( CDnnLayerGraph& graph, const char* fcName, const CLoraParams& params ) const;

// Builds LoRA wrappers for every layer of a specific type inside the graph and its subgraphs (composite layers)
// In some cases this may cause problems: some composite derivatives contain logic of their own,
// and these layers may break if some of their internal layers are replaced with LoRA wrappers.
// That is why not every derivative of CCompositeLayer is supported as a subgraph.
//
// The derivatives supported by default are:
// 1. CCompositeLayer
// 2. CTemplateLayer
// 3. CRecurrentLayer
// 4. CMultiheadAttentionLayer
// 5. CTransformerEncoderLayer
//
// If this list doesn't fit your task, you can replace it with your own via the constructor.
// CFullyConnectedLayer instances located directly inside the graph are always replaced.
// Returns the total number of fully-connected layers replaced by this call
// Fully-connected
int BuildAllFcWrappers( CDnnLayerGraph& graph, const CLoraParams& params ) const;

// Disables training for all layers in the net except LoRA wrappers
// Affects only trainable non-composite layers that have training enabled
int DisableNonLoraTraining( CDnnLayerGraph& graph ) const;

// Replaces a specific LoRA wrapper without merging LoRA weights (restores the original layer weights)
// Fully-connected
void DiscardFcWrapper( CDnnLayerGraph& graph, const char* fcName ) const
{ replaceFcWrapper( graph, fcName, false ); }

// Replaces all LoRA wrappers of a specific type in the graph without merging LoRA weights
// Fully-connected
int DiscardAllFcWrappers( CDnnLayerGraph& graph ) const { return replaceAllFcWrappers( graph, false ); }

// Replaces a specific LoRA wrapper, merging the LoRA weights into the original layer
// Fully-connected
void MergeFcWrapper( CDnnLayerGraph& graph, const char* fcName ) const { replaceFcWrapper( graph, fcName, true ); }

// Replaces all LoRA wrappers of a specific type in the graph, merging the LoRA weights
// Fully-connected
int MergeAllFcWrappers( CDnnLayerGraph& graph ) const { return replaceAllFcWrappers( graph, true ); }

private:
CArray<CString> compositeClasses;

void replaceFcWrapper( CDnnLayerGraph& graph, const char* fcName, bool mergeWeights ) const;
int replaceAllFcWrappers( CDnnLayerGraph& graph, bool mergeWeights ) const;
};

// A special mechanism that allows serializing only the LoRA weights of a CDnn
class NEOML_API CLoraSerializer {
public:
// Returns the number of LoRA layers whose weights were stored/loaded
// Weights can be loaded into a net containing either LoRA wrappers or the original layers
// In the second case LoRA wrappers will be built on the fly
int Serialize( CDnn& dnn, CArchive& archive ) const;
// The same as above but for distributed training
int Serialize( CDistributedTraining& distributed, CArchive& archive ) const;

// A LoRA checkpoint is the serialized LoRA weights + solver(s)
int SerializeCheckpoint( CDnn& dnn, CArchive& archive ) const;
int SerializeCheckpoint( CDistributedTraining& distributed, CArchive& archive ) const;
};

} // namespace NeoML
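
For orientation, a sketch of how the builder and serializer declared above might be used together. This is not part of the commit: dnn is assumed to be an already constructed CDnn containing fully-connected layers, checkpoint an open CArchive, and the training loop is elided.

#include <NeoML/NeoML.h>
using namespace NeoML;

void TuneWithLora( CDnn& dnn, CArchive& checkpoint )
{
    // rank 8, alpha 16, 10% dropout -- illustrative values only
    CLoraParams params( 8, 16.f, 0.1f );

    CLoraBuilder builder;
    // Wrap every supported fully-connected layer with a LoRA adapter
    builder.BuildAllFcWrappers( dnn, params );
    // Freeze everything except the LoRA wrappers
    builder.DisableNonLoraTraining( dnn );

    // ... run training iterations on dnn ...

    // Store the LoRA weights together with the solver state
    CLoraSerializer serializer;
    serializer.SerializeCheckpoint( dnn, checkpoint );

    // For inference-only deployment the adapters can be folded back
    // into plain fully-connected layers
    builder.MergeAllFcWrappers( dnn );
}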
129 changes: 129 additions & 0 deletions NeoML/include/NeoML/Dnn/Layers/LoraFullyConnectedLayer.h
@@ -0,0 +1,129 @@
/* Copyright © 2023 ABBYY
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include <NeoML/NeoMLDefs.h>
#include <NeoML/Dnn/DnnLora.h>
#include <NeoML/Dnn/Layers/ActivationLayers.h>
#include <NeoML/Dnn/Layers/CompositeLayer.h>
#include <NeoML/Dnn/Layers/DropoutLayer.h>
#include <NeoML/Dnn/Layers/EltwiseLayer.h>
#include <NeoML/Dnn/Layers/FullyConnectedLayer.h>

namespace NeoML {

// Fully-connected layer with Low-Rank Adaptation (LoRA), see
// https://arxiv.org/pdf/2106.09685v2.pdf
// A LoRA wrapper for CFullyConnectedLayer
//
// This layer lazily switches between 2 states
//
// 1st: "split" state
// In this state baseFc contains the unmodified weights of the original network,
// and the LoRA part is present explicitly as a group of layers
// Used during training
//
// loraSum
// ^
// |
// +---------+
// ^ ^
// | |
// | scaling
// | |
// | fcB
// baseFc |
// | fcA
// | |
// | dropout
// | |
// ^ ^
// +---------+
// ^
// |
// inputData
//
// 2nd: "merged" state
// In this state the baseFc layer contains weights that emulate the full LoRA computation,
// and no other layers are present in the composite
// Used during inference
//
// loraSum
// ^
// |
// baseFc
// ^
// |
// inputData
//
// NOTE: even in the merged state this layer has to store the A and B matrices
// in order to switch back to the "split" state when needed
// If you need only inference, you can replace this layer with a CFullyConnectedLayer with merged weights
// (this can be done via CLoraBuilder::MergeFcWrapper)
class NEOML_API CLoraFullyConnectedLayer : public CCompositeLayer {
NEOML_DNN_LAYER( CLoraFullyConnectedLayer )
public:
CLoraFullyConnectedLayer( CDnnBlob& baseWeights, CDnnBlob* baseFreeTerms, const CLoraParams& params );
explicit CLoraFullyConnectedLayer( IMathEngine& mathEngine ); // used for loading serialized layer

void Serialize( CArchive& ) override;

void UpdateParams( const CLoraParams& newParams, CDnnBlob* newA, CDnnBlob* newB );

int OutputSize() const { return baseFc->GetNumberOfElements(); }
int Rank() const { return fcA->GetNumberOfElements(); }
float Alpha() const { return scaling->GetMultiplier() * Rank(); }
float Dropout() const { return dropout->GetDropoutRate(); }

// Raw getters for weights
// These getters do not copy the weights, which may lead to hard-to-debug issues,
// but they are necessary for making LoRA work without excessive copying
// baseFc weights from "split" state
CPtr<CDnnBlob> GetSplitWeightsNoCopy() { split(); return baseFc->Weights(); }
// baseFc weights from "merged" state
CPtr<CDnnBlob> GetMergedWeightsNoCopy() { merge(); return baseFc->Weights(); }
// baseFc free terms
CPtr<CDnnBlob> GetFreeTermsNoCopy() { return baseFc->FreeTerms(); }
// A LoRA matrix
CPtr<CDnnBlob> GetAWeightsNoCopy() { return fcA->Weights(); }
// B LoRA matrix
CPtr<CDnnBlob> GetBWeightsNoCopy() { return fcB->Weights(); }

// Mostly for testing/debugging
CPtr<CDnnBlob>& GetWeightsNoCopy() { return baseFc->Weights(); }
bool IsMerged() const { return isMerged; }

protected:
~CLoraFullyConnectedLayer() override = default;

void Reshape() override;

private:
bool isMerged = true;
CPtr<CFullyConnectedLayer> baseFc;
CPtr<CDropoutLayer> dropout;
CPtr<CFullyConnectedLayer> fcA;
CPtr<CFullyConnectedLayer> fcB;
CPtr<CLinearLayer> scaling;
CPtr<CEltwiseSumLayer> sum;

void initialize( const CLoraParams& params );
void merge();
void split();
void recalcBaseWeights();
};

} // namespace NeoML
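
In formula form, the two states depicted above compute the same function at inference time (matrix orientation and the exact folding are assumptions here, not taken from the commit; dropout is active only while training in the split state):

    split:   y = W_base * x + b + (Alpha / Rank) * W_B * ( W_A * dropout( x ) )
    merged:  y = ( W_base + (Alpha / Rank) * W_B * W_A ) * x + b

Switching between merge() and split() thus amounts to adding or subtracting the (Alpha / Rank) * W_B * W_A term in baseFc's weights, which is why the A and B matrices are kept even in the merged state.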
6 changes: 3 additions & 3 deletions NeoML/include/NeoML/Dnn/Layers/TransformerLayer.h
@@ -99,7 +99,7 @@ class NEOML_API CTransformerEncoderLayer : public CCompositeLayer {
void SetDropoutRate( float rate );

// Sets the size of the first fully-connected layer inside of feed-forward
int GetFeedForwardSize() const { return fc1->GetNumberOfElements(); }
int GetFeedForwardSize() const { return CheckCast<CFullyConnectedLayer>( fc1 )->GetNumberOfElements(); }
void SetFeedForwardSize( int size );

// Sets activation between fully-connected layers inside of feed-forward
@@ -118,9 +118,9 @@ class NEOML_API CTransformerEncoderLayer {
CPtr<CMultiheadAttentionLayer> selfAttention;
CPtr<CDropoutLayer> dropoutSelfAttention;
CPtr<CEltwiseSumLayer> selfAttentionSum;
CPtr<CFullyConnectedLayer> fc1;
CPtr<CBaseLayer> fc1;
CPtr<CDropoutLayer> dropoutFc1;
CPtr<CFullyConnectedLayer> fc2;
CPtr<CBaseLayer> fc2;
CPtr<CDropoutLayer> dropoutFc2;
CPtr<CEltwiseSumLayer> feedForwardSum;

2 changes: 2 additions & 0 deletions NeoML/include/NeoML/NeoML.h
@@ -97,6 +97,7 @@ limitations under the License.
#include <NeoML/TraditionalML/WordDictionary.h>

#include <NeoML/Dnn/DnnDistributed.h>
#include <NeoML/Dnn/DnnLora.h>
#include <NeoML/Dnn/DnnOptimization.h>
#include <NeoML/Dnn/Layers/3dPoolingLayer.h>
#include <NeoML/Dnn/Layers/3dTransposedConvLayer.h>
@@ -125,6 +126,7 @@ limitations under the License.
#include <NeoML/Dnn/Layers/InterpolationLayer.h>
#include <NeoML/Dnn/Layers/IrnnLayer.h>
#include <NeoML/Dnn/Layers/LogicalLayers.h>
#include <NeoML/Dnn/Layers/LoraFullyConnectedLayer.h>
#include <NeoML/Dnn/Layers/LrnLayer.h>
#include <NeoML/Dnn/Layers/MaxOverTimePoolingLayer.h>
#include <NeoML/Dnn/Layers/ModelWrapperLayer.h>
4 changes: 4 additions & 0 deletions NeoML/src/CMakeLists.txt
@@ -98,6 +98,7 @@ set(NeoML_SOURCES_COMPACT
set(NeoML_SOURCES
${NeoML_SOURCES_COMPACT}
Dnn/DnnDistributed.cpp
Dnn/DnnLora.cpp
Dnn/DnnOptimization.cpp
Dnn/Layers/3dPoolingLayer.cpp
Dnn/Layers/3dTransposedConvLayer.cpp
@@ -126,6 +127,7 @@ set(NeoML_SOURCES
Dnn/Layers/InterpolationLayer.cpp
Dnn/Layers/IrnnLayer.cpp
Dnn/Layers/LogicalLayers.cpp
Dnn/Layers/LoraFullyConnectedLayer.cpp
Dnn/Layers/LrnLayer.cpp
Dnn/Layers/MaxOverTimePoolingLayer.cpp
Dnn/Layers/MobileNetV3BlockLayer.cpp
@@ -356,6 +358,7 @@ set(NeoML_HEADERS

# Headers
../include/NeoML/Dnn/DnnDistributed.h
../include/NeoML/Dnn/DnnLora.h
../include/NeoML/Dnn/DnnOptimization.h
../include/NeoML/Dnn/Layers/3dPoolingLayer.h
../include/NeoML/Dnn/Layers/3dTransposedConvLayer.h
@@ -384,6 +387,7 @@
../include/NeoML/Dnn/Layers/InterpolationLayer.h
../include/NeoML/Dnn/Layers/IrnnLayer.h
../include/NeoML/Dnn/Layers/LogicalLayers.h
../include/NeoML/Dnn/Layers/LoraFullyConnectedLayer.h
../include/NeoML/Dnn/Layers/LrnLayer.h
../include/NeoML/Dnn/Layers/MaxOverTimePoolingLayer.h
../include/NeoML/Dnn/Layers/MobileNetV3BlockLayer.h
2 changes: 2 additions & 0 deletions NeoML/src/Dnn/Dnn.cpp
@@ -83,6 +83,7 @@ limitations under the License.
#include <NeoML/Dnn/Layers/InterpolationLayer.h>
#include <NeoML/Dnn/Layers/IrnnLayer.h>
#include <NeoML/Dnn/Layers/LogicalLayers.h>
#include <NeoML/Dnn/Layers/LoraFullyConnectedLayer.h>
#include <NeoML/Dnn/Layers/LrnLayer.h>
#include <NeoML/Dnn/Layers/MaxOverTimePoolingLayer.h>
#include <NeoML/Dnn/Layers/MobileNetV3BlockLayer.h>
@@ -355,6 +356,7 @@ REGISTER_NEOML_LAYER( CImageResizeLayer, "FmlCnnImageResizeLayer" )
REGISTER_NEOML_LAYER( CImageToPixelLayer, "FmlCnnImageToPixelLayerClass" )
REGISTER_NEOML_LAYER( CFocalLossLayer, "FmlCnnFocalLossLayer" )
REGISTER_NEOML_LAYER( CFullyConnectedSourceLayer, "FmlCnnFullyConnectedSourceLayer" )
REGISTER_NEOML_LAYER( CLoraFullyConnectedLayer, "NeoMLDnnLoraFullyConnectedLayer" )
REGISTER_NEOML_LAYER( CMaxOverTimePoolingLayer, "FmlCnnMaxOverTimePoolingLayer" )
REGISTER_NEOML_LAYER( CMobileNetV3PreSEBlockLayer, "NeoMLDnnMobileNetV3PreSEBlockLayer" )
REGISTER_NEOML_LAYER( CMobileNetV3PostSEBlockLayer, "NeoMLDnnMobileNetV3PostSEBlockLayer" )

