Skip to content

Commit

Permalink
Merge pull request LeelaChessZero#1 from dubslow/crem-master
Browse files Browse the repository at this point in the history
Cosmetic changes to mcts (comments, variable renames)
  • Loading branch information
mooskagh authored Jun 5, 2018
2 parents a951c13 + 09964d0 commit aaad254
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 62 deletions.
6 changes: 3 additions & 3 deletions src/analyzer/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ void Analyzer::RunOnePosition(const std::vector<Move>& moves) {

// Fetch MCTS-agnostic per-move stats P and V.
std::vector<const Node*> nodes;
for (Node* iter : tree.GetCurrentHead()->Children()) {
nodes.emplace_back(iter);
for (Node* child : tree.GetCurrentHead()->Children()) {
nodes.emplace_back(child);
}
std::sort(nodes.begin(), nodes.end(), [](const Node* a, const Node* b) {
return a->GetNStarted() > b->GetNStarted();
Expand Down Expand Up @@ -261,4 +261,4 @@ void Analyzer::OnInfo(const ThinkingInfo& info) const {
WriteToLog(res);
}

} // namespace lczero
} // namespace lczero
14 changes: 7 additions & 7 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class Node::Pool {
Node::Pool::FreeNode* Node::Pool::UnrollNodeTree(FreeNode* node) {
if (!node->node.child_) return node;
FreeNode* prev = node;
for (Node* iter = node->node.child_; iter; iter = iter->sibling_) {
FreeNode* next = reinterpret_cast<FreeNode*>(iter);
for (Node* child = node->node.child_; child; child = child->sibling_) {
FreeNode* next = reinterpret_cast<FreeNode*>(child);
prev->next = next;
prev = UnrollNodeTree(next);
}
Expand Down Expand Up @@ -278,8 +278,8 @@ void Node::UpdateMaxDepth(int depth) {

bool Node::UpdateFullDepth(uint16_t* depth) {
if (full_depth_ > *depth) return false;
for (Node* iter : Children()) {
if (*depth > iter->full_depth_) *depth = iter->full_depth_;
for (Node* child : Children()) {
if (*depth > child->full_depth_) *depth = child->full_depth_;
}
if (*depth >= full_depth_) {
full_depth_ = ++*depth;
Expand Down Expand Up @@ -309,8 +309,8 @@ V3TrainingData Node::GetV3TrainingData(GameResult game_result,
float total_n =
static_cast<float>(n_ - 1); // First visit was expansion of it inself.
std::memset(result.probabilities, 0, sizeof(result.probabilities));
for (Node* iter : Children()) {
result.probabilities[iter->move_.as_nn_index()] = iter->n_ / total_n;
for (Node* child : Children()) {
result.probabilities[child->move_.as_nn_index()] = child->n_ / total_n;
}

// Populate planes.
Expand Down Expand Up @@ -407,4 +407,4 @@ void NodeTree::DeallocateTree() {
current_head_ = nullptr;
}

} // namespace lczero
} // namespace lczero
9 changes: 5 additions & 4 deletions src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Node {
// Returns move, with optional flip (false == player BEFORE the position).
Move GetMove(bool flip) const;

// Returns sum of probabilities for visited children.
// Returns sum of policy priors which have had at least one playout.
float GetVisitedPolicy() const;
uint32_t GetN() const { return n_; }
uint32_t GetNInFlight() const { return n_in_flight_; }
Expand All @@ -80,11 +80,12 @@ class Node {
return q_;
}
}
// Returns U / (Puct * N[parent])
// Returns p / N, which is equal to U / (cpuct * sqrt(N[parent])) by the MCTS
// equation. So it's really more of a "reduced U" than raw U.
float GetU() const { return p_ / (1 + n_ + n_in_flight_); }
// Returns value of Value Head returned from the neural net.
float GetV() const { return v_; }
// Returns value of Move probabilityreturned from the neural net.
// Returns value of Move probability returned from the neural net
// (but can be changed by adding Dirichlet noise).
float GetP() const { return p_; }
// Returns whether the node is known to be draw/lose/win.
Expand Down Expand Up @@ -219,4 +220,4 @@ class NodeTree {
// Performs garbage collection on node pool. Thread safe.
void GarbageCollectNodePool();

} // namespace lczero
} // namespace lczero
102 changes: 54 additions & 48 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ void ApplyDirichletNoise(Node* node, float eps, double alpha) {
std::vector<float> noise;

// TODO(mooskagh) remove this loop when we store number of children.
for (Node* iter : node->Children()) {
(void)iter; // Silence the unused variable warning.
for (Node* child : node->Children()) {
(void)child; // Silence the unused variable warning.
float eta = Random::Get().GetGamma(alpha, 1.0);
noise.emplace_back(eta);
total += eta;
Expand All @@ -129,8 +129,8 @@ void ApplyDirichletNoise(Node* node, float eps, double alpha) {
if (total < std::numeric_limits<float>::min()) return;

int noise_idx = 0;
for (Node* iter : node->Children()) {
iter->SetP(iter->GetP() * (1 - eps) + eps * noise[noise_idx++] / total);
for (Node* child : node->Children()) {
child->SetP(child->GetP() * (1 - eps) + eps * noise[noise_idx++] / total);
}
}

Expand Down Expand Up @@ -226,8 +226,8 @@ void Search::SendMovesStats() const {
const float parent_q =
-root_node_->GetQ(0, 0) -
kFpuReduction * std::sqrt(root_node_->GetVisitedPolicy());
for (Node* iter : root_node_->Children()) {
nodes.emplace_back(iter);
for (Node* child : root_node_->Children()) {
nodes.emplace_back(child);
}
std::sort(nodes.begin(), nodes.end(),
[](const Node* a, const Node* b) { return a->GetN() < b->GetN(); });
Expand Down Expand Up @@ -487,18 +487,20 @@ void SearchWorker::GatherMinibatch() {
// Node was never visited, extending.
ExtendNode(node);

// If node turned out to be a terminal one, no need to send to NN for
// evaluation.
// Only send non-terminal nodes to neural network
if (!node->IsTerminal()) {
nodes_to_process_.back().nn_queried = true;
AddNodeToCompute(node);
}
}
}

// Returns node and whether it should be processed.
// (false if it is a collision).
// Returns node and whether there's been a search collision on the node.
SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend() {
// Starting from search_->root_node_, generate a playout, choosing a
// node at each level according to the MCTS formula. n_in_flight_ is
// incremented for each node in the playout (via TryStartScoreUpdate()).

Node* node = search_->root_node_;
// Initialize position sequence with pre-move position.
history_.Trim(search_->played_history_.GetLength());
Expand All @@ -513,17 +515,20 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend() {
// True on first iteration, false as we dive deeper.
bool is_root_node = true;
while (true) {
// First, terminate if we find collisions or leaf nodes.
{
SharedMutex::Lock lock(search_->nodes_mutex_);
// Check whether we are in the leave.
// n_in_flight_ is incremented. If the method returns false, then there is
// a search collision, and this node is already being expanded.
if (!node->TryStartScoreUpdate()) return {node, true};
// Found leave, and we are the the first to visit it.
// Unexamined leaf node. We've hit the end of this playout.
if (!node->HasChildren()) return {node, false};
// If we fall through, then n_in_flight_ has been incremented but this
// playout remains incomplete; we must go deeper.
}

// Now we are not in leave, we need to go deeper.
SharedMutex::SharedLock lock(search_->nodes_mutex_);
float factor =
float puct_mult =
search_->kCpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
float best = -100.0f;
int possible_moves = 0;
Expand All @@ -532,29 +537,29 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend() {
? -node->GetQ(0, search_->kExtraVirtualLoss)
: -node->GetQ(0, search_->kExtraVirtualLoss) -
search_->kFpuReduction * std::sqrt(node->GetVisitedPolicy());
for (Node* iter : node->Children()) {
for (Node* child : node->Children()) {
if (is_root_node) {
// If there's no chance to catch up the currently best node with
// remaining playouts, not consider it.
// best_move_node_ can change since best_node_n computation.
// If there's no chance to catch up to the current best node with
// remaining playouts, don't consider it.
// best_move_node_ could have changed since best_node_n was retrieved.
// To ensure we have at least one node to expand, always include
// current best node.
if (iter != search_->best_move_node_ &&
if (child != search_->best_move_node_ &&
search_->remaining_playouts_ <
best_node_n - static_cast<int>(iter->GetN())) {
best_node_n - static_cast<int>(child->GetN())) {
continue;
}
++possible_moves;
}
float Q = iter->GetQ(parent_q, search_->kExtraVirtualLoss);
if (search_->kVirtualLossBug && iter->GetN() == 0) {
Q = (Q * iter->GetParent()->GetN() - search_->kVirtualLossBug) /
(iter->GetParent()->GetN() + std::fabs(search_->kVirtualLossBug));
float Q = child->GetQ(parent_q, search_->kExtraVirtualLoss);
if (search_->kVirtualLossBug && child->GetN() == 0) {
Q = (Q * child->GetParent()->GetN() - search_->kVirtualLossBug) /
(child->GetParent()->GetN() + std::fabs(search_->kVirtualLossBug));
}
const float score = factor * iter->GetU() + Q;
const float score = puct_mult * child->GetU() + Q;
if (score > best) {
best = score;
node = iter;
node = child;
}
}
history_.Append(node->GetMove());
Expand All @@ -574,20 +579,20 @@ void SearchWorker::ExtendNode(Node* node) {
const auto& board = history_.Last().GetBoard();
auto legal_moves = board.GenerateLegalMoves();

// Check whether it's a draw/lose by rules.
// Check whether it's a draw/lose by position. Importantly, we must check
// these before doing the by-rule checks below.
if (legal_moves.empty()) {
// Checkmate or stalemate.
// Could be a checkmate or a stalemate
if (board.IsUnderCheck()) {
// Checkmate.
node->MakeTerminal(GameResult::WHITE_WON);
} else {
// Stalemate.
node->MakeTerminal(GameResult::DRAW);
}
return;
}

// If it's root node and we're asked to think, pretend there's no draw.
// We can short-circuit these draws-by-rule only if they aren't root;
// if they are root, then thinking about them is the point.
if (node != search_->root_node_) {
if (!board.HasMatingMaterial()) {
node->MakeTerminal(GameResult::DRAW);
Expand Down Expand Up @@ -624,8 +629,8 @@ bool SearchWorker::AddNodeToCompute(Node* node, bool add_if_cached) {

if (node->HasChildren()) {
// Legal moves are known, using them.
for (Node* iter : node->Children()) {
moves.emplace_back(iter->GetMove().as_nn_index());
for (Node* child : node->Children()) {
moves.emplace_back(child->GetMove().as_nn_index());
}
} else {
// Cache pseudolegal moves. A bit of a waste, but faster.
Expand Down Expand Up @@ -676,28 +681,28 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget) {
return 1;
}

// If it's a node in progress of expansion or is terminal, not prefetching.
// If it's a node in process of expansion or is terminal, don't prefetch it.
if (!node->HasChildren()) return 0;

// Populate all subnodes and their scores.
typedef std::pair<float, Node*> ScoredNode;
std::vector<ScoredNode> scores;
float factor =
float puct_mult =
search_->kCpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
// FPU reduction is not taken into account.
const float parent_q = -node->GetQ(0, search_->kExtraVirtualLoss);
for (Node* iter : node->Children()) {
if (iter->GetP() == 0.0f) continue;
for (Node* child : node->Children()) {
if (child->GetP() == 0.0f) continue;
// Flipping sign of a score to be able to easily sort.
scores.emplace_back(-factor * iter->GetU() -
iter->GetQ(parent_q, search_->kExtraVirtualLoss),
iter);
scores.emplace_back(-puct_mult * child->GetU() -
child->GetQ(parent_q, search_->kExtraVirtualLoss),
child);
}

size_t first_unsorted_index = 0;
int total_budget_spent = 0;
int budget_to_spend = budget; // Initializing for the case there's only
// on child.
int budget_to_spend = budget; // Initialize for the case where there's only
// one child.
for (size_t i = 0; i < scores.size(); ++i) {
if (budget <= 0) break;

Expand All @@ -715,13 +720,14 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget) {
Node* n = scores[i].second;
// Last node gets the same budget as prev-to-last node.
if (i != scores.size() - 1) {
// Sign of the score was flipped for sorting, flipping back.
// Sign of the score was flipped for sorting, so flip it back.
const float next_score = -scores[i + 1].first;
const float q = n->GetQ(-parent_q, search_->kExtraVirtualLoss);
if (next_score > q) {
budget_to_spend = std::min(
budget,
int(n->GetP() * factor / (next_score - q) - n->GetNStarted()) + 1);
int(n->GetP() * puct_mult / (next_score - q) - n->GetNStarted())
+ 1);
} else {
budget_to_spend = budget;
}
Expand All @@ -738,7 +744,8 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget) {
// 4. Run NN computation.
// ~~~~~~~~~~~~~~~~~~~~~~
void SearchWorker::RunNNComputation() {
// Evaluate nodes through NN.
// This function is so small as to be silly, but its parent function is
// conceptually cleaner for it.
if (computation_->GetBatchSize() != 0) computation_->ComputeBlocking();
}

Expand Down Expand Up @@ -798,7 +805,7 @@ void SearchWorker::DoBackupUpdate() {
float v = node->GetV();
// Maximum depth the node is explored.
uint16_t depth = 0;
// If the node is terminal, mark it as fully explored to an infinite
// If the node is terminal, mark it as fully explored to an "infinite"
// depth.
uint16_t cur_full_depth = node->IsTerminal() ? 999 : 0;
bool full_depth_updated = true;
Expand Down Expand Up @@ -830,8 +837,7 @@ void SearchWorker::DoBackupUpdate() {
// 7. UpdateCounters()
//~~~~~~~~~~~~~~~~~~~~
void SearchWorker::UpdateCounters() {
search_
->UpdateRemainingMoves(); // Update remaining moves using smart pruning.
search_->UpdateRemainingMoves(); // Updates smart pruning counters.
search_->MaybeOutputInfo();
search_->MaybeTriggerStop();

Expand Down

0 comments on commit aaad254

Please sign in to comment.