-
Notifications
You must be signed in to change notification settings - Fork 56
/
common.py
126 lines (112 loc) · 4.44 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
import os,joblib
import torch,random
import torch.nn as nn
import cv2,imageio,PIL
from libtiff import TIFFfile
def readImg(img_path):
    """Open a local dataset image with Pillow.

    Data sets come in non-uniform formats, so image loading is kept in one
    place. The image is returned exactly as Pillow opens it; no colour-mode
    conversion is applied here.

    Args:
        img_path: Path to the image file.

    Returns:
        The opened ``PIL.Image.Image``.

    Raises:
        ValueError: if Pillow cannot open the file (missing path, unreadable
            or unsupported format). The underlying error is chained as the
            cause.
    """
    # `import PIL` alone does not guarantee the `PIL.Image` submodule is loaded.
    from PIL import Image

    # NOTE: on Windows, converting TIF images opened this way to numpy arrays
    # has been seen to abort; no better reader has been found yet. Reading
    # directly with PIL works fine on Linux.
    try:
        img = Image.open(img_path)
    except (OSError, ValueError) as e:
        # Previously the ValueError was built but never raised, so callers got
        # an UnboundLocalError on `img` instead of a useful message.
        raise ValueError(
            "Reading failed, please check path of dataset, %s" % (img_path,)
        ) from e
    return img
def count_parameters(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
class AverageMeter(object):
    """Keep the latest value and a running weighted average (e.g. of a loss).

    Call ``update(value, weight)`` repeatedly and read ``avg``; ``val`` holds
    the most recent value, ``sum``/``count`` the weighted total and weight.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The latest measurement (e.g. a batch's mean loss).
            n: Weight of ``val``, e.g. the number of samples it averages over.
                Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when the total weight is still 0
        # (e.g. update(val, 0) right after reset()); avg is left unchanged.
        if self.count:
            self.avg = self.sum / self.count
# Piecewise-constant learning-rate decay schedule
def make_lr_schedule(lr_epoch, lr_value):
    """Build a per-epoch learning-rate array from piecewise-constant segments.

    ``lr_epoch`` holds the (exclusive) end epoch of each segment and
    ``lr_value`` the learning rate used within it. Epochs ``[0, lr_epoch[0])``
    get ``lr_value[0]``, ``[lr_epoch[0], lr_epoch[1])`` get ``lr_value[1]``,
    and so on. The result has length ``lr_epoch[-1]``.
    """
    schedule = np.zeros(lr_epoch[-1])
    start = 0
    for idx, stop in enumerate(lr_epoch):
        schedule[start:stop] = lr_value[idx]
        start = stop
    return schedule
# Save configuration information
def save_args(args, save_path):
    """Print the run configuration and persist it under ``save_path``.

    Writes ``args.txt`` (one ``name: value`` line per attribute) and
    ``args.pkl`` (``args`` serialised with ``joblib.dump``).

    Args:
        args: Object whose attributes are the settings, read via ``vars()``
            (e.g. an ``argparse.Namespace``).
        save_path: Output directory; created (with parents) if missing.
    """
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(save_path, exist_ok=True)
    print('Config info -----')
    lines = ['%s: %s' % (arg, getattr(args, arg)) for arg in vars(args)]
    for line in lines:
        print(line)
    # Explicit encoding so non-ASCII values (e.g. paths) don't depend on the
    # platform's default locale encoding.
    with open(os.path.join(save_path, 'args.txt'), 'w', encoding='utf-8') as f:
        for line in lines:
            print(line, file=f)
    joblib.dump(args, os.path.join(save_path, 'args.pkl'))
    print('\033[0;33m================config infomation has been saved=================\033[0m')
# Seed for repeatability
def setpu_seed(seed):
    """Seed all random number generators used in training with ``seed``.

    Covers Python's ``random``, NumPy's global RNG, and torch's CPU and
    all-CUDA-device generators. It also sets
    ``torch.backends.cudnn.deterministic``.
    """
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
# Round off
def dict_round(dic, num):
    """Round every value of ``dic`` to ``num`` decimal places, in place.

    Returns the same (mutated) dict object for convenience.
    """
    for name in dic:
        # Re-assigning existing keys does not change the dict's size, so
        # mutating while iterating is safe here.
        dic[name] = round(dic[name], num)
    return dic
# params initialization
def weight_initV1(m):
    """Initialise one module in place; meant for ``net.apply(weight_initV1)``.

    - ``nn.Conv2d``: Kaiming-normal weight (``mode='fan_out'``, ReLU), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0 (skipped when ``affine=False``).
    - ``nn.Linear``: weight drawn from a normal with mean 0, std 0.01; zero bias.

    Other module types are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) registers weight/bias as None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def weight_initV2(m):
    """Initialise one module in place; meant for ``net.apply(weight_initV2)``.

    - ``nn.Linear``: Xavier-normal weight, zero bias (if the layer has one).
    - ``nn.Conv2d``: Kaiming-normal weight (``mode='fan_out'``, ReLU). The bias
      is deliberately left at its existing value.
    - ``nn.BatchNorm2d``: weight 1, bias 0 (skipped when ``affine=False``).

    Other module types are left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # nn.Linear(bias=False) has bias=None; constant_ would fail on it.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) registers weight/bias as None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def weight_initV3(net, init_type='normal', gain=0.02):
    """Initialise every Conv*/Linear and BatchNorm2d layer of ``net`` in place.

    Layers are matched by class name: any class whose name contains ``'Conv'``
    or ``'Linear'`` gets its weight re-initialised by ``init_type`` and its
    bias zeroed. Classes whose name contains ``'BatchNorm2d'`` get a weight
    drawn from a normal with mean 1.0 and std ``gain``, and a zero bias.

    Args:
        net: Module to initialise (traversed with ``net.apply``).
        init_type: ``'normal'``, ``'xavier'``, ``'kaiming'`` or ``'orthogonal'``.
        gain: Std for ``'normal'`` and for BatchNorm weights; gain for
            ``'xavier'``/``'orthogonal'``. Not used by ``'kaiming'``.

    Raises:
        NotImplementedError: if ``init_type`` is unknown. This is raised lazily,
            when the first Conv*/Linear layer is visited.
    """
    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        if weight is not None and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            # Use nn.init explicitly: the bare name `init` was never imported.
            if init_type == 'normal':
                nn.init.normal_(weight, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(weight, gain=gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(weight, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm2d(affine=False) registers weight/bias as None.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, gain)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)