Skip to content

Commit

Permalink
Update predict.py
Browse files Browse the repository at this point in the history
  • Loading branch information
taishan1994 committed May 26, 2023
1 parent 0f20f75 commit f306ec0
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,8 @@
logits = output.logits
preds = torch.argmax(logits, -1).detach().cpu().numpy()
preds = np.where(labels != -100, preds, tokenizer.pad_token_id)
for pre, lab in zip(preds.tolist(), labels.tolist()):
d1 = tokenizer.convert_ids_to_tokens(lab)
d2 = tokenizer.convert_ids_to_tokens(pre)
print(d1)
print(d2)
for dd1, dd2 in zip(d1, d2):
print(dd1, dd2)
print("="*100)
preds = preds[:, :-1]
labels = labels[:, 1:]
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True, decoder_end_token_id=eos_token_id)
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True, decoder_end_token_id=eos_token_id)
Expand Down

0 comments on commit f306ec0

Please sign in to comment.