merge scripts into ppo trainer
hiyouga committed May 12, 2023
1 parent 4b367e7 commit 9e3da36
Showing 7 changed files with 187 additions and 164 deletions.
73 changes: 17 additions & 56 deletions src/train_ppo.py
@@ -3,13 +3,12 @@
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py

from tqdm import tqdm
import math

import torch
from torch.optim import AdamW

from transformers.optimization import get_scheduler
from trl import PPOConfig
from trl.core import LengthSampler

from utils import (
prepare_args,
@@ -18,8 +17,6 @@
preprocess_data,
PPODataCollatorForChatGLM,
PPOTrainerForChatGLM,
compute_rewards,
get_logits_processor,
plot_loss
)

@@ -41,14 +38,22 @@ def main():
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=max(training_args.per_device_train_batch_size // 4, 1),
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
)

optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
total_train_batch_size = \
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
)

# Initialize our Trainer
ppo_trainer = PPOTrainerForChatGLM(
@@ -60,59 +65,15 @@
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer
optimizer=optimizer,
lr_scheduler=lr_scheduler
)

# Keyword arguments for `model.generate`
gen_kwargs = {
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
output_length_sampler = LengthSampler(data_args.max_target_length // 2, data_args.max_target_length)

n_batches = len(ppo_trainer.dataloader)
dataloader = iter(ppo_trainer.dataloader)
for step in tqdm(range(int(training_args.num_train_epochs) * n_batches)):

batch = next(dataloader)
queries = batch["input_ids"] # left-padded sequences

model.gradient_checkpointing_disable()
model.config.use_cache = True

# Get response from ChatGLM
responses_with_queries = ppo_trainer.generate(queries, length_sampler=output_length_sampler, **gen_kwargs)
responses = responses_with_queries[:, queries.size(1):].clone().detach() # right-padded sequences (remember to clone!!!)
# batch["response"] = tokenizer.batch_decode(responses, skip_special_tokens=True) # comment to avoid decode error

for i in range(responses_with_queries.size(0)): # change to right-padding
start = (responses_with_queries[i] != tokenizer.pad_token_id).nonzero()[0].item()
responses_with_queries[i] = torch.cat((responses_with_queries[i][start:], responses_with_queries[i][:start]))

# Compute rewards
rewards = compute_rewards(responses_with_queries, model, tokenizer)

# Run PPO step
model.gradient_checkpointing_enable()
model.config.use_cache = False

split_into_list = lambda x: [x[i] for i in range(x.size(0))]
stats = ppo_trainer.step(*map(split_into_list, [queries, responses, rewards]))

ppo_trainer.log_stats(stats, batch, rewards)
ppo_trainer.update_stats(stats, batch, rewards)

if (step+1) % n_batches == 0:
dataloader = iter(ppo_trainer.dataloader)

ppo_trainer.save_state() # along with the loss values
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
ppo_trainer.save_state()
ppo_trainer.save_model()
if finetuning_args.plot_loss:
plot_loss(training_args)
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args, keys=["loss", "reward"])


def _mp_fn(index):
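A quick sanity check of the scheduler budget introduced above: the step count passed to get_scheduler is num_train_epochs * ceil(len(dataset) / total_train_batch_size), where total_train_batch_size folds in gradient accumulation and the number of processes. A minimal sketch with assumed example values (not taken from this commit):

import math

# Assumed example values for illustration only.
num_train_epochs = 3
dataset_len = 2000
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
world_size = 1

total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size  # 16
num_training_steps = num_train_epochs * math.ceil(dataset_len / total_train_batch_size)          # 3 * 125 = 375
print(total_train_batch_size, num_training_steps)  # 16 375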
5 changes: 2 additions & 3 deletions src/utils/__init__.py
@@ -18,10 +18,9 @@

from .ppo import (
PPODataCollatorForChatGLM,
PPOTrainerForChatGLM,
compute_rewards
PPOTrainerForChatGLM
)

from .config import ModelArguments

from .other import get_logits_processor, plot_loss
from .other import plot_loss
62 changes: 27 additions & 35 deletions src/utils/common.py
@@ -2,7 +2,6 @@
import sys
import torch
import hashlib
import logging
from typing import Literal, Optional, Tuple

import transformers
@@ -37,6 +36,7 @@
)

from .other import (
get_logger,
load_trainable_params,
load_valuehead_params,
print_trainable_params,
@@ -46,13 +46,7 @@
)


logger = logging.getLogger(__name__) # setup logging
logger.setLevel(logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger = get_logger(__name__)


check_min_version("4.27.4")
@@ -66,7 +60,7 @@ def init_adapter(
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
is_trainable: bool
) -> None:
) -> PreTrainedModel:
r"""
Initializes the adapters.
@@ -95,8 +89,7 @@
load_trainable_params(model, model_args.checkpoint_dir[0])

if finetuning_args.finetuning_type == "p_tuning":
logger.info("Fine-tuning method: P-Tuning v2")
model.transformer.prefix_encoder.float() # other parameters are already fixed
logger.info("Fine-tuning method: P-Tuning v2") # nothing to do

if model_args.checkpoint_dir is not None:
load_trainable_params(model, model_args.checkpoint_dir[0])
@@ -131,12 +124,6 @@ def init_adapter(
)
model = get_peft_model(model, lora_config)

if not is_trainable:
for param in model.parameters():
param.requires_grad_(False) # fix all params

model = model.half() # cast all params to float16

return model


@@ -221,13 +208,18 @@ def load_pretrained(
model = prepare_model_for_training(model) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)

if not is_trainable:
model.requires_grad_(False) # fix all params
model = model.half() # cast all params to float16

# Quantization with the built-in method for P-Tuning v2 training or evaluation.
# Model parameters should be cast to float16 in quantized P-Tuning setting.
if quantization == "cpm":
assert model_args.quantization_bit in [4, 8], "P-Tuning v2 and inference mode only accept 4-bit or 8-bit quantization."
assert not (is_trainable and training_args.fp16), "FP16 training conflicts with cpm quantization."

model = model.quantize(model_args.quantization_bit)
model.quantize(model_args.quantization_bit) # in-place method

for name, param in model.named_parameters():
if "prefix_encoder" not in name:
param.data = param.data.to(torch.float16) # convert all params in half precision except prefix_encoder
@@ -236,9 +228,10 @@
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

if stage == "rwd" or stage == "ppo": # add value head
assert is_trainable, "Reward model and PPO model cannot be loaded at evaluation."
assert is_trainable, "Reward and PPO stages cannot be performed at evaluation."

model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

if stage == "ppo": # load reward model
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
load_valuehead_params(model, model_args.reward_model)
@@ -256,36 +249,35 @@
def prepare_args() -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# Provide arguments with a json file.

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
if int(training_args.do_train) + int(training_args.do_eval) + int(training_args.do_predict) != 1:
raise ValueError("We must perform single operation among do_train, do_eval and do_predict.")

if model_args.quantization_bit is not None and training_args.do_train == False:
logger.warning("We do not recommend to evaluaute model in 4/8-bit mode.")

if not training_args.fp16:
logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.")

training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

# Set logger
# Setup logging
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
if int(training_args.do_train) + int(training_args.do_eval) + int(training_args.do_predict) != 1:
raise ValueError("We must perform a single operation among do_train, do_eval and do_predict.")

if model_args.quantization_bit is not None and training_args.do_train == False:
logger.warning("We do not recommend to evaluaute model in 4/8-bit mode.")

if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.")

training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
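The quantized P-Tuning path above casts every parameter to float16 except the prefix encoder, which stays in full precision. A minimal standalone sketch of that selective cast on a dummy module (the module and parameter names here are illustrative, not the ChatGLM ones):

import torch
import torch.nn as nn

# Dummy stand-in for a model containing a prefix encoder; names are illustrative only.
model = nn.ModuleDict({
    "transformer": nn.Linear(8, 8),
    "prefix_encoder": nn.Embedding(16, 8),
})

for name, param in model.named_parameters():
    if "prefix_encoder" not in name:
        param.data = param.data.to(torch.float16)  # half precision for everything except the prefix encoder

print({name: param.dtype for name, param in model.named_parameters()})
# transformer.* -> torch.float16, prefix_encoder.weight stays torch.float32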
59 changes: 41 additions & 18 deletions src/utils/other.py
@@ -21,15 +21,19 @@
PREDICTION_FILE_NAME = "generated_predictions.txt"


logger = logging.getLogger(__name__) # setup logging
logger.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
level=logging.INFO,
handlers=[logging.StreamHandler(sys.stdout)]
)


def get_logger(name: str) -> logging.Logger:
return logging.getLogger(name)


class AverageMeter:
r"""
Computes and stores the average and current value.
@@ -61,7 +65,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores


def get_logits_processor():
def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
return logits_processor
@@ -73,7 +77,7 @@ def prepare_model_for_training(
model: PreTrainedModel,
output_embedding_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: List[str] = ["layernorm"] # for chatglm setting
layer_norm_names: Optional[List[str]] = ["layernorm"] # for chatglm setting
) -> PreTrainedModel:

for name, param in model.named_parameters():
@@ -156,18 +160,37 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))


def plot_loss(training_args: Seq2SeqTrainingArguments) -> None:
def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
"""
EMA implementation according to TensorBoard.
"""
last = scalars[0]
smoothed = list()
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)
last = smoothed_val
return smoothed


def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
import matplotlib.pyplot as plt
FIGURE_NAME = "trainer_state.png"
data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
train_steps, train_losses = [], []
for i in range(len(data["log_history"]) - 1):
train_steps.append(data["log_history"][i]["step"])
train_losses.append(data["log_history"][i]["loss"])
plt.figure()
plt.plot(train_steps, train_losses)
plt.title("training loss of {}".format(training_args.output_dir))
plt.xlabel("step")
plt.ylabel("training loss")
plt.savefig(os.path.join(training_args.output_dir, FIGURE_NAME), format="png", transparent=True, dpi=300)
print("Figure saved: {}".format(os.path.join(training_args.output_dir, FIGURE_NAME)))

for key in keys:
steps, metrics = [], []

for i in range(len(data["log_history"])):
if key in data["log_history"][i]:
steps.append(data["log_history"][i]["step"])
metrics.append(data["log_history"][i][key])
smoothed_value = smooth(metrics)

plt.figure()
plt.plot(steps, metrics, alpha=0.4)
plt.plot(steps, smoothed_value)
plt.title("training {} of {}".format(key, training_args.output_dir))
plt.xlabel("step")
plt.ylabel(key)
plt.savefig(os.path.join(training_args.output_dir, "training_{}.jpg".format(key)), format="jpg", dpi=100)
print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.jpg".format(key)))
11 changes: 2 additions & 9 deletions src/utils/pairwise.py
@@ -1,7 +1,5 @@
import os
import sys
import torch
import logging
from typing import Dict, Optional, Sequence

from transformers import Trainer, DataCollatorWithPadding
@@ -12,19 +10,14 @@
from .config import FinetuningArguments

from .other import (
get_logger,
save_trainable_params,
save_valuehead_params,
FINETUNING_ARGS_NAME
)


logger = logging.getLogger(__name__) # setup logging
logger.setLevel(logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger = get_logger(__name__)


class PairwiseDataCollatorForChatGLM(DataCollatorWithPadding):