update web UI, support RM predict
hiyouga committed Jul 21, 2023
1 parent c9ec965 commit 3ed046a
Showing 13 changed files with 191 additions and 27 deletions.
6 changes: 4 additions & 2 deletions src/glmtuner/dsets/preprocess.py
@@ -116,8 +116,10 @@ def print_ppo_dataset_example(example):
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

if stage == "sft":
preprocess_function = preprocess_evaluation_dataset \
if training_args.predict_with_generate else preprocess_supervised_dataset
if not training_args.predict_with_generate:
preprocess_function = preprocess_supervised_dataset
else:
preprocess_function = preprocess_evaluation_dataset
elif stage == "rm":
preprocess_function = preprocess_pairwise_dataset
elif stage == "ppo":
2 changes: 1 addition & 1 deletion src/glmtuner/tuner/core/parser.py
@@ -54,7 +54,7 @@ def get_train_args(
assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."

assert (not training_args.do_predict) or training_args.predict_with_generate, \
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."

if model_args.quantization_bit is not None:
13 changes: 8 additions & 5 deletions src/glmtuner/tuner/core/trainer.py
@@ -4,7 +4,8 @@

from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel

from glmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from glmtuner.extras.logging import get_logger
@@ -49,9 +50,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str,
else:
backbone_model = model

if self.finetuning_args.finetuning_type == "lora":
if isinstance(backbone_model, PeftModel): # LoRA tuning
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
else: # freeze/full tuning
elif isinstance(backbone_model, PreTrainedModel): # freeze/full-tuning or p_tuning
backbone_model.config.use_cache = True
backbone_model.save_pretrained(
output_dir,
@@ -61,6 +62,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str,
backbone_model.config.use_cache = False
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
else:
logger.warning("No model to save.")

with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
@@ -77,8 +80,8 @@ def _load_best_model(self):
model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

if self.finetuning_args.finetuning_type == "lora":
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
if isinstance(backbone_model, PeftModel):
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
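Taken together, these hunks move _save from branching on the finetuning_type string to branching on the model class, which also covers p-tuning (a plain PreTrainedModel rather than a PeftModel). Below is a minimal standalone sketch of the resulting dispatch, not part of this commit; get_state_dict is replaced by a plain state_dict argument, and the value-head and training-args handling is omitted:

import logging

from peft import PeftModel
from transformers.modeling_utils import PreTrainedModel

logger = logging.getLogger(__name__)

def save_backbone(backbone_model, tokenizer, output_dir, state_dict=None):
    # LoRA tuning: the backbone is a PeftModel, so only adapter weights are saved.
    if isinstance(backbone_model, PeftModel):
        backbone_model.save_pretrained(output_dir, state_dict=state_dict)
    # Freeze/full tuning or p-tuning: a plain PreTrainedModel with full weights.
    elif isinstance(backbone_model, PreTrainedModel):
        backbone_model.config.use_cache = True   # saved checkpoint should allow cached inference
        backbone_model.save_pretrained(output_dir, state_dict=state_dict)
        backbone_model.config.use_cache = False
        if tokenizer is not None:
            tokenizer.save_pretrained(output_dir)
    else:
        logger.warning("No model to save.")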
29 changes: 29 additions & 0 deletions src/glmtuner/tuner/rm/trainer.py
@@ -1,10 +1,17 @@
import os
import json
import torch
from typing import Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel

from glmtuner.extras.logging import get_logger
from glmtuner.tuner.core.trainer import PeftTrainer


logger = get_logger(__name__)


class PairwiseTrainerForChatGLM(PeftTrainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
@@ -36,3 +43,25 @@ def compute_loss(
r_accept, r_reject = values[-1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss

def save_predictions(
self,
predict_results: PredictionOutput
) -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return

output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")

acc_scores, rej_scores = predict_results.predictions

with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for acc_score, rej_score in zip(acc_scores, rej_scores):
res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
writer.write("\n".join(res))
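The loss in compute_loss above is the standard pairwise (Bradley-Terry style) reward-model objective, -log sigmoid(r_accept - r_reject) averaged over the batch, which pushes the chosen response's score above the rejected one's. A self-contained sketch, not part of this commit, with made-up scores:

import torch

def pairwise_loss(r_accept: torch.Tensor, r_reject: torch.Tensor) -> torch.Tensor:
    # Equivalent to maximizing P(accept preferred) = sigmoid(r_accept - r_reject).
    return -torch.log(torch.sigmoid(r_accept - r_reject)).mean()

# Illustrative value-head scores for two preference pairs.
r_accept = torch.tensor([1.2, 0.3])
r_reject = torch.tensor([0.4, -0.1])
print(pairwise_loss(r_accept, r_reject))  # positive scalar, shrinking as accept scores dominate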
7 changes: 7 additions & 0 deletions src/glmtuner/tuner/rm/workflow.py
@@ -56,3 +56,10 @@ def run_rm(
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)
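With do_predict enabled at the RM stage, save_predictions writes one JSON object per preference pair to generated_predictions.jsonl under output_dir. A sketch of reading it back, not part of this commit; the path and scores are illustrative:

import json

with open("output_dir/generated_predictions.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)  # e.g. {"accept": 1.25, "reject": -0.32}
        # A positive margin means the reward model scored the chosen answer higher.
        margin = record["accept"] - record["reject"]
        print(record, "margin:", round(margin, 2))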
5 changes: 3 additions & 2 deletions src/glmtuner/webui/components/__init__.py
@@ -1,4 +1,5 @@
from glmtuner.webui.components.eval import create_eval_tab
from glmtuner.webui.components.infer import create_infer_tab
from glmtuner.webui.components.top import create_top
from glmtuner.webui.components.sft import create_sft_tab
from glmtuner.webui.components.eval import create_eval_tab
from glmtuner.webui.components.infer import create_infer_tab
from glmtuner.webui.components.export import create_export_tab
10 changes: 3 additions & 7 deletions src/glmtuner/webui/components/chatbot.py
@@ -22,13 +22,9 @@ def create_chat_box(

with gr.Column(scale=1):
clear_btn = gr.Button()
max_length = gr.Slider(
10, 2048, value=chat_model.generating_args.max_length, step=1, interactive=True
)
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
temperature = gr.Slider(
0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
)
max_length = gr.Slider(10, 2048, value=chat_model.generating_args.max_length, step=1)
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)

history = gr.State([])

34 changes: 34 additions & 0 deletions src/glmtuner/webui/components/export.py
@@ -0,0 +1,34 @@
from typing import Dict
import gradio as gr
from gradio.components import Component

from glmtuner.webui.utils import export_model


def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
with gr.Row():
save_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)

export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)

export_btn.click(
export_model,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
max_shard_size,
save_dir
],
[info_box]
)

return dict(
save_dir=save_dir,
max_shard_size=max_shard_size,
export_btn=export_btn,
info_box=info_box
)
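Because export_model is a generator (see webui/utils.py below), each yield streams a status string into info_box. A toy sketch of this Gradio pattern, independent of the project's components; the labels and fake export body are assumptions:

import time
from typing import Generator

import gradio as gr

def fake_export(save_dir: str) -> Generator[str, None, None]:
    yield "Exporting model..."
    time.sleep(1)  # stand-in for loading and saving the real model
    yield f"Model exported to {save_dir}."

with gr.Blocks() as demo:
    save_dir = gr.Textbox(label="Export dir")
    export_btn = gr.Button("Export")
    info_box = gr.Textbox(show_label=False, interactive=False)
    export_btn.click(fake_export, [save_dir], [info_box])

# demo.queue().launch()  # generator callbacks require the queue in Gradio 3.x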
2 changes: 1 addition & 1 deletion src/glmtuner/webui/components/sft.py
@@ -57,7 +57,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,

with gr.Row():
with gr.Column(scale=4):
output_dir = gr.Textbox(interactive=True)
output_dir = gr.Textbox()

with gr.Box():
output_box = gr.Markdown()
8 changes: 6 additions & 2 deletions src/glmtuner/webui/interface.py
@@ -5,7 +5,8 @@
create_top,
create_sft_tab,
create_eval_tab,
create_infer_tab
create_infer_tab,
create_export_tab
)
from glmtuner.webui.css import CSS
from glmtuner.webui.manager import Manager
@@ -30,7 +31,10 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Chat"):
infer_elems = create_infer_tab(top_elems)

elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
with gr.Tab("Export"):
export_elems = create_export_tab(top_elems)

elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list)

demo.load(
44 changes: 44 additions & 0 deletions src/glmtuner/webui/locales.py
@@ -452,6 +452,34 @@
"zh": {
"label": "温度系数"
}
},
"save_dir": {
"en": {
"label": "Export dir",
"info": "Directory to save exported model."
},
"zh": {
"label": "导出目录",
"info": "保存导出模型的文件夹路径。"
}
},
"max_shard_size": {
"en": {
"label": "Max shard size (GB)",
"info": "The maximum size for a model file."
},
"zh": {
"label": "最大分块大小(GB)",
"info": "模型文件的最大大小。"
}
},
"export_btn": {
"en": {
"value": "Export"
},
"zh": {
"value": "开始导出"
}
}
}

@@ -477,6 +505,14 @@
"en": "Please choose a dataset.",
"zh": "请选择数据集。"
},
"err_no_checkpoint": {
"en": "Please select a checkpoint.",
"zh": "请选择断点。"
},
"err_no_save_dir": {
"en": "Please provide export dir.",
"zh": "请填写导出目录"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"
@@ -504,5 +540,13 @@
"info_unloaded": {
"en": "Model unloaded.",
"zh": "模型已卸载。"
},
"info_exporting": {
"en": "Exporting model...",
"zh": "正在导出模型……"
},
"info_exported": {
"en": "Model exported.",
"zh": "模型导出完成。"
}
}
14 changes: 9 additions & 5 deletions src/glmtuner/webui/runner.py
@@ -3,7 +3,7 @@
import threading
import time
import transformers
from typing import List, Optional, Tuple
from typing import Generator, List, Optional, Tuple

from glmtuner.extras.callbacks import LogCallback
from glmtuner.extras.logging import LoggerHandler
@@ -24,7 +24,9 @@ def set_abort(self):
self.aborted = True
self.running = False

def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
def initialize(
self, lang: str, model_name: str, dataset: List[str]
) -> Tuple[str, str, LoggerHandler, LogCallback]:
if self.running:
return None, ALERTS["err_conflict"][lang], None, None

@@ -49,7 +51,9 @@ def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, st

return model_name_or_path, "", logger_handler, trainer_callback

def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
def finalize(
self, lang: str, finish_info: Optional[str] = None
) -> str:
self.running = False
torch_gc()
if self.aborted:
@@ -85,7 +89,7 @@ def run_train(
lora_dropout: float,
lora_target: str,
output_dir: str
):
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
@@ -170,7 +174,7 @@ def run_eval(
max_samples: str,
batch_size: int,
predict: bool
):
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
44 changes: 42 additions & 2 deletions src/glmtuner/webui/utils.py
@@ -3,11 +3,13 @@
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Any, Dict, Tuple
from typing import Any, Dict, Generator, List, Tuple
from datetime import datetime

from glmtuner.extras.ploting import smooth
from glmtuner.webui.common import get_save_dir, DATA_CONFIG
from glmtuner.tuner import get_infer_args, load_model_and_tokenizer
from glmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
from glmtuner.webui.locales import ALERTS


def format_info(log: str, tracker: dict) -> str:
@@ -83,3 +85,41 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig


def export_model(
lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
) -> Generator[str, None, None]:
if not model_name:
yield ALERTS["err_no_model"][lang]
return

model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
yield ALERTS["err_no_path"][lang]
return

if not checkpoints:
yield ALERTS["err_no_checkpoint"][lang]
return

checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)

if not save_dir:
yield ALERTS["err_no_save_dir"][lang]
return

args = dict(
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type
)

yield ALERTS["info_exporting"][lang]
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
tokenizer.save_pretrained(save_dir)
yield ALERTS["info_exported"][lang]
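The final save step is the stock save_pretrained call with a shard-size cap built from the slider value. A stripped-down sketch using plain transformers, not part of this commit; the paths are placeholders and the checkpoint loading done through get_infer_args is skipped:

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("path/to/finetuned", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("path/to/finetuned", trust_remote_code=True)

max_shard_size = 10  # GB, mirroring the web UI slider
model.save_pretrained("exported_model", max_shard_size=f"{max_shard_size}GB")
tokenizer.save_pretrained("exported_model")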
