Releases: Lightning-AI/pytorch-lightning
Lightning v2.4
Lightning AI ⚡ is excited to announce the release of Lightning 2.4. This is mainly a compatibility upgrade for PyTorch 2.4 and Python 3.12, with a sprinkle of a few features and bug fixes.
Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.
Changes
PyTorch Lightning
Added
- Made saving non-distributed checkpoints fully atomic (#20011)
- Added dump_stats flag to AdvancedProfiler (#19703)
- Added a flag verbose to the seed_everything() function (#20108)
- Added support for PyTorch 2.4 (#20010)
- Added support for Python 3.12 (#20078)
- The TQDMProgressBar now provides an option to retain prior training epoch bars (#19578)
- Added the count of modules in train and eval mode to the printed ModelSummary table (#20159)
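To illustrate what "fully atomic" saving means in the first item above, here is a minimal sketch of the general write-then-rename technique. It is not Lightning's actual implementation: the function name atomic_save and the use of pickle are assumptions made for this example only.

```python
import os
import pickle
import tempfile


def atomic_save(obj, path):
    """Write `obj` to `path` so readers never observe a partially written file.

    Hypothetical helper for illustration; not part of the Lightning API.
    """
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temporary file next to the target so the final rename
    # stays on the same filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())
        # os.replace swaps the file in a single step: the target holds either
        # the old content or the new content, never a mix.
        os.replace(tmp_path, path)
    except BaseException:
        # On any failure, drop the temporary file and leave the target untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

If the process is interrupted mid-write, only the temporary file is affected, so an existing checkpoint at the target path stays loadable.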
Changed
- Triggering KeyboardInterrupt (Ctrl+C) during .fit(), .evaluate(), .test() or .predict() now terminates all processes launched by the Trainer and exits the program (#19976)
- Changed the implementation of how seeds are chosen for dataloader workers when using seed_everything(..., workers=True) (#20055)
- NumPy is no longer a required dependency (#20090)
Fixed
- Avoid LightningCLI saving hyperparameters with class_path and init_args since this would be a breaking change (#20068)
- Fixed an issue that would cause too many printouts of the seed info when using seed_everything() (#20108)
- Fixed _LoggerConnector's _ResultMetric to move all registered keys to the device of the logged value if needed (#19814)
- Fixed _optimizer_to_device logic for special 'step' key in optimizer state causing performance regression (#20019)
- Fixed parameter counts in ModelSummary when model has distributed parameters (DTensor) (#20163)
Lightning Fabric
Added
Changed
Fixed
Full commit list: 2.3.0 -> 2.4.0
Contributors
We thank all our contributors who submitted pull requests for features, bug fixes and documentation updates.
New Contributors
- @SamuelLarkin made their first contribution in #19969
- @liambsmith made their first contribution in #19986
- @EtayLivne made their first contribution in #19915
- @elmuz made their first contribution in #19998
- @swyo made their first contribution in #19982
- @corwinjoy made their first contribution in #20011
- @omahs made their first contribution in #19979
- @linbo0518 made their first contribution in #20040
- @01AbhiSingh made their first contribution in #20055
- @K-H-Ismail made their first contribution in #20099
- @adosar made their first contribution in #20146
- @jojje made their first contribution in #19578
Did you know?
Chuck Norris can solve NP-hard problems in polynomial time. In fact, any problem is easy when Chuck Norris solves it.
Patch release v2.3.3
This release removes the code from the main lightning package that was reported in CVE-2024-5980.
Patch release v2.3.2
Includes a minor bugfix that avoids a conflict between the entrypoint command and that of another package (#20041).
Patch release v2.3.1
Includes minor bugfixes and stability improvements.
Full Changelog: 2.3.0...2.3.1
Lightning v2.3: Tensor Parallelism and 2D Parallelism
Lightning AI is excited to announce the release of Lightning 2.3 ⚡
This release introduces experimental support for Tensor Parallelism and 2D Parallelism, PyTorch 2.3 support, and several bugfixes and stability improvements.
Highlights
Tensor Parallelism (beta)
Tensor parallelism (TP) is a technique that splits up the computation of selected layers across GPUs to save memory and speed up distributed models. To enable TP as well as other forms of parallelism, we introduce a ModelParallelStrategy for both Lightning Trainer and Fabric. Under the hood, TP is enabled through new experimental PyTorch APIs like DTensor and torch.distributed.tensor.parallel.
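As a rough intuition for how splitting a layer's computation works, the toy sketch below runs a gated feed-forward block in a single process, with plain Python lists standing in for tensors and loop iterations standing in for GPUs. It is an assumption-laden illustration, not how DTensor is implemented: all function names are hypothetical, and weights use an (in_features, out_features) layout with y = x @ W, which differs from how nn.Linear stores its weight. The first and third projections are split by output columns, the second by input rows, and the per-shard partial outputs are summed, which plays the role of the all-reduce.

```python
import math


def matmul(x, w):
    """Multiply a row vector `x` (length n) by a matrix `w` with n rows."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def silu(v):
    return [a / (1.0 + math.exp(-a)) for a in v]


def feed_forward(x, w1, w2, w3):
    """Single-device reference: w2(silu(w1(x)) * w3(x))."""
    hidden = [g * u for g, u in zip(silu(matmul(x, w1)), matmul(x, w3))]
    return matmul(hidden, w2)


def split_columns(w, parts):
    """Column-wise sharding: each shard keeps a slice of the output features."""
    step = len(w[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in w] for p in range(parts)]


def split_rows(w, parts):
    """Row-wise sharding: each shard keeps a slice of the input features."""
    step = len(w) // parts
    return [w[p * step:(p + 1) * step] for p in range(parts)]


def feed_forward_tp(x, w1, w2, w3, parts=2):
    """Same computation, with each 'device' holding only its weight shards."""
    partials = []
    shards = zip(split_columns(w1, parts), split_rows(w2, parts), split_columns(w3, parts))
    for s1, s2, s3 in shards:
        # Each device computes only its slice of the hidden activations ...
        hidden = [g * u for g, u in zip(silu(matmul(x, s1)), matmul(x, s3))]
        # ... and a partial output of full size.
        partials.append(matmul(hidden, s2))
    # Summing the partial outputs stands in for the all-reduce across devices.
    return [sum(values) for values in zip(*partials)]


x = [1.0, -2.0]
w1 = [[0.5, -1.0, 0.25, 2.0], [1.5, 0.0, -0.5, 1.0]]
w3 = [[1.0, 0.5, -1.0, 0.0], [0.0, 2.0, 1.0, -1.5]]
w2 = [[1.0, 0.0], [0.5, -1.0], [2.0, 1.0], [-1.0, 0.5]]

full = feed_forward(x, w1, w2, w3)
sharded = feed_forward_tp(x, w1, w2, w3, parts=2)
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(full, sharded)))
```

Because each shard needs only its own slice of the hidden activations, no communication is required until the final sum, which is why the column-wise and row-wise layouts are paired this way.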
PyTorch Lightning
Enabling TP in a model with PyTorch Lightning requires you to implement the LightningModule.configure_model() method, where you convert selected layers of a model to parallelized layers. This is an advanced feature, because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.
import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the `configure_model()` method in LightningModule
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        ...


# 2. Create the strategy
strategy = ModelParallelStrategy()

# 3. Configure devices and set the strategy in Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(...)
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.pytorch.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        if self.device_mesh is None:
            return

        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        output = self.model(batch)
        loss = output.sum()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3)

    def train_dataloader(self):
        # Trainer configures the sampler automatically for you such that
        # all batches in a tensor-parallel group are identical
        dataset = RandomDataset(8192, 64)
        return torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2)


strategy = ModelParallelStrategy()
trainer = L.Trainer(
    accelerator="cuda",
    devices=2,
    strategy=strategy,
    max_epochs=1,
)

model = LitModel()
trainer.fit(model)

trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
Lightning Fabric
Applying TP in a model with Fabric requires you to implement a special function where you convert selected layers of a model to parallelized layers. This is an advanced feature, because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.
import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the parallelization function for your model
def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


# 2. Pass the parallelization function to the strategy
strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)

# 3. Configure devices and set the strategy in Fabric
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.fabric.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

# Initialize the model
model = FeedForward(8192, 8192)
model = fabric.setup(model)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
optimizer = fabric.setup_optimizers(optimizer)

# Define dataset/dataloader
dataset = RandomDataset(8192, 64)
dataloader = torch.utils.data.DataLoader(dataset, batch_si...
Patch release v2.2.5
PyTorch Lightning + Fabric
Fixed
- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) (#19886)
Full Changelog: 2.2.4...2.2.5
Patch release v2.2.4
App
Fixed
- Fixed HTTPClient retry for flow/work queue (#19837)
PyTorch
No Changes.
Fabric
No Changes.
Full Changelog: 2.2.3...2.2.4
Patch release v2.2.3
PyTorch
Fixed
- Fixed WandbLogger.log_hyperparameters() raising an error if hyperparameters are not JSON serializable (#19769)
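For context on the kind of problem this fix addresses, the sketch below shows one general way to make a hyperparameter dictionary safe to serialize as JSON before handing it to a logger. The helper name to_json_serializable and the fall-back-to-string policy are assumptions for this example; the sketch does not reproduce the code changed in #19769.

```python
import json
from pathlib import Path


def to_json_serializable(params):
    """Return a copy of `params` where non-serializable values become strings.

    Hypothetical helper for illustration; not part of the Lightning API.
    """
    clean = {}
    for key, value in params.items():
        try:
            # Probe whether the value can be encoded as JSON as-is.
            json.dumps(value)
            clean[key] = value
        except (TypeError, ValueError):
            # Fall back to a readable string representation.
            clean[key] = str(value)
    return clean


hparams = {"lr": 3e-4, "layers": [64, 64], "data_dir": Path("data"), "activation": abs}
safe = to_json_serializable(hparams)
encoded = json.dumps(safe)  # no longer raises TypeError
```

Values that are already serializable pass through unchanged, so numeric hyperparameters keep their type in the logged output.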
Fabric
No Changes.
Full Changelog: 2.2.2...2.2.3
Patch release v2.2.2
PyTorch
Fixed
- Fixed an issue causing a TypeError when using torch.compile as a decorator (#19627)
- Fixed a KeyError when saving a FSDP sharded checkpoint and setting save_weights_only=True (#19524)
Fabric
Fixed
- Fixed an issue causing a TypeError when using torch.compile as a decorator (#19627)
- Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped (#19705)
- Fixed an issue causing weights to be reset in Fabric.setup() when using FSDP (#19755)
Full Changelog: 2.2.1...2.2.2
Contributors
@ankitgola005 @awaelchli @Borda @carmocca @dmitsf @dvoytan-spark @fnhirwa
Patch release v2.2.1
PyTorch
Fixed
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually (#19446)
- Fixed the divisibility check for Trainer.accumulate_grad_batches and Trainer.log_every_n_steps in ThroughputMonitor (#19470)
- Fixed support for Remote Stop and Remote Abort with NeptuneLogger (#19130)
- Fixed infinite recursion error in precision plugin graveyard (#19542)
Fabric
Fixed
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually (#19446)
Full Changelog: 2.2.0post...2.2.1
Contributors
@Raalsky @awaelchli @carmocca @Borda
If we forgot someone due to not matching commit email with GitHub account, let us know :]