diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py b/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
index f5db98d177..237a26c3e1 100644
--- a/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
@@ -28,7 +28,7 @@
 from benchmarks.utils import profile_fn
 from torchao.prototype.moe_training.kernels.mxfp8.comms import (
-    mxfp8_on_device_all_to_all_v,
+    to_mxfp8_a2a_dequant,
 )
 
 device = torch.device("cuda")
 
@@ -69,7 +69,8 @@ def get_configs() -> List[ExperimentConfig]:
 # Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
 def default_a2a_dispatch(
     routed_input: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor,
+    output_splits_list: list[int],
+    input_splits_list: list[int],
     device_mesh: DeviceMesh,
 ):
     """
@@ -81,34 +82,6 @@ def default_a2a_dispatch(
         output_splits: the output splits for all-to-all dispatch
         num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
     """
-    ep_degree = device_mesh.size(0)
-    # generate the input splits and output splits for all-to-all
-    with torch.no_grad():
-        num_tokens_per_expert_group = all_to_all_single(
-            num_tokens_per_expert,
-            None,
-            None,
-            group=device_mesh.get_group(),
-        )
-        # Need to wait explicitly because it is used by a triton kernel later
-        # which doesn't realize that AsyncCollectiveTensor needs unwrapping
-        num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-            num_tokens_per_expert_group
-        )
-        input_splits = (
-            num_tokens_per_expert.view(ep_degree, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        # NOTE: this would incur a device-to-host sync
-        output_splits = (
-            num_tokens_per_expert_group.view(ep_degree, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=False)
-        )
-        input_splits_list = input_splits.tolist()
-        output_splits_list = output_splits.tolist()
-
     # perform all-to-all
     routed_input = all_to_all_single_autograd(
         routed_input,
@@ -117,55 +90,7 @@
         output_splits_list,
         input_splits_list,
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
-    return (
-        routed_input,
-        input_splits_list,
-        output_splits_list,
-        num_tokens_per_expert_group,
-    )
-
-
-def mxfp8_a2a_dispatch(
-    routed_input: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor,
-    device_mesh: DeviceMesh,
-    max_tokens_per_ep_rank: int,
-):
-    """
-    Perform on-device all-to-all dispatch with dynamically quantized mxfp8 inputs to save network bandwidth
-    and avoid device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-    """
-
-    ep_degree = device_mesh.size(0)
-    num_tokens_per_expert_group = all_to_all_single(
-        num_tokens_per_expert,
-        None,
-        None,
-        group=device_mesh.get_group(),
-    )
-    input_splits_per_ep_rank = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
-    num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-        num_tokens_per_expert_group
-    )
-    routed_input, output_splits_per_ep_rank = mxfp8_on_device_all_to_all_v(
-        routed_input,
-        input_splits_per_ep_rank,
-        max_tokens_per_ep_rank,
-        device_mesh.get_group().group_name,
-    )
-    tokens_on_rank_after_a2a = output_splits_per_ep_rank.sum()
-    routed_input_no_padding = routed_input[:tokens_on_rank_after_a2a]
-    return (
-        routed_input_no_padding,
-        input_splits_per_ep_rank,
-        output_splits_per_ep_rank,
-        num_tokens_per_expert_group,
-    )
+    return routed_input
 
 
 def run_experiment(
@@ -184,7 +109,6 @@
 
     # Max output tokens per rank is worst case where one rank receives all tokens
     input_tokens_per_rank = batch_size * seq_len
-    max_output_tokens_per_rank = input_tokens_per_rank * dist.get_world_size()
 
     def warmup(func_no_args):
         for _ in range(2):
@@ -195,40 +119,50 @@
     input_splits = generate_split_sizes(
         num_splits, input_tokens_per_rank, device=device
     )
+    input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)
 
-    # Bench default a2a
-    warmup(lambda: default_a2a_dispatch(ref_x, input_splits, mesh))
+    # Compile target funcs
+    default_a2a_dispatch_c = torch.compile(default_a2a_dispatch)
+    to_mxfp8_a2a_dequant_c = torch.compile(to_mxfp8_a2a_dequant)
+
+    # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
+    warmup(
+        lambda: default_a2a_dispatch_c(
+            ref_x, output_splits_list, input_splits_list, mesh
+        )
+    )
     start_sec = time.perf_counter()
-    default_a2a_dispatch(ref_x, input_splits, mesh)
+    default_a2a_dispatch_c(ref_x, output_splits_list, input_splits_list, mesh)
     end_sec = time.perf_counter()
     bf16_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_dispatch,
+            default_a2a_dispatch_c,
             ref_x,
-            input_splits,
+            output_splits_list,
+            input_splits_list,
             mesh,
             distributed=True,
             profile_name="all_to_all_single_autograd",
         )
 
-    # Bench mxfp8 a2a
+    # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: mxfp8_a2a_dispatch(x, input_splits, mesh, max_output_tokens_per_rank)
+        lambda: to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
    )
     start_sec = time.perf_counter()
-    mxfp8_a2a_dispatch(x, input_splits, mesh, max_output_tokens_per_rank)
+    to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
     end_sec = time.perf_counter()
     mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            mxfp8_a2a_dispatch,
+            to_mxfp8_a2a_dequant_c,
             x,
-            input_splits,
+            output_splits_list,
+            input_splits_list,
             mesh,
-            max_output_tokens_per_rank,
             distributed=True,
-            profile_name="mxfp8_all_to_all_v",
+            profile_name="to_mxfp8_a2a_dequant",
         )
 
     return ExperimentResult(
@@ -258,6 +192,41 @@ def print_results(experiments: List[Experiment]):
     print(tabulate(rows, headers=headers))
 
 
+def get_split_lists(
+    num_tokens_per_expert: torch.Tensor, device_mesh: DeviceMesh
+) -> tuple[list[int], list[int]]:
+    ep_degree = device_mesh.size(0)
+
+    # generate the input splits and output splits for sync-impls
+    num_tokens_per_expert_group = all_to_all_single(
+        num_tokens_per_expert,
+        None,
+        None,
+        group=device_mesh.get_group(),
+    )
+    # Need to wait explicitly because it is used by a triton kernel later
+    # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+    num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+        num_tokens_per_expert_group
+    )
+    input_splits = (
+        num_tokens_per_expert.view(ep_degree, -1)
+        .sum(dim=1)
+        .to(torch.device("cpu"), non_blocking=True)
+    )
+    # NOTE: this would incur a device-to-host sync
+    output_splits = (
+        num_tokens_per_expert_group.view(ep_degree, -1)
+        .sum(dim=1)
+        .to(torch.device("cpu"), non_blocking=False)
+    )
+
+    input_splits_list = input_splits.tolist()
+    output_splits_list = output_splits.tolist()
+
+    return input_splits_list, output_splits_list
+
+
 def generate_split_sizes(K: int, N: int, device: str = "cuda") -> torch.Tensor:
     """
     Generates a tensor of K random non-negative integers that sum to N.
diff --git a/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py b/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py
index dca30be419..cd98fe69ac 100644
--- a/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py
+++ b/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py
@@ -7,6 +7,7 @@
 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem
 from torch.distributed._functional_collectives import (
+    all_to_all_single,
     all_to_all_single_autograd,
 )
 from torch.nn import functional as F
@@ -23,13 +24,14 @@
 )
 from torchao.prototype.moe_training.kernels.mxfp8.comms import (
     mxfp8_on_device_all_to_all_v,
+    to_mxfp8_a2a_dequant,
 )
 
 from ..testing_utils import generate_split_sizes
 
 
 @instantiate_parametrized_tests
-class MXFP8AllToAllVTest(MultiProcessTestCase):
+class MXFP8OnDeviceAllToAllVTest(MultiProcessTestCase):
     def setUp(self) -> None:
         super().setUp()
         self._spawn_processes()
@@ -143,5 +145,126 @@ def test_a2a_fwd_bwd(self):
             dist.destroy_process_group()
 
 
+@instantiate_parametrized_tests
+class ToMXFP8AllToAllVDequantTest(MultiProcessTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device(f"cuda:{self.rank}")
+
+    def _init_process(self):
+        torch.cuda.set_device(self.device)
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend="nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        torch.manual_seed(42 + self.rank)
+
+    def _init_device(self):
+        symm_mem.set_backend("NVSHMEM")
+
+    def test_a2a_fwd_bwd(self):
+        self._init_process()
+        try:
+            torch.manual_seed(42 + self.rank)
+            self._init_device()
+
+            group_name = dist.group.WORLD.group_name
+            symm_mem.enable_symm_mem_for_group(group_name)
+
+            tokens_per_ep_rank = 8192
+            dim = 2048
+            input_tensor = torch.randn(
+                tokens_per_ep_rank,
+                dim,
+                device=self.device,
+                dtype=torch.float32,
+                requires_grad=True,
+            )
+            ref_input_tensor = input_tensor.detach().clone().requires_grad_(True)
+
+            # Generate random input splits that sum to tokens_per_ep_rank
+            experts_per_rank = 2
+            num_splits = experts_per_rank * self.world_size
+            num_tokens_per_expert = generate_split_sizes(
+                num_splits, tokens_per_ep_rank, self.device
+            )
+
+            ep_degree = self.world_size
+
+            # Compute tokens per expert group using tokens per expert
+            with torch.no_grad():
+                num_tokens_per_expert_group = all_to_all_single(
+                    num_tokens_per_expert,
+                    None,
+                    None,
+                    group=dist.group.WORLD,
+                )
+                # Need to wait explicitly because it is used by a triton kernel later
+                # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+                num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                    num_tokens_per_expert_group
+                )
+                input_splits = (
+                    num_tokens_per_expert.view(ep_degree, -1)
+                    .sum(dim=1)
+                    .to(torch.device("cpu"), non_blocking=True)
+                )
+                # NOTE: this would incur a device-to-host sync
+                output_splits = (
+                    num_tokens_per_expert_group.view(ep_degree, -1)
+                    .sum(dim=1)
+                    .to(torch.device("cpu"), non_blocking=False)
+                )
+
+            # Compute reference a2a autograd
+            ref_output = all_to_all_single_autograd(
+                ref_input_tensor,
+                output_splits.tolist(),
+                input_splits.tolist(),
+                dist.group.WORLD,
+            )
+
+            # Compute mxfp8 a2a sync
+            output = to_mxfp8_a2a_dequant(
+                input_tensor,
+                output_splits.tolist(),
+                input_splits.tolist(),
+                group_name,
+            )
+
+            # Compare output
+            sqnr = compute_error(ref_output, output)
+            min_sqnr = 30.0
+            assert sqnr > min_sqnr, f"sqnr={sqnr} is less than min_sqnr={min_sqnr}"
+
+            # Test backwards
+            labels = torch.ones_like(output)
+            ref_loss = F.mse_loss(ref_output, labels)
+            loss = F.mse_loss(output, labels)
+            ref_loss.backward()
+            loss.backward()
+
+            # Compare grads
+            grad_sqnr = compute_error(ref_input_tensor.grad, input_tensor.grad)
+            min_grad_sqnr = 28.0
+            assert grad_sqnr > min_grad_sqnr, (
+                f"grad_sqnr={grad_sqnr} is less than min_grad_sqnr={min_grad_sqnr}"
+            )
+
+        finally:
+            dist.destroy_process_group()
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/prototype/moe_training/kernels/mxfp8/comms.py b/torchao/prototype/moe_training/kernels/mxfp8/comms.py
index e0dac42f71..7430010d3a 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8/comms.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8/comms.py
@@ -3,6 +3,9 @@
 import torch.distributed._symmetric_memory as symm_mem
 import triton
 import triton.language as tl
+from torch.distributed._functional_collectives import (
+    all_to_all_single,
+)
 
 from torchao.prototype.moe_training.kernels.triton_utils import (
     blockwise_barrier,
@@ -31,12 +34,6 @@ class MXFP8OnDeviceAllToAllV(torch.autograd.Function):
     # Maximum output length (need to be set before use of MXFP8OnDeviceAllToAllV)
     max_output_rows_per_rank = None
 
-    # A preallocated buffer for holding the output, that can be reused without cudaMalloc/cudaFree each iteration
-    output_buf = None
-
-    # A preallocated buffer for holding the output scales, that can be reused without cudaMalloc/cudaFree each iteration
-    output_scales_buf = None
-
     # A preallocated buffer for holding the grad_input, that can be reused without cudaMalloc/cudaFree each iteration
     grad_input_buf = None
 
@@ -47,6 +44,7 @@ class MXFP8OnDeviceAllToAllV(torch.autograd.Function):
     grad_input_splits_buf = None
 
     @staticmethod
+    @torch.compiler.disable
     def forward(
         ctx,
         input: torch.Tensor,
@@ -165,6 +163,7 @@ def forward(
         return hp_output_no_padding, output_splits
 
     @staticmethod
+    @torch.compiler.disable
     def backward(ctx, grad_output, grad_splits):
         """
         Backward is implemented as a shuffle of the output's gradients to the input.
@@ -455,3 +454,116 @@ def _exchange_row_offsets(
     output_offset_for_remote_rank = tl.sum(output_split_sizes)
 
     return input_offset_for_remote_rank, output_offset_for_remote_rank, num_rows_to_read
+
+
+class ToMXFP8AllToAllVDequant(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input: torch.Tensor,
+        output_splits: list[int],
+        input_splits: list[int],
+        group: dist.ProcessGroup = dist.group.WORLD,
+    ):
+        """
+        Dynamically quantizes input to mxfp8, performs all-to-all, then dequantizes output back to original precision.
+        Requires d2h sync to get input_splits and output_splits on host, as required by torch.distributed.all_to_all_single API.
+        """
+
+        # Quantize input
+        block_size = 32
+        input_scales, input_data = to_mx(
+            input,
+            elem_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+        )
+
+        # Dispatch data (async)
+        output_data = all_to_all_single(
+            input_data,
+            output_split_sizes=output_splits,
+            input_split_sizes=input_splits,
+            group=group,
+        )
+
+        # Dispatch scales (async)
+        output_scales = all_to_all_single(
+            input_scales.view(torch.uint8),  # NCCL cannot handle float8_e8m0fnu yet
+            output_split_sizes=output_splits,
+            input_split_sizes=input_splits,
+            group=group,
+        )
+
+        # Explicitly wait since the a2a ops are async
+        output_scales = torch.ops._c10d_functional.wait_tensor(output_scales)
+        output_data = torch.ops._c10d_functional.wait_tensor(output_data)
+
+        # Dequantize output
+        lowp_dtype = output_data.dtype
+        hp_dtype = input.dtype
+        hp_output = to_dtype(
+            output_data,
+            output_scales.view(torch.float8_e8m0fnu),
+            lowp_dtype,
+            block_size,
+            hp_dtype,
+        )
+
+        ctx.input_splits = input_splits
+        ctx.output_splits = output_splits
+        ctx.group = group
+        return hp_output
+
+    @staticmethod
+    def backward(ctx, grad_output_hp):
+        """
+        Backward is implemented as a shuffle of the output's gradients to the input.
+        Args:
+            `grad_output_hp`: high precision output gradient passed from upstream
+        """
+        # In backward, mxfp8_all_to_all_v input is `grad_output`, and output is `grad_input`.
+        # Input splits are the output splits from forward (and vice-versa).
+        input_splits, output_splits = ctx.input_splits, ctx.output_splits
+
+        # Quantize grad_output
+        block_size = 32
+        grad_out_scales, grad_out_data = to_mx(
+            grad_output_hp,
+            elem_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+        )
+
+        # Dispatch data (async)
+        grad_input_data = all_to_all_single(
+            grad_out_data,
+            output_split_sizes=input_splits,
+            input_split_sizes=output_splits,
+            group=ctx.group,
+        )
+
+        # Dispatch scales (async)
+        grad_input_scales = all_to_all_single(
+            grad_out_scales.view(torch.uint8),  # NCCL cannot handle float8_e8m0fnu yet
+            output_split_sizes=input_splits,
+            input_split_sizes=output_splits,
+            group=ctx.group,
+        )
+
+        # Explicitly wait since the a2a ops are async
+        grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
+        grad_input_data = torch.ops._c10d_functional.wait_tensor(grad_input_data)
+
+        hp_dtype = grad_output_hp.dtype
+        lowp_dtype = grad_input_data.dtype
+        grad_input_hp = to_dtype(
+            grad_input_data,
+            grad_input_scales.view(torch.float8_e8m0fnu),
+            lowp_dtype,
+            block_size,
+            hp_dtype,
+        )
+        return grad_input_hp, None, None, None
+
+
+# Alias
+to_mxfp8_a2a_dequant = ToMXFP8AllToAllVDequant.apply
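Usage sketch (editor's addition, not part of the patch): the snippet below shows how the new to_mxfp8_a2a_dequant entry point might be driven end to end, mirroring the d2h-sync split preparation that get_split_lists performs in the benchmark above. It assumes a torchrun launch with one rank per CUDA GPU and an initialized NCCL process group; the even token split and all variable names here are illustrative only.

# Illustrative sketch only; not part of the diff above. Launch with torchrun, one rank per GPU.
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single

from torchao.prototype.moe_training.kernels.mxfp8.comms import to_mxfp8_a2a_dequant

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")

# Hypothetical sizes; dim must be divisible by the mxfp8 block size of 32.
tokens, dim, experts_per_rank = 8192, 2048, 2
x = torch.randn(tokens, dim, device=device, dtype=torch.float32, requires_grad=True)

# Even splits for simplicity; real MoE routing produces ragged per-expert counts.
num_tokens_per_expert = torch.full(
    (experts_per_rank * world,), tokens // (experts_per_rank * world), device=device
)

# Same split preparation as get_split_lists in the benchmark: exchange per-expert
# counts, wait on the async collective, then reduce to per-rank splits on the host.
num_tokens_per_expert_group = all_to_all_single(
    num_tokens_per_expert, None, None, group=dist.group.WORLD
)
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
    num_tokens_per_expert_group
)
input_splits = num_tokens_per_expert.view(world, -1).sum(dim=1).tolist()
output_splits = num_tokens_per_expert_group.view(world, -1).sum(dim=1).tolist()

# Quantize to mxfp8, all-to-all the data and scales, dequantize back to fp32.
out = to_mxfp8_a2a_dequant(x, output_splits, input_splits, dist.group.WORLD)
out.sum().backward()  # backward runs the reverse all-to-all on the gradients

dist.destroy_process_group()

Passing dist.group.WORLD matches the group parameter declared on ToMXFP8AllToAllVDequant.forward; the test above passes the group name string instead, which the functional all_to_all_single also accepts.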