IAtari
A genetic algorithm that generates AIs capable of playing Atari 2600 games.
master.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <boost/asio/io_service.hpp>
3 #include <boost/asio/ip/tcp.hpp>
4 #include <boost/asio/spawn.hpp>
5 #include <boost/chrono.hpp>
6 #include <iostream>
7 #include <memory>
8 #include <list>
9 #include <queue>
10 #include <vector>
11 #include "message.hpp"
12 #include "connection.hpp"
13 
/// Coordinator that hands evaluation requests to connected slaves over TCP and
/// gathers their results.
///
/// @tparam Res    element type of the per-parameter result lists (`run` builds a
///                `std::vector<std::list<Res>>` from the collected results).
/// @tparam Params parameter-set type carried by each `message::request<Params>`.
14 template <typename Res, typename Params>
15 class master
16 {
// TCP port the acceptor created in run() listens on.
17  int port;
// Number of evaluations requested for each parameter set (see init_work_queue).
18  int nb_eval_by_parameter;
// Number of evaluations put in a single request (see init_work_queue).
19  int nb_eval_by_slave;
20 
// Event loop shared by the acceptor, the sessions and the client coroutine.
21  boost::asio::io_service io_service;
// Connections accepted by run(); a session removes its connection on close.
22  std::list<connection_ptr> slaves;
// Requests waiting to be sent; filled by init_work_queue, drained by sessions.
23  std::queue<message::request<Params>> work_queue;
// Results received from slaves since the last call to intermediate_results.clear().
24  std::list<message::result<Res>> intermediate_results;
25 
// Per-connection state, kept alive through shared_from_this().
// NOTE(review): the member declarations of `session` are not visible in this
// listing; the definitions further down are invoked as
// `session_ptr->run(*this, slaves.back())` and `self->close(master, connection)`.
26  struct session : public std::enable_shared_from_this<session>
27  {
30  };
31 
// Fills work_queue from the given parameter sets and returns the number of
// requests that were enqueued.
32  unsigned int init_work_queue(const std::vector<Params>& parameters_to_evaluate);
33 public:
// Stores the listening port and the two evaluation counts.
34  master(int port = 54321, int nb_eval_by_parameter = 1, int nb_eval_by_slave = 1);
// Accepts slaves and drives `client` until `client->finished()` is true;
// the call returns once io_service.run() returns.
35  template <typename Client>
36  void run(Client client);
37 };
38 
// Serves one slave connection from a coroutine spawned on master.io_service.
//
// Loop, while the connection's socket is open:
//   1. poll work_queue once per second until it is non-empty;
//   2. pop the front request and arm a 1800-second watchdog timer;
//   3. write the request to the slave, then read one result back;
//   4. cancel the watchdog; on read error re-queue the request, otherwise
//      append the result to master.intermediate_results.
//
// NOTE(review): the signature line of this definition is missing from this
// listing; the body relies on two names from it, `master` and `connection`.
39 template <typename Res, typename Params>
41 {
42  using namespace boost;
// Captured by the coroutine and by the watchdog handler so the session
// object stays alive while either of them can still run.
43  auto self(this->shared_from_this());
44  asio::spawn(master.io_service,
45  [self, connection, &master](asio::yield_context yield)
46  {
// Declared outside the try block so the catch handler below can re-queue it.
47  message::request<Params> request;
48  try
49  {
50  asio::deadline_timer timer(master.io_service);
51  system::error_code ec;
52  std::string remote_address = connection->get_socket().remote_endpoint().address().to_string();
53  while (connection->get_socket().is_open())
54  {
// No work available: sleep one second on the coroutine and check again.
55  while (master.work_queue.empty())
56  {
57  timer.expires_from_now(posix_time::seconds(1));
58  timer.async_wait(yield[ec]);
59  }
60  request = master.work_queue.front();
61  master.work_queue.pop();
// Watchdog: if it expires (i.e. is not cancelled) the request is put back
// on the queue and the connection is closed.
62  asio::deadline_timer timer_slave(master.io_service, posix_time::seconds(1800));
63  timer_slave.async_wait([self, connection, request, remote_address, &master](const system::error_code& ec)
64  {
65  if (ec != boost::asio::error::operation_aborted)
66  {
67  master.work_queue.emplace(request);
68  std::cout << "abort slave from "
69  << remote_address
70  << std::endl;
71  self->close(master, connection);
72  }
73  });
// NOTE(review): the timer wait below stores its own status in the same `ec`
// that the loop condition tests, so after a failed write the 2-second wait
// appears to overwrite the write error and the loop exits without retrying —
// confirm whether a retry was intended.
74  do
75  {
76  connection->async_write(request, yield[ec]);
77  if (ec)
78  {
79  timer.expires_from_now(posix_time::seconds(2));
80  timer.async_wait(yield[ec]);
81  }
82  } while (ec);
83  message::result<Res> res;
84  connection->async_read(res, yield[ec]);
85  timer_slave.cancel();
// NOTE(review): if the watchdog fired first it has already re-queued
// `request` and closed the connection; should that make this read complete
// with an error, the branch below re-queues the same request a second time —
// confirm.
86  if (ec)
87  {
88  std::cout << "master read rejected" << std::endl;
89  master.work_queue.emplace(request);
90  continue;
91  }
92  // std::cout << "received results from: "
93  // << remote_address
94  // << std::endl;
95  master.intermediate_results.emplace_back(res);
96  // if (master.intermediate_results.size() % 50 == 0)
97  // std::cout << master.intermediate_results.size() << std::endl;
98  }
99  }
100  catch (const std::exception& e)
101  {
// NOTE(review): `request` holds whatever was last assigned to it — still
// default-constructed if the exception was thrown before the first dequeue
// (for instance by the remote_endpoint() call above) — so this can enqueue
// an empty or an already-answered request; confirm the intent.
102  master.work_queue.emplace(request);
103  std::cout << e.what() << std::endl;
104  self->close(master, connection);
105  }
106  });
107 }
108 
// Unregisters `connection` from master.slaves and then closes it.
// NOTE(review): the signature line of this definition is missing from this
// listing; `master` and `connection` are the names the body takes from it.
109 template <typename Res, typename Params>
111 {
112  master.slaves.remove(connection);
113  connection->close();
114 }
115 
116 template <typename Res, typename Params>
117 unsigned int master<Res, Params>::init_work_queue(const std::vector<Params>& parameters_to_evaluate)
118 {
119  unsigned int nb_jobs = 0;
120  int id = 0;
121  uint64_t seed = 1;
122  for (const auto& p : parameters_to_evaluate)
123  {
125  int n = std::min(nb_eval_by_parameter, nb_eval_by_slave);
126  int rest = nb_eval_by_parameter;
127  for (int i = 0; i < nb_eval_by_parameter; i += n, rest -= n)
128  {
129  r.id = id;
130  r.nb_eval = n;
131  r.seed = seed++;
132  r.params = p;
133  work_queue.emplace(r);
134  ++nb_jobs;
135  }
136  if (rest != 0)
137  {
138  r.id = id;
139  r.nb_eval = rest;
140  r.seed = seed++;
141  r.params = p;
142  work_queue.emplace(r);
143  ++nb_jobs;
144  }
145  ++id;
146  seed = 1;
147  }
148  return nb_jobs;
149 }
150 
151 template <typename Res, typename Params>
152 master<Res, Params>::master(int port, int nb_eval_by_parameter, int nb_eval_by_slave)
153  : port(port), nb_eval_by_parameter(nb_eval_by_parameter), nb_eval_by_slave(nb_eval_by_slave)
154 {
155 }
156 
157 template <typename Res, typename Params>
158 template <typename Client>
159 void master<Res, Params>::run(Client client)
160 {
161  using namespace boost;
162  asio::spawn(io_service,
163  [&](asio::yield_context yield)
164  {
165  asio::ip::tcp::acceptor acceptor(io_service,
166  asio::ip::tcp::endpoint(asio::ip::tcp::v4(), port));
167  while (true)
168  {
169  system::error_code ec;
170  asio::ip::tcp::socket socket(io_service);
171  acceptor.async_accept(socket, yield[ec]);
172  if (!ec)
173  {
174  slaves.emplace_back(new connection(std::move(socket)));
175  auto session_ptr = std::make_shared<session>();
176  std::cout << "new connection from: "
177  << slaves.back()->get_socket().remote_endpoint().address().to_string()
178  << std::endl;
179  session_ptr->run(*this, slaves.back());
180  }
181  }
182  });
183 
184  asio::spawn(io_service,
185  [&](asio::yield_context yield)
186  {
187  system::error_code ec;
188  asio::deadline_timer timer(io_service);
189  while (!client->finished())
190  {
191  unsigned int nb_params, nb_jobs;
192  const auto params = client->get_parameters(yield);
193  nb_params = params.size();
194  nb_jobs = init_work_queue(params);
195  while (intermediate_results.size() < nb_jobs)
196  {
197  timer.expires_from_now(posix_time::seconds(2));
198  timer.async_wait(yield[ec]);
199  }
200  std::cout << "master "
201  << intermediate_results.size()
202  << " jobs done"
203  << std::endl;
204  std::vector<std::list<Res>> res(nb_params);
205  for (auto& r : intermediate_results)
206  {
207  res[r.id].splice(res[r.id].end(), r.results);
208  }
209  client->set_results(std::move(res), yield);
210  intermediate_results.clear();
211  }
212  io_service.stop();
213  });
214  io_service.run();
215 }
void run(Client client)
Definition: master.hpp:159
int id
Definition: message.hpp:20
void close()
Definition: connection.hpp:39
void run(int argc, char *argv[])
Definition: master_slave.cpp:11
master(int port=54321, int nb_eval_by_parameter=1, int nb_eval_by_slave=1)
Definition: master.hpp:152
uint64_t seed
Definition: message.hpp:22
Definition: master.hpp:15
int nb_eval
Definition: message.hpp:21
Definition: connection.hpp:13
Represent a serialized agent extracted from an archive.
Definition: message.hpp:18
T params
Definition: message.hpp:23
std::shared_ptr< connection > connection_ptr
Definition: connection.hpp:85