Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions examples/cross_val_fashion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Simple Convolution and fully connected blocks cross validation example."""
from vulcanai import datasets
from vulcanai.models import ConvNet, DenseNet
from vulcanai.models.metrics import Metrics

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# prepare the data
normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
std=[x/255.0 for x in [63.0, 62.1, 66.7]])

transform = transforms.Compose([transforms.ToTensor(),
normalize])


data_path = "../data"
dataset = datasets.FashionData(root=data_path,
train=True,
transform=transform,
download=True)

batch_size = 100

data_loader = DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True)



# define neural network - 3 2D conv layers followed by a dense layer
conv_2D_config = {
'conv_units': [
dict(
in_channels=1,
out_channels=16,
kernel_size=(5, 5),
stride=2,
dropout=0.1
),
dict(
in_channels=16,
out_channels=32,
kernel_size=(5, 5),
dropout=0.1
),
dict(
in_channels=32,
out_channels=64,
kernel_size=(5, 5),
pool_size=2,
dropout=0.1
)
],
}

dense_config = {
'dense_units': [100, 50],
'dropout': 0.5, # Single value or List
}

conv_2D = ConvNet(
name='conv_2D',
in_dim=(1, 28, 28),
config=conv_2D_config
)

dense_model = DenseNet(
name='dense_model',
input_networks=conv_2D,
config=dense_config,
num_classes=10,
early_stopping="best_validation_error",
early_stopping_patience=2
)


# cross validate on 5 folds training each fold for 2 epochs
m = Metrics()

m.cross_validate(dense_model, data_loader, 5, 2)
42 changes: 28 additions & 14 deletions vulcanai/models/basenetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

sns.set(style='dark')
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())


# Because pytorch causes a bunch of unresolved references
Expand Down Expand Up @@ -651,9 +652,11 @@ def fit(self, train_loader, val_loader, epochs,
else:
save_path = save_path + '/' + self.name + '_'
save_path = get_save_path(save_path, vis_type='train')
iterator = trange(epochs, desc='Epoch: ')
# iterator = trange(epochs, desc='Epoch: ')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't this remove the progress bar at the epoch level?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem was the overlapping progress bars; that's why I removed the outer one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you still need to print the epoch progress bar. Nested progress bars are what we are going for, so they just need to be cleaned up in the implementation.


for epoch in iterator:
for epoch in range(epochs):

logger.info('\n -------- Epoch: {} --------\n'.format(epoch))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why have this additional epoch log if the tqdm writer also writes the current epoch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I don't use tqdm on the epoch level anymore

Copy link
Contributor Author

@sneha-desai sneha-desai Oct 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem was the overlapping progress bars; that's why I removed the epoch-level one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see my other comment about this. I don't think the solution is to remove the progress bar altogether; unless I'm missing something, we just need to clean up the nested progress bar implementation.


train_loss, train_acc = self._train_epoch(train_loader,
retain_graph)
Expand Down Expand Up @@ -687,7 +690,7 @@ def fit(self, train_loader, val_loader, epochs,
self.__dict__.update(self.load_model(
early_stopping.save_path).__dict__)
# for tqdm
iterator.close()
# iterator.close()
break

# reset from None so that a distinction can be made
Expand All @@ -697,15 +700,26 @@ def fit(self, train_loader, val_loader, epochs,
if not valid_acc:
valid_acc = np.nan

tqdm.write(
"\n Epoch {}:\n"
"Train Loss: {:.6f} | Val Loss: {:.6f} |"
"Train Acc: {:.4f} | Val Acc: {:.4f}".format(
self.epoch,
train_loss,
valid_loss,
train_acc,
valid_acc))

if epoch % valid_interv == 0:
tqdm.write(
"\nEpoch {} Summary:\n"
"Train Loss: {:.6f} | Val Loss: {:.6f} |"
"Train Acc: {:.4f} | Val Acc: {:.4f} \n".format(
self.epoch,
train_loss,
valid_loss,
train_acc,
valid_acc))

else:
tqdm.write(
"\nEpoch {} Summary:\n"
"Train Loss: {:.6f} | Train Acc: {:.4f} \n".format(
self.epoch,
train_loss,
train_acc))


self.record['epoch'].append(self.epoch)
self.record['train_error'].append(train_loss)
Expand All @@ -723,8 +737,8 @@ def fit(self, train_loader, val_loader, epochs,

except KeyboardInterrupt:
logger.warning(
"\n\n**********KeyboardInterrupt: "
"Training stopped prematurely.**********\n\n")
"\n\n********** KeyboardInterrupt: "
"Training stopped prematurely. **********\n\n")

def _train_epoch(self, train_loader, retain_graph):
"""
Expand Down
5 changes: 3 additions & 2 deletions vulcanai/models/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,8 +1069,9 @@ def cross_validate(network, data_loader, k, epochs,
# all the results so far
except KeyboardInterrupt:
logger.info(
"\n\n***KeyboardInterrupt: Cross validate stopped \
prematurely.***\n\n")
"\n\n********** KeyboardInterrupt: Cross validate stopped "
"prematurely. **********\n\n")
cross_val_network.save_model()

if average_results:
averaged_all_results = {}
Expand Down