Lightning is very slow: performance divided by ~4 compared to PyTorch, with a ~10 s wait between epochs #10382
I converted some PyTorch code to Lightning. The dataset is loaded lazily by the train and eval dataloaders. However, after moving the code to Lightning I noticed a huge slowdown: there is a ~10 second delay between each epoch, whereas with vanilla PyTorch a whole epoch takes ~4 s. I first thought it was a data loading problem, but during the 10 s delay no data is loaded (at least, that is what my prints show). I suspect the issue is related to the number of workers. Since this is company code, I cannot disclose the before/after, but I'll try to "anonymize" some code if necessary.

Here is the Lightning module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import pytorch_lightning as pl


class RawModule(pl.LightningModule):
    def __init__(self):
        super(RawModule, self).__init__()
        self.encoder1 = nn.Sequential(...)
        self.encoder2 = nn.Sequential(...)

    def forward(self, data1, data2):
        # Encode both inputs, flatten, and L2-normalize the embeddings.
        result1 = self.encoder1(data1)
        result2 = self.encoder2(data2)
        result1 = result1.view(result1.size(0), -1)
        result2 = result2.view(result2.size(0), -1)
        result1 = F.normalize(result1, p=2, dim=1)
        result2 = F.normalize(result2, p=2, dim=1)
        return result1, result2

    def calculate_loss(self, batch):
        x, r, y = batch
        a, v = self.forward(r, x)
        d = nn.functional.cosine_similarity(a, v)
        loss = logloss(d.unsqueeze(1), y)  # logloss is defined elsewhere (omitted here)
        return loss

class Module(RawModule):
    def training_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("validation_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

if __name__ == '__main__':
    # stuff...
    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=256, shuffle=True,
        num_workers=5, persistent_workers=True,
        pin_memory=True,
    )
    val_loader = data_utils.DataLoader(
        test_dataset, batch_size=256,
        num_workers=2, persistent_workers=True,
        pin_memory=True,
    )

    # Model
    load_from_pytorch = True
    if checkpoint_path is None:
        model = Module()
        if load_from_pytorch:
            if not checkpoint_path:
                raise ValueError("Please provide a checkpoint path")
            model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    else:
        model = Module.load_from_checkpoint(checkpoint_path)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=5,
        check_val_every_n_epoch=10,
        log_every_n_steps=5,
    )
    trainer.fit(model, train_loader, val_loader)
```
Finally, here is a video demonstrating the problem. I'm printing each step of the data loading to show that loading is not the issue.

Some additional information:

Any idea how to find the source of the problem?
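(For reference, not part of the original report: one way to narrow down where the inter-epoch time goes is Lightning's built-in profiler, which prints a per-hook timing summary once `fit()` returns. This is a minimal sketch assuming the same `model`, `train_loader`, and `val_loader` defined above; another quick experiment would be to set `num_workers=0` in both DataLoaders to rule out worker startup cost.)

```python
# Sketch only: enable the "simple" profiler shipped with Lightning 1.5.x
# to see which hooks account for the ~10 s gap between epochs.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    profiler="simple",  # prints a timing summary per hook when fit() finishes
)
trainer.fit(model, train_loader, val_loader)
```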
Replies: 1 comment
Fixed in the 1.5.1 release. See issue #10389, or the release itself.
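(A quick way to confirm that the running environment actually contains the fix; this is a sketch assuming a pip-installed `pytorch-lightning` and that the `packaging` package is available.)

```python
from packaging import version

import pytorch_lightning as pl

# The inter-epoch slowdown described above was fixed in 1.5.1, so fail fast
# if an older release is still installed.
assert version.parse(pl.__version__) >= version.parse("1.5.1"), pl.__version__
```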