* [DEBUG] simply save the code
* Add LoRA with unit tests
* Add CLoraBuilder::DisableNonLoraTraining
* Fix lora serializer test
* Update comments a bit
* Fix tests
* Fix clang warning
* Comment a bit more
* Don't require access to different offsets during LoRA serialization
* Fix Alpha coefficient
* Fix CCompositeLayer's forceBackward flags
* Support multiple inputs in CDropoutLayer
* Support multiple inputs in CLinearLayer
* Clarify comments to CLoraFullyConnectedLayer
* Remove default number of elements for baseFc
* Fix serialization test
* Fix method naming in CLoraFullyConnectedLayer
* Revert "Support multiple inputs in CDropoutLayer" (reverts commit a68c029)
* Revert "Support multiple inputs in CLinearLayer" (reverts commit 8462ad9)
* Fix serialization test
* Support LoRA serialization for distributed training
* Add CLoraSerializer::SerializeCheckpoint
* Move LoRA builder and serializer out of compact version
* Revert "Move LoRA builder and serializer out of compact version" (reverts commit f104a54)
* Fix Android/iOS
* Use more complex nets during distributed checkpoint test
* Remove debug output

Signed-off-by: Valerii Fediunin <valery.fedyunin@abbyy.com>
Signed-off-by: Valery Fedyunin <valery.fedyunin@abbyy.com>
Valeriy Fedyunin committed Oct 24, 2023
1 parent a035923 · commit 45b82ac
Showing 15 changed files with 1,525 additions and 34 deletions.
@@ -0,0 +1,114 @@
/* Copyright © 2023 ABBYY
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
	http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include <NeoML/NeoMLDefs.h>
#include <NeoML/Dnn/Dnn.h>

namespace NeoML {

// Forward declaration
class CDistributedTraining;

// Implementation of Low-Rank Adaptation (LoRA)
// https://arxiv.org/pdf/2106.09685v2.pdf

struct NEOML_API CLoraParams {
	int Rank; // Size of the vector between the A and B matrices of LoRA
	float Alpha; // Coefficient; the output will be multiplied by Alpha / Rank
	float Dropout; // Dropout applied to the input before the matrix multiplications

	explicit CLoraParams( int rank = 1, float alpha = 1.f, float dropout = 0.f )
		: Rank( rank ), Alpha( alpha ), Dropout( dropout ) {}

	void Serialize( CArchive& archive );
};

// Mechanism for adding, removing, and merging LoRA in networks
// It works with CDnnLayerGraph, which allows you to modify a CDnn or specific composites (e.g. CTransformerEncoderLayer)
class NEOML_API CLoraBuilder {
public:
	CLoraBuilder();
	// Special constructor which sets the list of composites allowed to be modified by this builder
	// See the BuildAll* methods for more information
	explicit CLoraBuilder( const CArray<CString>& _compositeClasses );

	// Adds LoRA weights to a specific layer
	// Fully-connected
	void BuildFcWrapper( CDnnLayerGraph& graph, const char* fcName, const CLoraParams& params ) const;

	// Builds LoRA weights for every layer of a specific type inside the graph and its subgraphs (composite layers)
	// In some cases this may cause trouble, because some composite derivatives contain logic of their own,
	// and these layers may break if some of their internal layers are replaced with LoRA wrappers.
	// This is why not every derivative of CCompositeLayer is supported as a subgraph.
	//
	// By default the supported derivatives are:
	// 1. CCompositeLayer
	// 2. CTemplateLayer
	// 3. CRecurrentLayer
	// 4. CMultiheadAttentionLayer
	// 5. CTransformerEncoderLayer
	//
	// If this list doesn't fit your task you can replace it with your own via the constructor.
	// CFullyConnectedLayer instances directly inside the graph are always replaced.
	// Returns the total number of fully-connected layers replaced by this call
	// Fully-connected
	int BuildAllFcWrappers( CDnnLayerGraph& graph, const CLoraParams& params ) const;

	// Disables training for all layers in the net except LoRA wrappers
	// Affects only trainable non-composite layers which have training enabled
	int DisableNonLoraTraining( CDnnLayerGraph& graph ) const;

	// Replaces a specific LoRA wrapper without merging LoRA weights (restores the original layer weights)
	// Fully-connected
	void DiscardFcWrapper( CDnnLayerGraph& graph, const char* fcName ) const
		{ replaceFcWrapper( graph, fcName, false ); }

	// Replaces all LoRA wrappers of a specific type in the graph without merging LoRA weights
	// Fully-connected
	int DiscardAllFcWrappers( CDnnLayerGraph& graph ) const { return replaceAllFcWrappers( graph, false ); }

	// Replaces a specific LoRA wrapper, merging the LoRA weights
	// Fully-connected
	void MergeFcWrapper( CDnnLayerGraph& graph, const char* fcName ) const { replaceFcWrapper( graph, fcName, true ); }

	// Replaces all LoRA wrappers of a specific type in the graph, merging the LoRA weights
	// Fully-connected
	int MergeAllFcWrappers( CDnnLayerGraph& graph ) const { return replaceAllFcWrappers( graph, true ); }

private:
	CArray<CString> compositeClasses;

	void replaceFcWrapper( CDnnLayerGraph& graph, const char* fcName, bool mergeWeights ) const;
	int replaceAllFcWrappers( CDnnLayerGraph& graph, bool mergeWeights ) const;
};

// A special mechanism which serializes only the LoRA weights of a CDnn
class NEOML_API CLoraSerializer {
public:
	// Returns the number of LoRA layers whose weights were stored/loaded
	// Weights can be loaded into a net containing either wrappers or original layers;
	// in the second case LoRA wrappers will be built on the fly
	int Serialize( CDnn& dnn, CArchive& archive ) const;
	// The same as above but for distributed training
	int Serialize( CDistributedTraining& distributed, CArchive& archive ) const;

	// A LoRA checkpoint is the serialized LoRA weights + solver(s)
	int SerializeCheckpoint( CDnn& dnn, CArchive& archive ) const;
	int SerializeCheckpoint( CDistributedTraining& distributed, CArchive& archive ) const;
};

} // namespace NeoML
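
For context, here is a minimal usage sketch of the API above. It assumes a CDnn `dnn` containing fully-connected layers and a CArchive `archive` already opened for storing; the variable names and parameter values are illustrative, not part of this commit.

// Illustrative sketch: fine-tune with LoRA, then store only the LoRA weights.
CLoraParams params( /*rank*/ 8, /*alpha*/ 16.f, /*dropout*/ 0.1f );

CLoraBuilder builder;
int wrapped = builder.BuildAllFcWrappers( dnn, params ); // wrap fully-connected layers
builder.DisableNonLoraTraining( dnn ); // train only the LoRA matrices

// ... the usual training loop over dnn ...

CLoraSerializer().Serialize( dnn, archive ); // stores only the LoRA weights

// For inference, fold LoRA into the base weights and remove the wrappers
builder.MergeAllFcWrappers( dnn );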
129 changes: 129 additions & 0 deletions
NeoML/include/NeoML/Dnn/Layers/LoraFullyConnectedLayer.h
@@ -0,0 +1,129 @@
/* Copyright © 2023 ABBYY
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
	http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include <NeoML/NeoMLDefs.h>
#include <NeoML/Dnn/DnnLora.h>
#include <NeoML/Dnn/Layers/ActivationLayers.h>
#include <NeoML/Dnn/Layers/CompositeLayer.h>
#include <NeoML/Dnn/Layers/DropoutLayer.h>
#include <NeoML/Dnn/Layers/EltwiseLayer.h>
#include <NeoML/Dnn/Layers/FullyConnectedLayer.h>

namespace NeoML {

// Fully-connected layer with Low-Rank Adaptation, implementing
// https://arxiv.org/pdf/2106.09685v2.pdf
// LoRA wrapper for CFullyConnectedLayer
//
// This layer lazily switches between two states
//
// 1st: "split" state
// In this state baseFc contains the unmodified weights of the original network,
// and the LoRA part is present explicitly as a group of layers
// Used during training
//
//          loraSum
//             ^
//             |
//        +---------+
//        ^         ^
//        |         |
//        |      scaling
//        |         |
//        |        fcB
//     baseFc       |
//        |        fcA
//        |         |
//        |      dropout
//        |         |
//        ^         ^
//        +---------+
//             ^
//             |
//         inputData
//
// 2nd: "merged" state
// In this state baseFc contains weights which emulate the full LoRA computation,
// and no other layers are present in the composite
// Used during inference
//
//     loraSum
//        ^
//        |
//     baseFc
//        ^
//        |
//    inputData
//
// NOTE: even in the merged state this layer has to store the A and B matrices
// in order to switch back to the "split" state when needed.
// If you need only inference, you can replace this layer with a CFullyConnectedLayer
// with merged weights (this can be done via CLoraBuilder::MergeFcWrapper)
class NEOML_API CLoraFullyConnectedLayer : public CCompositeLayer {
	NEOML_DNN_LAYER( CLoraFullyConnectedLayer )
public:
	CLoraFullyConnectedLayer( CDnnBlob& baseWeights, CDnnBlob* baseFreeTerms, const CLoraParams& params );
	explicit CLoraFullyConnectedLayer( IMathEngine& mathEngine ); // used for loading a serialized layer

	void Serialize( CArchive& ) override;

	void UpdateParams( const CLoraParams& newParams, CDnnBlob* newA, CDnnBlob* newB );

	int OutputSize() const { return baseFc->GetNumberOfElements(); }
	int Rank() const { return fcA->GetNumberOfElements(); }
	float Alpha() const { return scaling->GetMultiplier() * Rank(); }
	float Dropout() const { return dropout->GetDropoutRate(); }

	// Raw getters for the weights
	// These getters do not copy the weights, which may lead to difficult-to-debug problems,
	// but they're necessary for making LoRA work without excessive copying
	// baseFc weights from the "split" state
	CPtr<CDnnBlob> GetSplitWeightsNoCopy() { split(); return baseFc->Weights(); }
	// baseFc weights from the "merged" state
	CPtr<CDnnBlob> GetMergedWeightsNoCopy() { merge(); return baseFc->Weights(); }
	// baseFc free terms
	CPtr<CDnnBlob> GetFreeTermsNoCopy() { return baseFc->FreeTerms(); }
	// A LoRA matrix
	CPtr<CDnnBlob> GetAWeightsNoCopy() { return fcA->Weights(); }
	// B LoRA matrix
	CPtr<CDnnBlob> GetBWeightsNoCopy() { return fcB->Weights(); }

	// Mostly for testing/debugging
	CPtr<CDnnBlob>& GetWeightsNoCopy() { return baseFc->Weights(); }
	bool IsMerged() const { return isMerged; }

protected:
	~CLoraFullyConnectedLayer() override = default;

	void Reshape() override;

private:
	bool isMerged = true;
	CPtr<CFullyConnectedLayer> baseFc;
	CPtr<CDropoutLayer> dropout;
	CPtr<CFullyConnectedLayer> fcA;
	CPtr<CFullyConnectedLayer> fcB;
	CPtr<CLinearLayer> scaling;
	CPtr<CEltwiseSumLayer> sum;

	void initialize( const CLoraParams& params );
	void merge();
	void split();
	void recalcBaseWeights();
};

} // namespace NeoML
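
For reference, the merge that recalcBaseWeights() performs folds the low-rank update into the base weight matrix: W' = W + (Alpha / Rank) * B * A, which follows from the LoRA paper and the Alpha()/Rank() getters above. The standalone sketch below illustrates that arithmetic on plain row-major arrays; it is an assumption about the math, not the layer's actual implementation.

#include <vector>

// Sketch of the LoRA merge: W' = W + (alpha / rank) * B * A.
// W is outSize x inSize, A is rank x inSize, B is outSize x rank (row-major).
void MergeLoraWeights( std::vector<float>& w, const std::vector<float>& a,
	const std::vector<float>& b, int inSize, int outSize, int rank, float alpha )
{
	const float scale = alpha / rank;
	for( int i = 0; i < outSize; ++i ) {
		for( int j = 0; j < inSize; ++j ) {
			// (B * A)[i][j] accumulated over the rank dimension
			float update = 0.f;
			for( int r = 0; r < rank; ++r ) {
				update += b[i * rank + r] * a[r * inSize + j];
			}
			w[i * inSize + j] += scale * update;
		}
	}
}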