Skip to content

Commit

Permalink
Merge pull request hiyouga#20 from codemayq/main
Browse files Browse the repository at this point in the history
add web_demo adapted from ChatGLM-6B
  • Loading branch information
hiyouga committed Apr 23, 2023
2 parents 32eb96f + 7ff1d3f commit fed1a48
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 0 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ CUDA_VISIBLE_DEVICES=0 python src/infer.py \
--checkpoint_dir path_to_checkpoint
```

### Web Demo
```bash
CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
--checkpoint_dir path_to_checkpoint
```

### Deploy the Fine-tuned Model
```python
from .src import load_pretrained, ModelArguments
Expand Down
7 changes: 7 additions & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ CUDA_VISIBLE_DEVICES=0 python src/infer.py \
--checkpoint_dir path_to_checkpoint
```

### 网页版测试

```bash
CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
--checkpoint_dir path_to_checkpoint
```

### 部署微调模型

```python
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ jieba
rouge_chinese
nltk
git+https://github.com/huggingface/peft.git
gradio
mdtex2html
106 changes: 106 additions & 0 deletions src/web_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""copy from https://github.com/THUDM/ChatGLM-6B"""
import gradio as gr
import mdtex2html

from utils import ModelArguments, load_pretrained
from transformers import HfArgumentParser

parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
model = model.half().cuda()
model = model.eval()


"""Override Chatbot.postprocess"""

def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert((message)),
None if response is None else mdtex2html.convert(response),
)
return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>"+line
text = "".join(lines)
return text


def predict(input, chatbot, max_length, top_p, temperature, history):
chatbot.append((parse_text(input), ""))
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
temperature=temperature):
chatbot[-1] = (parse_text(input), parse_text(response))

yield chatbot, history


def reset_user_input():
return gr.update(value='')


def reset_state():
return [], []


with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM-Efficient-Tuning</h1>""")

chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

history = gr.State([])

submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])

emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True)

0 comments on commit fed1a48

Please sign in to comment.