-
Notifications
You must be signed in to change notification settings - Fork 5
/
extract_logits.py
95 lines (67 loc) · 2.76 KB
/
extract_logits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import os
from tqdm import tqdm
import torch
import torch.nn as nn
# Tensors printed by this script use fixed-point notation (no scientific
# notation) with 4 digits after the decimal point.
torch.set_printoptions(precision=4, sci_mode=False)
class ClsFC(nn.Module):
    """A single linear layer producing ``num_cls`` scores from ``in_dim`` features.

    Args:
        num_cls: number of output scores (``out_features`` of the layer).
        in_dim: length of each input feature vector (``in_features``).
    """

    def __init__(self, num_cls, in_dim):
        super().__init__()
        # The attribute name ``fc`` determines the state-dict keys
        # ("fc.weight", "fc.bias") that callers load weights into.
        self.fc = nn.Linear(in_features=in_dim, out_features=num_cls)

    @torch.no_grad()
    def forward(self, x):
        """Apply the linear layer to ``x`` without recording gradients."""
        scores = self.fc(x)
        return scores
def create_model(weight="training_dir/COCO34ORfreq32_4gpu/model_0180000.pth"):
    """Build a ClsFC from the box-classifier weights stored in a checkpoint.

    The default weight has been released,
    refer to https://github.com/Dawn-LX/VidVRD-tracklets#quick-start

    Args:
        weight: path to a ``torch.save``-d checkpoint. It must be a dict whose
            ``"model"`` entry contains the tensors
            ``module.roi_heads.box.predictor.cls_score.weight`` and
            ``module.roi_heads.box.predictor.cls_score.bias``.

    Returns:
        A ClsFC whose ``fc`` layer holds those two tensors, on CPU. For the
        released checkpoint this is an 81-class head over 1024-d features.
    """
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on any
    # machine (including CPU-only ones) instead of being restored onto the
    # device it was saved from.
    state_dict = torch.load(weight, map_location="cpu")["model"]
    prefix = "module.roi_heads.box.predictor.cls_score."
    fc_weight = state_dict[prefix + "weight"].cpu()
    fc_bias = state_dict[prefix + "bias"].cpu()
    # nn.Linear stores its weight as (out_features, in_features), i.e.
    # (num_cls, in_dim); reading the sizes from the tensor instead of
    # hard-coding 81/1024 lets heads of other sizes load as well.
    num_cls, in_dim = fc_weight.shape
    model = ClsFC(num_cls, in_dim)
    model.load_state_dict({"fc.weight": fc_weight, "fc.bias": fc_bias})
    return model
if __name__ == "__main__":
    # For every VidOR-train tracking result file, run each entry's stored RoI
    # feature through the classifier head from create_model() and save the
    # resulting logits next to the other results as "<name>_logits.npy".
    # NOTE: originally run on 10.12.86.103
    dim_feature = 1024
    num_cls = 81
    cls_model = create_model()
    # Fail early if the loaded head does not match the data layout below.
    assert cls_model.fc.in_features == dim_feature and cls_model.fc.out_features == num_cls
    # Use the first GPU when available; fall back to CPU instead of crashing.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cls_model = cls_model.to(device)
    load_dir = "/home/gkf/deepSORT/tracking_results/miss60_minscore0p3/"
    save_dir = "/home/gkf/deepSORT/tracking_results/miss60_minscore0p3/VidORtrain_freq1_logits"
    # np.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    # Results are spread over part directories 01..14; collect all file paths.
    res_path_list = []
    for part_id in range(1, 15):
        part_name = "VidORtrain_freq1_part{:02d}".format(part_id)
        part_dir = os.path.join(load_dir, part_name)
        res_path_list.extend(os.path.join(part_dir, p) for p in sorted(os.listdir(part_dir)))
    assert len(res_path_list) == 7000  # expected number of result files

    for load_path in tqdm(res_path_list):
        track_res = np.load(load_path, allow_pickle=True)
        batch_features = []
        for box_info in track_res:
            if not isinstance(box_info, list):
                box_info = box_info.tolist()
            # Each entry is either a 6-element record without a feature, or
            # 12 leading fields followed by a dim_feature-long RoI feature.
            assert len(box_info) == 6 or len(box_info) == 12 + dim_feature, "len(box_info)=={}".format(len(box_info))
            if len(box_info) == 12 + dim_feature:
                cat_id = box_info[7]
                assert cat_id > 0
                batch_features.append(box_info[12:])
            else:
                # A zero placeholder keeps exactly one logits row per entry.
                batch_features.append([0] * dim_feature)
        # reshape gives an empty file a (0, dim_feature) batch, so the linear
        # layer still runs instead of failing on a 1-D empty tensor.
        batch_features = torch.tensor(batch_features, dtype=torch.float32).reshape(-1, dim_feature)
        assert len(track_res) == batch_features.shape[0]
        cls_logits = cls_model(batch_features.to(device)).cpu().numpy()  # shape == (N, num_cls)
        # splitext removes only the final extension, so names containing other
        # dots are not truncated (which could make output files collide).
        stem = os.path.splitext(os.path.basename(load_path))[0]
        save_path = os.path.join(save_dir, stem + "_logits.npy")
        np.save(save_path, cls_logits)
    print("finish")