diff --git a/src/mcts/allocator.cpp b/src/mcts/allocator.cpp index 3dc230918912a87494e93c83bfaf910be9bb0593..198c1d571baa4e7ed30d19333732f6f71fb59c2d 100644 --- a/src/mcts/allocator.cpp +++ b/src/mcts/allocator.cpp @@ -56,10 +56,11 @@ namespace mcts return n; } - void allocator::copy(node* n1, node* n2, unsigned int prunning) + void allocator::copy(node* n1, node* n2, int prunning) { - if (n1->get_statistics().count < prunning) return; n2->set_statistics(n1->get_statistics()); + n2->set_won(n1->get_won()); + if (n1->get_statistics().count < prunning) return; unsigned int nb_children = n1->get_number_of_children(); n2->set_number_of_children(nb_children); if (nb_children == 0) return; @@ -72,13 +73,15 @@ namespace mcts } } - node* allocator::move(node* root, unsigned int prunning) + node* allocator::move(node* root, int prunning) { node* r = allocate_unsafe(1); + std::cout << root->size() << std::endl; copy(root, r, prunning); free_pointer = node_arena; node* res = allocate_unsafe(1); copy(r, res); + std::cout << res->size() << std::endl; return res; } } diff --git a/src/mcts/allocator.hpp b/src/mcts/allocator.hpp index f887fdc79c0ac430980c76aee6ba995e94b13200..ad9c7647b01962f16f62a6c9adc7204f4e04328a 100644 --- a/src/mcts/allocator.hpp +++ b/src/mcts/allocator.hpp @@ -12,14 +12,14 @@ namespace mcts node* free_pointer; node* allocate_unsafe(unsigned int size); - void copy(node* n1, node* n2, unsigned int prunning = 0); + void copy(node* n1, node* n2, int prunning = 0); public: allocator(unsigned int size = 100000000U); ~allocator(); node* allocate(unsigned int size); void clear(); - node* move(node* root, unsigned int prunning = 0); + node* move(node* root, int prunning = 0); unsigned int size() const; unsigned int free_space() const; }; diff --git a/src/mcts/mcts_two_players.hpp b/src/mcts/mcts_two_players.hpp index 10555e0b570e68d2c42fe4100cd51a515c68e37b..f2424c3b40a910b3f2ed25f0d50e6ed159f1ef67 100644 --- a/src/mcts/mcts_two_players.hpp +++ b/src/mcts/mcts_two_players.hpp @@ -73,13 +73,12 @@ namespace mcts uint16_t best_move_so_far = k; node* const children = parent->get_children(); node* best_child_so_far = children + k; - unsigned int count; - float v; + double v; for (uint16_t i = 0; i < nb_children; ++i) { - node* child = children + k; - count = child->get_statistics().count; - v = -child->get_statistics().value + C_ * sqrt(log_of_N / count); + node* const child = children + k; + const unsigned int count = child->get_statistics().count; + v = -((double)child->get_statistics().value / count) + C_ * sqrt(log_of_N / count); if (v > best_value_so_far) { best_value_so_far = v; @@ -87,7 +86,8 @@ namespace mcts best_move_so_far = k; } ++k; - if (k == nb_children) k = 0; + k &= ~(-(k == nb_children)); + // if (k == nb_children) k = 0; } if (best_child_so_far->is_proven()) { @@ -122,8 +122,8 @@ namespace mcts for (unsigned int i = 0; i < nb_children; ++i) { node* child = children + i; - child->get_statistics_ref().count = 1; - child->get_statistics_ref().value = 0; + child->get_statistics_ref().count = 10; + child->get_statistics_ref().value = -5; } n->set_children(children); n->set_number_of_children(nb_children); @@ -134,21 +134,24 @@ namespace mcts void mcts_two_players<Game>::think(const std::shared_ptr<Game>& game) { using namespace std; + const int VIRTUAL_LOSS = 2; const chrono::steady_clock::time_point start = chrono::steady_clock::now(); chrono::steady_clock::time_point now; mt19937& generator = mcts<Game>::generators[util::omp_util::get_thread_num()]; auto state = game->get_state(); vector<node*> visited(200); - vector<uint16_t> moves(200); + // vector<uint16_t> moves(300); unsigned int nb_iter = 0; do { int size = 1; node* current = this->root; visited[0] = current; + current->add_virtual_loss(VIRTUAL_LOSS); while (!game->end_of_game() && !current->is_leaf() && !current->is_proven()) { current = select(game, generator, current); + current->add_virtual_loss(VIRTUAL_LOSS); visited[size++] = current; } int game_value = 0; @@ -157,25 +160,25 @@ namespace mcts if (current->is_won()) game_value = 1; else { - game_value = -1; + game_value = -1; } } else if (game->end_of_game()) { - int v = game->value_for_current_player(); - if (v > 0) + game_value = game->value_for_current_player(); + if (game_value > 0) { - game_value = 1; - if (new_version_) current->set_won(); + //game_value = 1; + /*if (new_version_)*/ current->set_won(); } - else if (v < 0) + else if (game_value < 0) { - game_value = -1; - if (new_version_) - { - current->set_lost(); - if (size > 1) visited[size - 2]->set_won(); - } + // game_value = -1; + /* if (new_version_) + {*/ + current->set_lost(); + if (size > 1) visited[size - 2]->set_won(); + // } } } else @@ -183,20 +186,27 @@ namespace mcts uint8_t player = game->current_player(); expand(game, current); game->playout(generator); - int v = game->value(player); - if (v > 0) game_value = 1; - else if (v < 0) game_value = -1; + //int v = game->value(player); + game_value = game->value(player); + // std::cout << game->player_to_string(player) << std::endl; + // std::cout << game_value << std::endl; + // std::cout << game->to_string() << std::endl; + // std::string wait; + // getline(std::cin, wait); + + // if (v > 0) game_value = 1; + // else if (v < 0) game_value = -1; } for (int i = size - 1; i >= 0; --i) { - visited[i]->update(game_value); + visited[i]->update(game_value, VIRTUAL_LOSS); game_value = -game_value; } game->set_state(state); ++nb_iter; - if ((nb_iter & 0x3F) == 0) now = chrono::steady_clock::now(); + if ((nb_iter & 0xFF) == 0) now = chrono::steady_clock::now(); } - while ((nb_iter & 0x3F) != 0 || now < start + this->milliseconds); + while ((nb_iter & 0xFF) != 0 || now < start + this->milliseconds); } template <typename Game> @@ -208,9 +218,10 @@ namespace mcts #pragma omp parallel think(game::copy(this->game)); } - // std::ofstream ofs ("graph.gv", ofstream::out); - // util::display_node::node_to_dot(ofs, this->root, 1000, 50); - util::display_node::node_to_ascii(cout, this->root, 1); + //std::ofstream ofs ("graph.gv", ofstream::out); + // util::display_node::node_to_dot(ofs, this->root, 2, 1000); + util::display_node::node_to_ascii(cout, this->root, 2); + std::cout << this->root->size() << std::endl; // std::cout << "finished " << new_version_ << std::endl; // string _; // getline(cin, _); @@ -236,11 +247,14 @@ namespace mcts best_move_so_far = k; } ++k; - if (k == nb_children) k = 0; + k &= ~(-(k == nb_children)); + //if (k == nb_children) k = 0; } return best_move_so_far; } + const int PRUNNING = 0; + template <typename Game> void mcts_two_players<Game>::last_moves(uint16_t computer, uint16_t other) { @@ -251,7 +265,7 @@ namespace mcts } else { - this->root = alloc_.move(&this->root->get_children()[computer].get_children()[other]); + this->root = alloc_.move(&this->root->get_children()[computer].get_children()[other], PRUNNING); } } @@ -265,7 +279,7 @@ namespace mcts } else { - this->root = alloc_.move(&this->root->get_children()[move], 20); + this->root = alloc_.move(&this->root->get_children()[move], PRUNNING); } } diff --git a/src/mcts/node.cpp b/src/mcts/node.cpp index bc40789a0f24a01cf8960a9d3247a8069b4fa81f..b2b6b438ea242fd7869b17c5a831eff04e92f6e7 100644 --- a/src/mcts/node.cpp +++ b/src/mcts/node.cpp @@ -6,6 +6,16 @@ using namespace std; namespace mcts { + uint32_t node::size() const + { + uint32_t res = 1; + for (uint16_t i = 0; i < number_of_children; ++i) + { + res += children[i].size(); + } + return res; + } + string node::to_string() const { stringbuf buffer; diff --git a/src/mcts/node.hpp b/src/mcts/node.hpp index 591e6dfd4dfc41ad77b2e44d9bd8df25bb5e7c1d..c98176e7b688e7a0ddcf3d0f5555e978155a52d5 100644 --- a/src/mcts/node.hpp +++ b/src/mcts/node.hpp @@ -6,15 +6,13 @@ #include <iostream> #include <limits> -#define NODE_WON_VALUE 1e15 -#define NODE_LOST_VALUE -1e15 - namespace mcts { class node { statistics stats; bool flag = false; + signed char won = 0; uint16_t number_of_children = 0; node* children = nullptr; @@ -22,9 +20,12 @@ namespace mcts inline uint16_t get_winning_index() const; inline bool is_leaf() const; inline uint16_t get_number_of_children() const; + uint32_t size() const; inline node* get_children() const; inline void set_number_of_children(uint16_t n); inline void set_children(node* n); + inline void set_won(signed char v); + inline signed char get_won() const; inline void set_won(); inline void set_lost(); inline bool is_proven() const; @@ -34,7 +35,8 @@ namespace mcts inline statistics& get_statistics_ref(); inline void set_statistics(const statistics& s); inline bool test_and_set(); - inline void update(int value); + inline void add_virtual_loss(int n); + inline void update(int value, int virtual_loss = 0); inline void update_count(); std::string to_string() const; friend std::ostream& operator<<(std::ostream& os, const node& n); @@ -45,24 +47,34 @@ namespace mcts return is_won() || is_lost(); } + void node::set_won(signed char v) + { + won = v; + } + + signed char node::get_won() const + { + return won; + } + void node::set_won() { - stats.value = NODE_WON_VALUE; + won = 1; } void node::set_lost() { - stats.value = NODE_LOST_VALUE; + won = -1; } bool node::is_won() const { - return stats.value > 1.1; + return won == 1; } bool node::is_lost() const { - return stats.value < -1.1; + return won == -1; } bool node::is_leaf() const @@ -122,17 +134,22 @@ namespace mcts { ++stats.count; } - - void node::update(int v) + + void node::add_virtual_loss(int n) { - unsigned int count = stats.count; - double value = stats.value; - ++count; - value += (v - value) / count; - stats.value = value; - stats.count = count; +#pragma omp atomic update + stats.count += n; +#pragma omp atomic update + stats.value -= n; } + void node::update(int v, int virtual_loss) + { +#pragma omp atomic update + stats.count += 1 - virtual_loss; +#pragma omp atomic update + stats.value += v + virtual_loss; + } } #endif diff --git a/src/mcts/statistics.cpp b/src/mcts/statistics.cpp index cafa2adac339ebe8362b206de8426ad4bde63a7f..d3fbaee7aaa4bc2ea14b43c94c1b85130d16f577 100644 --- a/src/mcts/statistics.cpp +++ b/src/mcts/statistics.cpp @@ -10,7 +10,7 @@ namespace mcts { stringbuf buffer; ostream os(&buffer); - os << "(count: " << count << ", value: " << setprecision(2) << value << ")"; + os << "(count: " << count << ", value: " << setprecision(2) << (double)value / count << ")"; return buffer.str(); } } diff --git a/src/mcts/statistics.hpp b/src/mcts/statistics.hpp index 78dfa693946e7f46f184ad358c6cd85a4acaa0c8..01c3ee67f9d281dbf49286f78c5ebaa535c21c01 100644 --- a/src/mcts/statistics.hpp +++ b/src/mcts/statistics.hpp @@ -7,8 +7,8 @@ namespace mcts { struct statistics { - unsigned int count = 0; - float value = 0; + int count = 0; + int value = 0; std::string to_string() const; }; } diff --git a/src/util/display_node.cpp b/src/util/display_node.cpp index 71cfd06187fd839417a79fa3dae9cacb3d4abfdf..6b1900bf078ef63fbdeba090830361998eb8975a 100644 --- a/src/util/display_node.cpp +++ b/src/util/display_node.cpp @@ -7,20 +7,20 @@ using namespace std; namespace util { - void display_node::node_to_ascii(std::ostream& os, const mcts::node* n, unsigned int depth, unsigned int prunning) + void display_node::node_to_ascii(std::ostream& os, const mcts::node* n, int depth, int prunning) { node_to_ascii(os, "", n, depth, prunning); os << endl; } - void display_node::node_to_dot(std::ostream& os, const mcts::node* n, unsigned int depth, unsigned int prunning) + void display_node::node_to_dot(std::ostream& os, const mcts::node* n, int depth, int prunning) { os << "digraph {" << endl; node_to_dot(os, 0, n, depth, prunning); os << "}" << endl; } - int display_node::node_to_dot(ostream& os, int id, const mcts::node* n, unsigned int depth, unsigned int prunning) + int display_node::node_to_dot(ostream& os, int id, const mcts::node* n, int depth, int prunning) { stringbuf buffer; ostream o (&buffer); @@ -42,7 +42,7 @@ namespace util return cpt + 1; } - void display_node::node_to_ascii(ostream& os, string prefix, const mcts::node* n, unsigned int depth, unsigned int prunning) + void display_node::node_to_ascii(ostream& os, string prefix, const mcts::node* n, int depth, int prunning) { string s; s = n->get_statistics().to_string(); @@ -63,7 +63,7 @@ namespace util children_to_ascii(os, new_prefix, n->get_number_of_children(), n->get_children(), depth, prunning); } - void display_node::children_to_ascii(ostream& os, string prefix, unsigned int nb_children, const mcts::node* children, unsigned int depth, unsigned int prunning) + void display_node::children_to_ascii(ostream& os, string prefix, unsigned int nb_children, const mcts::node* children, int depth, int prunning) { os << "+-"; node_to_ascii(os, prefix + "| ", children, depth - 1, prunning); diff --git a/src/util/display_node.hpp b/src/util/display_node.hpp index 86837fe8629e48580c13852d91b9b0c94bc19579..da3846ce29b4249bb5037c87e3f9699c4599412b 100644 --- a/src/util/display_node.hpp +++ b/src/util/display_node.hpp @@ -10,12 +10,12 @@ namespace util class display_node { public: - static void node_to_ascii(std::ostream& os, const mcts::node* n, unsigned int depth = std::numeric_limits<unsigned int>::max(), unsigned int prunning = 0); - static void node_to_dot(std::ostream& os, const mcts::node* n, unsigned int depth = std::numeric_limits<int>::max(), unsigned int prunning = 0); + static void node_to_ascii(std::ostream& os, const mcts::node* n, int depth = std::numeric_limits<int>::max(), int prunning = 0); + static void node_to_dot(std::ostream& os, const mcts::node* n, int depth = std::numeric_limits<int>::max(), int prunning = 0); private: - static void node_to_ascii(std::ostream& os, std::string prefix, const mcts::node* n, unsigned int depth, unsigned int prunning); - static void children_to_ascii(std::ostream& os, std::string prefix, unsigned int nb_children, const mcts::node* children, unsigned int depth, unsigned int prunning); - static int node_to_dot(std::ostream& os, int id, const mcts::node* n, unsigned int depth, unsigned int prunning); + static void node_to_ascii(std::ostream& os, std::string prefix, const mcts::node* n, int depth, int prunning); + static void children_to_ascii(std::ostream& os, std::string prefix, unsigned int nb_children, const mcts::node* children, int depth, int prunning); + static int node_to_dot(std::ostream& os, int id, const mcts::node* n, int depth, int prunning); }; }