-
Notifications
You must be signed in to change notification settings - Fork 4
/
test.py
61 lines (51 loc) · 2.14 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
import os

import numpy as np
import torch
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from config_utils import ConfigParser
from dataset import NerCollate, load_data
# Load the training arguments that were saved next to the LoRA adapter and
# parse them into the project's config object (ConfigParser.parse_main).
data_name = "msra"
train_args_path = "./checkpoint/{}/train_trainer/adapter_model/train_args.json".format(data_name)
# Explicit UTF-8 so loading does not depend on the platform's default locale
# encoding (e.g. GBK on Windows), which can fail on non-ASCII content.
with open(train_args_path, "r", encoding="utf-8") as fp:
    args = json.load(fp)
config_parser = ConfigParser(args)
args = config_parser.parse_main()
# Build the inference model: the tokenizer and base model come from
# args.model_dir, and the LoRA adapter weights are attached from
# <args.save_dir>/adapter_model. The result is cast to fp16, moved to the GPU
# and put in eval mode.
tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
base_model = AutoModel.from_pretrained(args.model_dir, trust_remote_code=True)
base_model.eval()
adapter_dir = os.path.join(args.save_dir, "adapter_model")
model = PeftModel.from_pretrained(
    base_model,
    adapter_dir,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
model.half().cuda()
model.eval()
# Evaluation data. Note that this reads args.dev_path (the dev split), despite
# the "test_" variable names.
test_data = load_data(args.dev_path)
ner_collate = NerCollate(args, tokenizer)
# shuffle=False: evaluation should be deterministic, so repeated runs produce
# the same ordering (and the same printed samples below).
test_dataloader = DataLoader(
    test_data,
    batch_size=args.train_batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=ner_collate.collate_fn,
)
# Teacher-forced evaluation: run the model over the eval dataloader, take the
# argmax token at each position, and decode both the predictions and the
# reference labels, restricted to the positions that carry a label
# (label != -100).
# Special token ids, kept for locating where the answer starts in labels
# (not used by the loop below).
bos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id
with torch.no_grad():
    all_preds = []
    all_trues = []
    for batch in tqdm(test_dataloader, ncols=100):
        for k, v in batch.items():
            batch[k] = v.cuda()
        labels = batch["labels"].detach().cpu().numpy()
        output = model(**batch)
        logits = output.logits
        preds = torch.argmax(logits, -1).detach().cpu().numpy()
        # The logits at position t predict token t+1, so align first
        # (drop the last prediction and the first label), THEN mask using the
        # shifted labels. Masking before shifting is off by one: it drops the
        # prediction of the first answer token and keeps one past the end.
        preds = preds[:, :-1]
        labels = labels[:, 1:]
        label_mask = labels != -100
        preds = np.where(label_mask, preds, tokenizer.pad_token_id)
        labels = np.where(label_mask, labels, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        all_preds.extend(decoded_preds)
        all_trues.extend(decoded_labels)
print("预测:", all_preds[:20])
print("真实:", all_trues[:20])