# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict

pgn_or_fen = "fen"

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=True,
)
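# The include_* flags expose FEN and SAN strings for each position as well as
# hashed observations; the hashes are used below as compact node identifiers,
# and mask_actions adds an "action_mask" marking the legal moves.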


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
#   white win = 1
#   draw = 0
#   black win = -1
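# After a move is applied, "turn" already holds the side to move next (True for
# white, following python-chess), so reward == 1 with turn == True means black
# just delivered mate and the reward is negated.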
env.append_transform(transform_reward)

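# Nodes are identified by the position hash, the side to move and the legal-move
# mask; "_visits" and "_reward_sum" are appended to the reward keys so these
# per-node statistics are stored in the forest as well.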
forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys + ["_visits", "_reward_sum"]
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys
forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]

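# Exploration constant for the UCB1 priority used below:
#   value_avg + C * sqrt(ln(parent_visits) / visits),
# where value_avg is the node's average reward from the mover's perspective;
# sqrt(2) is the classical choice.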
C = 2.0**0.5


def traversal_priority_UCB1(tree):
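    """Returns the UCB1 traversal priority of a node (infinite if unvisited)."""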
    if tree.rollout[-1]["next", "_visits"] == 0:
        res = float("inf")
    else:
        if tree.parent.rollout is None:
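            # The parent is the tree's root, which stores no rollout of its own,
            # so use the sum of its children's visits as the parent visit count.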
            parent_visits = 0
            for child in tree.parent.subtree:
                parent_visits += child.rollout[-1]["next", "_visits"]
        else:
            parent_visits = tree.parent.rollout[-1]["next", "_visits"]
            assert parent_visits > 0

        value_avg = (
            tree.rollout[-1]["next", "_reward_sum"]
            / tree.rollout[-1]["next", "_visits"]
        )

        # If it's black's turn, flip the reward, since black wants to optimize
        # for the lowest reward.
        if not tree.rollout[0]["turn"]:
            value_avg = -value_avg

        res = (
            value_avg
            + C
            * torch.sqrt(torch.log(parent_visits) / tree.rollout[-1]["next", "_visits"])
        ).item()

    return res


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
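    """Performs one MCTS iteration: select a leaf by UCB1 priority, expand it,
    estimate its value with a rollout, and backpropagate the result."""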
    done = False
    trees_visited = []

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"]

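            # Expansion: if this leaf has been visited before (or is the root)
            # and the game is not over, add a child node for every legal move.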
            if (td_tree["_visits"] > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree.clone()).update(action)).update(
                        TensorDict(
                            {
                                ("next", "_visits"): 0,
                                ("next", "_reward_sum"): env.reward_spec.zeros(),
                            }
                        )
                    )

                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                    )
                    subtrees.append(new_node)

                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

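            # Simulation: if the chosen state is terminal, use its reward
            # directly; otherwise estimate its value with a random rollout
            # (env.rollout with no policy samples random actions).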
            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
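            # Selection: descend into the child with the highest UCB1 priority.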
            priorities = torch.tensor(
                [traversal_priority_UCB1(subtree) for subtree in tree.subtree]
            )
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

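    # Backpropagation: credit the rollout result to every node selected on the
    # way down.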
    for tree in trees_visited:
        td = tree.rollout[-1]["next"]
        td["_visits"] += 1
        td["_reward_sum"] += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
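    # The first time a root is seen, seed the forest with one step per legal
    # move so that a tree can be built from it.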
    if root not in forest:
        for action in env.all_actions(root.clone()):
            td = env.step(env.reset(root.clone()).update(action)).update(
                TensorDict(
                    {
                        ("next", "_visits"): 0,
                        ("next", "_reward_sum"): env.reward_spec.zeros(),
                    }
                )
            )
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree


def tree_format_fn(tree):
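    """Formats a node as [SAN move, position, reward sum, visits] for
    Tree.to_string."""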
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        td["_reward_sum"].item(),
        td["_visits"].item(),
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
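    """Runs MCTS from the given FEN and returns the SAN of the move with the
    highest average value for the side to move."""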
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)

    # print('------------------------------')
    # print(tree.to_string(tree_format_fn))
    # print('------------------------------')

    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.rollout[-1]["next", "_reward_sum"]
        visits = subtree.rollout[-1]["next", "_visits"]
        value_avg = (reward_sum / visits).item()
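        # As in traversal_priority_UCB1, flip the sign when black is the one
        # moving so that a larger average value is always better for the mover.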
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


# White has mate in one; the best move is Rd8#. Any other move loses to a mate
# in one or two by black.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has mate in one; the best move is Qg6#. Other moves give rough equality
# or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has mate in two; the best move is Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"