-
Notifications
You must be signed in to change notification settings - Fork 1
/
NNet.py
120 lines (97 loc) · 4.04 KB
/
NNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import sys
import time
import numpy as np
from tqdm import tqdm
sys.path.append('../../')
from utils import *
from NeuralNet import NeuralNet
import torch
import torch.optim as optim
from GoNNet import GoNNet as onnet
# Hyperparameters for NNetWrapper; the same object is handed to GoNNet.
# dotdict (from the project's utils) is accessed by attribute below,
# e.g. ``args.cuda`` / ``args.epochs``.
args = dotdict(dict(
    lr=0.001,                         # learning rate
    dropout=0.3,                      # not read in this file; presumably used by GoNNet
    epochs=3,                         # passes per train() call (an old comment noted 10)
    batch_size=64,                    # examples per sampled mini-batch
    cuda=torch.cuda.is_available(),   # run on the GPU when one is available
    num_channels=512,                 # not read in this file; presumably used by GoNNet
))
class NNetWrapper(NeuralNet):
    """Training / inference wrapper around the ``GoNNet`` network.

    Adapts ``GoNNet`` (imported as ``onnet``) to the ``NeuralNet`` interface:
    ``train`` fits the network on ``(board, pi, v)`` examples, ``predict``
    evaluates a single board, and ``save_checkpoint`` / ``load_checkpoint``
    persist the network's ``state_dict``. Hyperparameters come from the
    module-level ``args``.
    """

    def __init__(self, game):
        """Build the network for ``game`` and move it to the GPU if enabled.

        Args:
            game: game object providing ``getBoardSize()`` (returning an
                ``(x, y)`` pair) and ``getActionSize()``.
        """
        self.nnet = onnet(game, args)
        self.board_x, self.board_y = game.getBoardSize()
        self.action_size = game.getActionSize()

        if args.cuda:
            self.nnet.cuda()

    @staticmethod
    def _as_float_tensor(data):
        """Copy array-like ``data`` into a float32 tensor on the configured device."""
        # torch.tensor always copies, so the result is contiguous and never
        # aliases the caller's array.
        tensor = torch.tensor(np.asarray(data, dtype=np.float32))
        return tensor.cuda() if args.cuda else tensor

    def train(self, examples):
        """Fit the network on ``examples`` for ``args.epochs`` epochs.

        Each epoch draws ``len(examples) // args.batch_size`` mini-batches,
        sampling example indices uniformly *with replacement*, and minimises
        ``loss_pi + loss_v`` with Adam at learning rate ``args.lr``.

        Args:
            examples: sequence of ``(board, pi, v)`` tuples: ``board`` is an
                array-like board, ``pi`` the target policy vector and ``v``
                the scalar target value.

        Note:
            If ``examples`` holds fewer than ``args.batch_size`` items, no
            batch is drawn and the weights are left unchanged.
        """
        # Pass the configured learning rate explicitly; it was previously
        # ignored (Adam's default is also 1e-3, so defaults are unchanged).
        optimizer = optim.Adam(self.nnet.parameters(), lr=args.lr)
        batch_count = len(examples) // args.batch_size

        for epoch in range(args.epochs):
            print('EPOCH ::: ' + str(epoch + 1))
            self.nnet.train()
            pi_losses = AverageMeter()
            v_losses = AverageMeter()

            t = tqdm(range(batch_count), desc='Training Net')
            for _ in t:
                sample_ids = np.random.randint(len(examples), size=args.batch_size)
                boards, pis, vs = zip(*(examples[i] for i in sample_ids))
                boards = self._as_float_tensor(boards)
                target_pis = self._as_float_tensor(pis)
                target_vs = self._as_float_tensor(vs)

                # forward pass and combined policy + value loss
                out_pi, out_v = self.nnet(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                total_loss = l_pi + l_v

                # record batch-size-weighted running losses for the progress bar
                pi_losses.update(l_pi.item(), boards.size(0))
                v_losses.update(l_v.item(), boards.size(0))
                t.set_postfix(Loss_pi=pi_losses, Loss_v=v_losses)

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

    def predict(self, board):
        """Evaluate the network on a single board.

        Args:
            board: array-like board with ``board_x * board_y`` entries; it is
                reshaped to ``(1, board_x, board_y)`` before the forward pass.

        Returns:
            ``(pi, v)`` as NumPy arrays for the single batch entry: ``pi`` is
            ``exp`` of the network's policy output (probabilities, assuming
            GoNNet emits log-probabilities -- confirm in GoNNet) and ``v`` is
            the value output.
        """
        board = self._as_float_tensor(board).view(1, self.board_x, self.board_y)
        self.nnet.eval()
        with torch.no_grad():
            pi, v = self.nnet(board)
        return torch.exp(pi).cpu().numpy()[0], v.cpu().numpy()[0]

    def loss_pi(self, targets, outputs):
        """Batch-mean of ``-sum(targets * outputs)``.

        This is cross-entropy when ``outputs`` are log-probabilities, which is
        consistent with ``predict`` exponentiating the policy output.
        """
        return -torch.sum(targets * outputs) / targets.size(0)

    def loss_v(self, targets, outputs):
        """Batch-mean squared error between ``targets`` and flattened ``outputs``."""
        return torch.sum((targets - outputs.view(-1)) ** 2) / targets.size(0)

    def save_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        """Save the network's ``state_dict`` to ``folder/filename``.

        Creates ``folder`` (including missing parents) when it does not exist.
        """
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            print("Checkpoint Directory does not exist! Making directory {}".format(folder))
            # makedirs (unlike mkdir) also creates missing parent directories;
            # exist_ok tolerates the folder appearing between check and create.
            os.makedirs(folder, exist_ok=True)
        else:
            print("Checkpoint Directory exists! ")
        torch.save({'state_dict': self.nnet.state_dict()}, filepath)

    def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        """Load a ``state_dict`` saved by ``save_checkpoint`` into the network.

        Raises:
            FileNotFoundError: if ``folder/filename`` does not exist.
        """
        # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L98
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            # Raising a bare string is itself a TypeError in Python 3.
            raise FileNotFoundError("No model in path {}".format(filepath))
        # Remap GPU-saved tensors onto the CPU when CUDA is disabled.
        map_location = None if args.cuda else 'cpu'
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(filepath, map_location=map_location)
        self.nnet.load_state_dict(checkpoint['state_dict'])