Skip to content

Commit

Permalink
Update MCTS
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuandong Tian committed Oct 11, 2017
1 parent 3b015f8 commit dcecea8
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
2 changes: 2 additions & 0 deletions elf/context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self):
("mcts_use_prior", dict(action="store_true")),
("mcts_baseline", 3.0),
("mcts_baseline_sigma", 0.3),
("mcts_pseudo_games", 0),
],
on_get_args = self._on_get_args
)
Expand Down Expand Up @@ -61,5 +62,6 @@ def initialize(self, co):
# Transfer the parsed mcts_* argument values onto `mcts`.
mcts.use_prior = args.mcts_use_prior
mcts.baseline = args.mcts_baseline
mcts.baseline_sigma = args.mcts_baseline_sigma
# The option is registered as "mcts_pseudo_games", so it must be read under
# that name; `args.pseudo_games` does not match the registered option.
mcts.pseudo_games = args.mcts_pseudo_games


12 changes: 9 additions & 3 deletions elf/tree_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class TSOneThreadT {
using Node = NodeT<S, A>;
using NodeAlloc = NodeAllocT<S, A>;

TSOneThreadT(int thread_id, const TSOptions& options) : thread_id_(thread_id), options_(options) {
TSOneThreadT(int thread_id, const TSOptions& options)
: thread_id_(thread_id), options_(options), rng_(thread_id) {
if (options_.verbose) {
output_.reset(new ofstream("tree_search_" + std::to_string(thread_id) + ".txt"));
}
Expand Down Expand Up @@ -137,6 +138,8 @@ class TSOneThreadT {
Semaphore<RunInfo> state_ready_;
std::unique_ptr<ostream> output_;

std::mt19937 rng_;

// Logistic function 1 / (1 + e^(-x)).
static float sigmoid(float x) {
    const auto decay = exp(-x);
    return 1.0 / (1 + decay);
}
Expand Down Expand Up @@ -173,9 +176,12 @@ class TSOneThreadT {
auto func = [&](const Node *n) -> NodeResponseT<A> & {
return actor.evaluate(*n->s_ptr());
};
return node->ExpandIfNecessary(func, alloc);
auto init = [&](EdgeInfo &info) {
info.acc_reward = rng_() % (options_.pseudo_games + 1);
info.n = options_.pseudo_games;
};
return node->ExpandIfNecessary(func, init, alloc);
}

};

// Mcts algorithm
Expand Down
5 changes: 3 additions & 2 deletions elf/tree_search_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class NodeT : public NodeBaseT<S> {
int count() const { return count_; }
float value() const { return V_; }

template <typename ExpandFunc>
VisitType ExpandIfNecessary(ExpandFunc func, NodeAlloc &alloc) {
template <typename ExpandFunc, typename InitFunc>
VisitType ExpandIfNecessary(ExpandFunc func, InitFunc init, NodeAlloc &alloc) {
if (visited_) return NODE_ALREADY_VISITED;

// Otherwise visit.
Expand All @@ -100,6 +100,7 @@ class NodeT : public NodeBaseT<S> {
for (const pair<A, float> & action_pair : resp.pi) {
auto res = sa_.insert(make_pair(action_pair.first, EdgeInfo(action_pair.second)));
res.first->second.next = alloc.Alloc();
init(res.first->second);
}

// value
Expand Down
6 changes: 5 additions & 1 deletion elf/tree_search_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ struct TSOptions {
float baseline = 3.0;
float baseline_sigma = 0.3;

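// Number of pseudo playouts pre-added to every newly created edge: the edge's
// visit count starts at this value and its accumulated reward starts at a
// random integer in [0, pseudo_games]. 0 disables the feature.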
int pseudo_games = 0;

string info() const {
stringstream ss;
ss << "Maximal #moves (0 = no constraint): " << max_num_moves << endl;
Expand All @@ -32,12 +35,13 @@ struct TSOptions {
ss << "Verbose: " << (verbose ? "True" : "False") << endl;
ss << "Use prior: " << (use_prior ? "True" : "False") << endl;
ss << "Persistent tree: " << (persistent_tree ? "True" : "False") << endl;
ss << "#Pseudo game: " << pseudo_games << endl;
ss << "Pick method: " << pick_method << endl;
ss << "Baseline: " << baseline << ", baseline_sigma: " << baseline_sigma << endl;
return ss.str();
}

REGISTER_PYBIND_FIELDS(max_num_moves, num_threads, num_rollout_per_thread, verbose, persistent_tree, pick_method, use_prior, baseline, baseline_sigma);
REGISTER_PYBIND_FIELDS(max_num_moves, num_threads, num_rollout_per_thread, verbose, persistent_tree, pick_method, use_prior, baseline, baseline_sigma, pseudo_games);
};

} // namespace mcts

0 comments on commit dcecea8

Please sign in to comment.