
Commit

Merge pull request mymusise#73 from ypwhs/patch-3
Fix eos_token_id error
mymusise authored Mar 24, 2023
2 parents 19b626b + 158384b commit b60af47
Showing 3 changed files with 17 additions and 16 deletions.
8 changes: 3 additions & 5 deletions examples/finetune.ipynb
@@ -277,7 +277,7 @@
 "\n",
 "def data_collator(features: list) -> dict:\n",
 "    len_ids = [len(feature[\"input_ids\"]) for feature in features]\n",
-"    longest = max(len_ids) + 1\n",
+"    longest = max(len_ids)\n",
 "    input_ids = []\n",
 "    attention_mask_list = []\n",
 "    position_ids_list = []\n",
@@ -288,10 +288,9 @@
 "        labels = (\n",
 "            [-100] * (seq_len - 1)\n",
 "            + ids[(seq_len - 1) :]\n",
-"            + [tokenizer.eos_token_id]\n",
-"            + [-100] * (longest - ids_l - 1)\n",
+"            + [-100] * (longest - ids_l)\n",
 "        )\n",
-"        ids = ids + [tokenizer.eos_token_id] * (longest - ids_l)\n",
+"        ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)\n",
 "        _ids = torch.LongTensor(ids)\n",
 "        attention_mask, position_ids = get_masks_and_position_ids(\n",
 "            ids, seq_len, longest, _ids.device, gmask=False\n",
@@ -312,7 +311,6 @@
 "    }\n",
 "\n",
 "\n",
-"\n",
 "class ModifiedTrainer(Trainer):\n",
 "\n",
 "    def compute_loss(self, model, inputs, return_outputs=False):\n",
7 changes: 3 additions & 4 deletions finetune.py
@@ -60,7 +60,7 @@ def get_masks_and_position_ids(
 
 def data_collator(features: list) -> dict:
     len_ids = [len(feature["input_ids"]) for feature in features]
-    longest = max(len_ids) + 1
+    longest = max(len_ids)
     input_ids = []
     attention_mask_list = []
     position_ids_list = []
@@ -71,10 +71,9 @@ def data_collator(features: list) -> dict:
         labels = (
             [-100] * (seq_len - 1)
             + ids[(seq_len - 1) :]
-            + [tokenizer.eos_token_id]
-            + [-100] * (longest - ids_l - 1)
+            + [-100] * (longest - ids_l)
         )
-        ids = ids + [tokenizer.eos_token_id] * (longest - ids_l)
+        ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
         _ids = torch.LongTensor(ids)
         attention_mask, position_ids = get_masks_and_position_ids(
             ids, seq_len, longest, _ids.device, gmask=False
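The notebook and finetune.py receive the same fix: preprocessing (see tokenize_dataset_rows.py below) already appends the EOS token to every feature, so the collator no longer appends a second EOS, pads with pad_token_id instead of eos_token_id, and drops the stray "+ 1" from the batch width. A minimal, self-contained sketch of the corrected padding/label logic follows; the PAD_ID and EOS_ID values and the toy features are hypothetical stand-ins for the real ChatGLM-6B tokenizer, and the attention-mask/position-id handling is left out:

import torch

# Hypothetical ids standing in for tokenizer.pad_token_id / config.eos_token_id.
PAD_ID, EOS_ID = 3, 2

def collate(features: list) -> dict:
    # Every feature already ends with EOS, so the batch width is simply
    # the longest sequence -- no extra "+ 1".
    len_ids = [len(f["input_ids"]) for f in features]
    longest = max(len_ids)
    input_ids, labels_list = [], []
    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
        ids = feature["input_ids"]
        seq_len = feature["seq_len"]  # length of the prompt part
        # Prompt and padding positions are masked with -100; the target,
        # including its trailing EOS, is supervised as-is.
        labels = (
            [-100] * (seq_len - 1)
            + ids[(seq_len - 1):]
            + [-100] * (longest - ids_l)
        )
        # Pad inputs with PAD_ID, not EOS_ID, so padding can never be
        # mistaken for a real end-of-sequence token.
        input_ids.append(torch.LongTensor(ids + [PAD_ID] * (longest - ids_l)))
        labels_list.append(torch.LongTensor(labels))
    return {"input_ids": torch.stack(input_ids), "labels": torch.stack(labels_list)}

batch = collate([
    {"input_ids": [10, 11, 12, EOS_ID], "seq_len": 2},
    {"input_ids": [10, 11, EOS_ID], "seq_len": 2},
])
print(batch["input_ids"])  # second row is padded with PAD_ID
print(batch["labels"])     # every padding position is -100

With EOS already inside ids, the slice ids[(seq_len - 1):] covers the target and its EOS in one piece, which is why both the old "+ [tokenizer.eos_token_id]" term and the "- 1" in the padding count could be deleted together.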
18 changes: 11 additions & 7 deletions tokenize_dataset_rows.py
@@ -6,25 +6,29 @@
 import transformers
 
 
-def preprocess(tokenizer, example, max_seq_length):
+def preprocess(tokenizer, config, example, max_seq_length):
     prompt = example["context"]
     target = example["target"]
     prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
     target_ids = tokenizer.encode(
-        target, max_length=max_seq_length, truncation=True, add_special_tokens=False
-    )
-    input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
+        target,
+        max_length=max_seq_length,
+        truncation=True,
+        add_special_tokens=False)
+    input_ids = prompt_ids + target_ids + [config.eos_token_id]
     return {"input_ids": input_ids, "seq_len": len(prompt_ids)}
 
 
 def read_jsonl(path, max_seq_length, skip_overlength=False):
+    model_name = "THUDM/chatglm-6b"
     tokenizer = transformers.AutoTokenizer.from_pretrained(
-        "THUDM/chatglm-6b", trust_remote_code=True
-    )
+        model_name, trust_remote_code=True)
+    config = transformers.AutoConfig.from_pretrained(
+        model_name, trust_remote_code=True, device_map='auto')
     with open(path, "r") as f:
         for line in tqdm(f.readlines()):
             example = json.loads(line)
-            feature = preprocess(tokenizer, example, max_seq_length)
+            feature = preprocess(tokenizer, config, example, max_seq_length)
             if skip_overlength and len(feature["input_ids"]) > max_seq_length:
                 continue
             feature["input_ids"] = feature["input_ids"][:max_seq_length]
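The tokenization script now takes the EOS id from the model config instead of the tokenizer, which appears to be the error the PR title refers to: the model config is treated as the authoritative source for eos_token_id rather than the tokenizer object. A short usage sketch, assuming the THUDM/chatglm-6b checkpoint is reachable and that preprocess from this script is in scope; the example record is made up:

import transformers

model_name = "THUDM/chatglm-6b"
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True)
config = transformers.AutoConfig.from_pretrained(
    model_name, trust_remote_code=True)

# A made-up record in the {"context": ..., "target": ...} shape read_jsonl expects.
example = {"context": "Instruction: say hi\nAnswer: ", "target": "hi"}
feature = preprocess(tokenizer, config, example, max_seq_length=64)

# input_ids is prompt ids + target ids + [config.eos_token_id].
assert feature["input_ids"][-1] == config.eos_token_id
print(feature["seq_len"], len(feature["input_ids"]))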
