Commit
alter rewards data type hiyouga#127
hiyouga committed Jun 2, 2023
1 parent e733305 commit d22550d
Showing 11 changed files with 35 additions and 43 deletions.
10 changes: 6 additions & 4 deletions src/cli_demo.py
@@ -9,7 +9,7 @@
import signal
import platform

from utils import ModelArguments, auto_configure_device_map, load_pretrained
from utils import ModelArguments, FinetuningArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser


@@ -35,15 +35,17 @@ def signal_handler(signal, frame):
def main():

global stop_stream
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args, finetuning_args)

if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
device_map = auto_configure_device_map(torch.cuda.device_count())
model = dispatch_model(model, device_map)
else:
model = model.cuda()

model.eval()

history = []
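Note: a minimal, self-contained sketch of the parsing pattern the demo scripts switch to. HfArgumentParser accepts a tuple of dataclasses and returns one instance per class; the field names below are illustrative stand-ins, not the project's actual argument definitions.

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ModelArguments:                      # illustrative stand-in
    model_name_or_path: str = field(default="THUDM/chatglm-6b")

@dataclass
class FinetuningArguments:                 # illustrative stand-in
    finetuning_type: str = field(default="lora")

# Parsing several dataclasses at once yields one instance per class,
# so both can be forwarded to load_pretrained(model_args, finetuning_args).
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
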
2 changes: 1 addition & 1 deletion src/train_ppo.py
@@ -70,7 +70,7 @@ def main():
ppo_trainer.save_model()
ppo_trainer.save_state() # must be after save_model
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args, keys=["loss", "reward"])
plot_loss(training_args.output_dir, keys=["loss", "reward"])


def _mp_fn(index):
2 changes: 1 addition & 1 deletion src/train_rm.py
@@ -56,7 +56,7 @@ def main():
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

# Evaluation
if training_args.do_eval:
2 changes: 1 addition & 1 deletion src/train_sft.py
@@ -71,7 +71,7 @@ def main():
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

# Evaluation
if training_args.do_eval:
2 changes: 1 addition & 1 deletion src/utils/__init__.py
@@ -13,5 +13,5 @@
from .pairwise import PairwiseDataCollatorForChatGLM, PairwiseTrainerForChatGLM
from .ppo import PPOTrainerForChatGLM

from .config import ModelArguments
from .config import ModelArguments, FinetuningArguments
from .other import auto_configure_device_map, get_logits_processor, plot_loss
16 changes: 5 additions & 11 deletions src/utils/common.py
@@ -41,8 +41,7 @@
load_valuehead_params,
print_trainable_params,
prepare_model_for_training,
IGNORE_INDEX,
FINETUNING_ARGS_NAME
IGNORE_INDEX
)

check_min_version("4.27.4")
@@ -130,8 +129,8 @@ def init_adapter(

def load_pretrained(
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
training_args: Optional[Seq2SeqTrainingArguments] = None,
finetuning_args: Optional[FinetuningArguments] = None,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
@@ -140,14 +139,9 @@ def load_pretrained(
Support both training and inference.
"""
if finetuning_args is None: # load the fine-tuning arguments
if model_args.checkpoint_dir is None:
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
else:
raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")

assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."

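Note: finetuning_args is now a required argument of load_pretrained, so callers pass it explicitly for both inference and training; the old path that re-read FINETUNING_ARGS_NAME from the last checkpoint directory is gone. A hedged sketch of the caller side (imports mirror cli_demo.py above; the keyword values are illustrative):

from transformers import HfArgumentParser
from utils import ModelArguments, FinetuningArguments, load_pretrained

parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()

# Inference: if no checkpoint_dir is given and the model is not trainable,
# load_pretrained falls back to FinetuningArguments(finetuning_type="none").
model, tokenizer = load_pretrained(model_args, finetuning_args)

# Training (illustrative): training_args, is_trainable and stage are still
# accepted after finetuning_args, e.g.
# model, tokenizer = load_pretrained(model_args, finetuning_args,
#                                    training_args, is_trainable=True, stage="rm")
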
6 changes: 3 additions & 3 deletions src/utils/data_collator.py
@@ -2,7 +2,7 @@

from typing import Dict, Optional, Sequence, Union

from transformers import DataCollatorWithPadding
from transformers import DataCollatorWithPadding, BatchEncoding
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

@@ -64,7 +64,7 @@ def get_position_ids(self, input_ids: torch.Tensor, device: torch.device) -> tor
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
return position_ids

def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
r"""
Pads batched data to the longest sequence in the batch.
@@ -95,4 +95,4 @@ def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int
batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
batch["position_ids"] = self.get_position_ids(input_ids, device=input_ids.device)

return batch
return BatchEncoding(batch)
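
Note: a small sketch of why the collator now wraps its output in BatchEncoding rather than a plain dict. BatchEncoding keeps dict-style access but also offers conveniences such as moving every tensor to a device in one call; the tensors below are illustrative.

import torch
from transformers import BatchEncoding

batch = BatchEncoding({
    "input_ids": torch.tensor([[5, 6, 7]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
})

device = "cuda" if torch.cuda.is_available() else "cpu"
batch = batch.to(device)                 # moves all contained tensors at once
print(batch["input_ids"].device, batch.keys())
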
14 changes: 7 additions & 7 deletions src/utils/other.py
@@ -5,7 +5,6 @@
import logging
from typing import Dict, List, Optional

from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TRAINER_STATE_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
@@ -166,7 +165,7 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
return device_map


def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
"""
EMA implementation according to TensorBoard.
"""
@@ -179,9 +178,10 @@ def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
return smoothed


def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
import matplotlib.pyplot as plt
data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)

for key in keys:
steps, metrics = [], []
@@ -197,9 +197,9 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]]
plt.figure()
plt.plot(steps, metrics, alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), label="smoothed")
plt.title("training {} of {}".format(key, training_args.output_dir))
plt.title("training {} of {}".format(key, save_dictionary))
plt.xlabel("step")
plt.ylabel(key)
plt.legend()
plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))
plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
3 changes: 2 additions & 1 deletion src/utils/peft_trainer.py
@@ -109,7 +109,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str,
if hasattr(model, "v_head"): # save valuehead weights
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))

torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

def _load_best_model(self):
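Note: the change above swaps torch.save of the TrainingArguments object for a plain JSON dump, which is human-readable and avoids pickling. A short sketch of the pattern; the output path and file name are illustrative, and the repository's TRAINING_ARGS_NAME constant may differ.

import os
from transformers import TrainingArguments

output_dir = "outputs"                       # illustrative path
os.makedirs(output_dir, exist_ok=True)

args = TrainingArguments(output_dir=output_dir)
with open(os.path.join(output_dir, "training_args.json"), "w", encoding="utf-8") as f:
    f.write(args.to_json_string() + "\n")    # human-readable, not a pickle
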
11 changes: 2 additions & 9 deletions src/utils/ppo.py
@@ -75,7 +75,7 @@ def __init__(
self.finetuning_args = finetuning_args
self.log_callback = callbacks[0]
self.state = TrainerState()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer

def ppo_train(self, max_target_length: int) -> None:
r"""
@@ -148,7 +148,7 @@ def ppo_train(self, max_target_length: int) -> None:
# Compute rewards
replace_model(unwrapped_model, target="reward")
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[-1]]
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default") # make sure the model is default at the end

# Run PPO step
@@ -214,13 +214,6 @@ def generate(
return response[:, inputs["input_ids"].size(1):]
return response

def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data

@PPODecorators.empty_cuda_cache()
def batched_forward_pass(
self,
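Note: the commit's namesake change is that the per-sample rewards taken from the reward model's value head are upcast to float32 before the PPO step. A sketch under the assumption that values has shape (batch_size, sequence_length) and may be produced in half precision:

import torch

values = torch.randn(4, 16, dtype=torch.float16)   # hypothetical value-head output

# Take the value at the last position of each sequence and upcast, so that
# PPO's reward statistics are not accumulated in fp16.
rewards = [reward for reward in values[:, -1].to(torch.float32)]
print(len(rewards), rewards[0].dtype)               # 4 torch.float32
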
10 changes: 6 additions & 4 deletions src/web_demo.py
@@ -8,21 +8,23 @@
import mdtex2html
import gradio as gr

from utils import ModelArguments, auto_configure_device_map, load_pretrained
from utils import ModelArguments, FinetuningArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser
from transformers.utils.versions import require_version


require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args, finetuning_args)

if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
device_map = auto_configure_device_map(torch.cuda.device_count())
model = dispatch_model(model, device_map)
else:
model = model.cuda()

model.eval()



0 comments on commit d22550d
