Skip to content

Commit

Permalink
update web UI
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jul 20, 2023
1 parent d270895 commit 0c58dfc
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 43 deletions.
5 changes: 4 additions & 1 deletion src/api_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

import uvicorn

from glmtuner import ChatModel
from glmtuner.api.app import create_app
from glmtuner.tuner import get_infer_args


def main():
app = create_app()
chat_model = ChatModel(*get_infer_args())
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)


Expand Down
7 changes: 3 additions & 4 deletions src/glmtuner/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ async def lifespan(app: FastAPI): # collects GPU memory
torch_gc()


def create_app():
chat_model = ChatModel(*get_infer_args())

def create_app(chat_model: ChatModel) -> FastAPI:
app = FastAPI(lifespan=lifespan)

app.add_middleware(
Expand Down Expand Up @@ -124,5 +122,6 @@ async def predict(query: str, history: List[Tuple[str, str]], prefix: str, reque


if __name__ == "__main__":
app = create_app()
chat_model = ChatModel(*get_infer_args())
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
3 changes: 2 additions & 1 deletion src/glmtuner/webui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,15 @@ def predict(
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
prefix: str,
max_length: int,
top_p: float,
temperature: float
):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
query, history, max_length=max_length, top_p=top_p, temperature=temperature
query, history, prefix, max_length=max_length, top_p=top_p, temperature=temperature
):
response += new_text
new_history = history + [(query, response)]
Expand Down
11 changes: 5 additions & 6 deletions src/glmtuner/webui/components/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ def create_chat_box(

with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
query = gr.Textbox(show_label=False, lines=8)

with gr.Column(min_width=32, scale=1):
submit_btn = gr.Button(variant="primary")
prefix = gr.Dropdown(show_label=False)
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")

with gr.Column(scale=1):
clear_btn = gr.Button()
Expand All @@ -36,7 +34,7 @@ def create_chat_box(

submit_btn.click(
chat_model.predict,
[chatbot, query, history, max_length, top_p, temperature],
[chatbot, query, history, prefix, max_length, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
Expand All @@ -46,6 +44,7 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)

return chat_box, chatbot, history, dict(
prefix=prefix,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
Expand Down
3 changes: 2 additions & 1 deletion src/glmtuner/webui/components/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
start_btn = gr.Button()
stop_btn = gr.Button()

output_box = gr.Markdown()
with gr.Box():
output_box = gr.Markdown()

start_btn.click(
runner.run_eval,
Expand Down
2 changes: 1 addition & 1 deletion src/glmtuner/webui/components/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
load_btn = gr.Button()
unload_btn = gr.Button()

info_box = gr.Markdown()
info_box = gr.Textbox(show_label=False, interactive=False)

chat_model = WebChatModel()
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
Expand Down
37 changes: 30 additions & 7 deletions src/glmtuner/webui/components/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,21 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
)
max_grad_norm = gr.Textbox(value="1.0")
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
fp16 = gr.Checkbox(value=True)

with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")

with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)

with gr.Row():
start_btn = gr.Button()
Expand All @@ -49,7 +58,9 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
with gr.Row():
with gr.Column(scale=4):
output_dir = gr.Textbox(interactive=True)
output_box = gr.Markdown()

with gr.Box():
output_box = gr.Markdown()

with gr.Column(scale=1):
loss_viewer = gr.Plot()
Expand All @@ -73,10 +84,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
batch_size,
gradient_accumulation_steps,
lr_scheduler_type,
max_grad_norm,
dev_ratio,
fp16,
logging_steps,
save_steps,
warmup_steps,
compute_type,
lora_rank,
lora_dropout,
lora_target,
output_dir
],
[output_box]
Expand All @@ -102,10 +118,17 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
dev_ratio=dev_ratio,
fp16=fp16,
advanced_tab=advanced_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
compute_type=compute_type,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
Expand Down
8 changes: 5 additions & 3 deletions src/glmtuner/webui/components/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def create_top() -> Dict[str, Component]:
checkpoints = gr.Dropdown(multiselect=True, scale=5)
refresh_btn = gr.Button(scale=1)

with gr.Row():
quantization_bit = gr.Dropdown([8, 4], scale=1)
source_prefix = gr.Textbox(scale=6)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown([8, 4], scale=1)
source_prefix = gr.Textbox(scale=4)

model_name.change(
get_model_path, [model_name], [model_path]
Expand All @@ -47,6 +48,7 @@ def create_top() -> Dict[str, Component]:
finetuning_type=finetuning_type,
checkpoints=checkpoints,
refresh_btn=refresh_btn,
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
source_prefix=source_prefix
)
2 changes: 1 addition & 1 deletion src/glmtuner/webui/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner)

with gr.Tab("Inference"):
with gr.Tab("Chat"):
infer_elems = create_infer_tab(top_elems)

elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
Expand Down
106 changes: 90 additions & 16 deletions src/glmtuner/webui/locales.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@
"value": "刷新断点"
}
},
"advanced_tab": {
"en": {
"label": "Advanced configurations"
},
"zh": {
"label": "高级设置"
}
},
"quantization_bit": {
"en": {
"label": "Quantization bit (optional)",
Expand All @@ -71,12 +79,12 @@
},
"source_prefix": {
"en": {
"label": "Source prefix (optional)",
"info": "A sequence used as the prefix of each samples."
"label": "System prompt (optional)",
"info": "A sequence used as the default system prompt."
},
"zh": {
"label": "前缀序列(非必填)",
"info": "作为每个输入样本前缀的序列"
"label": "系统提示词(非必填)",
"info": "默认使用的系统提示词"
}
},
"dataset_dir": {
Expand Down Expand Up @@ -209,30 +217,30 @@
"info": "采用的学习率调节器名称。"
}
},
"dev_ratio": {
"max_grad_norm": {
"en": {
"label": "Dev ratio",
"info": "Proportion of data in the dev set."
"label": "Maximum gradient norm",
"info": "Norm for gradient clipping.."
},
"zh": {
"label": "验证集比例",
"info": "验证集占全部样本的百分比。"
"label": "最大梯度范数",
"info": "用于梯度裁剪的范数。"
}
},
"fp16": {
"dev_ratio": {
"en": {
"label": "fp16",
"info": "Whether to use fp16 mixed precision training."
"label": "Dev ratio",
"info": "Proportion of data in the dev set."
},
"zh": {
"label": "fp16",
"info": "是否启用 FP16 混合精度训练。"
"label": "验证集比例",
"info": "验证集占全部样本的百分比。"
}
},
"logging_steps": {
"en": {
"label": "Logging steps",
"info": "Number of update steps between two logs."
"info": "Number of steps between two logs."
},
"zh": {
"label": "日志间隔",
Expand All @@ -242,13 +250,71 @@
"save_steps": {
"en": {
"label": "Save steps",
"info": "Number of updates steps between two checkpoints."
"info": "Number of steps between two checkpoints."
},
"zh": {
"label": "保存间隔",
"info": "每两次断点保存间的更新步数。"
}
},
"warmup_steps": {
"en": {
"label": "Warmup steps",
"info": "Number of steps used for warmup."
},
"zh": {
"label": "预热步数",
"info": "学习率预热采用的步数。"
}
},
"compute_type": {
"en": {
"label": "Compute type",
"info": "Whether to use fp16 or bf16 mixed precision training."
},
"zh": {
"label": "计算类型",
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
},
"zh": {
"label": "LoRA 参数设置"
}
},
"lora_rank": {
"en": {
"label": "LoRA rank",
"info": "The rank of LoRA matrices."
},
"zh": {
"label": "LoRA 秩",
"info": "LoRA 矩阵的秩。"
}
},
"lora_dropout": {
"en": {
"label": "LoRA Dropout",
"info": "Dropout ratio of LoRA weights."
},
"zh": {
"label": "LoRA 随机丢弃",
"info": "LoRA 权重随机丢弃的概率。"
}
},
"lora_target": {
"en": {
"label": "LoRA modules (optional)",
"info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
},
"zh": {
"label": "LoRA 作用层(非必填)",
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
}
},
"start_btn": {
"en": {
"value": "Start"
Expand Down Expand Up @@ -323,6 +389,14 @@
"value": "模型未加载,请先加载模型。"
}
},
"prefix": {
"en": {
"placeholder": "System prompt (optional)"
},
"zh": {
"placeholder": "系统提示词(非必填)"
}
},
"query": {
"en": {
"placeholder": "Input..."
Expand Down
Loading

0 comments on commit 0c58dfc

Please sign in to comment.