-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
82 lines (67 loc) · 2.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import logging
import random
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """Pad (and optionally truncate) a batch of sequences to the same length.

    :param inputs: non-empty sequence whose elements are ``np.ndarray``/``list``
        objects, or ``torch.Tensor`` objects; the branch taken is decided by the
        type of ``inputs[0]``.
    :param length: target length. ``None`` pads to the longest element. For
        numpy/list inputs an int applies to the first dimension and a sequence
        gives one length per padded dimension; for tensors an int, or the first
        entry of a sequence, is used. Longer elements are truncated, keeping
        their leading items.
    :param value: fill value used for the padding.
    :param seq_dims: number of leading dimensions to pad (numpy/list inputs
        only; tensors are always padded along their first dimension).
    :param mode: ``'post'`` appends the padding, ``'pre'`` prepends it.
    :return: an ``np.ndarray`` for numpy/list inputs; for tensor inputs a
        ``torch.Tensor`` with the batch as dimension 0 and the padded
        sequence axis as dimension 1.
    :raises ValueError: if ``mode`` is not ``'post'``/``'pre'`` or the element
        type is not supported.
    """
    # Validate once up front (instead of an ``assert`` that ``-O`` strips).
    if mode not in ('post', 'pre'):
        raise ValueError('"mode" argument must be "post" or "pre".')

    if isinstance(inputs[0], (np.ndarray, list)):
        if length is None:
            length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
        elif not hasattr(length, '__getitem__'):
            length = [length]
        slices = [np.s_[:length[i]] for i in range(seq_dims)]
        slices = tuple(slices) if len(slices) > 1 else slices[0]
        # Dimensions beyond seq_dims keep a (0, 0) pad; the first seq_dims
        # entries are overwritten for every element below.
        pad_width = [(0, 0) for _ in np.shape(inputs[0])]
        outputs = []
        for x in inputs:
            x = x[slices]
            for i in range(seq_dims):
                gap = length[i] - np.shape(x)[i]
                pad_width[i] = (0, gap) if mode == 'post' else (gap, 0)
            outputs.append(np.pad(x, pad_width, 'constant', constant_values=value))
        return np.array(outputs)

    if isinstance(inputs[0], torch.Tensor):
        if length is not None:
            # Accept the same forms as the numpy branch: a scalar, or a
            # sequence whose first entry is the length of the first dimension.
            if np.ndim(length) > 0:
                length = length[0]
            length = int(length)
            inputs = [t[:length] for t in inputs]
        if mode == 'post':
            padded = pad_sequence(inputs, padding_value=value, batch_first=True)
        else:
            # pad_sequence only pads at the end: reverse each sequence, pad,
            # then reverse the padded batch back along the sequence axis.
            padded = pad_sequence([t.flip(0) for t in inputs],
                                  padding_value=value, batch_first=True).flip(1)
        if length is not None and padded.size(1) < length:
            # pad_sequence stops at the longest element in the batch; extend to
            # ``length`` so the result width matches the numpy branch.
            extra_shape = (padded.size(0), length - padded.size(1)) + tuple(padded.shape[2:])
            extra = padded.new_full(extra_shape, value)
            pieces = [padded, extra] if mode == 'post' else [extra, padded]
            padded = torch.cat(pieces, dim=1)
        return padded

    raise ValueError('"input" argument must be tensor/list/ndarray.')
def set_seed(seed=123):
    """Seed every random number generator used by the project for reproducibility.

    The same ``seed`` is fed to Python's ``random`` module, NumPy's global
    generator, PyTorch's CPU generator and the generators of all CUDA devices.

    :param seed: integer seed shared by all generators (default 123).
    :return: None
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        np.random.seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def set_logger(log_path):
    """Configure the root logger to emit INFO-level records to a file and the console.

    The file handler writes ``asctime - level - filename - funcName - lineno -
    message`` lines; the console handler prints the bare message. The function
    is safe to call repeatedly: each handler is only attached when the root
    logger does not already hold one of exactly that class, which prevents
    duplicated output. NOTE: as a consequence, a later call with a different
    ``log_path`` does not redirect logging while a ``FileHandler`` is attached.

    :param log_path: path of the log file, written as UTF-8.
    :return: None
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Exact class comparison is deliberate: FileHandler subclasses StreamHandler,
    # so an isinstance check would make the console check match the file handler.
    if not any(type(handler) is logging.FileHandler for handler in logger.handlers):
        # Explicit UTF-8 so non-ASCII (e.g. Chinese) messages do not depend on the
        # platform's locale encoding.
        file_handler = logging.FileHandler(log_path, encoding='utf-8')
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)
def getstr(string):
    """Return ``string`` with every double-quote character (``"``) removed."""
    drop_double_quotes = {ord('"'): None}
    return string.translate(drop_double_quotes)