153 changes: 61 additions & 92 deletions benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
@@ -28,7 +28,7 @@

from benchmarks.utils import profile_fn
from torchao.prototype.moe_training.kernels.mxfp8.comms import (
mxfp8_on_device_all_to_all_v,
to_mxfp8_a2a_dequant,
)

device = torch.device("cuda")
@@ -69,7 +69,8 @@ def get_configs() -> List[ExperimentConfig]:
# Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
def default_a2a_dispatch(
routed_input: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
output_splits_list: list[int],
input_splits_list: list[int],
device_mesh: DeviceMesh,
):
"""
@@ -81,34 +82,6 @@ def default_a2a_dispatch(
output_splits: the output splits for all-to-all dispatch
num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
"""
ep_degree = device_mesh.size(0)
# generate the input splits and output splits for all-to-all
with torch.no_grad():
num_tokens_per_expert_group = all_to_all_single(
num_tokens_per_expert,
None,
None,
group=device_mesh.get_group(),
)
# Need to wait explicitly because it is used by a triton kernel later
# which doesn't realize that AsyncCollectiveTensor needs unwrapping
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
num_tokens_per_expert_group
)
input_splits = (
num_tokens_per_expert.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)
input_splits_list = input_splits.tolist()
output_splits_list = output_splits.tolist()

# perform all-to-all
routed_input = all_to_all_single_autograd(
routed_input,
@@ -117,55 +90,7 @@ def default_a2a_dispatch(
device_mesh.get_group(),
)
routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
return (
routed_input,
input_splits_list,
output_splits_list,
num_tokens_per_expert_group,
)


def mxfp8_a2a_dispatch(
routed_input: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
device_mesh: DeviceMesh,
max_tokens_per_ep_rank: int,
):
"""
Perform on-device all-to-all dispatch with dynamically quantized mxfp8 inputs to save network bandwidth
and avoid device-to-host sync.

Returns:
routed_input: the local tokens after all-to-all dispatch
input_splits: the input splits for all-to-all dispatch
output_splits: the output splits for all-to-all dispatch
"""

ep_degree = device_mesh.size(0)
num_tokens_per_expert_group = all_to_all_single(
num_tokens_per_expert,
None,
None,
group=device_mesh.get_group(),
)
input_splits_per_ep_rank = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
num_tokens_per_expert_group
)
routed_input, output_splits_per_ep_rank = mxfp8_on_device_all_to_all_v(
routed_input,
input_splits_per_ep_rank,
max_tokens_per_ep_rank,
device_mesh.get_group().group_name,
)
tokens_on_rank_after_a2a = output_splits_per_ep_rank.sum()
routed_input_no_padding = routed_input[:tokens_on_rank_after_a2a]
return (
routed_input_no_padding,
input_splits_per_ep_rank,
output_splits_per_ep_rank,
num_tokens_per_expert_group,
)
return routed_input


def run_experiment(
@@ -184,7 +109,6 @@ def run_experiment(

# Max output tokens per rank is worst case where one rank receives all tokens
input_tokens_per_rank = batch_size * seq_len
max_output_tokens_per_rank = input_tokens_per_rank * dist.get_world_size()

def warmup(func_no_args):
for _ in range(2):
@@ -195,40 +119,50 @@ def warmup(func_no_args):
input_splits = generate_split_sizes(
num_splits, input_tokens_per_rank, device=device
)
input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)

# Bench default a2a
warmup(lambda: default_a2a_dispatch(ref_x, input_splits, mesh))
# Compile target funcs
default_a2a_dispatch_c = torch.compile(default_a2a_dispatch)
to_mxfp8_a2a_dequant_c = torch.compile(to_mxfp8_a2a_dequant)

# Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
warmup(
lambda: default_a2a_dispatch_c(
ref_x, output_splits_list, input_splits_list, mesh
)
)
start_sec = time.perf_counter()
default_a2a_dispatch(ref_x, input_splits, mesh)
default_a2a_dispatch_c(ref_x, output_splits_list, input_splits_list, mesh)
end_sec = time.perf_counter()
bf16_ms = (end_sec - start_sec) * 1e3
if args.profile:
profile_fn(
default_a2a_dispatch,
default_a2a_dispatch_c,
ref_x,
input_splits,
output_splits_list,
input_splits_list,
mesh,
distributed=True,
profile_name="all_to_all_single_autograd",
)

# Bench mxfp8 a2a
# Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
warmup(
lambda: mxfp8_a2a_dispatch(x, input_splits, mesh, max_output_tokens_per_rank)
lambda: to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
)
start_sec = time.perf_counter()
mxfp8_a2a_dispatch(x, input_splits, mesh, max_output_tokens_per_rank)
to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
end_sec = time.perf_counter()
mxfp8_ms = (end_sec - start_sec) * 1e3
if args.profile:
profile_fn(
mxfp8_a2a_dispatch,
to_mxfp8_a2a_dequant_c,
x,
input_splits,
output_splits_list,
input_splits_list,
mesh,
max_output_tokens_per_rank,
distributed=True,
profile_name="mxfp8_all_to_all_v",
profile_name="to_mxfp8_a2a_dequant",
)

return ExperimentResult(
@@ -258,6 +192,41 @@ def print_results(experiments: List[Experiment]):
print(tabulate(rows, headers=headers))


def get_split_lists(
num_tokens_per_expert: torch.Tensor, device_mesh: DeviceMesh
) -> tuple[list[int], list[int]]:
ep_degree = device_mesh.size(0)

# generate the input splits and output splits for sync-impls
num_tokens_per_expert_group = all_to_all_single(
num_tokens_per_expert,
None,
None,
group=device_mesh.get_group(),
)
# Need to wait explicitly because it is used by a triton kernel later
# which doesn't realize that AsyncCollectiveTensor needs unwrapping
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
num_tokens_per_expert_group
)
input_splits = (
num_tokens_per_expert.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)

input_splits_list = input_splits.tolist()
output_splits_list = output_splits.tolist()

return input_splits_list, output_splits_list


def generate_split_sizes(K: int, N: int, device: str = "cuda") -> torch.Tensor:
"""
Generates a tensor of K random non-negative integers that sum to N.
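Note for orientation (not part of the PR): the benchmark above times the default bf16 all_to_all_single_autograd dispatch against to_mxfp8_a2a_dequant, which both files call as fn(input, output_splits, input_splits, group_or_mesh). The torchao implementation of to_mxfp8_a2a_dequant is not shown in this diff; the sketch below only illustrates the general quantize -> all-to-all -> dequantize pattern being benchmarked, with simple row-wise float8_e4m3fn scaling standing in for real mxfp8 block scaling. It covers the forward path only (the new test also checks gradients), and the helper name quantize_a2a_dequant_sketch is made up for this illustration.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single


def quantize_a2a_dequant_sketch(
    x: torch.Tensor,            # (num_tokens, dim) routed tokens in bf16/fp32
    output_splits: list[int],   # tokens to receive from each EP rank
    input_splits: list[int],    # tokens to send to each EP rank
    group: dist.ProcessGroup,
) -> torch.Tensor:
    # 1) Quantize with one scale per token row (real mxfp8 uses 32-element
    #    blocks with shared E8M0 exponents; this is only a stand-in).
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)

    # 2) All-to-all the low-precision payload and its scales using the same
    #    token splits. The fp8 payload is viewed as uint8 so the collective
    #    only has to handle a plainly supported dtype.
    payload = all_to_all_single(
        x_fp8.view(torch.uint8), output_splits, input_splits, group=group
    )
    scales = all_to_all_single(scale.float(), output_splits, input_splits, group=group)
    payload = torch.ops._c10d_functional.wait_tensor(payload)
    scales = torch.ops._c10d_functional.wait_tensor(scales)

    # 3) Dequantize on the receiving rank back to the original dtype.
    return payload.view(torch.float8_e4m3fn).to(x.dtype) * scales.to(x.dtype)

The actual torchao kernel may differ substantially (fused quantization, different scale layout, autograd support); the sketch is only meant to make concrete what the mxfp8 path trades against the bf16 baseline: less bytes on the wire in exchange for a quantize/dequantize step on each side.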
125 changes: 124 additions & 1 deletion test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py
@@ -7,6 +7,7 @@
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.distributed._functional_collectives import (
all_to_all_single,
all_to_all_single_autograd,
)
from torch.nn import functional as F
@@ -23,13 +24,14 @@
)
from torchao.prototype.moe_training.kernels.mxfp8.comms import (
mxfp8_on_device_all_to_all_v,
to_mxfp8_a2a_dequant,
)

from ..testing_utils import generate_split_sizes


@instantiate_parametrized_tests
class MXFP8AllToAllVTest(MultiProcessTestCase):
class MXFP8OnDeviceAllToAllVTest(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@@ -143,5 +145,126 @@ def test_a2a_fwd_bwd(self):
dist.destroy_process_group()


@instantiate_parametrized_tests
class ToMXFP8AllToAllVDequantTest(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()

@property
def world_size(self) -> int:
return 4

@property
def device(self) -> torch.device:
return torch.device(f"cuda:{self.rank}")

def _init_process(self):
torch.cuda.set_device(self.device)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
torch.manual_seed(42 + self.rank)

def _init_device(self):
symm_mem.set_backend("NVSHMEM")

def test_a2a_fwd_bwd(self):
self._init_process()
try:
torch.manual_seed(42 + self.rank)
self._init_device()

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

tokens_per_ep_rank = 8192
dim = 2048
input_tensor = torch.randn(
[Inline review comment from a contributor on this line: "optional: if you quantize this to mxfp8 and then dequantize, you can then test for bitwise equality at the end". A sketch of this idea follows the diff.]

tokens_per_ep_rank,
dim,
device=self.device,
dtype=torch.float32,
requires_grad=True,
)
ref_input_tensor = input_tensor.detach().clone().requires_grad_(True)

# Generate random input splits that sum to tokens_per_ep_rank
experts_per_rank = 2
num_splits = experts_per_rank * self.world_size
num_tokens_per_expert = generate_split_sizes(
num_splits, tokens_per_ep_rank, self.device
)

ep_degree = self.world_size

# Compute tokens per expert group using tokens per expert
with torch.no_grad():
num_tokens_per_expert_group = all_to_all_single(
num_tokens_per_expert,
None,
None,
group=dist.group.WORLD,
)
# Need to wait explicitly because it is used by a triton kernel later
# which doesn't realize that AsyncCollectiveTensor needs unwrapping
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
num_tokens_per_expert_group
)
input_splits = (
num_tokens_per_expert.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)

# Compute reference a2a autograd
ref_output = all_to_all_single_autograd(
ref_input_tensor,
output_splits.tolist(),
input_splits.tolist(),
dist.group.WORLD,
)

# Compute mxfp8 a2a sync
output = to_mxfp8_a2a_dequant(
input_tensor,
output_splits.tolist(),
input_splits.tolist(),
group_name,
)

# Compare output
sqnr = compute_error(ref_output, output)
min_sqnr = 30.0
assert sqnr > min_sqnr, f"sqnr={sqnr} is less than min_sqnr={min_sqnr}"

# Test backwards
labels = torch.ones_like(output)
ref_loss = F.mse_loss(ref_output, labels)
loss = F.mse_loss(output, labels)
ref_loss.backward()
loss.backward()

# Compare grads
grad_sqnr = compute_error(ref_input_tensor.grad, input_tensor.grad)
min_grad_sqnr = 28.0
assert grad_sqnr > min_grad_sqnr, (
f"grad_sqnr={grad_sqnr} is less than min_grad_sqnr={min_grad_sqnr}"
)

finally:
dist.destroy_process_group()


if __name__ == "__main__":
run_tests()
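On the review comment above: the test currently accepts the output when the SQNR against the bf16 reference exceeds roughly 30 dB, which tolerates ordinary mxfp8 quantization error. The reviewer's optional suggestion is to round-trip the input through mxfp8 once up front, so every value is already exactly representable in that format; re-quantizing it inside to_mxfp8_a2a_dequant should then be lossless and the outputs could be compared for exact equality instead. The sketch below is only that idea in outline: quantize_to_mxfp8 and dequantize_from_mxfp8 are hypothetical stand-ins for whatever mxfp8 (de)quantization helpers the codebase provides (they are not defined in this diff), and the round trip must use the same block size and scale rounding as the kernel for the comparison to hold.

import torch

# Hypothetical helpers: quantize_to_mxfp8 / dequantize_from_mxfp8 stand in for the
# codebase's own mxfp8 (de)quantization utilities, which this diff does not show.

def make_mxfp8_exact(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # Round-trip once so the result is exactly representable in mxfp8; per the
    # reviewer's premise, re-quantizing it inside to_mxfp8_a2a_dequant is then lossless.
    data_lp, scales = quantize_to_mxfp8(x, block_size=block_size)
    return dequantize_from_mxfp8(data_lp, scales, block_size=block_size, out_dtype=x.dtype)

# In the test, this would replace the SQNR threshold with an exact comparison:
#   base = make_mxfp8_exact(torch.randn(tokens_per_ep_rank, dim, device=self.device))
#   input_tensor = base.clone().requires_grad_(True)
#   ref_input_tensor = base.clone().requires_grad_(True)
#   ...
#   torch.testing.assert_close(ref_output, output, atol=0, rtol=0)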