#ifndef __OPENINGS_HPP__
#define __OPENINGS_HPP__

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <vector>

#include "allocator.hpp"
#include "display_node.hpp"

namespace mcts
{
  /// Opening book kept as a game tree of `node`s.
  ///
  /// The tree is grown incrementally by update(): each call replays a move
  /// sequence from the root, updates node statistics along the way, and
  /// expands a leaf once it has been visited often enough.
  ///
  /// NOTE(review): the class stores a raw `root_` pointer into `alloc_` and
  /// relies on the implicit copy operations; whether copying an `openings`
  /// is safe depends on allocator's copy semantics — confirm.
  class openings
  {
  private:
    // Backing storage for every node of the tree (root_ is allocated here).
    allocator alloc_;
    // Root of the opening tree.
    node* root_;
    // A leaf is expanded once its visit count reaches this value.
    const unsigned int nb_visits_before_expansion_;

    // Creates the children of leaf `n` if it has enough visits, then credits
    // `value` to the child reached by `move`.
    template <typename Game>
    void expand(Game& game, node* n, uint16_t move, int value);
    // Recursive helper behind copy_to (defined elsewhere).
    void copy(node* src, node* dst, allocator& alloc) const;

  public:
    // Builds a book rooted at the current position of `game`.
    template <typename Game>
    openings(const Game& game, unsigned int nb_visits_before_expansion = 2);

    // Records one game: replays `moves` on `game` and backs up `value`.
    template <typename Game>
    void update(Game& game, const std::vector<uint16_t>& moves, int value);
    // Presumably copies the book into the tree at `root` using `alloc`
    // (defined elsewhere) — TODO confirm.
    void copy_to(node* root, allocator& alloc) const;

    friend std::ostream& operator<<(std::ostream& os, const openings& op);
  };

  /// Builds an opening book whose root has one child per move available in
  /// @p game.
  ///
  /// Each root child is seeded with count = 1 and value = 0. The children
  /// are linked under the root, mirroring what expand() does for deeper
  /// nodes, so the root is an interior node from the start.
  ///
  /// @param game  starting position; only number_of_moves() is queried.
  /// @param nb_visits_before_expansion  visit count a leaf must reach before
  ///        update() expands it.
  template <typename Game>
  openings::openings(const Game& game, unsigned int nb_visits_before_expansion)
    : nb_visits_before_expansion_(nb_visits_before_expansion)
  {
    root_ = alloc_.allocate(1);
    // NOTE(review): root_'s own statistics are not set here; presumably
    // allocator::allocate initialises them — confirm.
    const unsigned int nb_children = game.number_of_moves();
    node* const children = alloc_.allocate(nb_children);
    for (unsigned int i = 0; i < nb_children; ++i)
      {
        node* const child = children + i;
        child->get_statistics_ref().count = 1;
        child->get_statistics_ref().value = 0;
      }
    // Link the freshly initialised children to the root; without this the
    // allocation is unreachable and the root remains a leaf.
    root_->set_children(children);
    root_->set_number_of_children(nb_children);
  }

  template <typename Game>
  void openings::expand(Game& game, node* n, uint16_t move, int value)
  {
    unsigned int count = n->get_statistics().count;
    if (count >= nb_visits_before_expansion_)
      {
        unsigned int nb_children = game.number_of_moves();
        node* children = alloc_.allocate(nb_children);
        for (unsigned int i = 0; i < nb_children; ++i)
	  {
            node* child = children + i;
            child->get_statistics_ref().count = 1;
            child->get_statistics_ref().value = 0;
	  }
        n->set_children(children);
        n->set_number_of_children(nb_children);
        children[move].update(value);
      }
  }

  template <typename Game>
  void openings::update(Game& game, const std::vector<uint16_t>& moves, int value)
  {
    node* pred = nullptr;
    node* current = root_;
    int k = 0;
    while (true)
      {
        current->update(value);
        value = -value;
        if (current->is_leaf()) break;
        uint16_t m = moves[k++];
        game.play(m);
        pred = current;
        current = current->get_children() + m;
        if (current->is_proven())
	  {
            if (current->is_lost()) pred->set_won();
            else
	      {
                const uint16_t nb_children = pred->get_number_of_children();
                node* const children = pred->get_children();
                bool all_won = true;
                for (uint16_t i = 0; i < nb_children; ++i)
		  {
                    node* child = children + i;
                    if (!child->is_won())
		      {
                        all_won = false;
                        break;
		      }
		  }
                if (all_won) pred->set_lost();
	      }
            return;
	  }
      }
    if (!game.end_of_game())
      {
        expand(game, current, moves[k], value);
      }
    else
      {
        if (value == 1) current->set_won();
        else if (value == -1)
	  {
            current->set_lost();
            if (pred != nullptr) pred->set_won();
	  }
      }
  }
}

#endif