mirror of
https://github.com/RGBCube/serenity
synced 2025-07-27 05:17:35 +00:00
ChessEngine: don't store board in non-leaf nodes in MCTS
Also make parameters static so they aren't in every node of the tree this saves a substantial amount of memory.
This commit is contained in:
parent
34433f5dc4
commit
49539abee0
3 changed files with 30 additions and 30 deletions
|
@ -42,9 +42,6 @@ void ChessEngine::handle_go(const GoCommand& command)
|
|||
|
||||
MCTSTree mcts(m_board);
|
||||
|
||||
// FIXME: optimize simulations enough for use.
|
||||
mcts.set_eval_method(MCTSTree::EvalMethod::Heuristic);
|
||||
|
||||
int rounds = 0;
|
||||
while (elapsed_time.elapsed() <= command.movetime.value()) {
|
||||
mcts.do_round();
|
||||
|
|
|
@ -8,13 +8,12 @@
|
|||
#include <AK/String.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
MCTSTree::MCTSTree(const Chess::Board& board, double exploration_parameter, MCTSTree* parent)
|
||||
MCTSTree::MCTSTree(const Chess::Board& board, MCTSTree* parent)
|
||||
: m_parent(parent)
|
||||
, m_exploration_parameter(exploration_parameter)
|
||||
, m_board(board)
|
||||
, m_board(make<Chess::Board>(board))
|
||||
, m_last_move(board.last_move())
|
||||
, m_turn(board.turn())
|
||||
{
|
||||
if (m_parent)
|
||||
m_eval_method = m_parent->eval_method();
|
||||
}
|
||||
|
||||
MCTSTree& MCTSTree::select_leaf()
|
||||
|
@ -25,7 +24,7 @@ MCTSTree& MCTSTree::select_leaf()
|
|||
MCTSTree* node = nullptr;
|
||||
double max_uct = -double(INFINITY);
|
||||
for (auto& child : m_children) {
|
||||
double uct = child.uct(m_board.turn());
|
||||
double uct = child.uct(m_turn);
|
||||
if (uct >= max_uct) {
|
||||
max_uct = uct;
|
||||
node = &child;
|
||||
|
@ -40,13 +39,15 @@ MCTSTree& MCTSTree::expand()
|
|||
VERIFY(!expanded() || m_children.size() == 0);
|
||||
|
||||
if (!m_moves_generated) {
|
||||
m_board.generate_moves([&](Chess::Move move) {
|
||||
Chess::Board clone = m_board;
|
||||
m_board->generate_moves([&](Chess::Move move) {
|
||||
Chess::Board clone = *m_board;
|
||||
clone.apply_move(move);
|
||||
m_children.append(make<MCTSTree>(clone, m_exploration_parameter, this));
|
||||
m_children.append(make<MCTSTree>(clone, this));
|
||||
return IterationDecision::Continue;
|
||||
});
|
||||
m_moves_generated = true;
|
||||
if (m_children.size() != 0)
|
||||
m_board = nullptr; // Release the board to save memory.
|
||||
}
|
||||
|
||||
if (m_children.size() == 0) {
|
||||
|
@ -63,8 +64,7 @@ MCTSTree& MCTSTree::expand()
|
|||
|
||||
int MCTSTree::simulate_game() const
|
||||
{
|
||||
VERIFY_NOT_REACHED();
|
||||
Chess::Board clone = m_board;
|
||||
Chess::Board clone = *m_board;
|
||||
while (!clone.game_finished()) {
|
||||
clone.apply_move(clone.random_move());
|
||||
}
|
||||
|
@ -73,10 +73,10 @@ int MCTSTree::simulate_game() const
|
|||
|
||||
int MCTSTree::heuristic() const
|
||||
{
|
||||
if (m_board.game_finished())
|
||||
return m_board.game_score();
|
||||
if (m_board->game_finished())
|
||||
return m_board->game_score();
|
||||
|
||||
double winchance = max(min(double(m_board.material_imbalance()) / 6, 1.0), -1.0);
|
||||
double winchance = max(min(double(m_board->material_imbalance()) / 6, 1.0), -1.0);
|
||||
|
||||
double random = double(rand()) / RAND_MAX;
|
||||
if (winchance >= random)
|
||||
|
@ -101,7 +101,7 @@ void MCTSTree::do_round()
|
|||
auto& node = select_leaf().expand();
|
||||
|
||||
int result;
|
||||
if (m_eval_method == EvalMethod::Simulation) {
|
||||
if constexpr (s_eval_method == EvalMethod::Simulation) {
|
||||
result = node.simulate_game();
|
||||
} else {
|
||||
result = node.heuristic();
|
||||
|
@ -111,7 +111,7 @@ void MCTSTree::do_round()
|
|||
|
||||
Chess::Move MCTSTree::best_move() const
|
||||
{
|
||||
int score_multiplier = (m_board.turn() == Chess::Color::White) ? 1 : -1;
|
||||
int score_multiplier = (m_turn == Chess::Color::White) ? 1 : -1;
|
||||
|
||||
Chess::Move best_move = { { 0, 0 }, { 0, 0 } };
|
||||
double best_score = -double(INFINITY);
|
||||
|
@ -119,8 +119,7 @@ Chess::Move MCTSTree::best_move() const
|
|||
for (auto& node : m_children) {
|
||||
double node_score = node.expected_value() * score_multiplier;
|
||||
if (node_score >= best_score) {
|
||||
// The best move is the last move made in the child.
|
||||
best_move = node.m_board.moves()[node.m_board.moves().size() - 1];
|
||||
best_move = node.m_last_move.value();
|
||||
best_score = node_score;
|
||||
}
|
||||
}
|
||||
|
@ -143,7 +142,7 @@ double MCTSTree::uct(Chess::Color color) const
|
|||
|
||||
// Fun fact: Szepesvári was my data structures professor.
|
||||
double expected = expected_value() * ((color == Chess::Color::White) ? 1 : -1);
|
||||
return expected + m_exploration_parameter * sqrt(log(m_parent->m_simulations) / m_simulations);
|
||||
return expected + s_exploration_parameter * sqrt(log(m_parent->m_simulations) / m_simulations);
|
||||
}
|
||||
|
||||
bool MCTSTree::expanded() const
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include <AK/Function.h>
|
||||
#include <AK/NonnullOwnPtrVector.h>
|
||||
#include <AK/OwnPtr.h>
|
||||
#include <LibChess/Chess.h>
|
||||
#include <math.h>
|
||||
|
||||
|
@ -18,7 +19,7 @@ public:
|
|||
Heuristic,
|
||||
};
|
||||
|
||||
MCTSTree(const Chess::Board& board, double exploration_parameter = sqrt(2), MCTSTree* parent = nullptr);
|
||||
MCTSTree(const Chess::Board& board, MCTSTree* parent = nullptr);
|
||||
|
||||
MCTSTree& select_leaf();
|
||||
MCTSTree& expand();
|
||||
|
@ -32,16 +33,19 @@ public:
|
|||
double uct(Chess::Color color) const;
|
||||
bool expanded() const;
|
||||
|
||||
EvalMethod eval_method() const { return m_eval_method; }
|
||||
void set_eval_method(EvalMethod method) { m_eval_method = method; }
|
||||
|
||||
private:
|
||||
// While static parameters are less configurable, they don't take up any
|
||||
// memory in the tree, which I believe to be a worthy tradeoff.
|
||||
static constexpr double s_exploration_parameter { sqrt(2) };
|
||||
// FIXME: Optimize simulations enough for use.
|
||||
static constexpr EvalMethod s_eval_method { EvalMethod::Heuristic };
|
||||
|
||||
NonnullOwnPtrVector<MCTSTree> m_children;
|
||||
MCTSTree* m_parent { nullptr };
|
||||
int m_white_points { 0 };
|
||||
int m_simulations { 0 };
|
||||
bool m_moves_generated { false };
|
||||
double m_exploration_parameter;
|
||||
EvalMethod m_eval_method { EvalMethod::Simulation };
|
||||
Chess::Board m_board;
|
||||
OwnPtr<Chess::Board> m_board;
|
||||
Optional<Chess::Move> m_last_move;
|
||||
Chess::Color m_turn : 2;
|
||||
bool m_moves_generated : 1 { false };
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue