Commit ecf5760

fix bugs in distributed PPO training

hiyouga committed May 18, 2023
1 parent b855320 commit ecf5760

Showing 4 changed files with 69 additions and 61 deletions.
Binary file modified assets/wechat.jpg
42 changes: 23 additions & 19 deletions src/utils/data_collator.py

@@ -1,6 +1,6 @@
 import torch

-from typing import Dict, Optional, Sequence
+from typing import Dict, Optional, Sequence, Union

 from transformers import DataCollatorWithPadding
 from transformers.modeling_utils import PreTrainedModel
@@ -23,7 +23,7 @@ def __init__(
         self.model = model
         self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id

-    def get_attention_masks(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
         r"""
         Generates attention masks for left-padded sequences.
@@ -32,7 +32,7 @@ def get_attention_masks(self, input_ids: torch.Tensor) -> torch.Tensor:
         According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L680
         """
         batch_size, seq_length = input_ids.size()
-        attention_mask = torch.ones((batch_size, seq_length, seq_length))
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
         for i, seq in enumerate(input_ids):
             attention_mask[i, :, :(seq == self.tokenizer.bos_token_id).nonzero()[0].item()] = 1 # context
@@ -41,7 +41,7 @@ def get_attention_masks(self, input_ids: torch.Tensor) -> torch.Tensor:
         attention_mask = (attention_mask < 0.5).bool()
         return attention_mask

-    def get_position_ids(self, input_ids: torch.Tensor):
+    def get_position_ids(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
         r"""
         Generates position ids for left-padded sequences.
@@ -50,36 +50,40 @@ def get_position_ids(self, input_ids: torch.Tensor):
         batch_size, seq_length = input_ids.size()
         mask: int = self.model.config.mask_token_id
         gmask: int = self.model.config.gmask_token_id
-        position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long)
-        block_position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long)
+        position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
+        block_position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
         for i, seq in enumerate(input_ids):
             mask_token = gmask if gmask in seq else mask
             context_length = (seq == self.tokenizer.bos_token_id).nonzero()[0].item()
             padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item()
-            position_ids[i, padding_length:] = torch.arange(seq_length - padding_length, dtype=torch.long)
+            position_ids[i, padding_length:] = torch.arange(seq_length - padding_length, dtype=torch.long, device=device)
             if self.model.position_encoding_2d or (mask_token != gmask): # 2d position encoding or not gMASK
                 position_ids[i, context_length:] = (seq == mask_token).nonzero()[0].item() - padding_length # mask position
-                block_position_ids[i, context_length:] = torch.arange(seq_length - context_length, dtype=torch.long) + 1
+                block_position_ids[i, context_length:] = torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
         if self.model.position_encoding_2d:
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         return position_ids

-    def __call__(self, features: Sequence[Dict[str, Sequence[int]]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
         r"""
         Pads batched data to the longest sequence in the batch.
         We adopt left-padding in both training and evaluation.
         """
-        input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]
-        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1)
-        if "labels" in features[0]:
-            labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
-            labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=self.label_pad_token_id).flip(-1)
-        else:
-            labels = None
-        return {
+        if isinstance(features[0]["input_ids"], torch.Tensor):
+            input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
+        else:
+            input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]
+        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1)
+        batch = {
             "input_ids": input_ids,
-            "labels": labels,
-            "attention_mask": self.get_attention_masks(input_ids),
-            "position_ids": self.get_position_ids(input_ids)
+            "attention_mask": self.get_attention_masks(input_ids, device=input_ids.device),
+            "position_ids": self.get_position_ids(input_ids, device=input_ids.device)
         }
+        if "labels" in features[0]:
+            if isinstance(features[0]["labels"], torch.Tensor):
+                labels = [feature["labels"].clone().detach().flip(0) for feature in features]
+            else:
+                labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
+            batch["labels"] = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=self.label_pad_token_id).flip(-1)
+        return batch
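A note on the __call__ hunk above: pad_sequence can only pad on the right, so the collator flips each sequence, pads, and flips back to obtain left-padded batches; the new isinstance branches simply let the same collator accept both plain token-id lists and ready-made tensors. A minimal self-contained sketch of the flip-pad-flip trick, with made-up token ids and a stand-in pad id:

import torch

pad_id = 0  # stand-in pad token id, not the real tokenizer value
features = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# pad_sequence pads on the right only, so flip -> pad -> flip pads on the left
flipped = [f.flip(0) for f in features]
batch = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=pad_id).flip(-1)
print(batch.tolist())  # [[5, 6, 7], [0, 8, 9]]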
2 changes: 1 addition & 1 deletion src/utils/other.py

@@ -187,7 +187,7 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]]
             metrics.append(data["log_history"][i][key])

     if len(metrics) == 0:
-        logger.warning("No metrics to plot.")
+        logger.warning("No metric to plot.")
         return

     plt.figure()
86 changes: 45 additions & 41 deletions src/utils/ppo.py

@@ -67,45 +67,11 @@ class PPOTrainerForChatGLM(PPOTrainer):

     def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.data_collator = kwargs["data_collator"]
+        self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
         self.state = {"log_history": []}
         self.training_args = training_args
         self.finetuning_args = finetuning_args

-    @torch.no_grad()
-    def generate(
-        self,
-        inputs: Dict[str, torch.Tensor],
-        length_sampler: Callable = None,
-        return_prompt: bool = True,
-        **generation_kwargs,
-    ) -> torch.Tensor:
-        r"""
-        Generate response with the model given the query tensor.
-
-        Inspired by: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/trl/trainer/ppo_trainer.py#L387
-        """
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
-
-        if length_sampler is not None:
-            generation_kwargs["max_new_tokens"] = length_sampler()
-
-        unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
-
-        response = unwrapped_model.generate(**inputs, **generation_kwargs)
-
-        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
-        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
-        if unwrapped_model.pretrained_model.generation_config._from_model_config:
-            unwrapped_model.pretrained_model.generation_config._from_model_config = False
-
-        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
-
-        if not return_prompt and not self.is_encoder_decoder:
-            return response[:, inputs["input_ids"].size(1):]
-        return response
-
     def ppo_train(self, max_target_length: int) -> None:

         total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.training_args.world_size
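For context on the last line above: the effective PPO batch size is the per-device batch size multiplied by the gradient accumulation steps and the number of distributed processes. A quick illustrative computation (all numbers invented):

batch_size = 4                   # per-device PPO batch size (illustrative)
gradient_accumulation_steps = 4  # illustrative
world_size = 2                   # e.g. two GPUs / processes
print(batch_size * gradient_accumulation_steps * world_size)  # 32 samples per optimizer update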
@@ -161,14 +127,15 @@ def ppo_train(self, max_target_length: int) -> None:
             for i in range(len(query_tensors)):
                 query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
                 response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-                if response_length < 2:
-                    continue
                 queries.append(query_tensors[i, query_length:]) # remove padding from left
-                responses.append(response_tensors[i, :response_length]) # remove padding from right
+                if response_length < 2: # make response have at least 2 tokens
+                    responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
+                else:
+                    responses.append(response_tensors[i, :response_length]) # remove padding from right

             # Compute rewards
             replace_model(unwrapped_model, target="reward")
-            _, _, values = unwrapped_model(**self.prepare_model_inputs(queries, responses))
+            _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
             rewards = [reward for reward in values[-1]]
             replace_model(unwrapped_model, target="default")
@@ -201,7 +168,41 @@ def ppo_train(self, max_target_length: int) -> None:
             if (step+1) % self.training_args.save_steps == 0: # save checkpoint
                 self.save_model(os.path.join(self.training_args.output_dir, f"checkpoint-{step+1}"))

-    def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Dict[str, torch.Tensor],
+        length_sampler: Callable = None,
+        return_prompt: bool = True,
+        **generation_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Generate response with the model given the query tensor.
+
+        Inspired by: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/trl/trainer/ppo_trainer.py#L387
+        """
+        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
+
+        if length_sampler is not None:
+            generation_kwargs["max_new_tokens"] = length_sampler()
+
+        unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
+
+        response = unwrapped_model.generate(**inputs, **generation_kwargs)
+
+        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
+        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
+        if unwrapped_model.pretrained_model.generation_config._from_model_config:
+            unwrapped_model.pretrained_model.generation_config._from_model_config = False
+
+        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
+
+        if not return_prompt and not self.is_encoder_decoder:
+            return response[:, inputs["input_ids"].size(1):]
+        return response
+
+    def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
         input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
         input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
         input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
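The context lines above are the heart of prepare_model_inputs: each query/response pair is concatenated into one sequence, the batch is re-collated (the collator shown earlier left-pads it and rebuilds attention_mask/position_ids), and the tensors are moved to the current device. A stripped-down sketch of the concatenation step with invented ids:

import torch

queries = [torch.tensor([5, 6]), torch.tensor([7])]
responses = [torch.tensor([8]), torch.tensor([9, 10])]

# one sequence per sample: query tokens followed by response tokens
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
print([ids.tolist() for ids in input_ids])  # [[5, 6, 8], [7, 9, 10]]
# the trainer then wraps each as {"input_ids": ids} and calls self.data_collator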
@@ -229,8 +230,11 @@ def batched_forward_pass(
         all_values = []

         for i in range(int(bs / fbs)):
-            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+            input_kwargs = {k: v[i * fbs : (i + 1) * fbs] for k, v in model_inputs.items()}
+            input_ids: torch.Tensor = input_kwargs["input_ids"] # left-padded sequences
+            if self.is_distributed: # re-generate them to adapt padded inputs
+                input_kwargs["attention_mask"] = self.data_collator.get_attention_masks(input_ids, device=self.current_device)
+                input_kwargs["position_ids"] = self.data_collator.get_position_ids(input_ids, device=self.current_device)
             logits, _, values = model(**input_kwargs)
             logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
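Why the new is_distributed branch regenerates the masks: in multi-process runs trl pads each process's inputs to a common sequence length before the forward pass (presumably the "padded inputs" the in-diff comment refers to), so an attention mask or position ids computed for the original, shorter batch no longer line up with the extra left padding. A toy illustration with an invented pad id:

import torch

pad_id = 0  # stand-in pad token id

local = torch.tensor([[5, 6, 7]])         # length 3 on this process
synced = torch.tensor([[0, 0, 5, 6, 7]])  # after cross-process padding to length 5

def first_real_token(batch):
    # index of the first non-pad token per row, the anchor the collator's masks are built from
    return [(row != pad_id).nonzero()[0].item() for row in batch]

print(first_real_token(local))   # [0]
print(first_real_token(synced))  # [2] -> masks built for `local` are now misaligned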
