Fix eos_token_id error
ypwhs authored Mar 24, 2023
1 parent 19b626b commit 58b195e
Showing 1 changed file with 11 additions and 7 deletions.
tokenize_dataset_rows.py
@@ -6,25 +6,29 @@
 import transformers
 
 
-def preprocess(tokenizer, example, max_seq_length):
+def preprocess(tokenizer, config, example, max_seq_length):
     prompt = example["context"]
     target = example["target"]
     prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
     target_ids = tokenizer.encode(
-        target, max_length=max_seq_length, truncation=True, add_special_tokens=False
-    )
-    input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
+        target,
+        max_length=max_seq_length,
+        truncation=True,
+        add_special_tokens=False)
+    input_ids = prompt_ids + target_ids + [config.eos_token_id]
     return {"input_ids": input_ids, "seq_len": len(prompt_ids)}
 
 
 def read_jsonl(path, max_seq_length, skip_overlength=False):
+    model_name = "THUDM/chatglm-6b"
     tokenizer = transformers.AutoTokenizer.from_pretrained(
-        "THUDM/chatglm-6b", trust_remote_code=True
-    )
+        model_name, trust_remote_code=True)
+    config = transformers.AutoConfig.from_pretrained(
+        model_name, trust_remote_code=True, device_map='auto')
     with open(path, "r") as f:
         for line in tqdm(f.readlines()):
             example = json.loads(line)
-            feature = preprocess(tokenizer, example, max_seq_length)
+            feature = preprocess(tokenizer, config, example, max_seq_length)
             if skip_overlength and len(feature["input_ids"]) > max_seq_length:
                 continue
             feature["input_ids"] = feature["input_ids"][:max_seq_length]
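The substance of the fix: input_ids now ends with config.eos_token_id instead of tokenizer.eos_token_id, so the end-of-sequence id is taken from the model config rather than from ChatGLM-6B's custom tokenizer, evidently the source of the error named in the commit title. A minimal sketch of calling the patched function, assuming the script's directory is on the import path; the sample record is hypothetical, but its "context"/"target" fields match what the diff expects:

import transformers

from tokenize_dataset_rows import preprocess

model_name = "THUDM/chatglm-6b"
# trust_remote_code is required to load ChatGLM-6B's custom tokenizer and config classes.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# Hypothetical record in the {"context": ..., "target": ...} format the script expects.
example = {"context": "Instruction: say hello\nAnswer: ", "target": "Hello!"}

feature = preprocess(tokenizer, config, example, max_seq_length=512)
# After this commit, the sequence is terminated by the config's eos id,
# not the tokenizer attribute that previously caused the error.
assert feature["input_ids"][-1] == config.eos_token_id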
