// openings.hpp
// Author: Bariatti Francesco
#ifndef __OPENINGS_HPP__
#define __OPENINGS_HPP__
#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <vector>

#include "allocator.hpp"
#include "display_node.hpp"
namespace mcts
{
/// Tree of `node`s, allocated from the object's own allocator, whose
/// statistics are refined by replaying sequences of moves (see update()).
///
/// NOTE(review): `root_` points into memory handed out by `alloc_`; whether
/// an implicitly generated copy of this class is safe depends on how
/// `allocator` copies — confirm before copying `openings` objects.
class openings
{
private:
  /// Source of every node this object creates (root and children).
  allocator alloc_;

  /// Entry point of the tree; assigned in the constructor.
  node* root_;

  /// Visit count a node must have reached for expand() to give it children.
  const unsigned int nb_visits_before_expansion_;

  /// Internal helper working from `src` towards `dst` with nodes taken from
  /// `alloc`; defined outside this header (presumably a recursive subtree
  /// copy — verify against the implementation file).
  void copy(node* src, node* dst, allocator& alloc) const;

  /// Gives `n` one child per move available in `game` once `n` has been
  /// visited often enough, then records `value` on the child at index `move`.
  template <typename Game>
  void expand(Game& game, node* n, uint16_t move, int value);

public:
  /// Allocates the root node and one child node per move available in `game`.
  ///
  /// @param game  Position queried for its number of moves.
  /// @param nb_visits_before_expansion  Expansion threshold used by expand().
  template <typename Game>
  openings(const Game& game, unsigned int nb_visits_before_expansion = 2);

  /// Defined outside this header; presumably copies this tree under `root`
  /// using nodes from `alloc` — verify against the implementation file.
  void copy_to(node* root, allocator& alloc) const;

  /// Replays `moves` on `game` from the root, updating the statistics of the
  /// nodes along the path with `value` (sign flipped at every ply).
  template <typename Game>
  void update(Game& game, const std::vector<uint16_t>& moves, int value);

  /// Stream insertion for an `openings` object; defined outside this header.
  friend std::ostream& operator<<(std::ostream& os, const openings& op);
};
/// Builds the root of the tree together with its first layer of children.
///
/// One child is allocated per move reported by `game.number_of_moves()`.
/// Every child starts with a visit count of 1 and a value of 0, the same
/// initial statistics expand() assigns to the nodes it creates.
///
/// @param game  Position whose number of moves sizes the root's child array.
/// @param nb_visits_before_expansion  Stored as the visit threshold that
///        expand() compares a node's count against.
template <typename Game>
openings::openings(const Game& game, unsigned int nb_visits_before_expansion)
  : nb_visits_before_expansion_(nb_visits_before_expansion)
{
  root_ = alloc_.allocate(1);

  const unsigned int nb_children = game.number_of_moves();
  node* const children = alloc_.allocate(nb_children);
  for (unsigned int i = 0; i < nb_children; ++i)
  {
    node* const child = children + i;
    child->get_statistics_ref().count = 1;
    child->get_statistics_ref().value = 0;
  }

  // Attach the children to the root exactly as expand() does for any other
  // node. Without these two calls the array allocated above is unreachable
  // from the tree and the root stays childless.
  root_->set_children(children);
  root_->set_number_of_children(nb_children);
}
template <typename Game>
void openings::expand(Game& game, node* n, uint16_t move, int value)
{
unsigned int count = n->get_statistics().count;
if (count >= nb_visits_before_expansion_)
{
unsigned int nb_children = game.number_of_moves();
node* children = alloc_.allocate(nb_children);
for (unsigned int i = 0; i < nb_children; ++i)
{
node* child = children + i;
child->get_statistics_ref().count = 1;
child->get_statistics_ref().value = 0;
}
n->set_children(children);
n->set_number_of_children(nb_children);
children[move].update(value);
}
}
/// Replays a sequence of moves through the tree and records its outcome.
///
/// Starting at the root, every node on the path is updated with `value`,
/// whose sign is flipped after each node. Each consumed move is also played
/// on `game`. The walk ends in one of three ways:
///  - the child just reached is already proven: the proof is propagated to
///    its parent (parent won if the child is lost, parent lost if every
///    sibling is won) and the function returns;
///  - the move list runs out before a leaf is reached: the function returns;
///  - a leaf is reached: if the game is not over the leaf is handed to
///    expand() with the next move (when there is one); if the game is over
///    the leaf is marked won or lost according to `value`, and a lost leaf
///    marks its parent as won.
///
/// @param game   Position at the root of the tree; mutated by `play`.
/// @param moves  Moves to follow; each one is used as a child index.
/// @param value  Outcome applied to the root; negated at every ply.
template <typename Game>
void openings::update(Game& game, const std::vector<uint16_t>& moves, int value)
{
  node* pred = nullptr;
  node* current = root_;
  std::size_t k = 0;
  while (true)
  {
    current->update(value);
    value = -value;
    if (current->is_leaf()) break;

    // The tree is deeper than the recorded line: there is no further move to
    // follow, and reading moves[k] would run past the end of the vector.
    if (k >= moves.size()) return;

    const uint16_t m = moves[k++];
    game.play(m);
    pred = current;
    current = current->get_children() + m;
    if (current->is_proven())
    {
      if (current->is_lost()) pred->set_won();
      else
      {
        // The parent is lost only when every one of its children is won.
        const uint16_t nb_children = pred->get_number_of_children();
        node* const children = pred->get_children();
        bool all_won = true;
        for (uint16_t i = 0; i < nb_children; ++i)
        {
          node* child = children + i;
          if (!child->is_won())
          {
            all_won = false;
            break;
          }
        }
        if (all_won) pred->set_lost();
      }
      return;
    }
  }

  // From here on `value` has already been negated after updating `current`.
  if (!game.end_of_game())
  {
    // Expansion needs the next move of the line; skip it when the list is
    // exhausted rather than reading past its end.
    if (k < moves.size()) expand(game, current, moves[k], value);
  }
  else
  {
    if (value == 1) current->set_won();
    else if (value == -1)
    {
      current->set_lost();
      if (pred != nullptr) pred->set_won();
    }
  }
}
}
#endif