Skip to content

Commit

Permalink
add draw utility penalty in puct selection
Browse files Browse the repository at this point in the history
test f15
  • Loading branch information
dhbloo committed Oct 12, 2024
1 parent e73658a commit 34dc002
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
3 changes: 3 additions & 0 deletions Rapfi/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ int MaxNonPVRootmovesToPrint = 10;
int NumNodesAfterSingularRoot = 100;
/// The power of two number of shards that the node table has.
int NumNodeTableShardsPowerOfTwo = 10;
/// The ratio by which a child's utility is decreased when its draw rate is high (0 disables it).
float DrawUtilityPenalty = 0.0f;

// Time management options

Expand Down Expand Up @@ -430,6 +432,7 @@ void Config::readSearch(const cpptoml::table &t)
t.get_as<int>("num_nodes_after_singular_root").value_or(NumNodesAfterSingularRoot);
NumNodeTableShardsPowerOfTwo =
t.get_as<int>("num_node_table_shards_power_of_two").value_or(NumNodeTableShardsPowerOfTwo);
DrawUtilityPenalty = t.get_as<double>("draw_utility_penalty").value_or(DrawUtilityPenalty);

// Read time management options
if (auto tm = t.get_table("timectl")) {
Expand Down
15 changes: 8 additions & 7 deletions Rapfi/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,14 @@ extern int NumIterationAfterMate;
extern int NumIterationAfterSingularRoot;
extern int MaxSearchDepth;

extern bool ExpandWhenFirstEvaluate;
extern int MaxNumVisitsPerPlayout;
extern int NodesToPrintMCTSRootmoves;
extern int TimeToPrintMCTSRootmoves;
extern int MaxNonPVRootmovesToPrint;
extern int NumNodesAfterSingularRoot;
extern int NumNodeTableShardsPowerOfTwo;
extern bool ExpandWhenFirstEvaluate;
extern int MaxNumVisitsPerPlayout;
extern int NodesToPrintMCTSRootmoves;
extern int TimeToPrintMCTSRootmoves;
extern int MaxNonPVRootmovesToPrint;
extern int NumNodesAfterSingularRoot;
extern int NumNodeTableShardsPowerOfTwo;
extern float DrawUtilityPenalty;

// -------------------------------------------------
// Time management options
Expand Down
11 changes: 11 additions & 0 deletions Rapfi/search/mcts/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ inline float fpuValue(float parentAvgUtility, float parentRawUtility, float expl

/// Compute PUCT selection value with the given child statistics.
inline float puctSelectionValue(float childUtility,
float childDraw,
float parentDraw,
float childPolicy,
uint32_t childVisits,
uint32_t childVirtualVisits,
Expand All @@ -265,6 +267,9 @@ inline float puctSelectionValue(float childUtility,
float U = cpuctExploration * childPolicy / (1 + childVisits);
float Q = childUtility;

if (Config::DrawUtilityPenalty != 0)
Q -= Config::DrawUtilityPenalty * childDraw * (1 - parentDraw);

// Account for virtual losses
if (childVirtualVisits > 0)
Q = (Q * childVisits - childVirtualVisits) / (childVisits + childVirtualVisits);
Expand Down Expand Up @@ -302,6 +307,7 @@ std::pair<Edge *, Node *> selectChild(Node &node, const Board &board)
SearchThread *thisThread = board.thisThread();

uint32_t parentVisits = node.getVisits();
float parentDraw = node.getD();
float cpuctExploration = cpuctExplorationFactor(parentVisits);

// Apply dynamic cpuct scaling based on parent utility variance if needed
Expand Down Expand Up @@ -346,7 +352,10 @@ std::pair<Edge *, Node *> selectChild(Node &node, const Board &board)
uint32_t childVisits = childEdge.getVisits();
uint32_t childVirtualVisits = childNode->getVirtualVisits();
float childUtility = -childNode->getQ();
float childDraw = childNode->getD();
float selectionValue = puctSelectionValue(childUtility,
childDraw,
parentDraw,
childPolicy,
childVisits,
childVirtualVisits,
Expand All @@ -368,6 +377,8 @@ std::pair<Edge *, Node *> selectChild(Node &node, const Board &board)
uint32_t childVisits = 0; // Unexplored edge must has zero edge visit
uint32_t childVirtualVisits = 0; // Unexplored edge must has zero virtual visit
float selectionValue = puctSelectionValue(fpuUtility,
parentDraw,
parentDraw,
childPolicy,
childVisits,
childVirtualVisits,
Expand Down

0 comments on commit 34dc002

Please sign in to comment.