Skip to content
Snippets Groups Projects
Commit a6c5137c authored by Bariatti Francesco's avatar Bariatti Francesco
Browse files

New MCTS version from Pascal

Added a virtual loss when a thread is exploring so that all threads don't explore the same branch
parent 37aef8b3
No related branches found
No related tags found
No related merge requests found
......@@ -56,10 +56,11 @@ namespace mcts
return n;
}
void allocator::copy(node* n1, node* n2, unsigned int prunning)
void allocator::copy(node* n1, node* n2, int prunning)
{
if (n1->get_statistics().count < prunning) return;
n2->set_statistics(n1->get_statistics());
n2->set_won(n1->get_won());
if (n1->get_statistics().count < prunning) return;
unsigned int nb_children = n1->get_number_of_children();
n2->set_number_of_children(nb_children);
if (nb_children == 0) return;
......@@ -72,13 +73,15 @@ namespace mcts
}
}
node* allocator::move(node* root, unsigned int prunning)
node* allocator::move(node* root, int prunning)
{
node* r = allocate_unsafe(1);
std::cout << root->size() << std::endl;
copy(root, r, prunning);
free_pointer = node_arena;
node* res = allocate_unsafe(1);
copy(r, res);
std::cout << res->size() << std::endl;
return res;
}
}
......@@ -12,14 +12,14 @@ namespace mcts
node* free_pointer;
node* allocate_unsafe(unsigned int size);
void copy(node* n1, node* n2, unsigned int prunning = 0);
void copy(node* n1, node* n2, int prunning = 0);
public:
allocator(unsigned int size = 100000000U);
~allocator();
node* allocate(unsigned int size);
void clear();
node* move(node* root, unsigned int prunning = 0);
node* move(node* root, int prunning = 0);
unsigned int size() const;
unsigned int free_space() const;
};
......
......@@ -73,13 +73,12 @@ namespace mcts
uint16_t best_move_so_far = k;
node* const children = parent->get_children();
node* best_child_so_far = children + k;
unsigned int count;
float v;
double v;
for (uint16_t i = 0; i < nb_children; ++i)
{
node* child = children + k;
count = child->get_statistics().count;
v = -child->get_statistics().value + C_ * sqrt(log_of_N / count);
node* const child = children + k;
const unsigned int count = child->get_statistics().count;
v = -((double)child->get_statistics().value / count) + C_ * sqrt(log_of_N / count);
if (v > best_value_so_far)
{
best_value_so_far = v;
......@@ -87,7 +86,8 @@ namespace mcts
best_move_so_far = k;
}
++k;
if (k == nb_children) k = 0;
k &= ~(-(k == nb_children));
// if (k == nb_children) k = 0;
}
if (best_child_so_far->is_proven())
{
......@@ -122,8 +122,8 @@ namespace mcts
for (unsigned int i = 0; i < nb_children; ++i)
{
node* child = children + i;
child->get_statistics_ref().count = 1;
child->get_statistics_ref().value = 0;
child->get_statistics_ref().count = 10;
child->get_statistics_ref().value = -5;
}
n->set_children(children);
n->set_number_of_children(nb_children);
......@@ -134,21 +134,24 @@ namespace mcts
void mcts_two_players<Game>::think(const std::shared_ptr<Game>& game)
{
using namespace std;
const int VIRTUAL_LOSS = 2;
const chrono::steady_clock::time_point start = chrono::steady_clock::now();
chrono::steady_clock::time_point now;
mt19937& generator = mcts<Game>::generators[util::omp_util::get_thread_num()];
auto state = game->get_state();
vector<node*> visited(200);
vector<uint16_t> moves(200);
// vector<uint16_t> moves(300);
unsigned int nb_iter = 0;
do
{
int size = 1;
node* current = this->root;
visited[0] = current;
current->add_virtual_loss(VIRTUAL_LOSS);
while (!game->end_of_game() && !current->is_leaf() && !current->is_proven())
{
current = select(game, generator, current);
current->add_virtual_loss(VIRTUAL_LOSS);
visited[size++] = current;
}
int game_value = 0;
......@@ -157,25 +160,25 @@ namespace mcts
if (current->is_won()) game_value = 1;
else
{
game_value = -1;
game_value = -1;
}
}
else if (game->end_of_game())
{
int v = game->value_for_current_player();
if (v > 0)
game_value = game->value_for_current_player();
if (game_value > 0)
{
game_value = 1;
if (new_version_) current->set_won();
//game_value = 1;
/*if (new_version_)*/ current->set_won();
}
else if (v < 0)
else if (game_value < 0)
{
game_value = -1;
if (new_version_)
{
current->set_lost();
if (size > 1) visited[size - 2]->set_won();
}
// game_value = -1;
/* if (new_version_)
{*/
current->set_lost();
if (size > 1) visited[size - 2]->set_won();
// }
}
}
else
......@@ -183,20 +186,27 @@ namespace mcts
uint8_t player = game->current_player();
expand(game, current);
game->playout(generator);
int v = game->value(player);
if (v > 0) game_value = 1;
else if (v < 0) game_value = -1;
//int v = game->value(player);
game_value = game->value(player);
// std::cout << game->player_to_string(player) << std::endl;
// std::cout << game_value << std::endl;
// std::cout << game->to_string() << std::endl;
// std::string wait;
// getline(std::cin, wait);
// if (v > 0) game_value = 1;
// else if (v < 0) game_value = -1;
}
for (int i = size - 1; i >= 0; --i)
{
visited[i]->update(game_value);
visited[i]->update(game_value, VIRTUAL_LOSS);
game_value = -game_value;
}
game->set_state(state);
++nb_iter;
if ((nb_iter & 0x3F) == 0) now = chrono::steady_clock::now();
if ((nb_iter & 0xFF) == 0) now = chrono::steady_clock::now();
}
while ((nb_iter & 0x3F) != 0 || now < start + this->milliseconds);
while ((nb_iter & 0xFF) != 0 || now < start + this->milliseconds);
}
template <typename Game>
......@@ -208,9 +218,10 @@ namespace mcts
#pragma omp parallel
think(game::copy(this->game));
}
// std::ofstream ofs ("graph.gv", ofstream::out);
// util::display_node::node_to_dot(ofs, this->root, 1000, 50);
util::display_node::node_to_ascii(cout, this->root, 1);
//std::ofstream ofs ("graph.gv", ofstream::out);
// util::display_node::node_to_dot(ofs, this->root, 2, 1000);
util::display_node::node_to_ascii(cout, this->root, 2);
std::cout << this->root->size() << std::endl;
// std::cout << "finished " << new_version_ << std::endl;
// string _;
// getline(cin, _);
......@@ -236,11 +247,14 @@ namespace mcts
best_move_so_far = k;
}
++k;
if (k == nb_children) k = 0;
k &= ~(-(k == nb_children));
//if (k == nb_children) k = 0;
}
return best_move_so_far;
}
const int PRUNNING = 0;
template <typename Game>
void mcts_two_players<Game>::last_moves(uint16_t computer, uint16_t other)
{
......@@ -251,7 +265,7 @@ namespace mcts
}
else
{
this->root = alloc_.move(&this->root->get_children()[computer].get_children()[other]);
this->root = alloc_.move(&this->root->get_children()[computer].get_children()[other], PRUNNING);
}
}
......@@ -265,7 +279,7 @@ namespace mcts
}
else
{
this->root = alloc_.move(&this->root->get_children()[move], 20);
this->root = alloc_.move(&this->root->get_children()[move], PRUNNING);
}
}
......
......@@ -6,6 +6,16 @@ using namespace std;
namespace mcts
{
uint32_t node::size() const
{
uint32_t res = 1;
for (uint16_t i = 0; i < number_of_children; ++i)
{
res += children[i].size();
}
return res;
}
string node::to_string() const
{
stringbuf buffer;
......
......@@ -6,15 +6,13 @@
#include <iostream>
#include <limits>
#define NODE_WON_VALUE 1e15
#define NODE_LOST_VALUE -1e15
namespace mcts
{
class node
{
statistics stats;
bool flag = false;
signed char won = 0;
uint16_t number_of_children = 0;
node* children = nullptr;
......@@ -22,9 +20,12 @@ namespace mcts
inline uint16_t get_winning_index() const;
inline bool is_leaf() const;
inline uint16_t get_number_of_children() const;
uint32_t size() const;
inline node* get_children() const;
inline void set_number_of_children(uint16_t n);
inline void set_children(node* n);
inline void set_won(signed char v);
inline signed char get_won() const;
inline void set_won();
inline void set_lost();
inline bool is_proven() const;
......@@ -34,7 +35,8 @@ namespace mcts
inline statistics& get_statistics_ref();
inline void set_statistics(const statistics& s);
inline bool test_and_set();
inline void update(int value);
inline void add_virtual_loss(int n);
inline void update(int value, int virtual_loss = 0);
inline void update_count();
std::string to_string() const;
friend std::ostream& operator<<(std::ostream& os, const node& n);
......@@ -45,24 +47,34 @@ namespace mcts
return is_won() || is_lost();
}
void node::set_won(signed char v)
{
won = v;
}
signed char node::get_won() const
{
return won;
}
void node::set_won()
{
stats.value = NODE_WON_VALUE;
won = 1;
}
void node::set_lost()
{
stats.value = NODE_LOST_VALUE;
won = -1;
}
bool node::is_won() const
{
return stats.value > 1.1;
return won == 1;
}
bool node::is_lost() const
{
return stats.value < -1.1;
return won == -1;
}
bool node::is_leaf() const
......@@ -122,17 +134,22 @@ namespace mcts
{
++stats.count;
}
void node::update(int v)
void node::add_virtual_loss(int n)
{
unsigned int count = stats.count;
double value = stats.value;
++count;
value += (v - value) / count;
stats.value = value;
stats.count = count;
#pragma omp atomic update
stats.count += n;
#pragma omp atomic update
stats.value -= n;
}
void node::update(int v, int virtual_loss)
{
#pragma omp atomic update
stats.count += 1 - virtual_loss;
#pragma omp atomic update
stats.value += v + virtual_loss;
}
}
#endif
......@@ -10,7 +10,7 @@ namespace mcts
{
stringbuf buffer;
ostream os(&buffer);
os << "(count: " << count << ", value: " << setprecision(2) << value << ")";
os << "(count: " << count << ", value: " << setprecision(2) << (double)value / count << ")";
return buffer.str();
}
}
......@@ -7,8 +7,8 @@ namespace mcts
{
struct statistics
{
unsigned int count = 0;
float value = 0;
int count = 0;
int value = 0;
std::string to_string() const;
};
}
......
......@@ -7,20 +7,20 @@ using namespace std;
namespace util
{
void display_node::node_to_ascii(std::ostream& os, const mcts::node* n, unsigned int depth, unsigned int prunning)
void display_node::node_to_ascii(std::ostream& os, const mcts::node* n, int depth, int prunning)
{
node_to_ascii(os, "", n, depth, prunning);
os << endl;
}
void display_node::node_to_dot(std::ostream& os, const mcts::node* n, unsigned int depth, unsigned int prunning)
void display_node::node_to_dot(std::ostream& os, const mcts::node* n, int depth, int prunning)
{
os << "digraph {" << endl;
node_to_dot(os, 0, n, depth, prunning);
os << "}" << endl;
}
int display_node::node_to_dot(ostream& os, int id, const mcts::node* n, unsigned int depth, unsigned int prunning)
int display_node::node_to_dot(ostream& os, int id, const mcts::node* n, int depth, int prunning)
{
stringbuf buffer;
ostream o (&buffer);
......@@ -42,7 +42,7 @@ namespace util
return cpt + 1;
}
void display_node::node_to_ascii(ostream& os, string prefix, const mcts::node* n, unsigned int depth, unsigned int prunning)
void display_node::node_to_ascii(ostream& os, string prefix, const mcts::node* n, int depth, int prunning)
{
string s;
s = n->get_statistics().to_string();
......@@ -63,7 +63,7 @@ namespace util
children_to_ascii(os, new_prefix, n->get_number_of_children(), n->get_children(), depth, prunning);
}
void display_node::children_to_ascii(ostream& os, string prefix, unsigned int nb_children, const mcts::node* children, unsigned int depth, unsigned int prunning)
void display_node::children_to_ascii(ostream& os, string prefix, unsigned int nb_children, const mcts::node* children, int depth, int prunning)
{
os << "+-";
node_to_ascii(os, prefix + "| ", children, depth - 1, prunning);
......
......@@ -10,12 +10,12 @@ namespace util
class display_node
{
public:
static void node_to_ascii(std::ostream& os, const mcts::node* n, unsigned int depth = std::numeric_limits<unsigned int>::max(), unsigned int prunning = 0);
static void node_to_dot(std::ostream& os, const mcts::node* n, unsigned int depth = std::numeric_limits<int>::max(), unsigned int prunning = 0);
static void node_to_ascii(std::ostream& os, const mcts::node* n, int depth = std::numeric_limits<int>::max(), int prunning = 0);
static void node_to_dot(std::ostream& os, const mcts::node* n, int depth = std::numeric_limits<int>::max(), int prunning = 0);
private:
static void node_to_ascii(std::ostream& os, std::string prefix, const mcts::node* n, unsigned int depth, unsigned int prunning);
static void children_to_ascii(std::ostream& os, std::string prefix, unsigned int nb_children, const mcts::node* children, unsigned int depth, unsigned int prunning);
static int node_to_dot(std::ostream& os, int id, const mcts::node* n, unsigned int depth, unsigned int prunning);
static void node_to_ascii(std::ostream& os, std::string prefix, const mcts::node* n, int depth, int prunning);
static void children_to_ascii(std::ostream& os, std::string prefix, unsigned int nb_children, const mcts::node* children, int depth, int prunning);
static int node_to_dot(std::ostream& os, int id, const mcts::node* n, int depth, int prunning);
};
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment