forked from dxzmpk/image_classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
60 lines (59 loc) · 2.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import time

# comet_ml is kept first among third-party imports: it must be imported
# before torch for its auto-logging hooks to attach.
from comet_ml import Experiment
import torch.optim as optim
from torchsummary import summary
from poutyne.framework import Model
from poutyne.framework.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping

from Project import project
from data import get_dataloaders
from data.transformation import train_transform, val_transform
from models import MyCNN, resnet18
from utils import device, show_dl
from callbacks import CometCallback
from logger import logging
# Hyperparameters for this run (also sent to Comet via experiment.log_parameters).
params = {
    'lr': 0.001,
    'batch_size': 32,
    'model': 'resnet18-finetune'
}
logging.info(f'Using device={device} 🚀')
# Everything starts with the data. Only the train and val directories are
# passed in, yet three loaders come back; how the test split is built is up to
# get_dataloaders (not visible here -- TODO confirm).
train_dl, val_dl, test_dl = get_dataloaders(
    project.data_dir / "train",
    project.data_dir / "val",
    val_transform=val_transform,
    train_transform=train_transform,
    batch_size=params['batch_size'],
)
# It is always good practice to visualise some of the train and val images to
# be sure data augmentation is applied properly.
show_dl(train_dl)
show_dl(val_dl)
# Define our Comet experiment. The API key is taken from the COMET_API_KEY
# environment variable rather than hard-coded, so no secret is committed with
# the source. If the variable is unset, None is passed and Comet falls back to
# its own configuration lookup (see the comet_ml Experiment docs).
experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"),
                        project_name="dl-pytorch-template",
                        workspace="francescosaveriozuppichini")
experiment.log_parameters(params)
# Create our special resnet18; the argument 2 is presumably the number of
# output classes -- TODO confirm against models.resnet18.
cnn = resnet18(2).to(device)
# Print the model summary. torchsummary's summary() prints the table itself
# (the sksq96 package returns None), so it is called directly instead of being
# wrapped in logging.info, which would only log "None".
# Input size is (channels, height, width); assumes 3-channel 224x224 inputs,
# the usual ResNet size -- confirm against data.transformation.
summary(cnn, (3, 224, 224))
# Define a custom optimizer and instantiate the poutyne trainer `Model`.
optimizer = optim.Adam(cnn.parameters(), lr=params['lr'])
model = Model(cnn, optimizer, "cross_entropy", batch_metrics=["accuracy"]).to(device)
# Usually you want to reduce the lr on plateau, store the best model and stop
# early. All three callbacks track validation accuracy, which is *maximised*.
# torch's ReduceLROnPlateau and poutyne's ModelCheckpoint default to
# mode='min' (and ModelCheckpoint to monitor='val_loss'), so monitor/mode are
# spelled out explicitly to keep them consistent with EarlyStopping.
callbacks = [
    ReduceLROnPlateau(monitor="val_acc", mode="max", patience=5, verbose=True),
    ModelCheckpoint(
        # timestamped filename so successive runs do not overwrite each other
        str(project.checkpoint_dir / f"{time.time()}-model.pt"),
        monitor="val_acc",
        mode="max",
        save_best_only=True,
        verbose=True,
    ),
    EarlyStopping(monitor="val_acc", patience=10, mode='max'),
    CometCallback(experiment),
]
model.fit_generator(
    train_dl,
    val_dl,
    epochs=50,
    callbacks=callbacks,
)
# Get the results on the test set. With exactly one batch metric ("accuracy")
# poutyne's evaluate_generator returns a (loss, metric) pair; both are recorded
# rather than discarding the loss.
test_loss, test_acc = model.evaluate_generator(test_dl)
logging.info(f'test_loss=({test_loss})')
logging.info(f'test_acc=({test_acc})')
experiment.log_metric('test_loss', test_loss)
experiment.log_metric('test_acc', test_acc)