Add Baichuan2 Support (#247)
Co-authored-by: Casper <casperbh.96@gmail.com>
AoyuQC and casper-hansen authored Dec 23, 2023
1 parent 9e8e28b commit cef9f11
Showing 7 changed files with 163 additions and 9 deletions.
1 change: 1 addition & 0 deletions awq/models/__init__.py
@@ -10,5 +10,6 @@
from .aquila import AquilaAWQForCausalLM
from .yi import YiAWQForCausalLM
from .qwen import QwenAWQForCausalLM
from .baichuan import BaichuanAWQForCausalLM
from .llava import LlavaAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
@@ -19,6 +19,7 @@
"aquila": AquilaAWQForCausalLM,
"Yi": YiAWQForCausalLM,
"qwen": QwenAWQForCausalLM,
"baichuan": BaichuanAWQForCausalLM,
"llava": LlavaAWQForCausalLM,
}
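
With "baichuan" added to this map, the auto class can route Baichuan2 checkpoints to BaichuanAWQForCausalLM. A minimal loading sketch, not part of this commit; the checkpoint path is a placeholder, and trust_remote_code is assumed to be required for Baichuan's custom modeling code:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Hypothetical path to an AWQ-quantized Baichuan2 checkpoint.
quant_path = "path/to/baichuan2-7b-awq"

# from_quantized picks BaichuanAWQForCausalLM from the map above based on
# config.model_type; fuse_layers=True invokes the new BaichuanFuser.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)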

137 changes: 137 additions & 0 deletions awq/models/baichuan.py
@@ -0,0 +1,137 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
    LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "BaichuanLayer"
    max_new_tokens_key = "model_max_length"

    @staticmethod
    def fuse_layers(model):
        fuser = BaichuanFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    # def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.W_pack],
            inp=input_feat['self_attn.W_pack'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # # attention out
        # # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        # if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
        #     layers.append(dict(
        #         prev_op=module.self_attn.v_proj,
        #         layers=[module.self_attn.o_proj],
        #         inp=input_feat['self_attn.o_proj'],
        #     ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        layers.append(dict(
            prev_op=module.self_attn.W_pack,
            layers=[module.self_attn.o_proj],
            inp=input_feat['self_attn.o_proj'],
        ))

        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class BaichuanFuser:
    def __init__(self, model):
        self.model = model

        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            # qkv = fuse_qkv(
            #     module,
            #     module.self_attn.q_proj,
            #     module.self_attn.k_proj,
            #     module.self_attn.v_proj
            # )
            qkv = module.self_attn.W_pack
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_attention_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens,
                use_alibi=True
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
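
The class above plugs into the usual AutoAWQ quantization flow. Below is a minimal sketch of quantizing a Baichuan2 model with it, assuming the standard AutoAWQ API; the Hub id, output path, and quantization settings are illustrative, not values from this commit:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "baichuan-inc/Baichuan2-7B-Chat"  # assumed Hub id
quant_path = "baichuan2-7b-awq"                # assumed output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Baichuan2 ships custom modeling code, so trust_remote_code is assumed here.
model = AutoAWQForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# get_layers_for_scaling above tells the quantizer which Baichuan modules to
# scale: the fused W_pack attention input, o_proj, and the MLP projections.
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
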
2 changes: 2 additions & 0 deletions awq/models/base.py
@@ -55,6 +55,7 @@
"aquila": "AutoModelForCausalLM",
"Yi": "AutoModelForCausalLM",
"qwen": "AutoModelForCausalLM",
"baichuan": "AutoModelForCausalLM",
"llava": "AutoModelForVision2Seq",
}

@@ -90,6 +91,7 @@ def quantize(self, tokenizer=None, quant_config={},
            self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
        )
        quantizer.quantize()

        self.is_quantized = True

    @staticmethod
6 changes: 3 additions & 3 deletions awq/modules/fused/block.py
@@ -43,7 +43,7 @@ class LlamaLikeBlock(nn.Module):
"""
def __init__(
self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
mlp, norm_1, norm_2, dev, max_seq_len, rope_theta
mlp, norm_1, norm_2, dev, max_seq_len, rope_theta, use_alibi=False
):
super().__init__()
self.n_heads = n_heads
@@ -52,7 +52,7 @@ def __init__(
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
+            dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
@@ -185,4 +185,4 @@ def forward(

         out = h_attn + h_mlp

-        return out, None, past_key_value
\ No newline at end of file
+        return out, None, past_key_value
24 changes: 18 additions & 6 deletions awq/utils/calib_data.py
@@ -3,7 +3,7 @@
 from typing import List, Union
 from datasets import load_dataset

-def get_calib_dataset(data: Union[str, List[str]] = "pileval",
+def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
                       tokenizer=None, n_samples=512, block_size=512,
                       split="train", text_column="text"):
     if isinstance(data, str):
@@ -15,18 +15,30 @@ def get_calib_dataset(data: Union[str, List[str]] = "pileval",
         dataset = dataset.shuffle(seed=42)

     elif isinstance(data, list):
-        dataset = [{text_column: text} for text in data]
+        if isinstance(data[0], str):
+            dataset = [{text_column: text} for text in data]
+        elif isinstance(data[0][0], int):
+            dataset = data
+        else:
+            raise NotImplementedError(
+                "Either pass a string to a huggingface dataset or a list"
+                "that is preprocessed with one sample of text per element"
+                " or a list of list of int for tokenized words.")
     else:
         raise NotImplementedError(
             "Either pass a string to a huggingface dataset or a list"
-            "that is preprocessed with one sample of text per element.")
+            "that is preprocessed with one sample of text per element"
+            " or a list of list of int for tokenized words.")

     samples = []
     n_run = 0
     for data in dataset:
-        line = data[text_column]
-        line = line.strip()
-        line_encoded = tokenizer.encode(line)
+        if isinstance(data, list):
+            line_encoded = data
+        else:
+            line = data[text_column]
+            line = line.strip()
+            line_encoded = tokenizer.encode(line)
         if len(line_encoded) > 512:
             continue
         sample = torch.tensor([line_encoded])
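
With this change, calibration data can be passed to get_calib_dataset in three forms: a Hugging Face dataset name, a list of raw text samples, or a list of pre-tokenized samples. A short sketch, assuming the calib_data argument of quantize() is forwarded here as in base.py; the texts and token ids are made up for illustration:

# 1) A Hugging Face dataset name (the default "pileval", or any dataset id).
model.quantize(tokenizer, quant_config=quant_config, calib_data="pileval")

# 2) A list of raw text samples, one string per element.
texts = ["Baichuan2 is a family of open LLMs.", "AWQ quantizes weights to 4 bits."]
model.quantize(tokenizer, quant_config=quant_config, calib_data=texts)

# 3) A list of lists of token ids, so the tokenizer is skipped per sample.
token_ids = [[1, 5042, 338, 263], [1, 910, 338]]  # illustrative ids only
model.quantize(tokenizer, quant_config=quant_config, calib_data=token_ids)
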
1 change: 1 addition & 0 deletions examples/benchmark.py
@@ -156,6 +156,7 @@ def main(args):
{"context": 512, "n_generate": 512},
{"context": 1024, "n_generate": 1024},
{"context": 2048, "n_generate": 2048},
{"context": 4096, "n_generate": 4096},
]

if args.generator == "torch":
