Skip to content

Commit

Permalink
PPO larger checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
hesic73 committed Dec 16, 2023
1 parent 16fdec5 commit 63835dd
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 8 deletions.
4 changes: 2 additions & 2 deletions cfg/algo/ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ optimizer:
lr: 5e-4


num_channels: 32
num_residual_blocks: 3
num_channels: 64
num_residual_blocks: 4


4 changes: 2 additions & 2 deletions cfg/baseline/ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ optimizer:
lr: 3e-4


num_channels: 32
num_residual_blocks: 3
num_channels: 64
num_residual_blocks: 4


8 changes: 4 additions & 4 deletions cfg/train_InRL.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ run_dir:

augment: false

epochs: 1500
epochs: 1000
rounds: 64
save_interval: -1
save_interval: 300


black_checkpoint: pretrained_models/${board_size}_${board_size}/${algo.name}/0.pt
white_checkpoint: pretrained_models/${board_size}_${board_size}/${algo.name}/1.pt
black_checkpoint: black_final.pt # pretrained_models/${board_size}_${board_size}/${algo.name}/0.pt
white_checkpoint: white_final.pt # pretrained_models/${board_size}_${board_size}/${algo.name}/1.pt

wandb:
group: ${board_size}_${board_size}_${algo.name}_InRL
Expand Down
24 changes: 24 additions & 0 deletions gomoku_rl/utils/elo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,27 @@ def compute_elo_ratings(payoff: np.ndarray, average_rating: float = 1200) -> np.
elo_ratings = payoff.mean(axis=-1)
elo_ratings = elo_ratings * (400 / np.log(10)) + average_rating
return elo_ratings


def compute_expected_score(rating_0: float, rating_1: float) -> float:
    """Return the expected score of a player rated ``rating_0`` against ``rating_1``.

    Implements the logistic Elo formula
    ``1 / (1 + 10 ** ((rating_1 - rating_0) / 400))``. The result lies in
    ``[0, 1]``; equal ratings give ``0.5`` and the two players' expected
    scores sum to 1.

    Args:
        rating_0: Rating of the player whose expected score is computed.
        rating_1: Rating of the opponent.

    Returns:
        The expected score of the first player.
    """
    try:
        return 1 / (1 + 10.0 ** ((rating_1 - rating_0) / 400))
    except OverflowError:
        # Python floats raise instead of saturating to ``inf`` when the
        # exponent is huge (rating gap above roughly 123,000 points). The
        # mathematical limit is 0, mirroring the opposite extreme, which
        # underflows silently and yields 1.0.
        return 0.0


class Elo:
    """Incremental Elo rating table keyed by player name.

    Ratings live in ``self.players`` (name -> rating) and are adjusted one
    game at a time through :meth:`update`.
    """

    def __init__(self) -> None:
        # Maps each registered player name to its current rating.
        self.players: dict[str, float] = {}

    def add_player(self, name: str, rating: float = 1200):
        """Register a new player with an initial rating.

        Args:
            name: Unique player identifier.
            rating: Initial rating assigned to the player.

        Raises:
            AssertionError: If ``name`` is already registered.
        """
        # Raised explicitly rather than via ``assert`` so the duplicate check
        # still runs under ``python -O``; the exception type is kept so
        # existing callers see the same error as before.
        if name in self.players:
            raise AssertionError(f"player {name!r} is already registered")
        self.players[name] = rating

    def addPlayer(self, name: str, rating: float = 1200):
        """Register a new player; camelCase alias of :meth:`add_player`."""
        self.add_player(name, rating)

    def expected_score(self, player_0: str, player_1: str) -> float:
        """Return the expected score of ``player_0`` against ``player_1``.

        Raises:
            KeyError: If either player has not been registered.
        """
        rating_0 = self.players[player_0]
        rating_1 = self.players[player_1]
        return compute_expected_score(rating_0, rating_1)

    def update(self, player_0: str, player_1: str, score: float, K: float = 64):
        """Adjust both players' ratings after one game.

        The update is zero-sum: whatever ``player_0`` gains, ``player_1``
        loses.

        Args:
            player_0: Name of the player that ``score`` refers to.
            player_1: Name of the opponent.
            score: Actual result for ``player_0`` (presumably 1 for a win,
                0.5 for a draw and 0 for a loss; confirm against callers).
            K: Step size scaling the gap between actual and expected score.

        Raises:
            KeyError: If either player has not been registered.
        """
        expected = self.expected_score(player_0, player_1)
        delta = K * (score - expected)
        self.players[player_0] += delta
        self.players[player_1] -= delta
Binary file modified pretrained_models/15_15/ppo/0.pt
Binary file not shown.
Binary file modified pretrained_models/15_15/ppo/1.pt
Binary file not shown.

0 comments on commit 63835dd

Please sign in to comment.