From ed798562a1823fb7fdf98cf56742bb261f49054c Mon Sep 17 00:00:00 2001 From: dhb <1084714805@qq.com> Date: Mon, 14 Oct 2024 02:38:11 +0800 Subject: [PATCH] add draw utility penalty in puct selection Add @hzyhhzy's idea of less exploration for drawish children nodes. This seems to reduce the amount of useless search spent on highly drawish node, so the search will focus more on important branches. For MCTS search, Passed LTC on f15: TC: 60+0.6 Total/Win/Draw/Lose: 2212 / 743 / 1059 / 410 PTNML: 31 / 151 / 495 / 312 / 117 WinRate: 57.53% ELO: 52.41[41.58, 63.54] Passed VVLTC on f15: TC: 360+3.6 Total/Win/Draw/Lose: 5048 / 1162 / 3257 / 629 PTNML: 30 / 282 / 1446 / 657 / 109 WinRate: 55.28% ELO: 36.58[28.33, 44.97] --- Rapfi/config.cpp | 3 +++ Rapfi/config.h | 15 ++++++++------- Rapfi/search/mcts/parameter.h | 22 +++++++++++----------- Rapfi/search/mcts/search.cpp | 13 +++++++++++++ 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/Rapfi/config.cpp b/Rapfi/config.cpp index 2b9cef5..4b29b04 100644 --- a/Rapfi/config.cpp +++ b/Rapfi/config.cpp @@ -129,6 +129,8 @@ int MaxNonPVRootmovesToPrint = 10; int NumNodesAfterSingularRoot = 100; /// The power of two number of shards that the node table has. int NumNodeTableShardsPowerOfTwo = 10; +/// The ratio to decrase utility when child draw rate is high. +float DrawUtilityPenalty = 0.35f; // Time management options @@ -430,6 +432,7 @@ void Config::readSearch(const cpptoml::table &t) t.get_as("num_nodes_after_singular_root").value_or(NumNodesAfterSingularRoot); NumNodeTableShardsPowerOfTwo = t.get_as("num_node_table_shards_power_of_two").value_or(NumNodeTableShardsPowerOfTwo); + DrawUtilityPenalty = t.get_as("draw_utility_penalty").value_or(DrawUtilityPenalty); // Read time management options if (auto tm = t.get_table("timectl")) { diff --git a/Rapfi/config.h b/Rapfi/config.h index fafc809..8c9f163 100644 --- a/Rapfi/config.h +++ b/Rapfi/config.h @@ -120,13 +120,14 @@ extern int NumIterationAfterMate; extern int NumIterationAfterSingularRoot; extern int MaxSearchDepth; -extern bool ExpandWhenFirstEvaluate; -extern int MaxNumVisitsPerPlayout; -extern int NodesToPrintMCTSRootmoves; -extern int TimeToPrintMCTSRootmoves; -extern int MaxNonPVRootmovesToPrint; -extern int NumNodesAfterSingularRoot; -extern int NumNodeTableShardsPowerOfTwo; +extern bool ExpandWhenFirstEvaluate; +extern int MaxNumVisitsPerPlayout; +extern int NodesToPrintMCTSRootmoves; +extern int TimeToPrintMCTSRootmoves; +extern int MaxNonPVRootmovesToPrint; +extern int NumNodesAfterSingularRoot; +extern int NumNodeTableShardsPowerOfTwo; +extern float DrawUtilityPenalty; // ------------------------------------------------- // Time management options diff --git a/Rapfi/search/mcts/parameter.h b/Rapfi/search/mcts/parameter.h index d894f2d..5d79048 100644 --- a/Rapfi/search/mcts/parameter.h +++ b/Rapfi/search/mcts/parameter.h @@ -24,19 +24,19 @@ namespace Search::MCTS { constexpr float MaxNewVisitsProp = 0.36f; -constexpr float CpuctExploration = 0.39f; -constexpr float CpuctExplorationLog = 0.98f; -constexpr float CpuctExplorationBase = 340; +constexpr float CpuctExploration = 0.35f; +constexpr float CpuctExplorationLog = 1.02f; +constexpr float CpuctExplorationBase = 328; -constexpr float CpuctUtilityStdevScale = 0.043f; -constexpr float CpuctUtilityVarPrior = 0.16f; -constexpr float CpuctUtilityVarPriorWeight = 1.87f; +constexpr float CpuctUtilityStdevScale = 0.05f; +constexpr float CpuctUtilityVarPrior = 0.15f; +constexpr float CpuctUtilityVarPriorWeight = 1.80f; -constexpr float FpuReductionMax = 0.06f; +constexpr float FpuReductionMax = 0.055f; constexpr float FpuLossProp = 0.0008f; -constexpr float RootFpuReductionMax = 0.073f; -constexpr float RootFpuLossProp = 0.0036f; -constexpr float FpuUtilityBlendPow = 0.84f; +constexpr float RootFpuReductionMax = 0.07f; +constexpr float RootFpuLossProp = 0.003f; +constexpr float FpuUtilityBlendPow = 0.75f; constexpr uint32_t MinTranspositionSkipVisits = 11; @@ -44,7 +44,7 @@ constexpr bool UseLCBForBestmoveSelection = true; constexpr float LCBStdevs = 6.28f; constexpr float LCBMinVisitProp = 0.1f; -constexpr float PolicyTemperature = 0.91f; +constexpr float PolicyTemperature = 0.90f; constexpr float RootPolicyTemperature = 1.05f; } // namespace Search::MCTS diff --git a/Rapfi/search/mcts/search.cpp b/Rapfi/search/mcts/search.cpp index ee51eef..8ac8984 100644 --- a/Rapfi/search/mcts/search.cpp +++ b/Rapfi/search/mcts/search.cpp @@ -257,6 +257,8 @@ inline float fpuValue(float parentAvgUtility, float parentRawUtility, float expl /// Compute PUCT selection value with the given child statistics. inline float puctSelectionValue(float childUtility, + float childDraw, + float parentDraw, float childPolicy, uint32_t childVisits, uint32_t childVirtualVisits, @@ -265,6 +267,11 @@ inline float puctSelectionValue(float childUtility, float U = cpuctExploration * childPolicy / (1 + childVisits); float Q = childUtility; + // Reduce utility value for drawish child nodes for PUCT selection + // Encourage exploration for less drawish child nodes + if (Config::DrawUtilityPenalty != 0) + Q -= Config::DrawUtilityPenalty * childDraw * (1 - parentDraw); + // Account for virtual losses if (childVirtualVisits > 0) Q = (Q * childVisits - childVirtualVisits) / (childVisits + childVirtualVisits); @@ -302,6 +309,7 @@ std::pair selectChild(Node &node, const Board &board) SearchThread *thisThread = board.thisThread(); uint32_t parentVisits = node.getVisits(); + float parentDraw = node.getD(); float cpuctExploration = cpuctExplorationFactor(parentVisits); // Apply dynamic cpuct scaling based on parent utility variance if needed @@ -346,7 +354,10 @@ std::pair selectChild(Node &node, const Board &board) uint32_t childVisits = childEdge.getVisits(); uint32_t childVirtualVisits = childNode->getVirtualVisits(); float childUtility = -childNode->getQ(); + float childDraw = childNode->getD(); float selectionValue = puctSelectionValue(childUtility, + childDraw, + parentDraw, childPolicy, childVisits, childVirtualVisits, @@ -368,6 +379,8 @@ std::pair selectChild(Node &node, const Board &board) uint32_t childVisits = 0; // Unexplored edge must has zero edge visit uint32_t childVirtualVisits = 0; // Unexplored edge must has zero virtual visit float selectionValue = puctSelectionValue(fpuUtility, + parentDraw, + parentDraw, childPolicy, childVisits, childVirtualVisits,