
Lightning v2.3: Tensor Parallelism and 2D Parallelism

@awaelchli released this on 13 Jun, 21:30 (commit a42484c)

Lightning AI is excited to announce the release of Lightning 2.3 ⚡

Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.

This release introduces experimental support for Tensor Parallelism and 2D Parallelism, PyTorch 2.3 support, and several bugfixes and stability improvements.

Highlights

Tensor Parallelism (beta)

Tensor parallelism (TP) is a technique that splits the computation of selected layers across GPUs, so that each GPU stores and computes only a shard of those layers' weights. This saves memory and speeds up distributed models. To enable TP as well as other forms of parallelism, we introduce a ModelParallelStrategy for both Lightning Trainer and Fabric. Under the hood, TP is enabled through new experimental PyTorch APIs like DTensor and torch.distributed.tensor.parallel.
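The examples that follow pair ColwiseParallel on `w1`/`w3` with RowwiseParallel on `w2`. The single-process NumPy sketch below shows why this pairing works for the FeedForward block. Every "GPU" is simulated by a loop iteration, and the final sum stands in for the all-reduce. The helper names are illustrative only, not Lightning or PyTorch APIs. Weights are stored as `(in_features, out_features)` here, which is the transpose of how `nn.Linear` stores them.

```python
import numpy as np


def silu(x):
    # SiLU activation, x * sigmoid(x), as computed by F.silu
    return x / (1.0 + np.exp(-x))


def feedforward(x, w1, w2, w3):
    # Same computation as FeedForward.forward: w2(silu(w1(x)) * w3(x)).
    # Weights are stored as (in_features, out_features), so a layer is x @ w.
    return (silu(x @ w1) * (x @ w3)) @ w2


def tensor_parallel_feedforward(x, w1, w2, w3, tp_size):
    # Column-wise shards: rank r owns a slice of the hidden (output) units of w1 and w3
    w1_shards = np.split(w1, tp_size, axis=1)
    w3_shards = np.split(w3, tp_size, axis=1)
    # Row-wise shards: rank r owns the matching slice of w2's hidden (input) units
    w2_shards = np.split(w2, tp_size, axis=0)

    partial_outputs = []
    for rank in range(tp_size):  # each iteration simulates one GPU
        # silu and the element-wise product need no communication,
        # because each rank only touches its own slice of hidden units
        hidden = silu(x @ w1_shards[rank]) * (x @ w3_shards[rank])
        partial_outputs.append(hidden @ w2_shards[rank])

    # Summing the partial results stands in for the all-reduce across the TP group
    return np.sum(partial_outputs, axis=0)


rng = np.random.default_rng(0)
dim, hidden_dim = 16, 32
x = rng.standard_normal((4, dim))
w1 = rng.standard_normal((dim, hidden_dim))
w2 = rng.standard_normal((hidden_dim, dim))
w3 = rng.standard_normal((dim, hidden_dim))

print(np.allclose(feedforward(x, w1, w2, w3), tensor_parallel_feedforward(x, w1, w2, w3, tp_size=2)))  # True
```

With this layout, the only communication the block needs is one reduction at the end. That is why the column-wise/row-wise pairing is the usual choice for MLP blocks.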

PyTorch Lightning

Enabling TP in a model with PyTorch Lightning requires you to implement the LightningModule.configure_model() method, where you convert selected layers of a model to parallelized layers. This is an advanced feature because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.


import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the `configure_model()` method in LightningModule
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)  # FeedForward is defined in the full example below

    def configure_model(self):
        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        ...


# 2. Create the strategy
strategy = ModelParallelStrategy()

# 3. Configure devices and set the strategy in Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(...)
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.pytorch.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        if self.device_mesh is None:
            return

        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        output = self.model(batch)
        loss = output.sum()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3)

    def train_dataloader(self):
        # Trainer configures the sampler automatically for you such that
        # all batches in a tensor-parallel group are identical
        dataset = RandomDataset(8192, 64)
        return torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2)


strategy = ModelParallelStrategy()
trainer = L.Trainer(
    accelerator="cuda",
    devices=2,
    strategy=strategy,
    max_epochs=1,
)

model = LitModel()
trainer.fit(model)

trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

Lightning Fabric

Applying TP in a model with Fabric requires you to implement a parallelization function, where you convert selected layers of a model to parallelized layers. This is an advanced feature because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.


import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the parallelization function for your model
def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


# 2. Pass the parallelization function to the strategy
strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)

# 3. Configure devices and set the strategy in Fabric
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.fabric.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

# Initialize the model
model = FeedForward(8192, 8192)
model = fabric.setup(model)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
optimizer = fabric.setup_optimizers(optimizer)

# Define dataset/dataloader
dataset = RandomDataset(8192, 64)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)

# Simplified training loop
for i, batch in enumerate(dataloader):
    output = model(batch)
    loss = output.sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    fabric.print(f"Iteration {i} complete")

fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

2D Parallelism (beta)

Tensor Parallelism by itself can be very effective for inference with very large models. For training, TP is typically combined with other forms of parallelism, such as FSDP, to increase throughput and scalability on large clusters with hundreds of GPUs. The new ModelParallelStrategy in this release supports the combination of TP + FSDP, which is referred to as 2D parallelism.

For an introduction to this feature, please also refer to the tutorial Studios (PyTorch Lightning, Lightning Fabric). At the moment, the PyTorch team is reimplementing FSDP under the name FSDP2, with the aim of making it compose well with other parallelisms such as TP. Therefore, the experimental 2D parallelism support requires switching to FSDP2 with the new ModelParallelStrategy. Please refer to our docs (PyTorch Lightning, Lightning Fabric) and stay tuned for future releases as these APIs mature.
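In 2D parallelism, the GPUs form a 2D device mesh. The examples above index one dimension as `device_mesh["tensor_parallel"]`, and the other dimension is the data-parallel dimension that FSDP shards across. The pure-Python sketch below lists which ranks end up in which group. It assumes a row-major `(data_parallel, tensor_parallel)` layout, which is how `torch.distributed.device_mesh.init_device_mesh` arranges ranks for a 2D shape. The `mesh_groups()` helper is hypothetical and not part of Lightning.

```python
def mesh_groups(world_size, data_parallel_size, tensor_parallel_size):
    """Hypothetical helper: list the TP and DP process groups of a (dp, tp) mesh."""
    if data_parallel_size * tensor_parallel_size != world_size:
        raise ValueError("data_parallel_size * tensor_parallel_size must equal world_size")
    # Row-major layout: mesh[dp][tp] = global rank, like torch.arange(world_size).view(dp, tp)
    mesh = [
        [dp * tensor_parallel_size + tp for tp in range(tensor_parallel_size)]
        for dp in range(data_parallel_size)
    ]
    # Each row holds one model replica whose selected layers are split with TP
    tensor_parallel_groups = mesh
    # Each column holds the same TP shard on different replicas; FSDP shards across it
    data_parallel_groups = [[row[tp] for row in mesh] for tp in range(tensor_parallel_size)]
    return tensor_parallel_groups, data_parallel_groups


# Example: 8 GPUs, assumed to be 2 nodes x 4 GPUs, with TP inside a node and FSDP across nodes
tp_groups, dp_groups = mesh_groups(world_size=8, data_parallel_size=2, tensor_parallel_size=4)
print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Keeping each TP group inside a node is a common design choice. TP communicates in every layer and benefits most from the fast intra-node interconnect, while FSDP's communication is less frequent and tolerates the slower links between nodes.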

Training Mode in Model Summary

The model summary table that gets displayed when you run Trainer.fit() now contains a new column "Mode" that shows the training mode each layer is in (#19468).

  | Name                 | Type            | Params | Mode 
-----------------------------------------------------------------
0 | model                | Sam             | 93.7 M | train
1 | model.image_encoder  | ImageEncoderViT | 89.7 M | eval 
2 | model.prompt_encoder | PromptEncoder   | 6.2 K  | train
3 | model.mask_decoder   | MaskDecoder     | 4.1 M  | train
-----------------------------------------------------------------
93.7 M    Trainable params
0         Non-trainable params
93.7 M    Total params
374.942   Total estimated model params size (MB)

A module in PyTorch is always either in train (default) or eval mode.
This improvement should give users more visibility into the state of their model and help debug issues, for example when you need to make sure certain layers of the model are frozen.
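The "Mode" column reflects the `training` flag that every `nn.Module` carries. You can inspect it directly with a few lines of plain PyTorch. In the sketch below, `layer_modes()` and `trainable_params()` are hypothetical helpers, not Lightning APIs. The parameter count shows that eval mode and frozen weights are separate things. `eval()` changes the behavior of layers such as dropout and batch norm, while freezing means setting `requires_grad=False`. The summary table above reflects this: the image encoder is in eval mode, yet all parameters are still counted as trainable.

```python
import torch.nn as nn


def layer_modes(model):
    # Map each submodule name to "train" or "eval" based on its `training` flag
    return {
        name or "(root)": ("train" if module.training else "eval")
        for name, module in model.named_modules()
    }


def trainable_params(model):
    # Counts parameters with requires_grad=True, independent of train/eval mode
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1), nn.Linear(4, 2))
model[0].eval()  # only the first layer switches to eval mode

print(layer_modes(model))       # {'(root)': 'train', '0': 'eval', '1': 'train', '2': 'train'}
print(trainable_params(model))  # 30: the eval-mode layer is still trainable
```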

Special Forward Methods in Fabric

Until now, Lightning Fabric warned the user if the forward pass of the model, or of a subset of its modules, was run through methods other than the module's dedicated forward method. The reason is that DDP, FSDP, and other strategies rely on special hooks that run when the wrapped model's forward method is called. Going through a different method skips these hooks and leads to correctness issues.

In Lightning Fabric 2.3, we added a feature to explicitly mark alternative forward methods so that Fabric can add the necessary rerouting behind the scenes:

import lightning as L

fabric = L.Fabric(devices=2, strategy="ddp")
fabric.launch()

# MyModel: your nn.Module that defines forward() and a custom generate() method
model = MyModel()
model = fabric.setup(model)

# OK: Calling the model directly
output = model(input)

# ERROR: Calling another method that calls forward indirectly
prediction = model.generate(input)

# New: Mark special forward methods explicitly before using them
model.mark_forward_method(model.generate)

# OK: Now can use `model.generate()` in DDP/FSDP without issues
prediction = model.generate(input)

Find the full example and more details in our docs.
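Why does calling `generate()` on a wrapped model skip the strategy's hooks? The pure-Python sketch below illustrates the general idea. `StrategyWrapperSketch` and `MyModelSketch` are hypothetical stand-ins, not Fabric's internals. The wrapper runs its "hooks" (here just a counter) only when it is called. Any other attribute is passed straight through to the inner module unless that method has been marked, in which case the call is rerouted through the same hook logic. The sketch takes the method name as a string for simplicity.

```python
class StrategyWrapperSketch:
    """Hypothetical stand-in for a wrapper like DDP: its hooks only run when it is called."""

    def __init__(self, module):
        self.module = module
        self.hook_runs = 0
        self.forward_methods = set()

    def __call__(self, *args, **kwargs):
        self.hook_runs += 1  # where a real strategy would run its pre/post-forward hooks
        return self.module.forward(*args, **kwargs)

    def mark_forward_method(self, name):
        self.forward_methods.add(name)

    def __getattr__(self, name):
        # Only reached for attributes the wrapper itself does not define
        attr = getattr(self.module, name)
        if name not in self.forward_methods:
            return attr  # plain passthrough: the hooks are bypassed

        def rerouted(*args, **kwargs):
            self.hook_runs += 1  # marked methods get the same treatment as __call__
            return attr(*args, **kwargs)

        return rerouted


class MyModelSketch:
    def forward(self, x):
        return 2 * x

    def generate(self, x):
        # Calls forward directly, never going through the wrapper
        return self.forward(x) + 1


model = StrategyWrapperSketch(MyModelSketch())
model(3)                # hooks run: hook_runs == 1
model.generate(3)       # hooks skipped: hook_runs is still 1
model.mark_forward_method("generate")
model.generate(3)       # rerouted: hook_runs == 2
print(model.hook_runs)  # 2
```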

Notable Changes

The 2.0 series of Lightning releases guarantees core API stability: no name changes, argument renaming, hook removals, etc. on core interfaces (Trainer, LightningModule, etc.) unless a feature is specifically marked experimental. Below we list a few behavioral changes, made only where the change significantly improves the user experience, improves performance, or fixes the correctness of a feature. These changes will likely not impact most users.

Skipping the training step in DDP

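It is no longer allowed to skip training_step() by returning None in distributed training (#19918). The following usage was previously possible, but it would result in unpredictable hangs and timeouts in distributed training, because a rank that skips the step no longer takes part in the gradient synchronization that the other ranks wait for: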

def training_step(self, batch):
    loss = ...
    if loss.isnan():
        # No longer allowed in multi-GPU!
        # Raises error in Lightning >= 2.3
        return None
    return loss

We decided to raise an error if the user attempts to return None when running in a multi-GPU setting.

Miscellaneous Changes

  • Dropped support for PyTorch 1.13 (#19300). With every new Lightning release, we add official support for the latest PyTorch stable version and drop the oldest version in our support window.
  • The prepare_data() hook in LightningModule and LightningDataModule is now subject to a barrier without timeout to avoid interrupting long-running tasks (#19448). Similarly, the Fabric.rank_zero_first context manager in Fabric now uses an infinite barrier (#19448).

CHANGELOG

PyTorch Lightning

Added
  • The ModelSummary and RichModelSummary callbacks now display the training mode of each layer in the column "Mode" (#19468)
  • Added load_from_checkpoint support for LightningCLI when using dependency injection (#18105)
  • Added robust timer duration parsing with an informative error message when parsing fails (#19513)
  • Added on_exception hook to LightningDataModule (#19601)
  • Added support for PyTorch 2.3 (#19708)
  • Added ModelParallelStrategy to support 2D parallelism (#19878, #19888)
  • Added a call to torch.distributed.destroy_process_group in atexit handler if process group needs destruction (#19931)
  • Added support for configuring hybrid-sharding by passing a tuple for the FSDPStrategy(device_mesh=...) argument (#19504)
Changed
  • The prepare_data() hook in LightningModule and LightningDataModule is now subject to a barrier without timeout to avoid interrupting long-running tasks (#19448)
  • Relaxed the requirement for custom batch samplers to expose drop_last for prediction (#19678)
  • It is no longer allowed to skip training_step() by returning None in distributed training (#19918)
Removed
  • Removed the Bagua integration (Trainer(strategy="bagua")) (#19445)
  • Removed support for PyTorch 1.13 (#19706)
Fixed
  • Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) (#19886)
  • Fixed WandbLogger.log_hyperparameters() raising an error if hyperparameters are not JSON serializable (#19769)
  • Fixed an issue with the LightningCLI not being able to set the ModelCheckpoint(save_last=...) argument (#19808)
  • Fixed an issue causing ValueError for certain objects such as TorchMetrics when dumping hyperparameters to YAML (#19804)
  • Fixed resetting epoch_loop.restarting to avoid full validation run after LearningRateFinder (#19818)

Lightning Fabric

Added
  • Added sanitization for classes before logging them as hyperparameters (#19771)
  • Enabled consolidating distributed checkpoints through fabric consolidate in the new CLI (#19560)
  • Added the ability to explicitly mark forward methods in Fabric via _FabricModule.mark_forward_method() (#19690)
  • Added support for PyTorch 2.3 (#19708)
  • Added ModelParallelStrategy to support 2D parallelism (#19846, #19852, #19870, #19872)
  • Added a call to torch.distributed.destroy_process_group in atexit handler if process group needs destruction (#19931)
  • Added support for configuring hybrid-sharding by passing a tuple for the FSDPStrategy(device_mesh=...) argument (#19504)
Changed
  • Renamed lightning run model to fabric run (#19442, #19527)
  • The Fabric.rank_zero_first context manager now uses a barrier without timeout to avoid interrupting long-running tasks (#19448)
  • Fabric now raises an error if you forget to call fabric.backward() when it is needed by the strategy or precision selection (#19447, #19493)
  • _BackwardSyncControl can now control what to do when gradient accumulation is disabled (#19577)
Removed
  • Removed support for PyTorch 1.13 (#19706)
Fixed
  • Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) (#19886)

Full commit list: 2.2.0 -> 2.3.0

Contributors

We thank all our contributors who submitted pull requests for features, bug fixes and documentation updates.

New Contributors

Did you know?

Chuck Norris is a big fan and daily user of Lightning Studio.