From b833b13c9fb6745642ebd7a1a251047857f89b41 Mon Sep 17 00:00:00 2001 From: Francesco Bariatti <francesco.bariatti@insa-rennes.fr> Date: Mon, 18 Apr 2016 14:00:53 +0200 Subject: [PATCH] Revert "New MCTS version from Pascal" This reverts commit a6c5137c34380c7c92f0aa70c7f94e3c0fd74dbd. The program would exit with an error. --- src/mcts/allocator.cpp | 9 ++-- src/mcts/allocator.hpp | 4 +- src/mcts/mcts_two_players.hpp | 80 +++++++++++++++-------------------- src/mcts/node.cpp | 10 ----- src/mcts/node.hpp | 49 +++++++-------------- src/mcts/statistics.cpp | 2 +- src/mcts/statistics.hpp | 4 +- src/util/display_node.cpp | 10 ++--- src/util/display_node.hpp | 10 ++--- 9 files changed, 67 insertions(+), 111 deletions(-) diff --git a/src/mcts/allocator.cpp b/src/mcts/allocator.cpp index 198c1d5..3dc2309 100644 --- a/src/mcts/allocator.cpp +++ b/src/mcts/allocator.cpp @@ -56,11 +56,10 @@ namespace mcts return n; } - void allocator::copy(node* n1, node* n2, int prunning) + void allocator::copy(node* n1, node* n2, unsigned int prunning) { - n2->set_statistics(n1->get_statistics()); - n2->set_won(n1->get_won()); if (n1->get_statistics().count < prunning) return; + n2->set_statistics(n1->get_statistics()); unsigned int nb_children = n1->get_number_of_children(); n2->set_number_of_children(nb_children); if (nb_children == 0) return; @@ -73,15 +72,13 @@ namespace mcts } } - node* allocator::move(node* root, int prunning) + node* allocator::move(node* root, unsigned int prunning) { node* r = allocate_unsafe(1); - std::cout << root->size() << std::endl; copy(root, r, prunning); free_pointer = node_arena; node* res = allocate_unsafe(1); copy(r, res); - std::cout << res->size() << std::endl; return res; } } diff --git a/src/mcts/allocator.hpp b/src/mcts/allocator.hpp index ad9c764..f887fdc 100644 --- a/src/mcts/allocator.hpp +++ b/src/mcts/allocator.hpp @@ -12,14 +12,14 @@ namespace mcts node* free_pointer; node* allocate_unsafe(unsigned int size); - void copy(node* n1, node* n2, int prunning = 0); + void copy(node* n1, node* n2, unsigned int prunning = 0); public: allocator(unsigned int size = 100000000U); ~allocator(); node* allocate(unsigned int size); void clear(); - node* move(node* root, int prunning = 0); + node* move(node* root, unsigned int prunning = 0); unsigned int size() const; unsigned int free_space() const; }; diff --git a/src/mcts/mcts_two_players.hpp b/src/mcts/mcts_two_players.hpp index f2424c3..10555e0 100644 --- a/src/mcts/mcts_two_players.hpp +++ b/src/mcts/mcts_two_players.hpp @@ -73,12 +73,13 @@ namespace mcts uint16_t best_move_so_far = k; node* const children = parent->get_children(); node* best_child_so_far = children + k; - double v; + unsigned int count; + float v; for (uint16_t i = 0; i < nb_children; ++i) { - node* const child = children + k; - const unsigned int count = child->get_statistics().count; - v = -((double)child->get_statistics().value / count) + C_ * sqrt(log_of_N / count); + node* child = children + k; + count = child->get_statistics().count; + v = -child->get_statistics().value + C_ * sqrt(log_of_N / count); if (v > best_value_so_far) { best_value_so_far = v; @@ -86,8 +87,7 @@ namespace mcts best_move_so_far = k; } ++k; - k &= ~(-(k == nb_children)); - // if (k == nb_children) k = 0; + if (k == nb_children) k = 0; } if (best_child_so_far->is_proven()) { @@ -122,8 +122,8 @@ namespace mcts for (unsigned int i = 0; i < nb_children; ++i) { node* child = children + i; - child->get_statistics_ref().count = 10; - child->get_statistics_ref().value = -5; + child->get_statistics_ref().count = 1; + child->get_statistics_ref().value = 0; } n->set_children(children); n->set_number_of_children(nb_children); @@ -134,24 +134,21 @@ namespace mcts void mcts_two_players<Game>::think(const std::shared_ptr<Game>& game) { using namespace std; - const int VIRTUAL_LOSS = 2; const chrono::steady_clock::time_point start = chrono::steady_clock::now(); chrono::steady_clock::time_point now; mt19937& generator = mcts<Game>::generators[util::omp_util::get_thread_num()]; auto state = game->get_state(); vector<node*> visited(200); - // vector<uint16_t> moves(300); + vector<uint16_t> moves(200); unsigned int nb_iter = 0; do { int size = 1; node* current = this->root; visited[0] = current; - current->add_virtual_loss(VIRTUAL_LOSS); while (!game->end_of_game() && !current->is_leaf() && !current->is_proven()) { current = select(game, generator, current); - current->add_virtual_loss(VIRTUAL_LOSS); visited[size++] = current; } int game_value = 0; @@ -160,25 +157,25 @@ namespace mcts if (current->is_won()) game_value = 1; else { - game_value = -1; + game_value = -1; } } else if (game->end_of_game()) { - game_value = game->value_for_current_player(); - if (game_value > 0) + int v = game->value_for_current_player(); + if (v > 0) { - //game_value = 1; - /*if (new_version_)*/ current->set_won(); + game_value = 1; + if (new_version_) current->set_won(); } - else if (game_value < 0) + else if (v < 0) { - // game_value = -1; - /* if (new_version_) - {*/ - current->set_lost(); - if (size > 1) visited[size - 2]->set_won(); - // } + game_value = -1; + if (new_version_) + { + current->set_lost(); + if (size > 1) visited[size - 2]->set_won(); + } } } else @@ -186,27 +183,20 @@ namespace mcts uint8_t player = game->current_player(); expand(game, current); game->playout(generator); - //int v = game->value(player); - game_value = game->value(player); - // std::cout << game->player_to_string(player) << std::endl; - // std::cout << game_value << std::endl; - // std::cout << game->to_string() << std::endl; - // std::string wait; - // getline(std::cin, wait); - - // if (v > 0) game_value = 1; - // else if (v < 0) game_value = -1; + int v = game->value(player); + if (v > 0) game_value = 1; + else if (v < 0) game_value = -1; } for (int i = size - 1; i >= 0; --i) { - visited[i]->update(game_value, VIRTUAL_LOSS); + visited[i]->update(game_value); game_value = -game_value; } game->set_state(state); ++nb_iter; - if ((nb_iter & 0xFF) == 0) now = chrono::steady_clock::now(); + if ((nb_iter & 0x3F) == 0) now = chrono::steady_clock::now(); } - while ((nb_iter & 0xFF) != 0 || now < start + this->milliseconds); + while ((nb_iter & 0x3F) != 0 || now < start + this->milliseconds); } template <typename Game> @@ -218,10 +208,9 @@ namespace mcts #pragma omp parallel think(game::copy(this->game)); } - //std::ofstream ofs ("graph.gv", ofstream::out); - // util::display_node::node_to_dot(ofs, this->root, 2, 1000); - util::display_node::node_to_ascii(cout, this->root, 2); - std::cout << this->root->size() << std::endl; + // std::ofstream ofs ("graph.gv", ofstream::out); + // util::display_node::node_to_dot(ofs, this->root, 1000, 50); + util::display_node::node_to_ascii(cout, this->root, 1); // std::cout << "finished " << new_version_ << std::endl; // string _; // getline(cin, _); @@ -247,14 +236,11 @@ namespace mcts best_move_so_far = k; } ++k; - k &= ~(-(k == nb_children)); - //if (k == nb_children) k = 0; + if (k == nb_children) k = 0; } return best_move_so_far; } - const int PRUNNING = 0; - template <typename Game> void mcts_two_players<Game>::last_moves(uint16_t computer, uint16_t other) { @@ -265,7 +251,7 @@ namespace mcts } else { - this->root = alloc_.move(&this->root->get_children()[computer].get_children()[other], PRUNNING); + this->root = alloc_.move(&this->root->get_children()[computer].get_children()[other]); } } @@ -279,7 +265,7 @@ namespace mcts } else { - this->root = alloc_.move(&this->root->get_children()[move], PRUNNING); + this->root = alloc_.move(&this->root->get_children()[move], 20); } } diff --git a/src/mcts/node.cpp b/src/mcts/node.cpp index b2b6b43..bc40789 100644 --- a/src/mcts/node.cpp +++ b/src/mcts/node.cpp @@ -6,16 +6,6 @@ using namespace std; namespace mcts { - uint32_t node::size() const - { - uint32_t res = 1; - for (uint16_t i = 0; i < number_of_children; ++i) - { - res += children[i].size(); - } - return res; - } - string node::to_string() const { stringbuf buffer; diff --git a/src/mcts/node.hpp b/src/mcts/node.hpp index c98176e..591e6df 100644 --- a/src/mcts/node.hpp +++ b/src/mcts/node.hpp @@ -6,13 +6,15 @@ #include <iostream> #include <limits> +#define NODE_WON_VALUE 1e15 +#define NODE_LOST_VALUE -1e15 + namespace mcts { class node { statistics stats; bool flag = false; - signed char won = 0; uint16_t number_of_children = 0; node* children = nullptr; @@ -20,12 +22,9 @@ namespace mcts inline uint16_t get_winning_index() const; inline bool is_leaf() const; inline uint16_t get_number_of_children() const; - uint32_t size() const; inline node* get_children() const; inline void set_number_of_children(uint16_t n); inline void set_children(node* n); - inline void set_won(signed char v); - inline signed char get_won() const; inline void set_won(); inline void set_lost(); inline bool is_proven() const; @@ -35,8 +34,7 @@ namespace mcts inline statistics& get_statistics_ref(); inline void set_statistics(const statistics& s); inline bool test_and_set(); - inline void add_virtual_loss(int n); - inline void update(int value, int virtual_loss = 0); + inline void update(int value); inline void update_count(); std::string to_string() const; friend std::ostream& operator<<(std::ostream& os, const node& n); @@ -47,34 +45,24 @@ namespace mcts return is_won() || is_lost(); } - void node::set_won(signed char v) - { - won = v; - } - - signed char node::get_won() const - { - return won; - } - void node::set_won() { - won = 1; + stats.value = NODE_WON_VALUE; } void node::set_lost() { - won = -1; + stats.value = NODE_LOST_VALUE; } bool node::is_won() const { - return won == 1; + return stats.value > 1.1; } bool node::is_lost() const { - return won == -1; + return stats.value < -1.1; } bool node::is_leaf() const @@ -134,22 +122,17 @@ namespace mcts { ++stats.count; } - - void node::add_virtual_loss(int n) + + void node::update(int v) { -#pragma omp atomic update - stats.count += n; -#pragma omp atomic update - stats.value -= n; + unsigned int count = stats.count; + double value = stats.value; + ++count; + value += (v - value) / count; + stats.value = value; + stats.count = count; } - void node::update(int v, int virtual_loss) - { -#pragma omp atomic update - stats.count += 1 - virtual_loss; -#pragma omp atomic update - stats.value += v + virtual_loss; - } } #endif diff --git a/src/mcts/statistics.cpp b/src/mcts/statistics.cpp index d3fbaee..cafa2ad 100644 --- a/src/mcts/statistics.cpp +++ b/src/mcts/statistics.cpp @@ -10,7 +10,7 @@ namespace mcts { stringbuf buffer; ostream os(&buffer); - os << "(count: " << count << ", value: " << setprecision(2) << (double)value / count << ")"; + os << "(count: " << count << ", value: " << setprecision(2) << value << ")"; return buffer.str(); } } diff --git a/src/mcts/statistics.hpp b/src/mcts/statistics.hpp index 01c3ee6..78dfa69 100644 --- a/src/mcts/statistics.hpp +++ b/src/mcts/statistics.hpp @@ -7,8 +7,8 @@ namespace mcts { struct statistics { - int count = 0; - int value = 0; + unsigned int count = 0; + float value = 0; std::string to_string() const; }; } diff --git a/src/util/display_node.cpp b/src/util/display_node.cpp index 6b1900b..71cfd06 100644 --- a/src/util/display_node.cpp +++ b/src/util/display_node.cpp @@ -7,20 +7,20 @@ using namespace std; namespace util { - void display_node::node_to_ascii(std::ostream& os, const mcts::node* n, int depth, int prunning) + void display_node::node_to_ascii(std::ostream& os, const mcts::node* n, unsigned int depth, unsigned int prunning) { node_to_ascii(os, "", n, depth, prunning); os << endl; } - void display_node::node_to_dot(std::ostream& os, const mcts::node* n, int depth, int prunning) + void display_node::node_to_dot(std::ostream& os, const mcts::node* n, unsigned int depth, unsigned int prunning) { os << "digraph {" << endl; node_to_dot(os, 0, n, depth, prunning); os << "}" << endl; } - int display_node::node_to_dot(ostream& os, int id, const mcts::node* n, int depth, int prunning) + int display_node::node_to_dot(ostream& os, int id, const mcts::node* n, unsigned int depth, unsigned int prunning) { stringbuf buffer; ostream o (&buffer); @@ -42,7 +42,7 @@ namespace util return cpt + 1; } - void display_node::node_to_ascii(ostream& os, string prefix, const mcts::node* n, int depth, int prunning) + void display_node::node_to_ascii(ostream& os, string prefix, const mcts::node* n, unsigned int depth, unsigned int prunning) { string s; s = n->get_statistics().to_string(); @@ -63,7 +63,7 @@ namespace util children_to_ascii(os, new_prefix, n->get_number_of_children(), n->get_children(), depth, prunning); } - void display_node::children_to_ascii(ostream& os, string prefix, unsigned int nb_children, const mcts::node* children, int depth, int prunning) + void display_node::children_to_ascii(ostream& os, string prefix, unsigned int nb_children, const mcts::node* children, unsigned int depth, unsigned int prunning) { os << "+-"; node_to_ascii(os, prefix + "| ", children, depth - 1, prunning); diff --git a/src/util/display_node.hpp b/src/util/display_node.hpp index da3846c..86837fe 100644 --- a/src/util/display_node.hpp +++ b/src/util/display_node.hpp @@ -10,12 +10,12 @@ namespace util class display_node { public: - static void node_to_ascii(std::ostream& os, const mcts::node* n, int depth = std::numeric_limits<int>::max(), int prunning = 0); - static void node_to_dot(std::ostream& os, const mcts::node* n, int depth = std::numeric_limits<int>::max(), int prunning = 0); + static void node_to_ascii(std::ostream& os, const mcts::node* n, unsigned int depth = std::numeric_limits<unsigned int>::max(), unsigned int prunning = 0); + static void node_to_dot(std::ostream& os, const mcts::node* n, unsigned int depth = std::numeric_limits<int>::max(), unsigned int prunning = 0); private: - static void node_to_ascii(std::ostream& os, std::string prefix, const mcts::node* n, int depth, int prunning); - static void children_to_ascii(std::ostream& os, std::string prefix, unsigned int nb_children, const mcts::node* children, int depth, int prunning); - static int node_to_dot(std::ostream& os, int id, const mcts::node* n, int depth, int prunning); + static void node_to_ascii(std::ostream& os, std::string prefix, const mcts::node* n, unsigned int depth, unsigned int prunning); + static void children_to_ascii(std::ostream& os, std::string prefix, unsigned int nb_children, const mcts::node* children, unsigned int depth, unsigned int prunning); + static int node_to_dot(std::ostream& os, int id, const mcts::node* n, unsigned int depth, unsigned int prunning); }; } -- GitLab