-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMCTS.h
64 lines (51 loc) · 1.72 KB
/
MCTS.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#pragma once
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "GameField.h"
#include "thread_pool.h"
#include "libtorch.h"
// One node of the search tree that MCTS operates on.
//
// A node points back to its parent, holds raw pointers to its children, and
// carries per-node statistics. A mutex and two atomic counters are part of the
// node, so it is evidently meant to be touched from several threads; the exact
// locking protocol lives in the implementation file and is not visible here.
class TreeNode {
 public:
  // MCTS is granted friendship; note that every member below is public anyway.
  friend class MCTS;

  // --- construction / copying -------------------------------------------
  // Copy operations are user-declared (std::mutex and std::atomic members are
  // not copyable themselves); their definitions live outside this header.
  TreeNode();
  TreeNode(const TreeNode& node);
  TreeNode(TreeNode* parent, double p_sa, unsigned int action_size);
  TreeNode& operator=(const TreeNode& p);

  // --- tree operations (defined outside this header) --------------------
  // Returns an unsigned index; presumably the index of the chosen child
  // in `children` -- confirm against the implementation.
  unsigned select(double c_puct, double c_virtual_loss);

  // Takes one prior per action; presumably creates the child nodes -- confirm
  // against the implementation.
  void expand(const std::vector<double>& action_priors);

  // Takes the value obtained at a leaf; presumably folds it into the
  // statistics on the path to the root -- confirm against the implementation.
  void backup(double leaf_value);

  // Score of this node given the exploration constant, the virtual-loss
  // weight, and a visit-count total supplied by the caller.
  double get_value(double c_puct,
                   double c_virtual_loss,
                   unsigned sum_n_visited) const;

  // Accessor for the is_leaf flag.
  bool get_is_leaf() const { return is_leaf; }

  // --- tree structure ----------------------------------------------------
  TreeNode* parent;                 // raw pointer to the parent node
  std::vector<TreeNode*> children;  // raw pointers to the child nodes
  bool is_leaf;                     // value reported by get_is_leaf()
  std::mutex lock;                  // per-node mutex

  // --- statistics --------------------------------------------------------
  std::atomic<unsigned> n_visited;  // visit counter
  double p_sa;                      // presumably the prior P(s, a) -- confirm
  double q_sa;                      // presumably the mean value Q(s, a) -- confirm
  std::atomic<int> virtual_loss;    // virtual-loss counter
};
// Search driver that owns the root TreeNode and a thread pool, and holds a
// pointer to the neural network together with the search parameters.
// All member functions are defined outside this header.
class MCTS {
 public:
  // Takes the network pointer, the thread count, the exploration constant,
  // the number of simulations, the virtual-loss weight, and the action count.
  // The network is held as a raw pointer (see `neural_network` below).
  MCTS(NeuralNetwork* neural_network,
       unsigned thread_num,
       double c_puct,
       unsigned num_mcts_sims,
       double c_virtual_loss,
       unsigned action_size);

  // Returns one double per entry of the result vector for the position `g`;
  // `temp` defaults to 1e-3 (presumably a sampling temperature -- confirm).
  std::vector<double> get_action_probs(GameField* g, double temp = 1e-3);

  // Informs the search of the move that was just played.
  void update_with_move(int last_move);

  // Runs on the given game object, which is passed as a shared_ptr by value.
  void simulate(std::shared_ptr<GameField> game, bool explore);

  // Function whose signature defines the deleter type of `root`.
  static void tree_deleter(TreeNode* t);

  // --- state ---------------------------------------------------------------
  // Root of the tree; the deleter is a plain function pointer taking a
  // TreeNode* (presumably &tree_deleter is supplied at construction).
  std::unique_ptr<TreeNode, void (*)(TreeNode*)> root;
  std::unique_ptr<ThreadPool> thread_pool;  // owned thread pool
  NeuralNetwork* neural_network;            // raw, non-smart pointer

  // --- search parameters ---------------------------------------------------
  unsigned action_size;
  unsigned num_mcts_sims;
  double c_puct;
  double c_virtual_loss;
};