Skip to content

Commit

Permalink
implement mcts_alphago based on mcts_pure
Browse files Browse the repository at this point in the history
  • Loading branch information
cmusjtuliuyuan committed Feb 6, 2018
1 parent d342367 commit bf01312
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 42 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.pyc
.DS_Store
175 changes: 175 additions & 0 deletions MCTS_AlphaGo_Style.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""
A pure implementation of the Monte Carlo Tree Search (MCTS) in AlphaGo style
The original version was written by:
@author: Junxiao Song
@github: https://github.com/junxiaosong/AlphaZero_Gomoku/blob/master/MCTS_Pure.py
It is modified to Upper Confidence Bounds for Trees (UCT) http://mcts.ai/index.html version by:
@author: Yuan Liu
@github: https://github.com/cmusjtuliuyuan/AlphaGoZero/blob/master/MCTS_Pure.py
"""
import numpy as np
import copy
import random
import math

def fake_NN(board):
    """Stand-in for a policy/value network, backed by one random playout.

    Args:
        board: the state to evaluate. It is deep-copied before the playout,
            so the caller's board is left untouched.

    Returns:
        A ``(p, v)`` pair:
        p -- array of ones with ``board.width * board.height`` entries
             (a uniform, unnormalised prior indexed by move).
        v -- playout outcome from the viewpoint of the player who just moved
             on ``board``, scaled to [-1, 1] (win = 1, tie = 0, loss = -1).
    """
    # Get v: play uniformly random moves until no move is left.
    player = board.get_player_just_moved()
    rollout = copy.deepcopy(board)
    moves = rollout.get_moves()
    while len(moves) != 0:  # while state is non-terminal
        rollout.do_move(random.choice(moves))
        moves = rollout.get_moves()
    # get_result() reports 1.0 / 0.5 / 0.0 for win / tie / loss, whereas the
    # search negates the value at every tree level and uses -1 / 0 / +1 for
    # terminal states. Rescale so that negation really swaps the viewpoint.
    v = 2.0 * rollout.get_result(player) - 1.0
    # Get p
    p = np.ones(board.width * board.height)
    return p, v

def Net_to_movelist(NN_fn, board, node):
    """Evaluate ``board`` with ``NN_fn`` and pair each untried move with its prior.

    Args:
        NN_fn: callable taking the board and returning ``(p, v)``, where
            ``p`` is indexable by move.
        board: the state to evaluate.
        node: the tree node for ``board``; moves it already has children for
            are left out.

    Returns:
        ``(move_priorP_pairs, v)``: a list of ``(move, p[move])`` tuples and
        the value ``v`` exactly as returned by ``NN_fn``.
    """
    priors, value = NN_fn(board)
    pairs = [(move, priors[move]) for move in get_untried_moves(board, node)]
    return pairs, value

def get_untried_moves(board, node):
    """Return the set of moves available on ``board`` that ``node`` has no child for."""
    legal_moves = set(board.get_moves())
    expanded_moves = node.get_already_moved()
    return legal_moves - expanded_moves


class TreeNode(object):
    """A node in the search tree.

    Each node stores the statistics of the edge leading into it: visit count
    N, mean value Q and prior probability P. Q is kept from the viewpoint of
    ``player_just_moved``, which is why that player is recorded on the node.
    """

    def __init__(self, parent, move, prior_p, player_just_moved):
        self._parent = parent  # None for the root node
        self._player_just_moved = player_just_moved
        self._move = move  # the move that got us to this node - None for the root node
        self._children = {}  # a map from action to TreeNode
        self._n_visits = 0.0  # N(self._parent.board, self._move)
        self._Q = 0.0  # Q(self._parent.board, self._move)
        self._P = prior_p  # NN_P(self._parent.board)[self._move]

    def expand(self, move_priorP_pairs, next_player):
        """Add one child per ``(move, prior_p)`` pair.

        Every new child records ``next_player`` as its ``player_just_moved``.
        A move that already has a child gets that child replaced. Returns None.
        """
        for move, prior_p in move_priorP_pairs:
            self._children[move] = TreeNode(self, move, prior_p, next_player)

    def UCT_select(self, c_puct):
        """Return the ``(move, child)`` pair with the highest Q + U value.

        ``c_puct`` scales the exploration term U against the exploitation
        term Q. Raises ValueError if the node has no children.
        """
        # .items() instead of .iteritems() so this runs on Python 2 and 3.
        return max(self._children.items(),
                   key=lambda act_node: act_node[1].get_UCT_value(c_puct))

    def get_UCT_value(self, c_puct):
        """Return Q + U, with U = c_puct * P * sqrt(N_parent) / (1 + N).

        Reads the parent's visit count, so it must not be called on the root.
        """
        U = c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits)
        return self._Q + U

    def update(self, leaf_value):
        """Record one more visit and fold ``leaf_value`` into the running mean Q.

        ``leaf_value`` must be from the viewpoint of this node's
        ``player_just_moved``.
        """
        self._Q = (self._n_visits * self._Q + leaf_value) / (self._n_visits + 1.0)
        self._n_visits += 1.0

    def get_already_moved(self):
        """Return the set of moves that already have a child node."""
        return set(self._children)

    def is_leaf(self):
        """Check if leaf node (i.e. no nodes below this have been expanded)."""
        return len(self._children) == 0

    def __repr__(self):
        # First "M:" is the move into this node, second "M:" the expanded moves.
        return "[M:{} Q:{} M:{}]".format(
            self._move, self._Q, self.get_already_moved())

    def TreeToString(self, indent):
        """Return a multi-line dump of this subtree, one node per line."""
        parts = [self.IndentString(indent) + str(self)]
        parts.extend(child.TreeToString(indent + 1)
                     for child in self._children.values())
        return "".join(parts)

    def IndentString(self, indent):
        """Return a newline followed by ``indent`` copies of ``"| "``."""
        return "\n" + "| " * indent

def UCT(root_board, n_iteration, NN_fn, temp=1.0, c_puct=5):
    """Run ``n_iteration`` search iterations from ``root_board``.

    Args:
        root_board: the state to search from; it is deep-copied for every
            iteration and never modified.
        n_iteration: number of select / expand / backpropagate passes.
        NN_fn: callable taking a board and returning ``(p, v)``; ``v`` is
            expected in [-1, 1] from the viewpoint of the player who just
            moved, matching the -1 / 0 / +1 used here for terminal states.
        temp: temperature applied to the root visit counts; small values
            concentrate the probability on the most visited move.
        c_puct: exploration constant passed to ``TreeNode.UCT_select``.

    Returns:
        ``(moves, move_probs)``: a tuple of the root's expanded moves and an
        array of matching probabilities, proportional to
        ``visits ** (1 / temp)``.

    Raises:
        ValueError: if the root ends up with no children (``n_iteration`` is
            0 or ``root_board`` is already terminal).
    """
    rootnode = TreeNode(parent=None, move=None, prior_p=1.0,
                        player_just_moved=root_board.get_player_just_moved())

    for _ in range(n_iteration):
        node = rootnode
        board = copy.deepcopy(root_board)

        # Selection: descend through already expanded nodes until a leaf.
        while not node.is_leaf():
            move, node = node.UCT_select(c_puct)
            board.do_move(move)

        # Expansion / evaluation. The network is only queried for
        # non-terminal leaves; a terminal state is scored by its real result.
        end, winner = board.game_end()
        if not end:
            move_priorP_pairs, leaf_value = Net_to_movelist(NN_fn, board, node)
            node.expand(move_priorP_pairs, next_player=board.get_current_player())
        elif winner == -1:  # tie
            leaf_value = 0.0
        else:
            leaf_value = 1.0 if winner == node._player_just_moved else -1.0

        # Backpropagate: walk back to the root, flipping the sign at every
        # level because consecutive nodes belong to alternating players.
        while node is not None:
            node.update(leaf_value)
            leaf_value = -leaf_value
            node = node._parent

    # Output some information about the tree - can be omitted
    # print(rootnode.TreeToString(0))

    if rootnode.is_leaf():
        raise ValueError("UCT found no moves: n_iteration must be >= 1 and "
                         "root_board must not be a finished game")

    def softmax(x):
        # Subtract the maximum first for numerical stability.
        probs = np.exp(x - np.max(x))
        probs /= np.sum(probs)
        return probs

    # .items() instead of .iteritems() so this runs on Python 2 and 3.
    move_visits = [(move, child._n_visits)
                   for move, child in rootnode._children.items()]
    moves, visits = zip(*move_visits)
    # softmax(log(N) / temp) equals N ** (1 / temp), normalised; the epsilon
    # keeps log() finite for children that were never visited.
    move_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
    return moves, move_probs


class MCTSPlayer(object):
    """AI player based on MCTS.

    Args:
        n_iteration: number of search iterations run for every move.
        NN_fn: callable taking a board and returning ``(p, v)``, used by the
            search to evaluate leaves.
    """

    def __init__(self, n_iteration=1000, NN_fn=fake_NN):
        self._n_iteration = n_iteration
        self._NN_fn = NN_fn
        # Set for real by set_player_ind(); defined here so that __str__
        # works on a freshly constructed player.
        self.player = None

    def set_player_ind(self, p):
        """Store the player index ``p`` for this player."""
        self.player = p

    def get_action(self, board, temp=1e-5, dirichlet_weight=.0):
        """Search from ``board`` and sample a move from the resulting distribution.

        Args:
            board: current state; must expose ``availables``.
            temp: temperature forwarded to ``UCT``; the tiny default makes
                the choice effectively the most visited move.
            dirichlet_weight: share of Dirichlet(0.3) noise mixed into the
                move probabilities; 0 disables the noise.

        Returns:
            The chosen move, or None (after printing a warning) when
            ``board.availables`` is empty.
        """
        if len(board.availables) == 0:
            print("WARNING: the board is full")
            return None

        moves, move_probs = UCT(board, self._n_iteration, self._NN_fn, temp)
        if dirichlet_weight:
            # With a weight of 0 the mix is exactly move_probs, so the noise
            # is only sampled when it can actually contribute.
            noise = np.random.dirichlet(0.3 * np.ones(len(move_probs)))
            move_probs = (1 - dirichlet_weight) * move_probs + dirichlet_weight * noise
        move = np.random.choice(moves, p=move_probs)
        # Single formatted argument: prints the same text on Python 2 and 3.
        print('output position: {}'.format(move))
        return move

    def __str__(self):
        return "MCTS {}".format(self.player)
48 changes: 17 additions & 31 deletions MCTS_Pure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,19 @@ class TreeNode(object):
We need player_just_moved because n_wins depends on it.
"""

def __init__(self, parent, move, board):
def __init__(self, parent, move, player_just_moved):
self._parent = parent
self._player_just_moved = board.current_player
self._player_just_moved = player_just_moved
self._move = move # the move that got us to this node - "None" for the root node
# We use copy here because it dictionary is passed by reference
self._untried_moves = copy.copy(board.get_moves())
self._children = {} # a map from action to TreeNode
self._n_visits = 0.0
self._n_wins = 0.0

def expand(self, move, new_board):
""" Remove m from untriedMoves and add a new child node for this move.
def expand(self, move, player_just_moved):
""" Add a new child node for this move.
Return the added child node
"""
self._untried_moves.remove(move)
new_node = TreeNode(self, move, new_board)
new_node = TreeNode(self, move, player_just_moved)
self._children[move] = new_node
return new_node

Expand All @@ -56,25 +53,18 @@ def update(self, result):
self._n_visits += 1.0
self._n_wins += result

def get_moves(self):
"""Get avaliable moves
"""
return self._untried_moves
def get_already_moved(self):
return set(self._children.keys())

def is_leaf(self):
"""Check if leaf node (i.e. no nodes below this have been expanded).
"""
return len(self._children) == 0

def is_all_tried(self):
"""Check whether we have already tired all moves
"""
return len(self._untried_moves) == 0

def __repr__(self):
return "[M:" + str(self._move) +\
" W/V:" + str(self._n_wins) + "/" + str(self._n_visits) +\
" U:" + str(self._untried_moves) + "]"
" M:" + str(self.get_already_moved()) + "]"

def TreeToString(self, indent):
s = self.IndentString(indent) + str(self)
Expand All @@ -88,35 +78,31 @@ def IndentString(self,indent):
s += "| "
return s

def ChildrenToString(self):
s = ""
for c_move, c_node in self._children.iteritems():
s += str(c_node) + "\n"
return s

def UCT(root_board, n_iteration):
""" Conduct a UCT search for n_iterations starting from rootstate.
Return the best move from the rootstate.
Assumes 2 alternating players (player 1 starts), with game results in the range [0.0, 1.0]."""
def get_untried_moves(board, node):
return set(board.get_moves())-node.get_already_moved()

rootnode = TreeNode(parent=None, move=None, board=root_board)
rootnode = TreeNode(parent=None, move=None, player_just_moved=root_board.get_player_just_moved())

for i in range(n_iteration):
node = rootnode
board = copy.deepcopy(root_board)

# Selection: Starting at root node R, recursively select optimal child nodes (explained below)
# until a leaf node L is reached.
while node.is_all_tried() and not node.is_leaf(): #node is fully expanded and non-terminal
while len(get_untried_moves(board, node))==0 and not node.is_leaf(): #node is fully expanded and non-terminal
move, node = node.UCT_select()
board.do_move(move)

# Expansion: If L is a not a terminal node (i.e. it does not end the game) then create one
# or more child nodes and select one C.
if not node.is_all_tried(): # if we can expand
move = random.choice(node.get_moves())
if len(get_untried_moves(board, node))!=0: # if we can expand
move = random.sample(get_untried_moves(board, node),1)[0]
board.do_move(move)
node = node.expand(move, board)
node = node.expand(move, board.get_player_just_moved())

# Simulation: Run a simulated playout from C until a result is achieved.
while len(board.get_moves()) != 0: # while state is non-terminal
Expand All @@ -128,14 +114,14 @@ def UCT(root_board, n_iteration):
node.update(board.get_result(node._player_just_moved))
node = node._parent
# Output some information about the tree - can be omitted
print rootnode.TreeToString(0)
# print rootnode.TreeToString(0)
# return the move that was most visited
move, _ = max(rootnode._children.iteritems(), key=lambda act_node: act_node[1]._n_visits)
return move

class MCTSPlayer(object):
"""AI player based on MCTS"""
def __init__(self, n_iteration=400):
def __init__(self, n_iteration=1000):
self._n_iteration=n_iteration

def set_player_ind(self, p):
Expand Down
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
# AlphaGoZero
# MCTS play Gomoku

For simplicity, this game is 3 in a row, and the board is 5 * 5.
Run the following command:
```
python play_test.py
```
Player1 will be the MCTS, Player2 will be human (you).
7 changes: 5 additions & 2 deletions game.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
class Board(object):
"""
board for the game
current_player means the one who play next
player_just_moved means the one who has already moved
"""

def __init__(self, **kwargs):
Expand Down Expand Up @@ -113,10 +115,9 @@ def has_a_winner(self):

def get_result(self, playerJustMoved):
win, winner = self.has_a_winner()
current_player = self.players[0] if playerJustMoved == self.players[1] else self.players[1]
if not win:
return 0.5 # tie
if winner == current_player:
if winner == playerJustMoved:
return 1.0 # win
return 0.0 # fail

Expand All @@ -132,6 +133,8 @@ def game_end(self):
def get_current_player(self):
return self.current_player

def get_player_just_moved(self):
return self.players[0] if self.current_player == self.players[1] else self.players[1]

class Game(object):
"""
Expand Down
17 changes: 9 additions & 8 deletions play_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
The original version was written by:
@author: Junxiao Song
@github: https://github.com/junxiaosong/AlphaZero_Gomoku/blob/master/play_human.py
"""
"""

from __future__ import print_function
from game import Board, Game
from MCTS_Pure import MCTSPlayer
#from MCTS_Pure import MCTSPlayer
from MCTS_AlphaGo_Style import MCTSPlayer
import argparse

class Human(object):
Expand All @@ -18,7 +19,7 @@ class Human(object):

def __init__(self):
self.player = None

def set_player_ind(self, p):
self.player = p

Expand Down Expand Up @@ -49,15 +50,15 @@ def main():
width, height = 5, 5
try:
board = Board(width=width, height=height, n_in_row=n)
game = Game(board)
game = Game(board)

player1 = Human() if args.player1=='human' else MCTSPlayer()
player2 = Human() if args.player2=='human' else MCTSPlayer()
player2 = Human() if args.player2=='human' else MCTSPlayer()

# set start_player=0 for human first
game.start_play(player1, player2, start_player=0, is_shown=1)
except KeyboardInterrupt:
print('\n\rquit')

if __name__ == '__main__':
if __name__ == '__main__':
main()

0 comments on commit bf01312

Please sign in to comment.