Monte Carlo Tree Search
Monte Carlo Tree Search (MCTS) is a heuristic search algorithm for finding optimal decisions in certain kinds of decision processes, especially problems with enormous search spaces such as board games; by contrast, DFS and BFS are brute-force searches. MCTS was used in AlphaGo and MuZero, and recent LLM reasoning models use it as well.
Overview
Using MCTS to solve a problem (for example, playing a board game) means building a search tree for that problem. Starting from a given state, the environment hands out a reward when some terminal state is reached; the nodes along the traversed path are then revisited and their statistics updated.
Take a board game as an example: you are in some board position — what move should you make? If you have seen the same or a similar position before, you already know which move has the highest win rate. If not, you simulate in your head, and that mental simulation is exactly the MCTS process: no stone is actually played and no real board state changes. After running MCTS you know the best next move; only then do you actually play it, reaching a new position, where a new round of MCTS begins. Having seen a huge number of positions and run a huge number of MCTS searches, you can always pick the best next move. The logic can be expressed in pseudocode:
```python
tree = MCTS()
board = new_board()
while True:
    # enter the move you want to play
    row_col = input("enter row,col: ")
    row, col = map(int, row_col.split(","))
    index = 3 * (row - 1) + (col - 1)
    if board.tup[index] is not None:
        raise RuntimeError("Invalid move")
    # play the move (the board changes state)
    board = board.make_move(index)
    print(board.to_pretty_string())
    # if the board reaches a terminal state, the game is over
    if board.terminal:
        break
    # iterate the four MCTS steps
    for _ in range(mcts_iterations):
        tree.run_mcts(board)
    # after MCTS we know the best next move; execute it and update the board
    board = tree.choose(board)
```
The MCTS Algorithm
The core of an MCTS iteration consists of four steps: Selection, Expansion, Simulation (Rollout), and Backpropagation (Update).
As shown in the flowchart below:
```mermaid
graph TD;
    B["Current node = root S0"] --> C{"Is the current node
a leaf node?"};
    C -- No --> D["**Selection**
current node = the child with the largest UCB(Si)"];
    D --> C;
    C -- Yes --> E{"Is the current node's
ni = 0?"};
    E -- Yes --> F["**Rollout**"];
    E -- No --> G["**Expansion**
add all available actions of the current node as children at once"];
    G --> H["pick a random child"];
    H --> I["**Rollout**"];
    J["**Backpropagation**
update the N and Q values of every node on the path from the current leaf back to the root S0"];
    F --> J;
    I --> J;
    J -.-> B;
```
It can also be understood through the following code:
```python
def run_mcts(self, node):
    "Make the tree one layer better. (Train for one iteration.)"
    path = self._select(node)
    leaf = path[-1]
    self._expand(leaf)
    reward = self._simulate(leaf)
    self._backpropagate(path, reward)
```
Nodes
Each node needs to record three basic pieces of information:
- The node's state, e.g. the current board position
- The number of times the node has been visited
- The cumulative score, from which the average reward is computed, i.e. the total reward received divided by the visit count
Once the MCTS iterations finish, the cumulative score is the basis on which the game outside of MCTS decides the move to actually play: it executes the action with the highest score. By contrast, the UCB value does not need to be stored; it is computed only during the selection phase, to drive the search down the tree.
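As a minimal sketch of that final decision rule (the numbers below are invented for illustration, not taken from the article):

```python
# Toy statistics: after many MCTS iterations each child of the root
# carries a total reward Q and a visit count N; the game then plays
# the child with the highest average reward Q/N.
stats = {          # action -> (total reward Q, visit count N)
    "a": (60, 100),
    "b": (6, 8),
    "c": (10, 40),
}

best = max(stats, key=lambda a: stats[a][0] / stats[a][1])
print(best)  # "b": average 0.75 beats 0.60 and 0.25
```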
Selection
The goal of the selection phase is to find the next node to expand. We assume the expansion phase adds all feasible actions as children at once, so selection always ends at a leaf node. While searching downward during selection, we may encounter the following cases:
- The node is not a leaf: descend into the child with the largest UCB (Upper Confidence Bound) value. Since a node with a visit count of 0 has an infinite UCB value, in that case one such child is simply chosen at random.
- The node is a leaf: the selection phase ends here. A leaf is a node without children, and it comes in two flavors: either it has already been "explored", i.e. simulated, and therefore carries statistics (its visit count is non-zero), or it was added during an expansion phase but has not been simulated yet.
The following code may help:
```python
self.Q = defaultdict(int)  # total reward of each node
self.N = defaultdict(int)  # total visit count for each node
self.children = dict()     # records every expanded node; children.keys() gives the expanded set
# self.children[node] is the set of children of the given node
```
```python
def _select(self, node):
    "Find an unexplored descendent of `node`"
    path = []
    while True:
        path.append(node)
        # `node not in self.children` means this node has never been expanded
        # `not self.children[node]` being True means the node has no children (terminal)
        if node not in self.children or not self.children[node]:
            # node is either unexplored or terminal
            return path
        # removing the already-expanded nodes from the children leaves the unexpanded ones
        unexplored = self.children[node] - self.children.keys()
        if unexplored:
            # set.pop() picks an arbitrary unexplored child
            n = unexplored.pop()
            path.append(n)
            return path
        # every child has been expanded (recorded in self.children), so descend by UCT
        node = self._uct_select(node)  # descend a layer deeper
```
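The set arithmetic in `_select` is worth a closer look. A standalone toy illustration (the data is mine):

```python
# `self.children` maps every node that has been *expanded* to its child
# set, so subtracting `children.keys()` from a node's child set leaves
# exactly the children that have never been expanded themselves.
children = {
    "root": {"a", "b", "c"},   # root was expanded
    "a": {"a1", "a2"},         # "a" was expanded too; "b" and "c" were not
}

unexplored = children["root"] - children.keys()
print(sorted(unexplored))  # ['b', 'c']
```

This works because `dict.keys()` returns a set-like view, so it can appear directly on the right-hand side of a set difference.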
Now a word about the UCT function, which balances exploration and exploitation:

$$\mathrm{UCT}(n) = \frac{Q(n)}{N(n)} + c \sqrt{\frac{\ln N(\mathrm{parent}(n))}{N(n)}}$$

where $N(n)$ is the visit count of the current node (since it appears in a denominator, a newly added node has an infinite UCB value), $N(\mathrm{parent}(n))$ is the visit count of the node's parent, $c$ is a constant (the exploration weight), and $Q(n)$ is the total reward the node has received.
```python
def _uct_select(self, node):
    "Select a child of node, balancing exploration & exploitation"

    # All children of node should already be expanded:
    assert all(n in self.children for n in self.children[node])

    log_N_vertex = math.log(self.N[node])

    def uct(n):
        "Upper confidence bound for trees"
        return self.Q[n] / self.N[n] + self.exploration_weight * math.sqrt(
            log_N_vertex / self.N[n]
        )

    return max(self.children[node], key=uct)
```
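To get a feel for how the two terms trade off, here is a toy computation (the numbers are mine, not from the article):

```python
import math

# Assume the parent has been visited 100 times and the exploration
# weight c is 1, as in the code above.
def uct(q, n, parent_n, c=1.0):
    return q / n + c * math.sqrt(math.log(parent_n) / n)

well_explored = uct(q=45, n=90, parent_n=100)  # average 0.50 -> ~0.73
barely_tried = uct(q=1, n=4, parent_n=100)     # average 0.25 -> ~1.32

# Despite its worse average reward, the barely-tried child is selected,
# because its exploration term dominates.
print(well_explored < barely_tried)  # True
```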
Expansion
The input to the expansion phase is the not-fully-expanded node "chosen" by the selection phase. The goal of expansion is to take actions from the list of available, unused actions in the node's state and attach them to the node as new children.
Some descriptions say that expansion adds a single unexplored action as a child, but in this implementation all feasible actions are added at once:
```python
def _expand(self, node):
    "Update the `children` dict with the children of `node`"
    if node in self.children:
        return  # already expanded
    # add **all unexplored children** of `node` to the tree
    self.children[node] = node.find_children()
```
Simulation (Rollout)
The simulation phase starts from the given node and uses a default policy to pick an action from the legal moves, advancing to the next state, until the game is quickly played out and yields a reward — a win, a loss, a draw, and so on. Note that simulation does not add any new nodes; it only wants to finish the game quickly from the current state, so random decisions are good enough:
```python
def _simulate(self, node):
    "Returns the reward for a random simulation (to completion) of `node`"
    invert_reward = True
    while True:
        if node.is_terminal():
            reward = node.reward()
            return 1 - reward if invert_reward else reward
        node = node.find_random_child()
        invert_reward = not invert_reward
```
Backpropagation
In the backpropagation phase the game has reached an end and produced a reward; this reward is added to every node on the traversed path, i.e. each node's total reward is updated. Since the simulation phase created no new nodes, the reward is first credited to the node the simulation started from, then propagated along the selected path all the way up to the root:
```python
def _backpropagate(self, path, reward):
    "Send the reward back up to the ancestors of the leaf"
    for node in reversed(path):
        self.N[node] += 1
        self.Q[node] += reward
        reward = 1 - reward  # 1 for me is 0 for my enemy, and vice versa
```
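The sign-flipping in the update can be traced with a toy three-node path (the labels are mine): a reward of 1 from the leaf's perspective becomes 0 one ply up, 1 again two plies up, and so on, because the two players alternate.

```python
from collections import defaultdict

Q, N = defaultdict(int), defaultdict(int)

def backpropagate(path, reward):
    for node in reversed(path):
        N[node] += 1
        Q[node] += reward
        reward = 1 - reward  # my win is my opponent's loss

backpropagate(["root", "child", "leaf"], reward=1)
print(Q["leaf"], Q["child"], Q["root"])  # 1 0 1
```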
MCTS Implementation
The code follows a minimal open-source implementation1
```python
from abc import ABC, abstractmethod
from collections import defaultdict
import math


class MCTS:
    "Monte Carlo tree searcher. First rollout the tree then choose a move."

    def __init__(self, exploration_weight=1):
        self.Q = defaultdict(int)  # total reward of each node
        self.N = defaultdict(int)  # total visit count for each node
        self.children = dict()  # children of each node
        self.exploration_weight = exploration_weight

    def choose(self, node):
        "Choose the best successor of node. (Choose a move in the game)"
        if node.is_terminal():
            raise RuntimeError(f"choose called on terminal node {node}")

        if node not in self.children:
            return node.find_random_child()

        def score(n):
            if self.N[n] == 0:
                return float("-inf")  # avoid unseen moves
            return self.Q[n] / self.N[n]  # average reward

        return max(self.children[node], key=score)

    def do_rollout(self, node):
        "Make the tree one layer better. (Train for one iteration.)"
        path = self._select(node)
        leaf = path[-1]
        self._expand(leaf)
        reward = self._simulate(leaf)
        self._backpropagate(path, reward)

    def _select(self, node):
        "Find an unexplored descendent of `node`"
        path = []
        while True:
            path.append(node)
            if node not in self.children or not self.children[node]:
                # node is either unexplored or terminal
                return path
            unexplored = self.children[node] - self.children.keys()
            if unexplored:
                n = unexplored.pop()
                path.append(n)
                return path
            node = self._uct_select(node)  # descend a layer deeper

    def _expand(self, node):
        "Update the `children` dict with the children of `node`"
        if node in self.children:
            return  # already expanded
        # add all unexplored children of `node` to the tree
        self.children[node] = node.find_children()

    def _simulate(self, node):
        "Returns the reward for a random simulation (to completion) of `node`"
        invert_reward = True
        while True:
            if node.is_terminal():
                reward = node.reward()
                return 1 - reward if invert_reward else reward
            node = node.find_random_child()
            invert_reward = not invert_reward

    def _backpropagate(self, path, reward):
        "Send the reward back up to the ancestors of the leaf"
        for node in reversed(path):
            self.N[node] += 1
            self.Q[node] += reward
            reward = 1 - reward  # 1 for me is 0 for my enemy, and vice versa

    def _uct_select(self, node):
        "Select a child of node, balancing exploration & exploitation"

        # All children of node should already be expanded:
        assert all(n in self.children for n in self.children[node])

        log_N_vertex = math.log(self.N[node])

        def uct(n):
            "Upper confidence bound for trees"
            return self.Q[n] / self.N[n] + self.exploration_weight * math.sqrt(
                log_N_vertex / self.N[n]
            )

        return max(self.children[node], key=uct)


class Node(ABC):
    """
    A representation of a single board state.
    MCTS works by constructing a tree of these Nodes.
    Could be e.g. a chess or checkers board state.
    """

    @abstractmethod
    def find_children(self):
        "All possible successors of this board state"
        return set()

    @abstractmethod
    def find_random_child(self):
        "Random successor of this board state (for more efficient simulation)"
        return None

    @abstractmethod
    def is_terminal(self):
        "Returns True if the node has no children"
        return True

    @abstractmethod
    def reward(self):
        "Assumes `self` is terminal node. 1=win, 0=loss, .5=tie, etc"
        return 0

    @abstractmethod
    def __hash__(self):
        "Nodes must be hashable"
        return 123456789

    @abstractmethod
    def __eq__(node1, node2):
        "Nodes must be comparable"
        return True
```
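As a usage sketch, here is a hypothetical state class (all names are mine, and the "sum to 10" game is an assumption for illustration, not the article's board game) that provides the Node interface; subclassing a namedtuple supplies the required `__hash__` and `__eq__` for free. The tail of the block replays the same random-playout loop as `_simulate`:

```python
import random
from collections import namedtuple

# Toy game: start at 0, each move adds 1, 2 or 3, and the player who
# lands exactly on 10 wins.
class SumState(namedtuple("SumState", "total")):
    TARGET = 10

    def find_children(self):
        return {SumState(self.total + a)
                for a in (1, 2, 3) if self.total + a <= self.TARGET}

    def find_random_child(self):
        return random.choice(sorted(self.find_children()))

    def is_terminal(self):
        return self.total >= self.TARGET

    def reward(self):
        return 1  # the player who just moved reached the target and wins

# The random-playout loop from `_simulate`, run directly on the state:
node, invert_reward = SumState(0), True
while not node.is_terminal():
    node = node.find_random_child()
    invert_reward = not invert_reward
reward = node.reward()
result = 1 - reward if invert_reward else reward
print(node.total, result in (0, 1))  # 10 True
```

A class like this could be passed to `MCTS.do_rollout` and `MCTS.choose` exactly as in the game loop at the top of the article.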
MCTS Visualization
For a simple path-sum problem, I vibe-coded a visualization of how the MCTS tree unfolds. First install a flask environment, run the backend code, and then step through the details of single iterations.
MCTS visualization backend code
```python
import math
import random
from flask import Flask, jsonify, request
from flask_cors import CORS
import requests
import json
import logging
```
```python
# --- Game Definition ---
class Game:
    TARGET_NUMBER = 10
    POSSIBLE_ACTIONS = [1, 2, 3]
    MAX_MOVES = 7

    @staticmethod
    def get_initial_state():
        return (0, 0)  # (current_sum, moves_taken)

    @staticmethod
    def perform_action(state, action):
        current_sum, moves_taken = state
        new_sum = current_sum + action
        return (new_sum, moves_taken + 1)

    @staticmethod
    def get_legal_actions(state):
        current_sum, moves_taken = state
        app.logger.debug(f"Game.get_legal_actions called with state: {state}. TARGET: {Game.TARGET_NUMBER}, MAX_MOVES: {Game.MAX_MOVES}, CLASS POSSIBLE_ACTIONS: {Game.POSSIBLE_ACTIONS}")
        if current_sum >= Game.TARGET_NUMBER or moves_taken >= Game.MAX_MOVES:
            app.logger.debug(f"Game.get_legal_actions for state {state} returning [] because terminal condition met: sum_check: {current_sum >= Game.TARGET_NUMBER}, moves_check: {moves_taken >= Game.MAX_MOVES}")
            return []
        legal_actions_copy = list(Game.POSSIBLE_ACTIONS)
        app.logger.debug(f"Game.get_legal_actions for state {state} returning: {legal_actions_copy}")
        return legal_actions_copy

    @staticmethod
    def is_terminal(state):
        current_sum, moves_taken = state
        return current_sum >= Game.TARGET_NUMBER or moves_taken >= Game.MAX_MOVES

    @staticmethod
    def get_reward(state):
        current_sum, moves_taken = state
        if current_sum == Game.TARGET_NUMBER:
            return 1.0
        elif current_sum > Game.TARGET_NUMBER:
            return -1.0
        elif moves_taken >= Game.MAX_MOVES and current_sum != Game.TARGET_NUMBER:
            return -0.5
        return 0.0
```
```python
# --- MCTS Node ---
class MCTSNode:
    _id_counter = 0

    def __init__(self, state, parent=None, action_that_led_to_state=None):
        self.id = MCTSNode._id_counter
        MCTSNode._id_counter += 1
        self.state = state
        self.parent = parent
        self.action_that_led_to_state = action_that_led_to_state
        self.children = []
        self.visits = 0
        self.value = 0.0
        self.untried_actions = Game.get_legal_actions(self.state)

    def is_fully_expanded(self):
        return len(self.untried_actions) == 0

    def is_terminal(self):
        return Game.is_terminal(self.state)

    def ucb_score(self, exploration_constant=1.414):
        if self.visits == 0:
            return float('inf')
        exploitation_term = self.value / self.visits
        if self.parent is None or self.parent.visits == 0:
            return exploitation_term
        exploration_term = exploration_constant * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation_term + exploration_term

    def add_child(self, action, child_state):
        child_node = MCTSNode(child_state, parent=self, action_that_led_to_state=action)
        self.children.append(child_node)
        if action in self.untried_actions:
            self.untried_actions.remove(action)
        return child_node

    def update(self, reward):
        self.visits += 1
        self.value += reward

    def to_dict_simple(self):
        return {
            "id": self.id,
            "state": self.state,
            "action_from_parent": self.action_that_led_to_state,
            "visits": self.visits,
            "value": round(self.value, 2)
        }

    def __repr__(self):
        return (f"Node(ID:{self.id}, S:{self.state}, A:{self.action_that_led_to_state}, "
                f"V:{self.value:.2f}, N:{self.visits}, Children:{len(self.children)})")
```
```python
# --- MCTS Algorithm ---
class MCTS:
    _game_path_node_id_counter = 0

    def __init__(self, initial_game_state_tuple):
        MCTSNode._id_counter = 0
        self.root = MCTSNode(initial_game_state_tuple)
        self.iteration_count = 0

        MCTS._game_path_node_id_counter = 0
        self.game_path_history = [{
            "state": initial_game_state_tuple,
            "action_from_parent": None,
            "id": f"game_path_{MCTS._game_path_node_id_counter}"
        }]
        MCTS._game_path_node_id_counter += 1

    def _select_promising_node_with_path(self, node):
        current_node = node
        selection_path = [current_node.to_dict_simple()]
        while not current_node.is_terminal():
            if not current_node.is_fully_expanded():
                return current_node, selection_path
            else:
                if not current_node.children:
                    return current_node, selection_path
                current_node = max(current_node.children, key=lambda n: n.ucb_score())
                selection_path.append(current_node.to_dict_simple())
        return current_node, selection_path

    def _expand_node(self, node):
        if not node.untried_actions:
            return None, None
        action = random.choice(node.untried_actions)
        next_state_tuple = Game.perform_action(node.state, action)
        new_child_node = node.add_child(action, next_state_tuple)
        return new_child_node, action

    def _simulate_random_rollout_with_path(self, node):
        current_rollout_state = node.state
        rollout_path = []
        moves = 0
        while not Game.is_terminal(current_rollout_state) and moves < Game.MAX_MOVES * 2:
            legal_actions = Game.get_legal_actions(current_rollout_state)
            if not legal_actions:
                break
            action = random.choice(legal_actions)
            rollout_path.append({"from_state": current_rollout_state, "action_taken": action})
            current_rollout_state = Game.perform_action(current_rollout_state, action)
            moves += 1
        reward = Game.get_reward(current_rollout_state)
        rollout_path.append({"terminal_state": current_rollout_state, "reward_obtained": reward})
        return reward, rollout_path

    def _backpropagate_with_path(self, node, reward):
        temp_node = node
        backpropagation_path = []
        while temp_node is not None:
            temp_node.update(reward)
            backpropagation_path.append({
                "node_id": temp_node.id,
                "state": temp_node.state,
                "updated_visits": temp_node.visits,
                "updated_value": round(temp_node.value, 2)
            })
            temp_node = temp_node.parent
        return backpropagation_path

    def run_one_iteration_detailed(self):
        self.iteration_count += 1
        iteration_details = {"iteration_number": self.iteration_count}
        promising_node, selection_path = self._select_promising_node_with_path(self.root)
        iteration_details["selection"] = {
            "path": selection_path,
            "selected_node_id": promising_node.id,
            "selected_node_state": promising_node.state,
            "is_terminal": promising_node.is_terminal(),
            "is_fully_expanded": promising_node.is_fully_expanded()
        }
        node_for_rollout = promising_node
        expanded_child_info = None
        if not promising_node.is_terminal() and promising_node.untried_actions:
            expanded_child, expanded_action = self._expand_node(promising_node)
            if expanded_child:
                node_for_rollout = expanded_child
                expanded_child_info = {"action": expanded_action, "child_id": expanded_child.id, "child_state": expanded_child.state}
                iteration_details["expansion"] = {
                    "parent_node_id": promising_node.id,
                    "action_taken": expanded_action,
                    "new_child_node_id": expanded_child.id,
                    "new_child_state": expanded_child.state
                }
            else:
                iteration_details["expansion"] = {"error": "Expansion failed"}
        else:
            iteration_details["expansion"] = {"message": "No expansion (node terminal or fully expanded)."}

        app.logger.debug(f"Iter {self.iteration_count}: Selection ended at Node ID {promising_node.id}. Expansion: {expanded_child_info if expanded_child_info else iteration_details['expansion'].get('message', 'N/A')}")

        simulation_reward, rollout_path = self._simulate_random_rollout_with_path(node_for_rollout)
        iteration_details["simulation"] = {
            "start_node_id": node_for_rollout.id,
            "start_node_state": node_for_rollout.state,
            "rollout_path": rollout_path,
            "reward": simulation_reward
        }
        app.logger.debug(f"Iter {self.iteration_count}: Simulation from Node ID {node_for_rollout.id} yielded reward {simulation_reward}")

        backpropagation_path = self._backpropagate_with_path(node_for_rollout, simulation_reward)
        iteration_details["backpropagation"] = {
            "start_node_id": node_for_rollout.id,
            "reward_propagated": simulation_reward,
            "updated_path": backpropagation_path
        }
        app.logger.debug(f"Iter {self.iteration_count}: Backpropagation complete. Root visits: {self.root.visits}, Root value: {self.root.value:.2f}")
        return iteration_details

    def get_best_action(self):
        if not self.root or not self.root.children:
            return None
        best_child_node = max(self.root.children, key=lambda node: node.visits)
        if best_child_node.visits == 0 and len(self.root.children) > 0:
            app.logger.warning(f"Best child (action: {best_child_node.action_that_led_to_state}) has 0 visits. MCTS might need more iterations from current root.")
        return best_child_node.action_that_led_to_state

    def advance_tree(self, action):
        if not self.root:
            app.logger.error("Cannot advance tree: MCTS root is None.")
            return False, "MCTS tree not initialized."
        try:
            action_val = int(action)
        except ValueError:
            app.logger.error(f"Invalid action type for advance_tree: {action}. Expected int.")
            return False, f"Invalid type for the provided action '{action}'."

        found_child = None
        for child in self.root.children:
            if child.action_that_led_to_state == action_val:
                found_child = child
                break

        if found_child:
            self.root = found_child
            self.root.parent = None
            self.root.children = []
            self.root.untried_actions = Game.get_legal_actions(self.root.state)

            new_path_node_id = f"game_path_{MCTS._game_path_node_id_counter}"
            MCTS._game_path_node_id_counter += 1
            self.game_path_history.append({
                "state": self.root.state,
                "action_from_parent": action_val,
                "id": new_path_node_id
            })
            app.logger.info(f"Tree advanced. New root is Node ID: {self.root.id}, State: {self.root.state}, Untried Actions: {len(self.root.untried_actions)}, Children: {len(self.root.children)}")
            return True, "Tree advanced to the new state; its children and untried-action list have been reset."
        else:
            app.logger.warning(f"Action {action_val} not found among children of current root (ID: {self.root.id}). Children actions: {[c.action_that_led_to_state for c in self.root.children]}")
            return False, f"Action {action_val} is not a valid child action of the current root."

    def get_game_path_tree_data(self):
        if not self.game_path_history:
            return {}
        history_root_data = self.game_path_history[0]
        display_tree_root = {
            "id": history_root_data["id"],
            "state": history_root_data["state"],
            "action_from_parent": history_root_data.get("action_from_parent"),
            "is_current_game_node": len(self.game_path_history) == 1,
            "value": "N/A", "visits": "N/A", "avg_value": "N/A",
            "ucb_score_from_parent": "N/A",
            "is_terminal": Game.is_terminal(history_root_data["state"]),
            "is_fully_expanded": "N/A",
            "displayed_children": []
        }
        current_parent_in_display_tree = display_tree_root
        for i in range(1, len(self.game_path_history)):
            history_node_data = self.game_path_history[i]
            is_current = (i == len(self.game_path_history) - 1)
            child_display_node = {
                "id": history_node_data["id"],
                "state": history_node_data["state"],
                "action_from_parent": history_node_data.get("action_from_parent"),
                "is_current_game_node": is_current,
                "value": "N/A", "visits": "N/A", "avg_value": "N/A",
                "ucb_score_from_parent": "N/A",
                "is_terminal": Game.is_terminal(history_node_data["state"]),
                "is_fully_expanded": "N/A",
                "displayed_children": []
            }
            current_parent_in_display_tree["displayed_children"].append(child_display_node)
            current_parent_in_display_tree = child_display_node
        return display_tree_root

    def _get_node_data_recursive(self, node, current_depth, max_depth):
        if node is None or current_depth > max_depth:
            return None
        children_data_list = []
        if current_depth < max_depth:
            for child_node in node.children:
                child_data = self._get_node_data_recursive(child_node, current_depth + 1, max_depth)
                if child_data:
                    children_data_list.append(child_data)
        return {
            "id": node.id,
            "state": node.state,
            "action_from_parent": node.action_that_led_to_state,
            "visits": node.visits,
            "value": f"{node.value:.2f}",
            "avg_value": f"{(node.value / node.visits) if node.visits > 0 else 'N/A'}",
            "ucb_score_from_parent": f"{node.ucb_score():.2f}" if node.parent else "N/A (Root)",
            "is_terminal": node.is_terminal(),
            "is_fully_expanded": node.is_fully_expanded(),
            "untried_actions_count": len(node.untried_actions),
            "children_count": len(node.children),
            "displayed_children": children_data_list
        }

    def get_tree_visualization_data(self, max_display_depth=2):
        if self.root is None:
            return {}
        return self._get_node_data_recursive(self.root, 0, max_display_depth)
```
```python
# --- Flask Application ---
app = Flask(__name__)
CORS(app)
logging.basicConfig(level=logging.DEBUG)
app.logger.setLevel(logging.DEBUG)

mcts_algorithm = None
current_game_settings = {}

GEMINI_API_KEY = ""
GEMINI_API_URL_BASE = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"


def call_gemini_api(prompt_text):
    api_url = f"{GEMINI_API_URL_BASE}?key={GEMINI_API_KEY}"
    chat_history = [{"role": "user", "parts": [{"text": prompt_text}]}]
    payload = {"contents": chat_history}
    app.logger.info(f"Calling Gemini API with prompt (first 200 chars): {prompt_text[:200]}...")
    try:
        response = requests.post(api_url, json=payload, headers={'Content-Type': 'application/json'}, timeout=30)
        response.raise_for_status()
        result = response.json()
        app.logger.debug(f"Gemini API raw response: {result}")
        if (result.get("candidates") and
                result["candidates"][0].get("content") and
                result["candidates"][0]["content"].get("parts") and
                result["candidates"][0]["content"]["parts"][0].get("text")):
            return result["candidates"][0]["content"]["parts"][0]["text"]
        else:
            app.logger.error(f"Unexpected Gemini API response structure: {result}")
            return "The AI did not respond in the expected format."
    except requests.exceptions.Timeout:
        app.logger.error("Gemini API request timed out.")
        raise Exception("AI service request timed out")
    except requests.exceptions.RequestException as e:
        app.logger.error(f"Gemini API request failed: {e}")
        raise Exception(f"AI service request failed: {str(e)}")
    except Exception as e:
        app.logger.error(f"Error processing Gemini response or other unexpected error: {e}")
        raise Exception(f"Error while processing the AI response: {str(e)}")
```
```python
@app.route('/init_game', methods=['POST'])
def init_game_endpoint():
    global mcts_algorithm, current_game_settings
    try:
        data = request.json
        Game.TARGET_NUMBER = int(data.get('target', 10))
        Game.MAX_MOVES = int(data.get('max_moves', 7))
        Game.POSSIBLE_ACTIONS = [int(a) for a in data.get('actions', [1, 2, 3])]
        current_game_settings = {
            "target": Game.TARGET_NUMBER,
            "max_moves": Game.MAX_MOVES,
            "actions": Game.POSSIBLE_ACTIONS
        }
        app.logger.info(f"Game initialized with settings: {current_game_settings}")
        initial_state = Game.get_initial_state()
        mcts_algorithm = MCTS(initial_state)
        app.logger.info(f"Initial MCTS root: ID {mcts_algorithm.root.id}, State {mcts_algorithm.root.state}, Untried: {len(mcts_algorithm.root.untried_actions)}, Children: {len(mcts_algorithm.root.children)}")
        return jsonify({
            "message": "Game initialized successfully.",
            "settings": current_game_settings,
            "initial_state": initial_state,
            "is_initial_terminal": Game.is_terminal(initial_state),
            "tree_root": mcts_algorithm.get_tree_visualization_data(max_display_depth=2),
            "game_path_tree": mcts_algorithm.get_game_path_tree_data()
        }), 200
    except Exception as e:
        app.logger.error(f"Initialization failed: {str(e)}", exc_info=True)
        return jsonify({"error": f"Initialization failed: {str(e)}"}), 400
```
```python
@app.route('/run_mcts_iterations', methods=['POST'])
def run_mcts_iterations_endpoint():
    global mcts_algorithm
    if not mcts_algorithm:
        return jsonify({"error": "Game not initialized."}), 400
    app.logger.info(f"Entering /run_mcts_iterations. Root ID: {mcts_algorithm.root.id}, State: {mcts_algorithm.root.state}, Untried: {len(mcts_algorithm.root.untried_actions)}, Children: {len(mcts_algorithm.root.children)}")
    try:
        data = request.json
        num_iterations = int(data.get('iterations', 1))
        if num_iterations <= 0:
            return jsonify({"error": "The number of iterations must be positive."}), 400
        for i in range(num_iterations):
            app.logger.debug(f"--- Batch Iteration {i+1}/{num_iterations} for Root ID {mcts_algorithm.root.id} ---")
            promising_node, _ = mcts_algorithm._select_promising_node_with_path(mcts_algorithm.root)
            node_for_rollout = promising_node
            if not promising_node.is_terminal() and promising_node.untried_actions:
                expanded_child, expanded_action = mcts_algorithm._expand_node(promising_node)
                if expanded_child:
                    node_for_rollout = expanded_child
                    app.logger.debug(f"Batch Iter {i+1}: Expanded Node ID {promising_node.id} with action {expanded_action} to new child ID {expanded_child.id}")
                else:
                    app.logger.warning(f"Batch Iter {i+1}: Expansion failed for Node ID {promising_node.id} despite having untried_actions.")
            else:
                app.logger.debug(f"Batch Iter {i+1}: No expansion for Node ID {promising_node.id} (Terminal: {promising_node.is_terminal()}, Untried: {len(promising_node.untried_actions)})")

            simulation_reward, _ = mcts_algorithm._simulate_random_rollout_with_path(node_for_rollout)
            mcts_algorithm._backpropagate_with_path(node_for_rollout, simulation_reward)
            mcts_algorithm.iteration_count += 1
        app.logger.info(f"{num_iterations} MCTS batch iterations completed. Total user iterations: {mcts_algorithm.iteration_count}")
        app.logger.info(f"After batch iterations. Root ID: {mcts_algorithm.root.id}, State: {mcts_algorithm.root.state}, Untried: {len(mcts_algorithm.root.untried_actions)}, Children: {len(mcts_algorithm.root.children)}")
        return jsonify({
            "message": f"{num_iterations} batch MCTS iterations completed. Total user iterations: {mcts_algorithm.iteration_count}",
            "tree_root": mcts_algorithm.get_tree_visualization_data(max_display_depth=3),
            "root_stats": {"visits": mcts_algorithm.root.visits, "value": f"{mcts_algorithm.root.value:.2f}"}
        }), 200
    except Exception as e:
        app.logger.error(f"Error during MCTS batch iterations: {str(e)}", exc_info=True)
        return jsonify({"error": f"Error during MCTS batch iterations: {str(e)}"}), 500
```
```python
@app.route('/run_single_iteration_detailed', methods=['POST'])
def run_single_iteration_detailed_endpoint():
    global mcts_algorithm
    if not mcts_algorithm:
        return jsonify({"error": "Game not initialized."}), 400
    app.logger.info(f"Entering /run_single_iteration_detailed. Root ID: {mcts_algorithm.root.id}, State: {mcts_algorithm.root.state}, Untried: {len(mcts_algorithm.root.untried_actions)}, Children: {len(mcts_algorithm.root.children)}")
    try:
        detailed_steps = mcts_algorithm.run_one_iteration_detailed()
        app.logger.info(f"Detailed steps for user iteration {mcts_algorithm.iteration_count} completed.")
        app.logger.info(f"After single detailed iteration. Root ID: {mcts_algorithm.root.id}, State: {mcts_algorithm.root.state}, Untried: {len(mcts_algorithm.root.untried_actions)}, Children: {len(mcts_algorithm.root.children)}")
        return jsonify({
            "message": f"Single MCTS iteration {mcts_algorithm.iteration_count} completed.",
            "detailed_steps": detailed_steps,
            "tree_root": mcts_algorithm.get_tree_visualization_data(max_display_depth=3)
        }), 200
    except Exception as e:
        app.logger.error(f"Error during single detailed MCTS iteration: {str(e)}", exc_info=True)
        return jsonify({"error": f"Error during single detailed MCTS iteration: {str(e)}"}), 500
```
```python
@app.route('/get_best_move', methods=['GET'])
def get_best_move_endpoint():
    global mcts_algorithm
    if not mcts_algorithm or not mcts_algorithm.root:
        return jsonify({"error": "Game not initialized or no MCTS data."}), 400
    app.logger.info(f"Entering /get_best_move. Root ID: {mcts_algorithm.root.id}, State: {mcts_algorithm.root.state}, Untried: {len(mcts_algorithm.root.untried_actions)}, Children: {len(mcts_algorithm.root.children)}")

    if not mcts_algorithm.root.children:
        if Game.is_terminal(mcts_algorithm.root.state):
            message = "The current state is already terminal; no next action is available."
        else:
            message = "The root node has not explored any actions. Run more iterations or check the game state."
        app.logger.warning(f"/get_best_move: {message} for Root ID {mcts_algorithm.root.id}")
        return jsonify({
            "message": message,
            "best_action": None,
            "current_sum": mcts_algorithm.root.state[0] if mcts_algorithm.root.state else 'N/A',
            "moves_taken": mcts_algorithm.root.state[1] if mcts_algorithm.root.state and len(mcts_algorithm.root.state) > 1 else 'N/A',
            "details": "The root node has no children or is terminal."
        }), 200

    best_action = mcts_algorithm.get_best_action()
    if best_action is None and mcts_algorithm.root.children:
        app.logger.error("CRITICAL: get_best_action returned None even though root has children!")
        return jsonify({
            "message": "Internal error: could not determine the best action even though children exist.",
            "best_action": None,
            "current_sum": mcts_algorithm.root.state[0],
            "moves_taken": mcts_algorithm.root.state[1],
            "details": "Internal logic error."
        }), 500

    children_info = []
    for child in mcts_algorithm.root.children:
        children_info.append({
            "action": child.action_that_led_to_state,
            "visits": child.visits,
            "value": f"{child.value:.2f}",
            "avg_value": f"{(child.value / child.visits) if child.visits > 0 else 'N/A'}"
        })
    app.logger.info(f"Best move suggested: {best_action} from Root ID {mcts_algorithm.root.id}")
    return jsonify({
        "message": "Best action suggested by MCTS.",
        "best_action": best_action,
        "current_sum": mcts_algorithm.root.state[0] if mcts_algorithm.root.state else 'N/A',
        "moves_taken": mcts_algorithm.root.state[1] if mcts_algorithm.root.state and len(mcts_algorithm.root.state) > 1 else 'N/A',
        "root_children_stats": sorted(children_info, key=lambda x: x["visits"], reverse=True)
    }), 200
```
```python
@app.route('/apply_move', methods=['POST'])
def apply_move_endpoint():
    global mcts_algorithm
    if not mcts_algorithm:
        return jsonify({"error": "Game not initialized."}), 400
    try:
        data = request.json
        action_to_apply = data.get("action")
        if action_to_apply is None:
            return jsonify({"error": "No action provided in the request."}), 400
        try:
            action_to_apply_int = int(action_to_apply)
        except ValueError:
            return jsonify({"error": f"Action '{action_to_apply}' must be a valid number."}), 400

        success, message = mcts_algorithm.advance_tree(action_to_apply_int)

        if success:
            new_root_state = mcts_algorithm.root.state
            is_terminal_now = Game.is_terminal(new_root_state)

            # Removed AUTO_ITERATIONS_AFTER_MOVE block
            app.logger.info(f"Auto-iterations after move have been REMOVED.")
            app.logger.info(f"State of new root (ID: {mcts_algorithm.root.id}) after advancing tree: Children: {len(mcts_algorithm.root.children)}, Untried: {len(mcts_algorithm.root.untried_actions)}")

            current_sum_after_move = new_root_state[0]
            moves_taken_after_move = new_root_state[1]
            app.logger.info(f"Move {action_to_apply_int} applied. New state: {new_root_state}, Terminal: {is_terminal_now}")
            return jsonify({
                "message": f"Action {action_to_apply_int} applied. {message}",
                "new_root_state": new_root_state,
                "current_sum": current_sum_after_move,
                "moves_taken": moves_taken_after_move,
                "is_terminal": is_terminal_now,
                "tree_root": mcts_algorithm.get_tree_visualization_data(max_display_depth=2),
                "game_path_tree": mcts_algorithm.get_game_path_tree_data()
            }), 200
        else:
            app.logger.warning(f"Failed to apply move {action_to_apply_int}: {message}")
            return jsonify({
                "error": message,
                "tree_root": mcts_algorithm.get_tree_visualization_data(max_display_depth=2),
                "game_path_tree": mcts_algorithm.get_game_path_tree_data()
            }), 400
    except Exception as e:
        app.logger.error(f"Error in /apply_move: {str(e)}", exc_info=True)
        return jsonify({"error": f"Error applying the action: {str(e)}"}), 500


if __name__ == '__main__':
    app.run(debug=True, port=5000)
```