Commit: update mask
Signed-off-by: mymusise <mymusise1@gmail.com>
mymusise committed Mar 20, 2023
1 parent f7ba507 commit b89ee70
Showing 5 changed files with 149 additions and 48 deletions.
16 changes: 14 additions & 2 deletions README.md
@@ -21,14 +21,25 @@

## Data Preprocessing


Convert the alpaca dataset to jsonl

```bash
python cover_alpaca2jsonl.py \
--data_path data/alpaca_data.json \
--save_path data/alpaca_data.jsonl
```

tokenization

```bash
python tokenize_dataset_rows.py \
--jsonl_path data/alpaca_data.jsonl \
--save_path data/alpaca \
--max_seq_length 320
```

- `--jsonl_path` path to the fine-tuning data, in jsonl format; the `['context']` and `['target']` fields of each line are encoded (previously the `['text']` field)
- `--save_path` output path
- `--max_seq_length` maximum sample length
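
The tokenization step saves a `datasets` dataset with two columns, `input_ids` and `seq_len` (the prompt length), produced by `tokenize_dataset_rows.py` shown further down. A minimal sketch, assuming the default `data/alpaca` save path used above, for sanity-checking the output:

```python
# Quick check of the tokenized dataset written by tokenize_dataset_rows.py.
# Assumes it was saved to data/alpaca as in the command above.
import datasets

ds = datasets.load_from_disk("data/alpaca")
print(ds)  # expect columns: input_ids, seq_len
sample = ds[0]
print(len(sample["input_ids"]), sample["seq_len"])  # total length vs. prompt length
```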

@@ -45,6 +56,7 @@ python finetune.py \
--save_total_limit 2 \
--learning_rate 2e-5 \
--fp16 \
--remove_unused_columns false \
--logging_steps 50 \
--output_dir output
```
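
A note on the newly added `--remove_unused_columns false`: by default the Hugging Face `Trainer` drops dataset columns that are not parameters of the model's `forward`, which would strip the `seq_len` column the custom `data_collator` in `finetune.py` needs to build masks and labels. A hedged sketch of the same setting expressed directly in code (illustrative only; the repository passes it on the command line via `HfArgumentParser`):

```python
# Illustrative only: equivalent TrainingArguments configuration in code.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    remove_unused_columns=False,  # keep seq_len so data_collator can build masks/labels
    per_device_train_batch_size=1,
    fp16=True,
)
```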
@@ -56,6 +68,6 @@ python finetune.py \

# TODO:

- ~~bs > 1 support~~
- Use Chinese data
- Add RLHF
30 changes: 30 additions & 0 deletions cover_alpaca2jsonl.py
@@ -0,0 +1,30 @@
```python
import argparse
import json
from tqdm import tqdm


def format_example(example: dict) -> dict:
    context = f"Instruction: {example['instruction']}\n"
    if example.get("input"):
        context += f"Input: {example['input']}\n"
    context += "Answer: "
    target = example["output"]
    return {"context": context, "target": target}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="data/alpaca_data.json")
    parser.add_argument("--save_path", type=str, default="data/alpaca_data.jsonl")

    args = parser.parse_args()
    with open(args.data_path) as f:
        examples = json.load(f)

    with open(args.save_path, 'w') as f:
        for example in tqdm(examples, desc="formatting.."):
            f.write(json.dumps(format_example(example)) + '\n')


if __name__ == "__main__":
    main()
```
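
For a quick sense of what `format_example` emits, a small usage sketch with a made-up Alpaca-style record (not taken from the real dataset):

```python
# Hypothetical input record shaped like an Alpaca entry.
example = {
    "instruction": "Translate to French.",
    "input": "Good morning.",
    "output": "Bonjour.",
}
print(format_example(example))
# {'context': 'Instruction: Translate to French.\nInput: Good morning.\nAnswer: ',
#  'target': 'Bonjour.'}
```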
106 changes: 82 additions & 24 deletions finetune.py
```diff
@@ -1,5 +1,6 @@
 from transformers import TrainingArguments
 from transformers import Trainer, HfArgumentParser
+from transformers import AutoTokenizer
 from modeling_chatglm import ChatGLMForConditionalGeneration
 import torch
 import torch.nn as nn
@@ -9,6 +10,9 @@
 import os


+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+
+
 @dataclass
 class FinetuneArguments:
     dataset_path: str = field(default="data/alpaca")
@@ -17,53 +21,106 @@ class FinetuneArguments:


 class CastOutputToFloat(nn.Sequential):
-    def forward(self, x): return super().forward(x).to(torch.float32)
+    def forward(self, x):
+        return super().forward(x).to(torch.float32)
+
+
+def get_masks_and_position_ids(
+    seq, mask_position, context_length, device, gmask=False, position_encoding_2d=True
+):
+    attention_mask = torch.ones((1, context_length, context_length), device=device)
+    attention_mask.tril_()
+    attention_mask[..., : mask_position - 1] = 1
+    attention_mask = (attention_mask < 0.5).bool()
+
+    if position_encoding_2d:
+        seq_length = seq.index(tokenizer.bos_token_id)
+        position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+        if not gmask:
+            position_ids[seq_length:] = mask_position
+        block_position_ids = torch.cat(
+            (
+                torch.zeros(seq_length, dtype=torch.long, device=device),
+                torch.arange(
+                    context_length - seq_length, dtype=torch.long, device=device
+                )
+                + 1,
+            )
+        )
+        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+    else:
+        position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+        if not gmask:
+            position_ids[context_length - 1 :] = mask_position
+    return attention_mask, position_ids


-class ModifiedTrainer(Trainer):
+def data_collator(features: list) -> dict:
+    len_ids = [len(feature["input_ids"]) for feature in features]
+    longest = max(len_ids)
+    input_ids = []
+    attention_mask_list = []
+    position_ids_list = []
+    labels_list = []
+    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
+        ids = feature["input_ids"]
+        seq_len = feature["seq_len"]
+        ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
+        labels = [-100] * (seq_len - 1) + ids[(seq_len - 1) :]
+        _ids = torch.LongTensor(ids)
+        attention_mask, position_ids = get_masks_and_position_ids(
+            ids, seq_len, longest, _ids.device, gmask=True
+        )
+        labels_list.append(torch.LongTensor(labels))
+        input_ids.append(_ids)
+        attention_mask_list.append(attention_mask)
+        position_ids_list.append(position_ids)
+    input_ids = torch.stack(input_ids)
+    labels = torch.stack(labels_list)
+    attention_mask = torch.stack(attention_mask_list)
+    position_ids = torch.stack(position_ids_list)
+    return {
+        "input_ids": input_ids,
+        "labels": labels,
+        "attention_mask": attention_mask,
+        "position_ids": position_ids,
+    }
+
+
+class ModifiedTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
-        input_shape = inputs["input_ids"].shape
         return model(
             input_ids=inputs["input_ids"],
-            attention_mask=torch.ones(1, 1, input_shape[-1], input_shape[-1]).bool(),
-            labels=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            position_ids=inputs["position_ids"],
+            labels=inputs["labels"],
         ).loss


-def data_collator(features: list) -> dict:
-    len_ids = [len(feature['input_ids']) for feature in features]
-    longest = max(len_ids)
-    input_ids = []
-    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x:-x[0]):
-        ids = feature['input_ids']
-        _ids = torch.LongTensor(ids + [150004] * (longest - ids_l))
-        input_ids.append(_ids)
-    return {"input_ids": torch.stack(input_ids)}
-
-
 def save_tunable_parameters(model, path):
     saved_params = {
-        k: v.to("cpu")
-        for k, v in model.named_parameters()
-        if v.requires_grad
+        k: v.to("cpu") for k, v in model.named_parameters() if v.requires_grad
     }
     torch.save(saved_params, path)


 def main():
     finetune_args, training_args = HfArgumentParser(
-        (FinetuneArguments, TrainingArguments)).parse_args_into_dataclasses()
+        (FinetuneArguments, TrainingArguments)
+    ).parse_args_into_dataclasses()

     # init model
     model = ChatGLMForConditionalGeneration.from_pretrained(
-        "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map='auto')
+        "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map="auto"
+    )
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
     model.is_parallelizable = True
     model.model_parallel = True
     model.lm_head = CastOutputToFloat(model.lm_head)
-    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
+    model.config.use_cache = (
+        False  # silence the warnings. Please re-enable for inference!
+    )

     # setup peft
     peft_config = LoraConfig(
@@ -88,8 +145,9 @@ def main():
     trainer.train()

     # save model
-    save_tunable_parameters(model, os.path.join(training_args.output_dir, "chatglm-lora.pt"))
+    save_tunable_parameters(
+        model, os.path.join(training_args.output_dir, "chatglm-lora.pt")
+    )


 if __name__ == "__main__":
```
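
To make the new collation scheme concrete, here is a hand-worked sketch of the padding and label layout for a single feature (toy token ids, not real ChatGLM vocabulary):

```python
# Toy illustration of the padding/label logic in data_collator above.
ids = [11, 12, 13, 14, 15]  # prompt (3 tokens) followed by target (2 tokens)
seq_len = 3                 # length of the prompt portion
longest = 7                 # longest sequence in the batch
pad_id = 0                  # stand-in for tokenizer.pad_token_id

padded = ids + [pad_id] * (longest - len(ids))
labels = [-100] * (seq_len - 1) + padded[(seq_len - 1):]
print(padded)  # [11, 12, 13, 14, 15, 0, 0]
print(labels)  # [-100, -100, 13, 14, 15, 0, 0] -> supervision starts at the last prompt token
```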
4 changes: 2 additions & 2 deletions modeling_chatglm.py
```diff
@@ -779,7 +779,7 @@ def get_masks(seq, device):
         return attention_mask

     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(150004) + 1
+        context_length = len(seq)
         if self.position_encoding_2d:
             seq_length = seq.index(150004)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
@@ -949,7 +949,7 @@ def set_output_embeddings(self, new_embeddings):
     def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
         attention_mask = torch.ones((1, context_length, context_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        attention_mask[..., :mask_position] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
```
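
This hunk is the heart of the "update mask" change: the bidirectionally visible prefix of ChatGLM's attention mask now ends at `mask_position` instead of `context_length - 1`, so only the prompt prefix is fully visible to every token and the target portion stays causal. A toy reconstruction of the rule (values illustrative only, batch and head dimensions omitted):

```python
# Toy mask for context_length=5 with the prefix ending at mask_position=3.
import torch

context_length, mask_position = 5, 3
attention_mask = torch.ones((context_length, context_length))
attention_mask.tril_()
attention_mask[..., :mask_position] = 1          # prefix columns fully visible
attention_mask = (attention_mask < 0.5).bool()   # True means "masked out"
print(attention_mask.int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)
```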
41 changes: 21 additions & 20 deletions tokenize_dataset_rows.py
```diff
@@ -1,17 +1,30 @@
 import argparse
 import json
-import random
-import tqdm.auto as tqdm
+from tqdm import tqdm

 import datasets
 import transformers


+def preprocess(tokenizer, example, max_seq_length=512):
+    prompt = example["context"]
+    target = example["target"]
+    prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
+    target_ids = tokenizer.encode(
+        target, max_length=max_seq_length, truncation=True, add_special_tokens=False
+    )
+    input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
+    return {"input_ids": input_ids, "seq_len": len(prompt_ids)}
+
+
 def read_jsonl(path):
     # Manually open because .splitlines is different from iterating over lines
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b", trust_remote_code=True
+    )
     with open(path, "r") as f:
-        for line in f:
-            yield json.loads(line)
+        for line in tqdm(f):
+            example = json.loads(line)
+            yield preprocess(tokenizer, example)


 def main():
@@ -21,22 +34,10 @@ def main():
     parser.add_argument("--max_seq_length", type=int, default=384)
     args = parser.parse_args()

-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        "THUDM/chatglm-6b", trust_remote_code=True
+    dataset = datasets.Dataset.from_generator(
+        lambda: read_jsonl("data/alpaca_data.jsonl")
     )

-    all_tokenized = []
-    for elem in tqdm.tqdm(read_jsonl(args.jsonl_path)):
-        all_tokenized.append(
-            tokenizer.encode(
-                elem["text"], max_length=args.max_seq_length, truncation=True,
-            )
-        )
-    random.shuffle(all_tokenized)
-
-    ds = datasets.Dataset.from_dict({"input_ids": all_tokenized})
-    ds.save_to_disk(args.save_path)
-    print(f"Generated {len(all_tokenized)} samples.")
+    dataset.save_to_disk(args.save_path)


 if __name__ == "__main__":
```
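
Worth noting: as committed, `main()` calls `read_jsonl` with a hard-coded `"data/alpaca_data.jsonl"` and `preprocess` falls back to its own `max_seq_length=512` default, so the parsed `--jsonl_path` and `--max_seq_length` arguments are not actually applied. A possible wiring (hypothetical, not part of this commit) that reuses the definitions above:

```python
# Hypothetical adjustment: thread the CLI arguments through to the generator.
def read_jsonl(path, max_seq_length):
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True
    )
    with open(path, "r") as f:
        for line in tqdm(f):
            yield preprocess(tokenizer, json.loads(line), max_seq_length)


# inside main(), after args = parser.parse_args():
dataset = datasets.Dataset.from_generator(
    lambda: read_jsonl(args.jsonl_path, args.max_seq_length)
)
```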
