Lightning v2.3: Tensor Parallelism and 2D Parallelism
Lightning AI is excited to announce the release of Lightning 2.3 ⚡
Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.
This release introduces experimental support for Tensor Parallelism and 2D Parallelism, PyTorch 2.3 support, and several bugfixes and stability improvements.
Highlights
Tensor Parallelism (beta)
Tensor parallelism (TP) is a technique that splits up the computation of selected layers across GPUs to save memory and speed up distributed models. To enable TP as well as other forms of parallelism, we introduce a ModelParallelStrategy for both Lightning Trainer and Fabric. Under the hood, TP is enabled through new experimental PyTorch APIs like DTensor and torch.distributed.tensor.parallel.
PyTorch Lightning
Enabling TP in a model with PyTorch Lightning requires you to implement the LightningModule.configure_model() method, where you convert selected layers of a model to parallelized layers. This is an advanced feature because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.
import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the `configure_model()` method in LightningModule
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        ...


# 2. Create the strategy
strategy = ModelParallelStrategy()

# 3. Configure devices and set the strategy in Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(...)
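To see why the plan pairs ColwiseParallel for w1 and w3 with RowwiseParallel for w2, the following single-process sketch emulates the sharding with plain tensors. It is not how DTensor is implemented; it only assumes that column-wise sharding splits a linear layer's output features, that row-wise sharding splits its input features, and that summing the partial results stands in for the all-reduce. The small sizes (8, 16) and the shard count of 2 are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, hidden_dim, tp_size = 8, 16, 2

# Unsharded weights, stored like nn.Linear: (out_features, in_features)
w1 = torch.randn(hidden_dim, dim, dtype=torch.float64)
w2 = torch.randn(dim, hidden_dim, dtype=torch.float64)
w3 = torch.randn(hidden_dim, dim, dtype=torch.float64)
x = torch.randn(4, dim, dtype=torch.float64)

# Reference: the same computation as FeedForward.forward, without sharding
reference = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

# Column-wise: split the output features of w1 and w3 (dim 0 of the weight).
# Row-wise: split the input features of w2 (dim 1 of the weight).
w1_shards = w1.chunk(tp_size, dim=0)
w3_shards = w3.chunk(tp_size, dim=0)
w2_shards = w2.chunk(tp_size, dim=1)

partials = []
for rank in range(tp_size):
    # Each emulated rank only ever sees its own slice of the hidden activations
    hidden = F.silu(F.linear(x, w1_shards[rank])) * F.linear(x, w3_shards[rank])
    partials.append(F.linear(hidden, w2_shards[rank]))

# Summing the partial outputs plays the role of the all-reduce across ranks
tp_output = torch.stack(partials).sum(dim=0)
max_abs_diff = (tp_output - reference).abs().max().item()
print(f"max abs difference to unsharded output: {max_abs_diff:.2e}")
```

Because the element-wise SiLU and multiplication act independently on each hidden feature, no communication is needed between the column-wise and row-wise layers; only the final sum requires it.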
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.pytorch.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        if self.device_mesh is None:
            return

        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        output = self.model(batch)
        loss = output.sum()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3)

    def train_dataloader(self):
        # Trainer configures the sampler automatically for you such that
        # all batches in a tensor-parallel group are identical
        dataset = RandomDataset(8192, 64)
        return torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2)


strategy = ModelParallelStrategy()
trainer = L.Trainer(
    accelerator="cuda",
    devices=2,
    strategy=strategy,
    max_epochs=1,
)

model = LitModel()
trainer.fit(model)
trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
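As a rough sanity check for the peak memory printed by this example, the weight memory of FeedForward(8192, 8192) can be estimated by hand. The sketch below assumes float32 parameters (4 bytes each) and an even split across 2 tensor-parallel devices; it counts weights only, so the measured peak will be higher once gradients, AdamW state, and activations are included.

```python
dim = hidden_dim = 8192
bytes_per_param = 4  # assumption: float32 parameters
tp_size = 2          # matches devices=2 in the example

# w1, w2 and w3 are each a (8192 x 8192) matrix without bias
num_params = 3 * dim * hidden_dim
full_gb = num_params * bytes_per_param / 1e9
per_device_gb = full_gb / tp_size

print(f"parameters: {num_params:,}")
print(f"weights, unsharded: {full_gb:.2f} GB")
print(f"weights per device with {tp_size}-way TP: {per_device_gb:.2f} GB")
```

With all three layers in the parallelization plan, each device stores half of every weight matrix, which is where the memory saving comes from.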
Lightning Fabric
Applying TP in a model with Fabric requires you to implement a special function where you convert selected layers of a model to parallelized layers. This is an advanced feature because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.
import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the parallelization function for your model
def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


# 2. Pass the parallelization function to the strategy
strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)

# 3. Configure devices and set the strategy in Fabric
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.fabric.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

# Initialize the model
model = FeedForward(8192, 8192)
model = fabric.setup(model)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
optimizer = fabric.setup_optimizers(optimizer)

# Define dataset/dataloader
dataset = RandomDataset(8192, 64)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)

# Simplified training loop
for i, batch in enumerate(dataloader):
    output = model(batch)
    loss = output.sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    fabric.print(f"Iteration {i} complete")

fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
2D Parallelism (beta)
Tensor Parallelism by itself can be very effective for efficient inference of very large models. For training, TP is typically combined with other forms of parallelism, such as FSDP, to increase throughput and scalability on large clusters with 100s of GPUs. The new ModelParallelStrategy in this release supports the combination of TP + FSDP, which is referred to as 2D parallelism.
For an introduction to this feature, please also refer to the tutorial Studios (PyTorch Lightning, Lightning Fabric). At the moment, the PyTorch team is reimplementing FSDP under the name FSDP2 with the aim of making it compose well with other forms of parallelism such as TP. Therefore, for the experimental 2D parallelism support, you'll need to switch to using FSDP2 with the new ModelParallelStrategy. Please refer to our docs (PyTorch Lightning, Lightning Fabric) and stay tuned for future releases as these APIs mature.
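To make the "2D" layout concrete, here is a small pure-Python sketch of how ranks could be arranged in a two-dimensional mesh. It assumes a row-major layout with the tensor-parallel dimension innermost, so that ranks in the same TP group are adjacent; the helper name mesh_groups is hypothetical and is not part of Lightning or PyTorch, and the actual mesh that Lightning builds may differ in detail.

```python
def mesh_groups(world_size: int, tensor_parallel_size: int):
    """Return (tp_groups, dp_groups) for a row-major (data-parallel, tensor-parallel) mesh."""
    if world_size % tensor_parallel_size != 0:
        raise ValueError("world_size must be divisible by tensor_parallel_size")
    data_parallel_size = world_size // tensor_parallel_size

    # Assumed layout: rank = dp_index * tensor_parallel_size + tp_index
    mesh = [
        [dp * tensor_parallel_size + tp for tp in range(tensor_parallel_size)]
        for dp in range(data_parallel_size)
    ]
    tp_groups = mesh                               # each row splits layers with TP
    dp_groups = [list(col) for col in zip(*mesh)]  # each column shards with FSDP
    return tp_groups, dp_groups


tp_groups, dp_groups = mesh_groups(world_size=8, tensor_parallel_size=2)
print("tensor-parallel groups:", tp_groups)
print("data-parallel groups:  ", dp_groups)
```

Under this layout, every rank belongs to exactly one group along each dimension: TP splits the selected layers within a row, while FSDP shards parameters and distributes batches across a column.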
Training Mode in Model Summary
The model summary table that gets displayed when you run Trainer.fit() now contains a new column "Mode" that shows the training mode each layer is in (#19468).
  | Name                 | Type            | Params | Mode
-----------------------------------------------------------------
0 | model                | Sam             | 93.7 M | train
1 | model.image_encoder  | ImageEncoderViT | 89.7 M | eval
2 | model.prompt_encoder | PromptEncoder   | 6.2 K  | train
3 | model.mask_decoder   | MaskDecoder     | 4.1 M  | train
-----------------------------------------------------------------
93.7 M    Trainable params
0         Non-trainable params
93.7 M    Total params
374.942   Total estimated model params size (MB)
A module in PyTorch is always either in train (default) or eval mode.
This improvement should give users more visibility into the state of their model and help debug issues, for example when you need to make sure certain layers of the model are frozen.
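The same information can be inspected by hand in plain PyTorch through each module's training flag. The sketch below uses a toy two-layer model (the names encoder and decoder are made up for illustration) and also counts trainable parameters, since the training mode and requires_grad are independent settings.

```python
import torch.nn as nn

# Toy model with two named submodules (hypothetical names)
model = nn.Sequential()
model.add_module("encoder", nn.Linear(4, 4))
model.add_module("decoder", nn.Linear(4, 2))

# Put only the encoder into eval mode
model.encoder.eval()

# `module.training` is True in train mode and False in eval mode
modes = {
    name or "(root)": "train" if module.training else "eval"
    for name, module in model.named_modules()
}
for name, mode in modes.items():
    print(f"{name:10s} {mode}")

# eval mode does not freeze parameters: they all still require gradients
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("trainable params:", trainable)
```

This mirrors the table above, where the image encoder is in eval mode while the summary still reports 0 non-trainable parameters.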
Special Forward Methods in Fabric
Until now, Lightning Fabric warned the user in case the forward pass of the model or a subset of its modules was conducted through methods other than the dedicated forward method of the PyTorch module. The reason for this is that PyTorch needs to run special hooks in case of DDP/FSDP and other strategies to function properly, and not running through the real forward method would skip these hooks and lead to correctness issues.
In Lightning Fabric 2.3, we added a feature to explicitly mark alternative forward methods so that Fabric can add the necessary rerouting behind the scenes:
import lightning as L
fabric = L.Fabric(devices=2, strategy="ddp")
fabric.launch()
model = MyModel()
model = fabric.setup(model)
# OK: Calling the model directly
output = model(input)
# ERROR: Calling another method that calls forward indirectly
prediction = model.generate(input)
# New: Mark special forward methods explicitly before using them
model.mark_forward_method(model.generate)
# OK: Now can use `model.generate()` in DDP/FSDP without issues
prediction = model.generate(input)
Find the full example and more details in our docs.
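To illustrate the general problem, the following toy wrapper shows how a method such as generate() can slip past the code path that a wrapper relies on, and how explicitly marking it restores that path. This is a hypothetical stand-in written in plain Python, not Fabric's implementation: the classes TinyModel and ToyWrapper are invented here, a simple call counter stands in for the DDP/FSDP hooks, and the toy mark_forward_method() takes a method name, whereas the Fabric example above passes the bound method.

```python
class TinyModel:
    def forward(self, x):
        return x + 1

    def generate(self, x):
        # Reaches forward on the inner model, not through any wrapper
        return self.forward(x) * 2


class ToyWrapper:
    """Hypothetical stand-in for a strategy wrapper around a model."""

    def __init__(self, module):
        self._module = module
        self._marked = set()
        self.wrapped_calls = 0  # counts calls that ran the wrapper's "hooks"

    def _run_with_hooks(self, fn, *args):
        self.wrapped_calls += 1  # stands in for the hooks a strategy needs
        return fn(*args)

    def __call__(self, *args):
        return self._run_with_hooks(self._module.forward, *args)

    def mark_forward_method(self, name):
        self._marked.add(name)

    def __getattr__(self, name):
        # Only invoked for attributes the wrapper itself does not define
        attr = getattr(self._module, name)
        if name in self._marked:
            return lambda *args: self._run_with_hooks(attr, *args)
        return attr


wrapper = ToyWrapper(TinyModel())
wrapper(1)           # goes through the wrapper, hooks run
wrapper.generate(1)  # forwarded to the inner model, hooks skipped
calls_before_marking = wrapper.wrapped_calls

wrapper.mark_forward_method("generate")
result = wrapper.generate(1)  # now rerouted through the wrapper
calls_after_marking = wrapper.wrapped_calls
print(calls_before_marking, calls_after_marking, result)
```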
Notable Changes
The 2.0 series of Lightning releases guarantees core API stability: No name changes, argument renaming, hook removals etc. on core interfaces (Trainer, LightningModule, etc.) unless a feature is specifically marked experimental. Here we list a few behavioral changes that were justified because they significantly improve the user experience, improve performance, or fix the correctness of a feature. These changes will likely not impact most users.
Skipping the training step in DDP
It is no longer allowed to skip training_step() by returning None in distributed training (#19918). The following usage was previously possible but would result in unpredictable hangs and timeouts in distributed training:
def training_step(self, batch):
    loss = ...
    if loss.isnan():
        # No longer allowed in multi-GPU!
        # Raises error in Lightning >= 2.3
        return None
    return loss
We decided to raise an error if the user attempts to return None when running in a multi-GPU setting.
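The hangs come from collective operations that every process is expected to join. The following toy uses Python threads and a threading.Barrier as a stand-in for the gradient synchronization between two ranks; it is only an analogy under that assumption, not how DDP is implemented, and the helper run_ranks is made up for this illustration.

```python
import threading


def run_ranks(skip_rank=None, timeout=0.5):
    """Run two emulated ranks; the barrier stands in for gradient synchronization."""
    barrier = threading.Barrier(2)
    results = {}

    def rank_fn(rank):
        if rank == skip_rank:
            # Analogous to returning None from training_step on this rank only
            results[rank] = "skipped"
            return
        try:
            barrier.wait(timeout=timeout)
            results[rank] = "synced"
        except threading.BrokenBarrierError:
            # The other rank never arrived
            results[rank] = "timed out"

    threads = [threading.Thread(target=rank_fn, args=(r,)) for r in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


all_participate = run_ranks()
one_skips = run_ranks(skip_rank=1)
print("all ranks participate:", all_participate)
print("rank 1 skips:         ", one_skips)
```

When one rank skips the step, the other waits for a partner that never shows up, which is the kind of timeout the new error is meant to surface early.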
Miscellaneous Changes
- Dropped support for PyTorch 1.13 (#19300). With every new Lightning release, we add official support for the latest PyTorch stable version and drop the oldest version in our support window.
- The prepare_data() hook in LightningModule and LightningDataModule is now subject to a barrier without timeout, so that long-running tasks are not interrupted (#19448). Similarly, the Fabric.rank_zero_first context manager in Fabric now uses an infinite barrier (#19448).
CHANGELOG
PyTorch Lightning
Added
- The ModelSummary and RichModelSummary callbacks now display the training mode of each layer in the column "Mode" (#19468)
- Added load_from_checkpoint support for LightningCLI when using dependency injection (#18105)
- Added robust timer duration parsing with an informative error message when parsing fails (#19513)
- Added on_exception hook to LightningDataModule (#19601)
- Added support for PyTorch 2.3 (#19708)
- Added ModelParallelStrategy to support 2D parallelism (#19878, #19888)
- Added a call to torch.distributed.destroy_process_group in atexit handler if process group needs destruction (#19931)
- Added support for configuring hybrid-sharding by passing a tuple for the FSDPStrategy(device_mesh=...) argument (#19504)
Changed
- The prepare_data() hook in LightningModule and LightningDataModule is now subject to a barrier without timeout to avoid long-running tasks to be interrupted (#19448)
- Relaxed the requirement for custom batch samplers to expose drop_last for prediction (#19678)
- It is no longer allowed to skip training_step() by returning None in distributed training (#19918)
Removed
Fixed
- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) (#19886)
- Fixed WandbLogger.log_hyperparameters() raising an error if hyperparameters are not JSON serializable (#19769)
- Fixed an issue with the LightningCLI not being able to set the ModelCheckpoint(save_last=...) argument (#19808)
- Fixed an issue causing ValueError for certain objects such as TorchMetrics when dumping hyperparameters to YAML (#19804)
- Fixed resetting epoch_loop.restarting to avoid full validation run after LearningRateFinder (#19818)
Lightning Fabric
Added
- Added sanitization for classes before logging them as hyperparameters (#19771)
- Enabled consolidating distributed checkpoints through fabric consolidate in the new CLI (#19560)
- Added the ability to explicitly mark forward methods in Fabric via _FabricModule.mark_forward_method() (#19690)
- Added support for PyTorch 2.3 (#19708)
- Added ModelParallelStrategy to support 2D parallelism (#19846, #19852, #19870, #19872)
- Added a call to torch.distributed.destroy_process_group in atexit handler if process group needs destruction (#19931)
- Added support for configuring hybrid-sharding by passing a tuple for the FSDPStrategy(device_mesh=...) argument (#19504)
Changed
- Renamed lightning run model to fabric run (#19442, #19527)
- The Fabric.rank_zero_first context manager now uses a barrier without timeout to avoid long-running tasks to be interrupted (#19448)
- Fabric now raises an error if you forget to call fabric.backward() when it is needed by the strategy or precision selection (#19447, #19493)
- _BackwardSyncControl can now control what to do when gradient accumulation is disabled (#19577)
Removed
- Removed support for PyTorch 1.13 (#19706)
Fixed
- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) (#19886)
Full commit list: 2.2.0 -> 2.3.0
Contributors
We thank all our contributors who submitted pull requests for features, bug fixes and documentation updates.
New Contributors
- @cauyxy made their first contribution in #19437
- @mwip made their first contribution in #19518
- @kylebgorman made their first contribution in #19513
- @kashif made their first contribution in #19520
- @ash0ts made their first contribution in #19451
- @dimitri-voytan made their first contribution in #19524
- @ankitgola005 made their first contribution in #19615
- @invisprints made their first contribution in #19629
- @kvenkman made their first contribution in #19465
- @fnhirwa made their first contribution in #19640
- @inyong37 made their first contribution in #19677
- @clumsy made their first contribution in #19601
- @judidoko made their first contribution in #19692
- @Lunamos made their first contribution in #19701
- @dominicgkerr made their first contribution in #19727
- @daavoo made their first contribution in #19774
- @Peiffap made their first contribution in #19805
- @IvanYashchuk made their first contribution in #19926
- @ringohoffman made their first contribution in #19904
- @afspies made their first contribution in #19847
- @fedebotu made their first contribution in #19822
- @mariovas3 made their first contribution in #19808
- @Bhavay-2001 made their first contribution in #19947
- @V0XNIHILI made their first contribution in #19771
Did you know?
Chuck Norris is a big fan and daily user of Lightning Studio.