forked from vimar-gu/Bias-Eliminate-DA-ReID
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
86 lines (69 loc) · 2.25 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import argparse
import os
import sys
import torch
import time
import numpy as np
import random
from torch.backends import cudnn
sys.path.append('.')
from config import cfg
from data import make_test_data_loader
from data.datasets import init_dataset
from engine.tester import tester
from modeling import build_model, build_camera_model
from layers import make_loss
from solver import make_optimizer, WarmupMultiStepLR
from utils.logger import setup_logger
def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for reproducibility:
    deterministic kernels only, with timing-based auto-tuning disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed already seeds all devices; manual_seed_all keeps the
    # CUDA seeding explicit and is silently ignored when CUDA is unavailable.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode selects kernels by timing them, which can differ from run
    # to run; PyTorch's reproducibility notes recommend disabling it.
    # NOTE(review): main() sets cudnn.benchmark = True afterwards, which
    # overrides this for the script entry point.
    torch.backends.cudnn.benchmark = False


# Seed at import time so every run of this module starts from the same state.
setup_seed(1)
def test(cfg, *, num_classes=(700, 500), num_camera_classes=5):
    """Evaluate a trained ReID model together with its camera classifier.

    Builds the test data loader from ``cfg``, restores the ReID model from
    ``cfg.TEST.WEIGHT`` and the camera model from ``cfg.TEST.CAMERA_WEIGHT``,
    and hands everything to ``tester``.

    Args:
        cfg: Config node. This function reads ``OUTPUT_DIR``, ``TEST.WEIGHT``
            and ``TEST.CAMERA_WEIGHT`` directly; other keys are consumed by
            the helpers it calls.
        num_classes: Class counts passed (as a list) to ``build_model``. The
            default reproduces the previously hard-coded ``[700, 500]``.
            NOTE(review): presumably must match the checkpoint's classifier
            heads — confirm against the training config.
        num_camera_classes: Class count passed to ``build_camera_model``;
            defaults to the previously hard-coded 5.
    """
    logger = setup_logger("reid_baseline", cfg.OUTPUT_DIR)
    logger.info("Running with config:\n%s", cfg)

    # prepare dataset
    test_data_loader, num_query = make_test_data_loader(cfg)

    # prepare the ReID model
    model = build_model(cfg, num_classes=list(num_classes))
    logger.info("Path to the checkpoint of model:%s", cfg.TEST.WEIGHT)
    model.load_param(cfg.TEST.WEIGHT, 'self')

    # prepare the camera model (previously logged as "model", which was misleading)
    camera_model = build_camera_model(cfg, num_classes=num_camera_classes)
    logger.info("Path to the checkpoint of camera model:%s", cfg.TEST.CAMERA_WEIGHT)
    camera_model.load_param(cfg.TEST.CAMERA_WEIGHT, 'self')

    tester(cfg,
           model,
           camera_model,
           test_data_loader,
           num_query
           )
def main():
    """Parse CLI arguments, finalize the global config, and run evaluation.

    ``--config_file`` (if given) is merged into ``cfg`` first, then any
    trailing command-line options, after which the config is frozen. The
    output directory is created if needed before ``test`` is run.
    """
    parser = argparse.ArgumentParser(description="ReID Baseline Testing")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)

    if cfg.MODEL.DEVICE == "cuda":
        # os.environ only accepts strings; str() tolerates a non-string DEVICE_ID.
        # NOTE(review): this presumably relies on CUDA not being initialised
        # yet in this process — confirm nothing earlier touches the GPU.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.MODEL.DEVICE_ID)
    cudnn.benchmark = True

    test(cfg)
if __name__ == "__main__":
    # Entry point: run evaluation only when executed as a script, not on import.
    main()