-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.py
156 lines (120 loc) · 4.89 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# STUDENT's UCO: 482857
# Description:
# This file should be used for performing training of a network
# Usage: python training.py <path_2_dataset>
import sys
import matplotlib.pyplot as plt
import torch
from torchview import draw_graph
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import utils as vutils
from network import ModelExample
from dataset import SampleDataset, SampleDataSpliter
# sample function for model architecture visualization
# draw_graph function saves an additional file: Graphviz DOT graph file, it's not necessary to delete it
def draw_network_architecture(network, input_sample):
    """Render the architecture of ``network`` to ``model_architecture.png``.

    The graph is produced by torchview's ``draw_graph`` from a forward pass of
    ``input_sample`` and laid out left-to-right. torchview additionally writes
    the Graphviz DOT source file next to the image.

    Parameters:
    - network (torch.nn.Module): model to visualize
    - input_sample (torch.Tensor): example input used to trace the forward pass
    Returns:
    - None
    """
    draw_graph(
        network,
        input_sample,
        graph_dir='LR',
        save_graph=True,
        filename="model_architecture",
    )
# sample function for losses visualization
def plot_learning_curves(train_losses, validation_losses):
    """Plot train/validation loss curves and save them to ``learning_curves.png``.

    Parameters:
    - train_losses (sequence of float): training losses, one value per epoch
      as produced by ``fit``
    - validation_losses (sequence of float): validation losses, one value per
      epoch as produced by ``fit``
    Returns:
    - None
    """
    fig = plt.figure(figsize=(10, 5))
    plt.title("Train and Evaluation Losses During Training")
    plt.plot(train_losses, label="train_loss")
    plt.plot(validation_losses, label="validation_loss")
    # fit() appends one averaged loss per epoch, so the x-axis counts epochs.
    plt.xlabel("epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("learning_curves.png")
    # Release the figure; pyplot otherwise keeps every created figure in memory.
    plt.close(fig)
# sample function for training
def fit(net, batch_size, epochs, trainloader, validloader, loss_fn, optimizer, device):
    """Train ``net`` and keep a snapshot of the weights with the lowest validation loss.

    Parameters:
    - net (torch.nn.Module): model to train; moved to ``device`` in place
    - batch_size (int): not used here (batching is done by the loaders); kept
      for interface compatibility
    - epochs (int): number of passes over ``trainloader``
    - trainloader, validloader (iterables): yield ``(data, labels, extra)``
      triples; the third item is ignored
    - loss_fn (callable): ``loss_fn(outputs, labels)`` returning a scalar tensor
    - optimizer (torch.optim.Optimizer): optimizer over ``net``'s parameters
    - device (torch.device): device the model and each batch are moved to
    Returns:
    - (train_losses, validation_losses, best_model_state): lists with the mean
      per-batch loss of each epoch, and an independent copy of
      ``net.state_dict()`` from the epoch with the lowest validation loss
      (``None`` if no epoch ran or no validation loss was below infinity)
    Raises:
    - ValueError: if either loader yields no batches in an epoch
    """
    import copy  # only needed to snapshot the best weights

    train_losses = []
    validation_losses = []
    best_val_loss = float('inf')
    best_model_state = None
    net.to(device)
    for epoch in range(epochs):
        # Training phase
        net.train()
        running_loss = 0.0
        n_batches = 0
        for data, labels, _ in trainloader:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(data)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError("trainloader yielded no batches")
        avg_train_loss = running_loss / n_batches
        train_losses.append(avg_train_loss)

        # Validation phase: no gradients, model in eval mode
        net.eval()
        running_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for data, labels, _ in validloader:
                data, labels = data.to(device), labels.to(device)
                outputs = net(data)
                loss = loss_fn(outputs, labels)
                running_loss += loss.item()
                n_batches += 1
        if n_batches == 0:
            raise ValueError("validloader yielded no batches")
        avg_val_loss = running_loss / n_batches
        validation_losses.append(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameter tensors, which
            # later optimizer steps overwrite; deep-copy to freeze this epoch's weights.
            best_model_state = copy.deepcopy(net.state_dict())

        print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss:.5f}, Val Loss: {avg_val_loss:.5f}')
    print('Training finished!')
    return train_losses, validation_losses, best_model_state
# declaration for this function should not be changed
def training(dataset_path):
    """
    training(dataset_path) performs training on the given dataset;
    saves:
    - model.pt (trained model, with the weights of the best validation epoch)
    - learning_curves.png (learning curves generated during training)
    - model_architecture.png (a scheme of model's architecture)
    Parameters:
    - dataset_path (string): path to a dataset
    Returns:
    - None
    """
    # Check for available GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Computing with {}!'.format(device))

    batch_size = 64
    epochs = 12

    cityscape_dataset = SampleDataset(data_dir=dataset_path)
    sample_data_splitter = SampleDataSpliter(cityscape_dataset)
    traindataset = sample_data_splitter.get_train_dataset()
    valdataset = sample_data_splitter.get_val_dataset()
    trainloader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)
    # Validation order does not affect the averaged loss, so it is not shuffled.
    valloader = DataLoader(valdataset, batch_size=batch_size, shuffle=False)

    number_of_classes = 6
    dropout = 0.1
    net = ModelExample(number_of_classes, dropout)

    # Dummy input used only to trace the architecture diagram; assumes the
    # model expects 3x256x256 images — TODO confirm against the dataset.
    input_sample = torch.zeros((1, 3, 256, 256)).to(device)
    draw_network_architecture(net, input_sample)

    optimizer = optim.Adam(net.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    tr_losses, val_losses, best_model_state = fit(net, batch_size, epochs, trainloader, valloader, loss_fn, optimizer,
                                                  device)

    # Rebuild the model with the same hyper-parameters as the trained network,
    # so the saved model's configuration matches the one that was trained.
    best_model = ModelExample(number_of_classes, dropout)
    if best_model_state is None:
        # fit() recorded no best epoch (e.g. non-finite validation loss); keep the last weights.
        best_model_state = net.state_dict()
    best_model.load_state_dict(best_model_state)
    torch.save(best_model, './model.pt')
    plot_learning_curves(tr_losses, val_losses)
# #### code below should not be changed ############################################################################
def get_arguments():
    """Return the dataset path given as the single command-line argument.

    Prints a usage message and exits with status 1 unless exactly one
    argument was passed.
    """
    if len(sys.argv) != 2:
        print("Usage: python training.py <path_2_dataset> ")
        sys.exit(1)
    # NOTE(review): the length check above already guarantees sys.argv[1]
    # exists, so this try/except is purely defensive.
    try:
        path = sys.argv[1]
    except Exception as e:
        print(e)
        sys.exit(1)
    return path
if __name__ == "__main__":
    # Script entry point: read the dataset path from argv and run training.
    path_2_dataset = get_arguments()
    training(path_2_dataset)