-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[#9198][feat] Refactor dist ops in AutoDeploy #9301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
MrGeva
wants to merge
8
commits into
NVIDIA:main
Choose a base branch
from
nv-auto-deploy:egeva/template_dist_patterns
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+543
−274
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
9e0b47a
split dist custom ops and use templated patterns
MrGeva 81bb39d
fixed tests, readme, removed rms torch pattern
MrGeva 5003571
removed legacy names
MrGeva 168b656
simplified collectives.py
MrGeva 596c643
removed torch rms alred op
MrGeva 572a89c
Fixed CR comments
MrGeva aeb7410
added dist_backend arg
MrGeva 2f3b295
removed try/catch in trtllm)dist.py
MrGeva File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| """Custom ops required for implementing tensor parallelism. | ||
|
|
||
| This module defines atomic distributed ops - each op uses a specific backend | ||
| (torch.distributed or TRT-LLM) without internal dispatch logic. | ||
| """ | ||
|
|
||
| from typing import List, Optional | ||
|
|
||
| import torch | ||
|
|
||
| from ..distributed import common as dist | ||
|
|
||
| # ============================================================================ | ||
| # PyTorch Distributed Backend Ops (demollm mode) | ||
| # ============================================================================ | ||
|
|
||
|
|
||
| @torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda") | ||
| def torch_dist_all_gather( | ||
| tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None | ||
| ) -> torch.Tensor: | ||
| """All gather using PyTorch distributed backend. | ||
|
|
||
| This op always uses torch.distributed.all_gather and is used in demollm mode. | ||
| """ | ||
| tl = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] | ||
| dist.all_gather(tl, tensor) | ||
| return torch.cat(tl, dim=dim) | ||
|
|
||
|
|
||
| @torch_dist_all_gather.register_fake | ||
| def torch_dist_all_gather_fake(tensor, dim=0, sizes=None): | ||
| return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim) | ||
|
|
||
|
|
||
| @torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda") | ||
| def torch_dist_all_reduce(t: torch.Tensor) -> torch.Tensor: | ||
| """All_reduce using PyTorch distributed backend. Reduction op is SUM. | ||
|
|
||
| This op always uses torch.distributed.all_reduce and is used in demollm mode. | ||
|
|
||
| NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For | ||
| efficient all_reduce ops one should write/replace it with a fused op. | ||
| """ | ||
| t_res = t.clone() | ||
| dist.all_reduce(t_res, op=dist.ReduceOp.SUM) | ||
| return t_res | ||
|
|
||
|
|
||
| @torch_dist_all_reduce.register_fake | ||
| def torch_dist_all_reduce_fake(tensor): | ||
| return torch.empty_like(tensor) |
115 changes: 115 additions & 0 deletions
115
tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_dist.py
MrGeva marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| """TRT-LLM distributed operations and fused kernels. | ||
|
|
||
| This module defines atomic TRT-LLM-specific ops that use optimized kernels. | ||
| The torch fallback variants are defined separately to enable multi-pattern matching. | ||
| """ | ||
|
|
||
| from typing import List, Optional | ||
|
|
||
| import torch | ||
|
|
||
| # use trtllm distributed ops to improve TP performance if possible | ||
| from ....mapping import Mapping | ||
| from ...distributed import AllReduce, allgather | ||
| from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy | ||
| from ..distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi | ||
|
|
||
| # Cache AllReduce modules to avoid recreating on every call | ||
| # This is critical for CUDA graph compatibility - recreating modules during | ||
| # warmup causes hangs due to workspace allocation with CPU synchronization | ||
| _allreduce_cache = {} | ||
|
|
||
|
|
||
| def trtllm_allgather(tensor, dim, sizes=None): | ||
| rank, world_size = get_rank_world_size() | ||
| p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) | ||
| return allgather(tensor, p_config, dim=dim, sizes=sizes) | ||
|
|
||
|
|
||
| def trtllm_allreduce(tensor, op, all_reduce_params=None): | ||
| rank, world_size = get_rank_world_size() | ||
| assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op." | ||
|
|
||
| # Cache key includes rank, world_size, and dtype to handle different configurations | ||
| cache_key = (rank, world_size, tensor.dtype) | ||
| if cache_key not in _allreduce_cache: | ||
| p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) | ||
| # Use Strategy.AUTO for optimal performance | ||
| _allreduce_cache[cache_key] = AllReduce( | ||
| mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype | ||
| ) | ||
|
|
||
| torch_op = _allreduce_cache[cache_key] | ||
| return torch_op(tensor, all_reduce_params=all_reduce_params) | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # TRT-LLM Backend Ops (MPI mode) | ||
| # ============================================================================ | ||
|
|
||
|
|
||
| @torch.library.custom_op( | ||
| "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda" | ||
| ) | ||
| def trtllm_dist_all_gather( | ||
| tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None | ||
| ) -> torch.Tensor: | ||
| """All gather using TRT-LLM optimized backend. | ||
|
|
||
| This op always uses TRT-LLM's optimized allgather and is used in MPI mode. | ||
| """ | ||
| return trtllm_allgather(tensor, dim=dim, sizes=sizes) | ||
|
|
||
|
|
||
| @trtllm_dist_all_gather.register_fake | ||
| def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None): | ||
| return torch.cat([torch.empty_like(tensor) for _ in range(get_world_size())], dim=dim) | ||
|
|
||
|
|
||
| @torch.library.custom_op( | ||
| "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda" | ||
| ) | ||
| def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor: | ||
| """All_reduce using TRT-LLM optimized backend. Reduction op is SUM. | ||
|
|
||
| This op always uses TRT-LLM's optimized allreduce and is used in MPI mode. | ||
| """ | ||
| return trtllm_allreduce(t, op=ReduceOp.SUM) | ||
|
|
||
|
|
||
| @trtllm_dist_all_reduce.register_fake | ||
| def trtllm_dist_all_reduce_fake(tensor): | ||
| return torch.empty_like(tensor) | ||
|
|
||
|
|
||
| # TRT-LLM fused op (atomic - always uses TRT-LLM backend) | ||
| @torch.library.custom_op( | ||
| "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda" | ||
| ) | ||
| def trtllm_fused_allreduce_residual_rmsnorm( | ||
| tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel. | ||
|
|
||
| This op always uses TRT-LLM's fused kernel and is used in MPI mode. | ||
| """ | ||
| all_reduce_params = AllReduceParams( | ||
| fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, | ||
| bias=None, | ||
| residual=residual, | ||
| norm_weight=norm_weight, | ||
| eps=eps, | ||
| ) | ||
| return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params) | ||
|
|
||
|
|
||
| @trtllm_fused_allreduce_residual_rmsnorm.register_fake | ||
| def trtllm_fused_allreduce_residual_rmsnorm_fake( | ||
| tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| return torch.empty_like(tensor), torch.empty_like(tensor) | ||
|
|
||
|
|
||
| def is_trtllm_op_available(): | ||
| """Check if TRT-LLM ops are available and running with MPI.""" | ||
| return is_ompi() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.