
Add Qwen model #182

Merged · 4 commits · Nov 20, 2023
Conversation

@Sanster (Contributor) commented Nov 10, 2023

Modify according to this PR: #78

@casper-hansen (Owner)

Hi @Sanster, thank you for this. Did you test whether quantizing a model works and whether inference runs?

The problem I ran into in my old PR was an issue with the modeling code that prevented me from quantizing.

@casper-hansen (Owner)

I tried quantizing a model, but the model outputs are weird and the eval does not work for this model. At this time, I don't think we can merge this pull request until we can verify that it works after quantizing.
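
For reference, one way to measure whether a quantized checkpoint still behaves sensibly is a quick perplexity check. The sketch below is illustrative rather than part of this PR: the local path is hypothetical, and it assumes the AutoAWQ wrapper exposes the underlying Hugging Face causal LM via its .model attribute.

# Illustrative perplexity sanity check for a quantized model (not part of this PR).
# Assumes the AutoAWQ wrapper exposes the underlying HF causal LM as `.model`.
import torch
import torch.nn.functional as F
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from datasets import load_dataset

quant_path = "./Qwen-7B-Chat-quant"  # hypothetical local path
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

# Use the first ~8k tokens of wikitext-2 test as a small held-out corpus
text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
ids = tokenizer(text, return_tensors="pt").input_ids[:, :8192].cuda()

losses = []
with torch.no_grad():
    for i in range(0, ids.shape[1] - 1, 2048):
        chunk = ids[:, i : i + 2048]
        logits = model.model(chunk).logits.float()
        # Shift so each position predicts the next token
        losses.append(
            F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                chunk[:, 1:].reshape(-1),
            )
        )
print("perplexity:", torch.exp(torch.stack(losses).mean()).item())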

@Sanster (Contributor, Author) commented Nov 15, 2023

Hi, in my testing, the model is working properly.

[screenshot of the quantized model's output]

Quant script

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from datasets import load_dataset

model_dir = "Qwen/Qwen-7B-Chat"
save_dir = "./Qwen-7B-Chat-quant"
dataset = load_dataset("GAIR/lima")["train"]
quant_count = 16

quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
}

# Build calibration examples in Qwen's ChatML prompt format
examples = []
for it in dataset["conversations"][:quant_count]:
    query = it[0]
    answer = it[1]
    examples.append(
        f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|>"
    )

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
tokenizer.save_pretrained(save_dir)

# Load the FP16 model onto the GPU
model = (
    AutoAWQForCausalLM.from_pretrained(
        model_dir, trust_remote_code=True, safetensors=True
    )
    .eval()
    .cuda()
)

# Quantize against the calibration data and save the result
model.quantize(tokenizer, quant_config=quant_config, calib_data=examples)
model.save_quantized(save_dir, safetensors=True)

Test script

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_model_dir = "./Qwen-7B-Chat-quant"
text = "<|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n"

# Load the quantized model with fused layers for faster inference
model = AutoAWQForCausalLM.from_quantized(quant_model_dir, fuse_layers=True).eval()
tokenizer = AutoTokenizer.from_pretrained(quant_model_dir, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)

tokens = tokenizer(text, return_tensors="pt").input_ids.cuda()

# 151645 is Qwen's <|im_end|> token, used here as the stop token
model.generate(tokens, streamer=streamer, max_new_tokens=100, eos_token_id=151645)

@enbacoo commented Nov 19, 2023

@casper-hansen, I tested @Sanster's work. When I replace the prompt template in the quant script with "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|>", it works well (a sketch of the change is below).
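
A minimal sketch of that change, reusing the dataset and quant_count variables from the quant script above; the ChatML system line is simply prepended to each calibration example:

# Calibration examples with Qwen's system prompt prepended
examples = []
for it in dataset["conversations"][:quant_count]:
    query, answer = it[0], it[1]
    examples.append(
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{query}<|im_end|>\n"
        f"<|im_start|>assistant\n{answer}<|im_end|>"
    )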

@casper-hansen (Owner)

This seems to work for me now with the right prompt template. Thanks for the PR! (Note: the eval is not currently working, but Qwen's responses look good.)

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "qwen-7b-chat-awq"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

prompt = "You're standing on the surface of the Earth. "\
        "You walk one mile south, one mile west and one mile north. "\
        "You end up exactly where you started. Where are you?"

tokens = tokenizer(
    prompt_template.format(prompt=prompt), 
    return_tensors='pt'
).input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens, 
    streamer=streamer,
    max_new_tokens=512,
    eos_token_id=151645
)
