Commit: simplify code

hiyouga committed Jul 20, 2023
1 parent cebb2d0 commit bb2e050
Showing 13 changed files with 33 additions and 37 deletions.
8 changes: 6 additions & 2 deletions src/api_demo.py
@@ -5,9 +5,13 @@
 
 import uvicorn
 
-from glmtuner import create_app
+from glmtuner.api.app import create_app
 
 
-if __name__ == "__main__":
+def main():
     app = create_app()
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+
+
+if __name__ == "__main__":
+    main()
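Wrapping the uvicorn start-up in main() keeps the module importable without side effects and lets the script be wired up as a packaged command. A minimal sketch of such wiring, assuming a setuptools build; the setup.py stanza and command name below are hypothetical, not part of this commit:

```python
# setup.py (hypothetical sketch, not part of this commit)
from setuptools import setup

setup(
    name="glmtuner",
    entry_points={
        "console_scripts": [
            # Installs a `glmtuner-api` command that calls api_demo.main()
            "glmtuner-api = api_demo:main",
        ]
    },
)
```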
3 changes: 2 additions & 1 deletion src/cli_demo.py
@@ -2,7 +2,8 @@
 # Implements stream chat in command line for ChatGLM fine-tuned with PEFT.
 # Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint [--quantization_bit 4]
 
-from glmtuner import ChatModel, get_infer_args
+from glmtuner import ChatModel
+from glmtuner.tuner import get_infer_args
 
 
 def main():
3 changes: 1 addition & 2 deletions src/export_model.py
@@ -2,8 +2,7 @@
 # Exports the fine-tuned ChatGLM-6B model.
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
 
-
-from glmtuner import get_train_args, load_model_and_tokenizer
+from glmtuner.tuner import get_train_args, load_model_and_tokenizer
 
 
 def main():
3 changes: 0 additions & 3 deletions src/glmtuner/__init__.py
@@ -1,7 +1,4 @@
-from glmtuner.api import create_app
 from glmtuner.chat import ChatModel
-from glmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_sft, run_rm, run_ppo
-from glmtuner.webui import create_ui
 
 
 __version__ = "0.1.3"
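With the package-level re-exports removed, only ChatModel remains importable from glmtuner itself; the entry-point scripts updated in this commit import everything else from the submodules:

```python
from glmtuner import ChatModel              # still re-exported by glmtuner/__init__.py
from glmtuner.api.app import create_app
from glmtuner.tuner import (
    get_train_args, get_infer_args, load_model_and_tokenizer,
    run_sft, run_rm, run_ppo,
)
from glmtuner.webui.interface import create_ui
```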
1 change: 0 additions & 1 deletion src/glmtuner/api/__init__.py
@@ -1 +0,0 @@
-from glmtuner.api.app import create_app
1 change: 1 addition & 0 deletions src/glmtuner/dsets/__init__.py
@@ -1,3 +1,4 @@
 from glmtuner.dsets.collator import DataCollatorForChatGLM
 from glmtuner.dsets.loader import get_dataset
 from glmtuner.dsets.preprocess import preprocess_dataset
+from glmtuner.dsets.utils import split_dataset
16 changes: 16 additions & 0 deletions src/glmtuner/dsets/utils.py
@@ -0,0 +1,16 @@
+from typing import Dict
+from datasets import Dataset
+
+
+def split_dataset(
+    dataset: Dataset, dev_ratio: float, do_train: bool
+) -> Dict[str, Dataset]:
+    # Split the dataset
+    if do_train:
+        if dev_ratio > 1e-6:
+            dataset = dataset.train_test_split(test_size=dev_ratio)
+            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+        else:
+            return {"train_dataset": dataset}
+    else: # do_eval or do_predict
+        return {"eval_dataset": dataset}
14 changes: 2 additions & 12 deletions src/glmtuner/tuner/rm/workflow.py
@@ -5,7 +5,7 @@
 from typing import Optional, List
 from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
-from glmtuner.dsets import get_dataset, preprocess_dataset
+from glmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from glmtuner.extras.callbacks import LogCallback
 from glmtuner.extras.ploting import plot_loss
 from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
@@ -29,16 +29,6 @@ def run_rm(
 
     training_args.remove_unused_columns = False # Important for pairwise dataset
 
-    # Split the dataset
-    if training_args.do_train:
-        if data_args.dev_ratio > 1e-6:
-            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
-            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
-        else:
-            trainer_kwargs = {"train_dataset": dataset}
-    else: # do_eval or do_predict
-        trainer_kwargs = {"eval_dataset": dataset}
-
     # Initialize our Trainer
     trainer = PairwiseTrainerForChatGLM(
         finetuning_args=finetuning_args,
@@ -48,7 +38,7 @@
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
-        **trainer_kwargs
+        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
     )
 
     # Training
14 changes: 2 additions & 12 deletions src/glmtuner/tuner/sft/workflow.py
@@ -3,7 +3,7 @@
 from typing import Optional, List
 from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
-from glmtuner.dsets import DataCollatorForChatGLM, get_dataset, preprocess_dataset
+from glmtuner.dsets import DataCollatorForChatGLM, get_dataset, preprocess_dataset, split_dataset
 from glmtuner.extras.callbacks import LogCallback
 from glmtuner.extras.misc import get_logits_processor
 from glmtuner.extras.ploting import plot_loss
@@ -35,16 +35,6 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams if \
         data_args.eval_num_beams is not None else training_args.generation_num_beams
 
-    # Split the dataset
-    if training_args.do_train:
-        if data_args.dev_ratio > 1e-6:
-            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
-            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
-        else:
-            trainer_kwargs = {"train_dataset": dataset}
-    else: # do_eval or do_predict
-        trainer_kwargs = {"eval_dataset": dataset}
-
     # Initialize our Trainer
     trainer = Seq2SeqTrainerForChatGLM(
         finetuning_args=finetuning_args,
@@ -54,7 +44,7 @@
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **trainer_kwargs
+        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
     )
 
     # Keyword arguments for `model.generate`
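Both workflows now hand the split to their trainer through dict unpacking, so the trainer still receives the familiar train_dataset/eval_dataset keywords. A self-contained sketch of the pattern; trainer_stub is a hypothetical stand-in for Seq2SeqTrainerForChatGLM and PairwiseTrainerForChatGLM:

```python
from typing import Optional
from datasets import Dataset
from glmtuner.dsets import split_dataset

def trainer_stub(train_dataset: Optional[Dataset] = None,
                 eval_dataset: Optional[Dataset] = None) -> None:
    # Stand-in for the real trainers; just reports which splits arrived.
    print("train:", train_dataset is not None, "| eval:", eval_dataset is not None)

ds = Dataset.from_dict({"x": list(range(8))})
trainer_stub(**split_dataset(ds, 0.25, True))   # train: True | eval: True
trainer_stub(**split_dataset(ds, 0.25, False))  # train: False | eval: True
```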
1 change: 0 additions & 1 deletion src/glmtuner/webui/__init__.py
@@ -1 +0,0 @@
-from glmtuner.webui.interface import create_ui
2 changes: 1 addition & 1 deletion src/train_bash.py
@@ -1,4 +1,4 @@
-from glmtuner import get_train_args, run_sft, run_rm, run_ppo
+from glmtuner.tuner import get_train_args, run_sft, run_rm, run_ppo
 
 
 def main():
2 changes: 1 addition & 1 deletion src/train_web.py
@@ -1,4 +1,4 @@
-from glmtuner import create_ui
+from glmtuner.webui.interface import create_ui
 
 
 def main():
2 changes: 1 addition & 1 deletion src/web_demo.py
@@ -5,7 +5,7 @@
 import gradio as gr
 from transformers.utils.versions import require_version
 
-from glmtuner import get_infer_args
+from glmtuner.tuner import get_infer_args
 from glmtuner.webui.chat import WebChatModel
 from glmtuner.webui.components.chatbot import create_chat_box
 from glmtuner.webui.manager import Manager
