fix seq2seq predictions
hiyouga committed Jul 4, 2023
1 parent 4b20033 commit 3a53dd4
Showing 4 changed files with 35 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -107,7 +107,7 @@ pip install -r requirements.txt

If you want to enable quantized fine-tuning with LoRA (QLoRA) or Freeze on Windows, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.1.

-```
+```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
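
After installing the wheel, a quick import check can confirm the build loads; a minimal sketch (the exact CUDA banner bitsandbytes prints at import varies by version):

```python
# Sanity check for the Windows wheel: importing bitsandbytes should succeed
# and report a detected CUDA setup rather than falling back to CPU.
import bitsandbytes as bnb

# Linear8bitLt is the quantized layer that LoRA/QLoRA fine-tuning relies on.
print(bnb.nn.Linear8bitLt)
```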

2 changes: 1 addition & 1 deletion README_zh.md
@@ -113,7 +113,7 @@ pip install -r requirements.txt

For Windows users: to enable quantized fine-tuning with LoRA (QLoRA) or Freeze, please download the pre-built `bitsandbytes` package, which currently supports CUDA 11.1 to 12.1.

-```
+```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```

6 changes: 3 additions & 3 deletions src/utils/common.py
@@ -315,13 +315,13 @@ def prepare_args(
    assert (not training_args.do_predict) or training_args.predict_with_generate, \
        "Please enable `predict_with_generate` to save model predictions."

-    assert not (finetuning_args.finetuning_type == "p_tuning" and training_args.fp16), \
-        "Please disable fp16 training while using the P-Tuning v2 method."
-
    if model_args.quantization_bit is not None:
        assert finetuning_args.finetuning_type != "full" and finetuning_args.finetuning_type != "freeze", \
            "Quantization is incompatible with the full-parameter and freeze tuning."

+        assert not (finetuning_args.finetuning_type == "p_tuning" and training_args.fp16), \
+            "FP16 training conflicts with quantized P-Tuning."
+
        if not training_args.do_train:
            logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
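
The net effect of this hunk: the FP16/P-Tuning conflict is now only rejected when quantization is enabled. A standalone sketch of the reordered checks, using plain stand-in objects rather than the project's real argument classes:

```python
from types import SimpleNamespace

def validate(model_args, training_args, finetuning_args):
    # Saving predictions requires generate()-based evaluation.
    assert (not training_args.do_predict) or training_args.predict_with_generate, \
        "Please enable `predict_with_generate` to save model predictions."

    if model_args.quantization_bit is not None:
        # Quantization only works with adapter-style methods.
        assert finetuning_args.finetuning_type not in ("full", "freeze"), \
            "Quantization is incompatible with the full-parameter and freeze tuning."
        # After this commit, the FP16/P-Tuning conflict is only checked under quantization.
        assert not (finetuning_args.finetuning_type == "p_tuning" and training_args.fp16), \
            "FP16 training conflicts with quantized P-Tuning."

# Unquantized P-Tuning with fp16 now passes; quantized P-Tuning with fp16 would raise.
validate(
    SimpleNamespace(quantization_bit=None),
    SimpleNamespace(do_predict=False, predict_with_generate=False, fp16=True),
    SimpleNamespace(finetuning_type="p_tuning"),
)
```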

40 changes: 30 additions & 10 deletions src/utils/seq2seq.py
@@ -1,8 +1,10 @@
 import os
 import json
+import torch
 import numpy as np
+import torch.nn as nn
 from dataclasses import dataclass
-from typing import Dict, List, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from transformers.trainer import PredictionOutput
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -36,11 +38,10 @@ def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -
        preds, labels = eval_preds
        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

-        for pred, label in zip(preds, labels):
-            pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
-            pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
-            label = label[:len(label) - label_pad_len]
+        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
+        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)

+        for pred, label in zip(preds, labels):
            hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
            reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
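
The rewritten metric code works because `-100` (IGNORE_INDEX) cannot be decoded by the tokenizer; mapping it to the pad token first lets `skip_special_tokens=True` drop those positions during decoding. A small sketch of the masking step (the pad id here is hypothetical):

```python
import numpy as np

IGNORE_INDEX = -100   # loss-masked positions in the label tensor
pad_token_id = 3      # hypothetical; the real value comes from the tokenizer

labels = np.array([[5, 17, 42, IGNORE_INDEX, IGNORE_INDEX]])
# Replace masked positions with padding so the sequences decode cleanly.
labels = np.where(labels != IGNORE_INDEX, labels, pad_token_id)
print(labels)  # [[ 5 17 42  3  3]]
```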

@@ -65,6 +66,25 @@ class Seq2SeqTrainerForChatGLM(PeftTrainer):
    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
    """

+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        r"""
+        Removes the prompt part in the generated tokens.
+
+        Subclass and override to inject custom behavior.
+        """
+        input_ids = inputs["input_ids"]
+        loss, generated_tokens, labels = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        )
+        generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None
+        return (loss, generated_tokens, labels)

    def save_predictions(
        self,
        predict_results: PredictionOutput
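
Since `generate()` echoes the prompt at the start of its output for this decoder-only setup, slicing at the prompt length keeps only the newly generated response. A minimal sketch of that slicing with hypothetical shapes:

```python
import torch

# Hypothetical batch: 2 prompts of 4 tokens, 3 newly generated tokens each.
input_ids = torch.randint(5, 100, (2, 4))
generated = torch.cat([input_ids, torch.randint(5, 100, (2, 3))], dim=-1)

# Keep everything after the prompt length, mirroring the override above.
response_only = generated[:, input_ids.size(-1):]
print(response_only.shape)  # torch.Size([2, 3])
```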
@@ -79,13 +99,13 @@ def save_predictions(

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")

+        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
+        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
-            for pred, label in zip(predict_results.predictions, predict_results.label_ids):
-                pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
-                pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
-                label = label[:len(label) - label_pad_len]
-
+            for pred, label in zip(preds, labels):
                pred = self.tokenizer.decode(pred, skip_special_tokens=True)
                label = self.tokenizer.decode(label, skip_special_tokens=True)
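
The tail of this hunk (collapsed in the diff) serializes each decoded pair to one JSON object per line; a hedged sketch of that pattern, with illustrative key names rather than ones confirmed by the diff:

```python
import json
from typing import List, Tuple

def write_jsonl(path: str, pairs: List[Tuple[str, str]]) -> None:
    # One JSON object per line: decoded reference and model prediction.
    with open(path, "w", encoding="utf-8") as writer:
        for label, pred in pairs:
            writer.write(json.dumps({"label": label, "predict": pred}, ensure_ascii=False) + "\n")

write_jsonl("generated_predictions.jsonl", [("reference text", "model output")])
```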
