Commit
hiyouga committed May 21, 2023
1 parent 4c7b17f commit 91c62b7
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/utils/common.py
@@ -398,9 +398,9 @@ def preprocess_supervised_dataset(examples):
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)

if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
-    source_ids = source_ids[:data_args.max_source_length - 2]
+    source_ids = source_ids[-data_args.max_source_length + 2:] # truncating from left side
if len(target_ids) > data_args.max_target_length - 1: # eos token
-    target_ids = target_ids[:data_args.max_target_length - 1]
+    target_ids = target_ids[:data_args.max_target_length - 1] # truncating from right side

input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids)

@@ -419,9 +419,9 @@ def preprocess_evaluation_dataset(examples):
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)

if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
-    source_ids = source_ids[:data_args.max_source_length - 2]
+    source_ids = source_ids[-data_args.max_source_length + 2:] # truncating from left side
if len(target_ids) > data_args.max_target_length - 2: # gmask and bos tokens
-    target_ids = target_ids[:data_args.max_target_length - 2]
+    target_ids = target_ids[:data_args.max_target_length - 2] # truncating from right side

input_ids = tokenizer.build_inputs_with_special_tokens(source_ids)
labels = tokenizer.build_inputs_with_special_tokens(target_ids)
@@ -439,11 +439,11 @@ def preprocess_pairwise_dataset(examples):
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

if len(source_ids) > data_args.max_source_length - 2: # gmask and bos tokens
-    source_ids = source_ids[:data_args.max_source_length - 2]
+    source_ids = source_ids[-data_args.max_source_length + 2:] # truncating from left side
if len(accept_ids) > data_args.max_target_length - 1: # eos token
-    accept_ids = accept_ids[:data_args.max_target_length - 1]
+    accept_ids = accept_ids[:data_args.max_target_length - 1] # truncating from right side
if len(reject_ids) > data_args.max_target_length - 1: # eos token
-    reject_ids = reject_ids[:data_args.max_target_length - 1]
+    reject_ids = reject_ids[:data_args.max_target_length - 1] # truncating from right side

accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], accept_ids) # avoid copying error
reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], reject_ids)
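As a quick illustration of the slicing change in this commit (not part of the diff): a minimal, self-contained Python sketch of right-side versus left-side truncation of the prompt tokens, using a made-up max_source_length of 6 and dummy token ids.

# Sketch only: dummy ids and a hypothetical max_source_length, not taken from the repository.
max_source_length = 6
source_ids = [101, 102, 103, 104, 105, 106, 107, 108]  # 8 prompt tokens

if len(source_ids) > max_source_length - 2:  # reserve 2 slots for gmask and bos tokens
    old_ids = source_ids[:max_source_length - 2]   # before: keep the first 4 tokens
    new_ids = source_ids[-max_source_length + 2:]  # after: keep the last 4 tokens

print(old_ids)  # [101, 102, 103, 104]
print(new_ids)  # [105, 106, 107, 108]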
