An implementation of the "Show, Attend and Tell" paper in PyTorch.
The architecture is similar to the paper's: an EfficientNet-B5 encoder, Bahdanau attention (i.e. the soft attention described in the paper) and an LSTM decoder.
Train Acc - 65%
Test Acc - 75%
BLEU Score - 1.814
class ImageCaptionDataset(Dataset):
    """Dataset of (image, tokenized caption) pairs listed in a CSV file.

    Column 0 of the CSV is read as the image file name (relative to
    ``root_dir``) and column 1 as the caption text.

    Args:
        root_dir: Directory that contains the image files.
        captions_file: Path of the CSV file; it is parsed with ``pd.read_csv``.
        tokenizer: Callable applied to each caption with ``padding``,
            ``max_length``, ``truncation`` and ``return_tensors`` keyword
            arguments; its result must expose an ``'input_ids'`` tensor
            (presumably a Hugging Face tokenizer -- confirm against caller).
        transform: Optional callable applied to the loaded RGB image.
        max_length: Number of tokens every caption is padded/truncated to.
    """

    def __init__(self, root_dir, captions_file, tokenizer, transform=None, max_length=30):
        self.root_dir = root_dir
        # Despite its name, this attribute holds the parsed DataFrame.
        self.captions_file = pd.read_csv(captions_file)
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows (image/caption pairs) in the CSV."""
        return len(self.captions_file)

    def __getitem__(self, idx):
        """Return ``(image, caption_tensor)`` for row ``idx``.

        ``image`` is the RGB image, passed through ``transform`` when one was
        given; ``caption_tensor`` holds the caption's ``input_ids``.
        """
        img_name = self.captions_file.iloc[idx, 0]
        caption = self.captions_file.iloc[idx, 1]
        img_path = f"{self.root_dir}/{img_name}"

        # ``Image`` is expected to be ``PIL.Image`` imported at module level.
        # ``convert`` returns a fully loaded copy, so the underlying file can
        # be closed as soon as the ``with`` block ends.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        # Tokenize the caption to a fixed length so samples can be batched.
        caption_tokens = self.tokenizer(
            caption,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        # Drop only the leading dimension (assumed to be a batch axis of
        # size 1 -- TODO confirm); a bare squeeze() would also collapse the
        # token axis when max_length == 1.
        caption_tensor = caption_tokens['input_ids'].squeeze(0)
        return image, caption_tensor
The paper describes two variants of attention: hard and soft.
I went with soft attention, since hard attention involves more machinery, such as a Monte Carlo sampling approximation of the gradient that is used to update the parameters.
class BahdanauAttention(nn.Module):
    """Additive (soft) attention: score_j = v^T tanh(Wh h_j + Ws s_{i-1}).

    Args:
        enc_hid_dim: Feature size of each encoder output vector h_j.
        dec_hid_dim: Size of the decoder hidden state s_{i-1}.
        attn_dim: Size of the shared space both inputs are projected into.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim, attn_dim):
        super().__init__()
        # Projection of the encoder hidden states h_j.
        self.Wh = nn.Linear(in_features=enc_hid_dim, out_features=attn_dim)
        # Projection of the previous decoder hidden state s_{i-1}.
        self.Ws = nn.Linear(in_features=dec_hid_dim, out_features=attn_dim)
        # Reduces every projected vector to one unnormalised score.
        self.v = nn.Linear(in_features=attn_dim, out_features=1, bias=False)

    def forward(self, enc_out, dec_hid):
        """Attend over ``enc_out`` given the decoder state ``dec_hid``.

        Args:
            enc_out: 3-D tensor (batch, locations, enc_hid_dim), as required
                by the ``bmm`` below.
            dec_hid: 2-D tensor (batch, dec_hid_dim).

        Returns:
            Tuple ``(context, attention_weights)``: the weighted sum of the
            encoder outputs, (batch, enc_hid_dim), and the softmax weights
            over locations, (batch, locations).
        """
        # unsqueeze(1) lets the single decoder vector broadcast over locations.
        projected_state = self.Ws(dec_hid).unsqueeze(1)
        energy = (self.Wh(enc_out) + projected_state).tanh()
        raw_scores = self.v(energy).squeeze(-1)
        weights = torch.softmax(raw_scores, dim=1)
        weighted_sum = torch.bmm(weights.unsqueeze(1), enc_out)
        return weighted_sum.squeeze(1), weights
The inference part is taken from this
def predict_caption(image_path, model, tokenizer, max_len=50):
    """Greedily decode a caption for the image stored at ``image_path``.

    Relies on three module-level names: ``transform`` (image preprocessing),
    ``device`` (target torch device) and ``Image`` (expected to be
    ``PIL.Image``).

    Args:
        image_path: Path of the image file to caption.
        model: Captioning model exposing ``encoder`` and a ``decoder`` with
            ``init_hidden_state``, ``embedding``, ``attention``,
            ``lstm_cell``, ``drop`` and ``fcn``.
        tokenizer: Tokenizer providing ``cls_token_id``, ``sep_token_id``
            and ``decode``.
        max_len: Maximum number of decoding steps.

    Returns:
        Tuple ``(image, caption)``: the loaded RGB image and the decoded
        caption string.
    """
    # convert() returns a loaded copy, so the file handle can be released.
    with Image.open(image_path) as img_file:
        image = img_file.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Inference must not apply dropout or track gradients; remember the
    # current mode so a caller in the middle of training is not disturbed.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            features = model.encoder(image_tensor)
            h, c = model.decoder.init_hidden_state(features)

            # Start the caption with the [CLS] token. ``word`` has shape (1,),
            # so the embedding is 2-D and can be concatenated along dim=1.
            word = torch.tensor([tokenizer.cls_token_id]).to(device)
            embeds = model.decoder.embedding(word)

            captions = []
            for _ in range(max_len):
                # NOTE(review): this unpacks (weights, context). The
                # BahdanauAttention class in this file returns
                # (context, weights) -- confirm which attention module
                # ``model.decoder.attention`` actually is.
                _attn_weights, context = model.decoder.attention(features, h)
                lstm_input = torch.cat((embeds, context), dim=1)
                h, c = model.decoder.lstm_cell(lstm_input, (h, c))
                output = model.decoder.fcn(model.decoder.drop(h))
                predicted_word_idx = output.argmax(dim=1)

                # Stop as soon as the [SEP] token is generated.
                if predicted_word_idx.item() == tokenizer.sep_token_id:
                    break

                captions.append(predicted_word_idx.item())
                # Feed the predicted token back in; the embedding stays 2-D
                # so every step sees the same input layout.
                embeds = model.decoder.embedding(predicted_word_idx)
    finally:
        model.train(was_training)

    # Convert word indices to text, skipping special tokens.
    caption = tokenizer.decode(captions, skip_special_tokens=True)
    return image, caption
The BLEU score measures how similar the generated caption is to the actual (reference) caption.
from nltk.translate.bleu_score import corpus_bleu
def calculate_bleu(predicted_captions, ground_truth_captions):
    """Return the corpus-level BLEU score of predictions against references.

    Both arguments are sequences of caption strings, aligned by position.
    Captions are tokenised by whitespace, and every prediction is scored
    against exactly one reference caption.
    """
    hypotheses = [text.split() for text in predicted_captions]
    # corpus_bleu takes a list of references per hypothesis, hence the
    # extra level of nesting around each tokenised reference.
    references = [[text.split()] for text in ground_truth_captions]
    return corpus_bleu(references, hypotheses)
def evaluate_model(model, test_loader, tokenizer):
    """Compute, print and return the corpus BLEU score on ``test_loader``.

    The ground-truth captions are passed to the model's forward call
    (presumably teacher forcing -- confirm against the model), and the
    arg-max token at every position is decoded as the prediction. Relies on
    the module-level ``device`` and on ``calculate_bleu``.

    Args:
        model: Captioning model; ``model(image, captions)`` must return a
            pair whose first element is indexed (batch, position, vocab).
        test_loader: Iterable yielding ``(image, captions)`` batches.
        tokenizer: Tokenizer providing ``decode``.

    Returns:
        The BLEU score returned by ``calculate_bleu``.
    """
    predicted_captions = []
    ground_truth_captions = []

    # Evaluate without dropout, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for image, captions in test_loader:
                image = image.to(device)
                captions = captions.to(device)

                outputs, _ = model(image, captions)
                # Most likely token id at every position (vocab is dim 2).
                predicted = outputs.argmax(dim=2).cpu().numpy()
                # Skip the first target token (presumably [CLS]) so the
                # references line up with the predictions -- TODO confirm.
                references = captions[:, 1:].cpu().numpy()

                for pred_ids, ref_ids in zip(predicted, references):
                    predicted_captions.append(
                        tokenizer.decode(pred_ids, skip_special_tokens=True)
                    )
                    ground_truth_captions.append(
                        tokenizer.decode(ref_ids, skip_special_tokens=True)
                    )
    finally:
        model.train(was_training)

    # Calculate BLEU score over every decoded caption pair.
    bleu_score = calculate_bleu(predicted_captions, ground_truth_captions)
    print(f"BLEU Score: {bleu_score:.4f}")
    return bleu_score
- implementing the hard attention part
- shifting to Distributed Data Parallel in PyTorch
- training on Flickr30k
- wrapping the model weights in a FastAPI service and deploying it on AWS or Azure, i.e. making the project end-to-end
- experimenting with different hyperparameters to get the best results