Skip to content

Commit

Permalink
Update supported models for Liger Kernel (axolotl-ai-cloud#1875)
Browse files Browse the repository at this point in the history
* Update supported models for Liger Kernel

Add Mistral LCE, Gemma LCE, Gemma 2 without LCE (softcapping is not yet implemented for Gemma in Liger Kernel LCE forward), Phi3 without LCE

* move import to their appropriate conditions

* Integrate Phi3 LCE support

linkedin/Liger-Kernel#103

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
DocShotgun and winglian authored Sep 1, 2024
1 parent ce33e1e commit 15408d0
Showing 1 changed file with 44 additions and 7 deletions.
51 changes: 44 additions & 7 deletions src/axolotl/integrations/liger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
Expand All @@ -43,6 +42,9 @@ def get_input_args(self):

def pre_model_load(self, cfg):
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama

if cfg.liger_rope:
Expand All @@ -57,6 +59,9 @@ def pre_model_load(self, cfg):
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward

elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral

if cfg.liger_rope:
Expand All @@ -68,11 +73,12 @@ def pre_model_load(self, cfg):
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Mistral."
)
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward

elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma

if cfg.liger_rope:
Expand All @@ -84,9 +90,7 @@ def pre_model_load(self, cfg):
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma."
)
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward

elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
Expand Down Expand Up @@ -145,3 +149,36 @@ def pre_model_load(self, cfg):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward

elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2

if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)

elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3

if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

0 comments on commit 15408d0

Please sign in to comment.