-
Notifications
You must be signed in to change notification settings - Fork 1
/
Train.py
114 lines (88 loc) · 4.15 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import r2_score
from Net_Model import Net
import os
import numpy as np
if __name__ == '__main__':
save_path = "./models/net_arcloss.pth"
train_data = torchvision.datasets.MNIST(root="./MNIST", download=True, train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, ], std=[0.5, ])
]))
test_data = torchvision.datasets.MNIST(root="./MNIST", download=True, train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, ], std=[0.5,])
]))
train_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=100, num_workers=4)
test_loader = data.DataLoader(dataset=test_data, shuffle=True, batch_size=100, num_workers=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
if os.path.exists(save_path):
net.load_state_dict(torch.load(save_path))
else:
print("No Param")
'CrossEntropyLoss()=torch.log(torch.softmax(None))+nn.NLLLoss()'
'CrossEntropyLoss()=log_softmax() + NLLLoss() '
'nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合'
loss_fn = nn.NLLLoss()
# optimizer = torch.optim.Adam(net.parameters())
# optimizer = torch.optim.SGD(net.parameters(),lr=1e-3, momentum=0.9)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
# optimizer = torch.optim.SGD(net.parameters(),lr=1e-3, momentum=0.9, weight_decay=0.0005)
epoch = 0
while True:
feat_loader = []
label_loader = []
for i, (x, y) in enumerate(train_loader):
x = x.to(device)
y = y.to(device)
feature, output = net.forward(x)
# print(feature.shape) # torch.Size([100, 2])
# print(output.shape) # torch.Size([100, 10])
loss = loss_fn(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print(y.shape) # torch.Size([100])
feat_loader.append(feature)
label_loader.append(y)
if i % 100 == 0:
print("epoch:", epoch, "i:", i, "arcsoftmax_loss:", loss.item())
feat = torch.cat(feat_loader, 0)
labels = torch.cat(label_loader, 0)
net.visualize(feat.data.cpu().numpy(), labels.data.cpu().numpy(), epoch)
torch.save(net.state_dict(), save_path)
with torch.no_grad():
feat_loader2 = []
label_loader2 = []
label_list = []
output_list = []
for i, (x, y) in enumerate(test_loader): # 加验证集
x = x.to(device)
y = y.to(device)
feature, output = net.forward(x)
loss = loss_fn(output, y)
feat_loader2.append(feature) # 方便画图
label_loader2.append(y)
output = torch.argmax(output, 1)
# print(output.shape) # torch.Size([100])
# print(y.shape) # torch.Size([100])
label_list.append(y.data.cpu().numpy().reshape(-1)) # 方便做r2_score
output_list.append(output.data.cpu().numpy().reshape(-1))
if i % 600 == 0:
print("epoch:", epoch, "i:", i, "validate_loss:", loss.item())
feat2 = torch.cat(feat_loader2, 0)
labels2 = torch.cat(label_loader2, 0)
net.visualize2(feat2.data.cpu().numpy(), labels2.data.cpu().numpy(), epoch)
# r2_score评估
r2 = r2_score(label_list, output_list)
print("验证集第{}轮, r2_score评估分类精度为:{}".format(epoch, r2))
epoch += 1
if epoch == 30:
break