diff --git a/examples/cross_val_fashion.py b/examples/cross_val_fashion.py new file mode 100644 index 0000000..934c81d --- /dev/null +++ b/examples/cross_val_fashion.py @@ -0,0 +1,81 @@ +"""Simple Convolution and fully connected blocks cross validation example.""" +from vulcanai import datasets +from vulcanai.models import ConvNet, DenseNet +from vulcanai.models.metrics import Metrics + +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +# prepare the data +normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], + std=[x/255.0 for x in [63.0, 62.1, 66.7]]) + +transform = transforms.Compose([transforms.ToTensor(), + normalize]) + + +data_path = "../data" +dataset = datasets.FashionData(root=data_path, + train=True, + transform=transform, + download=True) + +batch_size = 100 + +data_loader = DataLoader(dataset=dataset, + batch_size=batch_size, + shuffle=True) + + + +# define neural network - 3 2D conv layers followed by a dense layer +conv_2D_config = { + 'conv_units': [ + dict( + in_channels=1, + out_channels=16, + kernel_size=(5, 5), + stride=2, + dropout=0.1 + ), + dict( + in_channels=16, + out_channels=32, + kernel_size=(5, 5), + dropout=0.1 + ), + dict( + in_channels=32, + out_channels=64, + kernel_size=(5, 5), + pool_size=2, + dropout=0.1 + ) + ], +} + +dense_config = { + 'dense_units': [100, 50], + 'dropout': 0.5, # Single value or List +} + +conv_2D = ConvNet( + name='conv_2D', + in_dim=(1, 28, 28), + config=conv_2D_config +) + +dense_model = DenseNet( + name='dense_model', + input_networks=conv_2D, + config=dense_config, + num_classes=10, + early_stopping="best_validation_error", + early_stopping_patience=2 +) + + +# cross validate on 5 folds training each fold for 2 epochs +m = Metrics() + +m.cross_validate(dense_model, data_loader, 5, 2) diff --git a/vulcanai/models/basenetwork.py b/vulcanai/models/basenetwork.py index 3157b95..02381d3 100644 --- a/vulcanai/models/basenetwork.py +++ b/vulcanai/models/basenetwork.py @@ -28,6 +28,7 @@ sns.set(style='dark') logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler()) # Because pytorch causes a bunch of unresolved references @@ -651,9 +652,11 @@ def fit(self, train_loader, val_loader, epochs, else: save_path = save_path + '/' + self.name + '_' save_path = get_save_path(save_path, vis_type='train') - iterator = trange(epochs, desc='Epoch: ') + # iterator = trange(epochs, desc='Epoch: ') - for epoch in iterator: + for epoch in range(epochs): + + logger.info('\n -------- Epoch: {} --------\n'.format(epoch)) train_loss, train_acc = self._train_epoch(train_loader, retain_graph) @@ -687,7 +690,7 @@ def fit(self, train_loader, val_loader, epochs, self.__dict__.update(self.load_model( early_stopping.save_path).__dict__) # for tqdm - iterator.close() + # iterator.close() break # reset from None so that a distinction can be made @@ -697,15 +700,26 @@ def fit(self, train_loader, val_loader, epochs, if not valid_acc: valid_acc = np.nan - tqdm.write( - "\n Epoch {}:\n" - "Train Loss: {:.6f} | Val Loss: {:.6f} |" - "Train Acc: {:.4f} | Val Acc: {:.4f}".format( - self.epoch, - train_loss, - valid_loss, - train_acc, - valid_acc)) + + if epoch % valid_interv == 0: + tqdm.write( + "\nEpoch {} Summary:\n" + "Train Loss: {:.6f} | Val Loss: {:.6f} |" + "Train Acc: {:.4f} | Val Acc: {:.4f} \n".format( + self.epoch, + train_loss, + valid_loss, + train_acc, + valid_acc)) + + else: + tqdm.write( + "\nEpoch {} Summary:\n" + "Train Loss: {:.6f} | Train Acc: {:.4f} \n".format( + self.epoch, + train_loss, + train_acc)) + self.record['epoch'].append(self.epoch) self.record['train_error'].append(train_loss) @@ -723,8 +737,8 @@ def fit(self, train_loader, val_loader, epochs, except KeyboardInterrupt: logger.warning( - "\n\n**********KeyboardInterrupt: " - "Training stopped prematurely.**********\n\n") + "\n\n********** KeyboardInterrupt: " + "Training stopped prematurely. **********\n\n") def _train_epoch(self, train_loader, retain_graph): """ diff --git a/vulcanai/models/metrics.py b/vulcanai/models/metrics.py index 0542886..1a736a6 100644 --- a/vulcanai/models/metrics.py +++ b/vulcanai/models/metrics.py @@ -1069,8 +1069,9 @@ def cross_validate(network, data_loader, k, epochs, # all the results so far except KeyboardInterrupt: logger.info( - "\n\n***KeyboardInterrupt: Cross validate stopped \ - prematurely.***\n\n") + "\n\n********** KeyboardInterrupt: Cross validate stopped " + "prematurely. **********\n\n") + cross_val_network.save_model() if average_results: averaged_all_results = {}