trainer.py
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F


def reward_fn(coords, tour):
    """Reward function. Compute the total distance of a tour, given the
    coordinates of each city and the tour's indexes.

    Args:
        coords (torch.Tensor): Tensor of size [batch_size, seq_len, dim],
            representing each city's coordinates.
        tour (torch.Tensor): Tensor of size [batch_size, seq_len + 1],
            representing the tour's indexes (comes back to the first city).

    Returns:
        torch.Tensor: Tensor of size [batch_size] with the total distance of
            each tour.
    """
    dim = coords.size(-1)
    # Reorder the cities according to the tour
    ordered_coords = torch.gather(coords, 1, tour.long().unsqueeze(-1).repeat(1, 1, dim))
    ordered_coords = ordered_coords.transpose(0, 2)  # [dim, seq_len + 1, batch_size]
    # For each dimension (x, y), compute the squared difference between consecutive cities
    delta2 = [torch.square(d[1:] - d[:-1]).transpose(0, 1) for d in ordered_coords]
    # Euclidean distance between consecutive cities
    inter_city_distances = torch.sqrt(sum(delta2))
    distance = inter_city_distances.sum(dim=-1)
    return distance.float()
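
# Minimal usage sketch for reward_fn (shapes and values are purely illustrative,
# not taken from this repository): a batch of 2 closed tours over 4 cities in 2D
# gives `coords` of shape [2, 4, 2] and `tour` of shape [2, 5]; the result is
# one total distance per tour.
#
#   coords = torch.rand(2, 4, 2)
#   tour = torch.tensor([[0, 2, 1, 3, 0], [1, 0, 3, 2, 1]])
#   distances = reward_fn(coords, tour)  # tensor of shape [2]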


class Trainer:
    def __init__(self, conf, agent, dataset):
        """Trainer class, taking care of training the agent.

        Args:
            conf (omegaconf.DictConfig): Configuration.
            agent (torch.nn.Module): Agent network to train.
            dataset (data.DataGenerator): Data generator.
        """
        super().__init__()
        self.conf = conf
        self.agent = agent
        self.dataset = dataset

        self.device = torch.device(self.conf.device)
        self.agent = self.agent.to(self.device)

        self.optim = torch.optim.Adam(params=self.agent.parameters(), lr=self.conf.lr)
        # Decay the LR every step; gamma is chosen to match the Tensorflow implementation's behavior
        gamma = 1 - self.conf.lr_decay_rate / self.conf.lr_decay_steps
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optim, gamma=gamma)
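
    # The configuration object is expected to provide at least the fields read
    # in this file: device, lr, lr_decay_rate, lr_decay_steps, grad_clip,
    # steps, batch_size, max_len, dimension and log_interval. A hypothetical
    # example (values are illustrative only, not taken from this repository):
    #
    #   conf = OmegaConf.create({
    #       'device': 'cuda', 'lr': 1e-3, 'lr_decay_rate': 0.96,
    #       'lr_decay_steps': 5000, 'grad_clip': 1.0, 'steps': 20000,
    #       'batch_size': 128, 'max_len': 20, 'dimension': 2,
    #       'log_interval': 200,
    #   })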

    def train_step(self, data):
        self.optim.zero_grad()

        # Forward pass
        tour, critique, log_probs, _ = self.agent(data)

        # Compute reward (total tour length)
        reward = reward_fn(data, tour)

        # Compute the losses for both the actor (REINFORCE with the critic's
        # prediction as baseline) and the critic. The advantage is detached so
        # the actor loss does not backpropagate into the critic.
        loss1 = ((reward - critique).detach() * log_probs).mean()
        loss2 = F.mse_loss(reward, critique)

        # Backward pass
        (loss1 + loss2).backward()

        # Clip gradients
        nn.utils.clip_grad_norm_(self.agent.parameters(), self.conf.grad_clip)

        # Optimize
        self.optim.step()

        # Update LR
        self.scheduler.step()

        # Return plain floats for logging, so no computation graph is kept alive
        return reward.mean().item(), [loss1.item(), loss2.item()]

    def run(self):
        self.agent.train()
        running_reward, running_losses = 0, [0, 0]

        for step in range(self.conf.steps):
            input_batch = self.dataset.train_batch(self.conf.batch_size, self.conf.max_len, self.conf.dimension)
            input_batch = torch.Tensor(input_batch).to(self.device)

            reward, losses = self.train_step(input_batch)

            running_reward += reward
            running_losses[0] += losses[0]
            running_losses[1] += losses[1]

            if step % self.conf.log_interval == 0 and step != 0:
                # Log the running averages to Weights & Biases
                wandb.log({
                    'reward': running_reward / self.conf.log_interval,
                    'actor_loss': running_losses[0] / self.conf.log_interval,
                    'critic_loss': running_losses[1] / self.conf.log_interval,
                    'learning_rate': self.scheduler.get_last_lr()[0],
                    'step': step
                })

                # Reset running reward/losses
                running_reward, running_losses = 0, [0, 0]
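

# Hypothetical end-to-end wiring, kept as a comment because the agent and data
# classes live outside this file (the names below are assumptions based on the
# docstrings above, not confirmed APIs):
#
#   from omegaconf import OmegaConf
#   from data import DataGenerator
#   from agent import Agent
#
#   conf = OmegaConf.load('config.yaml')
#   wandb.init(project='tsp-actor-critic')  # hypothetical project name
#   trainer = Trainer(conf, Agent(conf), DataGenerator())
#   trainer.run()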