
Conversation

@NVShreyas (Collaborator) commented Oct 2, 2025

Summary by CodeRabbit

  • New Features
  • Optional FlashInfer-accelerated all-reduce for tensor-parallel layers (including a vLLM-style variant), enabled via environment flags and a new Linear layer option; automatically falls back to the standard path when disabled.
    • New parameters to configure FlashInfer all-reduce strategy and fusion behavior.
    • Automatic initialization of tensor-parallel communication when TP is enabled, simplifying setup.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-reuse-test --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
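
For example, a few illustrative invocations built only from the flags documented above:

/bot run
/bot run --stage-list "A10-PyTorch-1"
/bot run --gpu-type "A30, H100_PCIe" --disable-fail-fast
/bot run --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1"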

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

@NVShreyas requested review from a team as code owners on October 2, 2025 19:06
@NVShreyas marked this pull request as draft on October 2, 2025 19:06
@NVShreyas (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #20551 [ run ] triggered by Bot

@coderabbitai bot commented Oct 2, 2025

📝 Walkthrough

Adds FlashInfer-based all-reduce support. Extends distributed exports, introduces FlashInfer ops and VLLM communicator, adds TP communicator initialization, integrates optional FlashInfer paths in Linear (and LMHead init), and defines FlashInferAllReduceParams. Minor whitespace cleanup in Llama.

Changes

Cohort / File(s) / Summary of Changes

  • Distributed API exports (tensorrt_llm/_torch/distributed/__init__.py): Exposes FlashInferAllReduce, FlashInferAllReduceParams, FlashInferVLLMAllReduce via imports from .ops and updates __all__.
  • FlashInfer comm and ops (tensorrt_llm/_torch/distributed/communicator.py, tensorrt_llm/_torch/distributed/ops.py): Adds FlashInferVLLMComm, global TP comm init/accessors, conditional TorchDist initialization, and an allgather stub. Implements FlashInfer all-reduce workspaces, the FlashInferAllReduce and FlashInferVLLMAllReduce modules, and FlashInfer invocation with error handling.
  • Functional params (tensorrt_llm/functional.py): Introduces the FlashInferAllReduceParams subclass with strategy/fusion/config fields; imports flashinfer.comm.
  • Linear/Embedding integration (tensorrt_llm/_torch/modules/linear.py, tensorrt_llm/_torch/modules/embedding.py): Adds a use_flashinfer_allreduce flag. Integrates FlashInferAllReduce/FlashInferVLLMAllReduce in the Linear forward pass under environment flags and conditions; updates imports. Embedding passes the flag in LMHead init. (A hypothetical gating sketch follows after this table.)
  • Engine TP init (tensorrt_llm/_torch/pyexecutor/model_engine.py): Initializes the Torch TP communicator when mapping.has_tp(); imports init_torch_dist_tp_comm.
  • Minor cleanup (tensorrt_llm/_torch/models/modeling_llama.py): Removes a blank line; no functional change.
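
To make the gating concrete, here is a minimal, hypothetical sketch of an environment-flag switch between the FlashInfer path and the standard all-reduce. The flag names and the standalone function are placeholders, not the PR's actual code (which lives inside Linear.forward):

import os

import torch
import torch.distributed as dist


def _env_enabled(name: str) -> bool:
    # Placeholder flag names; the actual environment variables used by the PR
    # are not shown on this page.
    return os.environ.get(name, "0") == "1"


def reduce_row_parallel_output(output: torch.Tensor,
                               flash_infer_all_reduce=None) -> torch.Tensor:
    # `flash_infer_all_reduce` stands in for the FlashInferAllReduce /
    # FlashInferVLLMAllReduce modules added by the PR. When the flags are off
    # or no FlashInfer module is available, fall back to the standard path.
    use_flashinfer = (_env_enabled("TRTLLM_USE_FLASHINFER_ALLREDUCE")
                      or _env_enabled("TRTLLM_USE_FLASHINFER_VLLM_ALLREDUCE"))
    if use_flashinfer and flash_infer_all_reduce is not None:
        return flash_infer_all_reduce(output)
    dist.all_reduce(output)  # standard NCCL all-reduce via torch.distributed
    return output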

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Engine as PyTorchModelEngine
  participant Dist as TorchDistTPComm
  participant TD as TorchDist
  participant FIV as FlashInferVLLMComm

  User->>Engine: construct(mapping)
  Engine->>Engine: if mapping.has_tp(): init_torch_dist_tp_comm(mapping)
  Engine->>Dist: init_torch_dist_tp_comm(mapping)
  Dist->>TD: create TorchDist(mapping)
  TD->>TD: if _USE_FLASHINFER_VLLM_ALLREDUCE: init _flashinfer_vllm_comm
  TD->>FIV: setup shared buffers and IPC
  note right of FIV: FlashInfer VLLM comm ready
sequenceDiagram
  autonumber
  participant App as Module(Linear)
  participant AR as FlashInferAllReduce/FlashInferVLLMAllReduce
  participant Ops as flashinfer_comm
  participant FWS as Workspace Cache

  App->>App: forward(x)
  App->>App: if ROW-parallel and reduce_output
  App->>App: if env flag enables FlashInfer and batch==150
  App->>AR: forward(output, params)
  AR->>FWS: get_flashinfer_allreduce_workspace(mapping, dims)
  FWS-->>AR: workspace handles
  AR->>Ops: trtllm_custom_all_reduce / vllm_all_reduce(...)
  Ops-->>AR: reduced tensor
  AR-->>App: return reduced output
  App->>App: else fallback to standard all_reduce
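
The second diagram references a workspace cache consulted before each FlashInfer all-reduce. A minimal sketch of such a cache is shown below; the key layout and the injected factory are assumptions, not the PR's actual implementation:

from typing import Dict, Tuple

# Hypothetical registry of FlashInfer all-reduce workspaces, keyed by the
# tensor-parallel size and the maximum problem shape they were created for.
_WORKSPACE_CACHE: Dict[Tuple[int, int, int], object] = {}


def get_cached_allreduce_workspace(tp_size: int, max_num_tokens: int,
                                   hidden_dim: int, create_fn):
    # `create_fn` stands in for whatever FlashInfer setup routine the PR calls;
    # it is injected here so the sketch stays library-agnostic.
    key = (tp_size, max_num_tokens, hidden_dim)
    if key not in _WORKSPACE_CACHE:
        _WORKSPACE_CACHE[key] = create_fn(tp_size, max_num_tokens, hidden_dim)
    return _WORKSPACE_CACHE[key]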

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)

  • Description Check (⚠️ Warning): The PR description contains only the template with placeholders and lacks any actual summary of the changes, detailed explanation, test coverage information, or completion of the checklist sections. Resolution: complete the description with a concise summary of the implemented changes and their purpose, list relevant tests covering the new functionality, and mark off the PR checklist items such as coding guidelines adherence, tests, and documentation updates.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 16.67%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Title Check (❓ Inconclusive): The title “draft: all reduce” is overly generic and does not follow the repository’s ticket/type template or clearly convey the specific feature being added. Resolution: revise the title to follow the required format and concisely describe the change, for example “[TRTLLM-1234][feat] Add FlashInfer all-reduce support.”
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai bot left a comment

Actionable comments posted: 8

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fefa7d8 and 325b45b.

📒 Files selected for processing (8)
  • tensorrt_llm/_torch/distributed/__init__.py (2 hunks)
  • tensorrt_llm/_torch/distributed/communicator.py (6 hunks)
  • tensorrt_llm/_torch/distributed/ops.py (2 hunks)
  • tensorrt_llm/_torch/models/modeling_llama.py (0 hunks)
  • tensorrt_llm/_torch/modules/embedding.py (1 hunks)
  • tensorrt_llm/_torch/modules/linear.py (4 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (2 hunks)
  • tensorrt_llm/functional.py (2 hunks)
💤 Files with no reviewable changes (1)
  • tensorrt_llm/_torch/models/modeling_llama.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/modules/embedding.py
  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/distributed/__init__.py
  • tensorrt_llm/_torch/distributed/communicator.py
  • tensorrt_llm/functional.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/modules/linear.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/modules/embedding.py
  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/distributed/__init__.py
  • tensorrt_llm/_torch/distributed/communicator.py
  • tensorrt_llm/functional.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/modules/linear.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/modules/embedding.py
  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/distributed/__init__.py
  • tensorrt_llm/_torch/distributed/communicator.py
  • tensorrt_llm/functional.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/modules/linear.py
🧬 Code graph analysis (6)
tensorrt_llm/_torch/distributed/ops.py (3)
tensorrt_llm/functional.py (7)
  • AllReduceFusionOp (3890-3899)
  • AllReduceParams (3902-3941)
  • AllReduceStrategy (3878-3887)
  • FlashInferAllReduceParams (3983-4013)
  • rank (494-498)
  • Tensor (109-604)
  • size (482-492)
tensorrt_llm/mapping.py (1)
  • Mapping (32-519)
tensorrt_llm/_torch/distributed/communicator.py (5)
  • TorchDist (421-486)
  • get_torch_dist_flashinfer_vllm_comm (535-536)
  • pp_rank (67-68)
  • rank (31-32)
  • tp_size (55-56)
tensorrt_llm/_torch/distributed/__init__.py (2)
tensorrt_llm/_torch/distributed/ops.py (3)
  • AllReduce (483-631)
  • FlashInferAllReduce (729-841)
  • FlashInferVLLMAllReduce (844-887)
tensorrt_llm/functional.py (3)
  • AllReduceParams (3902-3941)
  • AllReduceStrategy (3878-3887)
  • FlashInferAllReduceParams (3983-4013)
tensorrt_llm/_torch/distributed/communicator.py (2)
tensorrt_llm/mapping.py (2)
  • Mapping (32-519)
  • local_rank (403-404)
tensorrt_llm/_torch/distributed/ops.py (1)
  • allgather (186-270)
tensorrt_llm/functional.py (2)
cpp/tensorrt_llm/thop/allreduceOp.cpp (2)
  • strategy (691-722)
  • strategy (691-691)
cpp/tensorrt_llm/kernels/customAllReduceKernels.h (1)
  • AllReduceFusionOp (69-171)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
tensorrt_llm/_torch/distributed/communicator.py (3)
  • init_pp_comm (513-516)
  • init_torch_dist_tp_comm (529-532)
  • has_tp (39-40)
tensorrt_llm/mapping.py (1)
  • has_tp (427-428)
tensorrt_llm/_torch/modules/linear.py (1)
tensorrt_llm/_torch/distributed/ops.py (3)
  • AllReduce (483-631)
  • FlashInferAllReduce (729-841)
  • FlashInferVLLMAllReduce (844-887)
🪛 Ruff (0.13.2)
tensorrt_llm/_torch/distributed/ops.py

858-858: Unused method argument: all_reduce_params

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/modules/embedding.py (1)

67-67: LGTM! Conservative default preserves existing behavior.

The hardcoded use_flashinfer_allreduce=False explicitly opts LMHead out of the new FlashInfer all-reduce pathway, maintaining current behavior while enabling the feature for other layers.

tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

35-35: LGTM! Import follows existing pattern.

The addition of init_torch_dist_tp_comm to the import statement mirrors the existing init_pp_comm import and is consistent with the module's structure.

tensorrt_llm/_torch/distributed/__init__.py (1)

25-27: LGTM! Export list matches imports.

The __all__ additions correctly expose the three new FlashInfer-based all-reduce symbols to the public API, maintaining consistency with the import statements (assuming the import verification passes).

Comment on lines 4 to 6
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy,
FlashInferAllReduce, FlashInferAllReduceParams,
FlashInferVLLMAllReduce, MoEAllReduce, MoEAllReduceParams,

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify FlashInferAllReduceParams import source.

The code imports FlashInferAllReduceParams from .ops, but the provided code snippet shows it defined in tensorrt_llm/functional.py (lines 3982-4012). Verify that .ops either defines this class or re-exports it from functional.


🏁 Script executed:

#!/bin/bash
# Description: Verify FlashInferAllReduceParams is available in ops module

# Check if FlashInferAllReduceParams is defined or imported in ops.py
rg "class FlashInferAllReduceParams|from.*functional.*import.*FlashInferAllReduceParams|import.*FlashInferAllReduceParams" tensorrt_llm/_torch/distributed/ops.py -C 3

# Also check functional.py to confirm the definition location
rg "class FlashInferAllReduceParams" tensorrt_llm/functional.py -C 2

Length of output: 75


Ensure FlashInferAllReduceParams is imported from the correct module.
File: tensorrt_llm/_torch/distributed/__init__.py, lines 4–6 – FlashInferAllReduceParams isn't defined or re-exported in ops.py. Either add a re-export in tensorrt_llm/_torch/distributed/ops.py or import it directly from tensorrt_llm.functional.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/distributed/__init__.py around lines 4 to 6,
FlashInferAllReduceParams is being imported from .ops but that name is not
defined or re-exported there; either add "from .functional import
FlashInferAllReduceParams" to this __init__.py and replace the import in the
tuple, or preferably open tensorrt_llm/_torch/distributed/ops.py and add "from
.functional import FlashInferAllReduceParams" and include
FlashInferAllReduceParams in ops.py's __all__ (or re-export it) so the name is
available when importing from .ops; update the import list accordingly and run
tests to ensure no import errors remain.

Comment on lines +23 to +28
try:
    import flashinfer.comm as flashinfer_comm
except ImportError:
    print(
        "FlashInfer comm module not found. Follow readme to install Flashinfer >=2.8.0."
    )
    exit(1)

⚠️ Potential issue | 🔴 Critical

Do not exit(1) on optional FlashInfer import failure

Executing exit(1) during module import will terminate every consumer that imports tensorrt_llm._torch.distributed.ops on hosts without FlashInfer, even if they never touch the FlashInfer path. That silently kills training/inference jobs and breaks environments where FlashInfer is optional. Replace this with a graceful fallback (e.g., set a feature flag, raise a descriptive exception when FlashInfer usage is attempted, or log a warning) so the rest of the module continues to work.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/distributed/ops.py around lines 23 to 29, importing
flashinfer currently calls exit(1) on ImportError which kills any process that
imports this module; instead, catch the ImportError, log a clear warning, set a
local flag or variable (e.g., flashinfer_comm = None or _HAS_FLASHINFER = False)
so the module can continue importing, and update any FlashInfer-dependent
functions to check that flag and raise a descriptive exception only when
FlashInfer functionality is actually invoked.
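
A minimal sketch of the graceful fallback this comment asks for, assuming a module-level sentinel and a standard logger (names are illustrative, not the PR's actual code):

import logging

logger = logging.getLogger(__name__)

try:
    import flashinfer.comm as flashinfer_comm
except ImportError:
    flashinfer_comm = None
    logger.warning(
        "flashinfer.comm is not available; FlashInfer all-reduce paths are "
        "disabled. Install FlashInfer >= 2.8.0 to enable them.")


def _require_flashinfer():
    # Raise only when a FlashInfer code path is actually invoked.
    if flashinfer_comm is None:
        raise ImportError(
            "FlashInfer all-reduce was requested but flashinfer.comm could "
            "not be imported.")
    return flashinfer_comm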

from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import flashinfer.comm as flashinfer_comm

⚠️ Potential issue | 🔴 Critical

Guard optional FlashInfer dependency

Importing flashinfer.comm unconditionally makes this module unusable anywhere FlashInfer is not installed (ImportError during import). Since FlashInfer support is optional, wrap the import in a try/except and gate all FlashInfer usage on it.

- import flashinfer.comm as flashinfer_comm
+try:
+    import flashinfer.comm as flashinfer_comm  # optional dependency
+except ImportError:
+    flashinfer_comm = None

Then, before instantiating any FlashInfer objects, check flashinfer_comm is None and either fall back to the NCCL path or raise a clear error only when the feature is explicitly requested.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-import flashinfer.comm as flashinfer_comm
+try:
+    import flashinfer.comm as flashinfer_comm  # optional dependency
+except ImportError:
+    flashinfer_comm = None
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/linear.py around line 10, the module imports
flashinfer.comm unconditionally which raises ImportError when FlashInfer isn't
installed; wrap the import in a try/except (setting flashinfer_comm = None on
failure) and update all subsequent FlashInfer usages to first check if
flashinfer_comm is None and either fall back to the NCCL path or raise a clear,
explicit error only when the caller explicitly requested FlashInfer
functionality.
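
One possible shape for the call-site check described above, which fails loudly only when FlashInfer was explicitly requested (helper name and return convention are hypothetical):

def select_allreduce_path(use_flashinfer_allreduce: bool, flashinfer_comm) -> str:
    # Fall back silently when FlashInfer was never requested; raise only on an
    # explicit request that cannot be honored.
    if use_flashinfer_allreduce and flashinfer_comm is None:
        raise ImportError(
            "use_flashinfer_allreduce=True but flashinfer-python is not "
            "installed; install it or disable the flag.")
    return "flashinfer" if use_flashinfer_allreduce else "nccl"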

Comment on lines 2011 to 2027
if self.flashinfer_trtllm and output.size(0) == 150:
    output = self.flash_infer_all_reduce(
        output,
        all_reduce_params=None,
    )
elif self.flashinfer_vllm and output.size(0) == 150:
    output = self.flash_infer_all_reduce(
        output,
        all_reduce_params=None,
    )
else:
    output = self.all_reduce(
        output,
        all_reduce_params=all_reduce_params,
    )

⚠️ Potential issue | 🟠 Major

Remove hard-coded token-count gate

With _USE_FLASHINFER_* enabled we still fall back to the regular AllReduce unless the batch has exactly 150 tokens. That makes the new path effectively unusable outside this single size. Please drop the magic number and run the FlashInfer all-reduce whenever the corresponding flag is set.

-                if self.flashinfer_trtllm and output.size(0) == 150:
+                if self.flashinfer_trtllm:
                     output = self.flash_infer_all_reduce(
                         output,
                         all_reduce_params=None,
                     )
-                elif self.flashinfer_vllm and output.size(0) == 150:
+                elif self.flashinfer_vllm:
                     output = self.flash_infer_all_reduce(
                         output,
                         all_reduce_params=None,
                     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-if self.flashinfer_trtllm and output.size(0) == 150:
-    output = self.flash_infer_all_reduce(
-        output,
-        all_reduce_params=None,
-    )
-elif self.flashinfer_vllm and output.size(0) == 150:
-    output = self.flash_infer_all_reduce(
-        output,
-        all_reduce_params=None,
-    )
-else:
-    output = self.all_reduce(
-        output,
-        all_reduce_params=all_reduce_params,
-    )
+if self.flashinfer_trtllm:
+    output = self.flash_infer_all_reduce(
+        output,
+        all_reduce_params=None,
+    )
+elif self.flashinfer_vllm:
+    output = self.flash_infer_all_reduce(
+        output,
+        all_reduce_params=None,
+    )
+else:
+    output = self.all_reduce(
+        output,
+        all_reduce_params=all_reduce_params,
+    )
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/linear.py around lines 2011 to 2025, the code
gates the FlashInfer all-reduce behind a hard-coded batch-size check of
output.size(0) == 150; change this so that whenever self.flashinfer_trtllm or
self.flashinfer_vllm is true you call self.flash_infer_all_reduce (passing the
same all_reduce_params or None as currently used) without checking the token
count, otherwise call self.all_reduce with all_reduce_params as before. Ensure
both flag branches invoke flash_infer_all_reduce and preserve the original
all_reduce_params behavior for the fallback.

import tensorrt as trt
# isort: on

import flashinfer.comm as flashinfer_comm

⚠️ Potential issue | 🔴 Critical

Make FlashInfer import optional

This module now hard-depends on flashinfer-python; if that wheel is absent the entire functional API fails to import. Please guard the import (e.g. try/except ImportError) and defer raising until the FlashInfer path is used.

🤖 Prompt for AI Agents
In tensorrt_llm/functional.py around line 29, the direct import "import
flashinfer.comm as flashinfer_comm" makes the module fail to import if
flashinfer-python is not installed; wrap the import in a try/except ImportError
and set a module-level sentinel (e.g. flashinfer_comm = None) on failure, then
update any code paths that require FlashInfer to check the sentinel and only
attempt to import or raise a descriptive ImportError when those
FlashInfer-specific functions/paths are actually invoked (or perform a lazy
import inside those functions).
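
One way to realize the sentinel-plus-deferred-error pattern described in this prompt (helper name is illustrative, not the PR's actual code):

try:
    import flashinfer.comm as flashinfer_comm  # optional dependency
except ImportError:
    flashinfer_comm = None


def _flashinfer_comm_or_raise():
    # Defer the error to the moment a FlashInfer-specific feature is used.
    if flashinfer_comm is None:
        raise ImportError(
            "FlashInferAllReduceParams requires flashinfer-python; install it "
            "or use the default AllReduceParams instead.")
    return flashinfer_comm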

Comment on lines +4011 to +4014
        self.strategy = strategy
        self.fusion_op = fusion_op
        self.config_mode = config_mode


⚠️ Potential issue | 🔴 Critical

Don’t overwrite base enums with FlashInfer enums

FlashInferAllReduceParams inherits AllReduceParams, whose helpers compare self.strategy/self.fusion_op against the local AllReduceStrategy / AllReduceFusionOp. Overwriting those attributes with flashinfer_comm enums breaks those comparisons (FlashInfer NONE != AllReduce NONE), so generic call-sites (e.g. create_allreduce_plugin, allreduce) will mis-detect the configuration and crash or emit wrong inputs. Keep the base enums intact (store FlashInfer values in dedicated attributes).
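
One possible separation along the lines the review suggests; the base-class signature and the stand-in enums below are simplified assumptions, not the actual tensorrt_llm definitions:

import enum


class AllReduceStrategy(enum.IntEnum):
    # Stand-in for the TensorRT-LLM enum referenced above.
    AUTO = 0
    NCCL = 1


class AllReduceFusionOp(enum.IntEnum):
    # Stand-in for the TensorRT-LLM enum referenced above.
    NONE = 0
    RESIDUAL_RMS_NORM = 1


class AllReduceParams:
    # Simplified stand-in for the real base class in tensorrt_llm/functional.py.
    def __init__(self, strategy=AllReduceStrategy.AUTO,
                 fusion_op=AllReduceFusionOp.NONE):
        self.strategy = strategy
        self.fusion_op = fusion_op


class FlashInferAllReduceParams(AllReduceParams):
    """Keep the base-class enums intact and store FlashInfer-specific values
    in dedicated attributes, so generic call sites still compare against the
    TensorRT-LLM enums."""

    def __init__(self, flashinfer_strategy=None, flashinfer_fusion_op=None,
                 config_mode=None, **kwargs):
        super().__init__(**kwargs)  # self.strategy / self.fusion_op stay TRT-LLM enums
        self.flashinfer_strategy = flashinfer_strategy
        self.flashinfer_fusion_op = flashinfer_fusion_op
        self.config_mode = config_mode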

@tensorrt-cicd (Collaborator) commented:

PR_Github #20551 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15508 completed with status: 'FAILURE'

@NVShreyas (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #20559 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #20559 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15516 completed with status: 'FAILURE'

@NVShreyas (Collaborator, Author) commented:

/bot run

1 similar comment

@NVShreyas (Collaborator, Author) commented:

/bot run

@NVShreyas marked this pull request as ready for review on October 7, 2025 17:04
@NVShreyas (Collaborator, Author) commented:

/bot run

@NVShreyas (Collaborator, Author) commented:

/bot kill

@NVShreyas force-pushed the user/shreyasm/flashinfer-all-reduce branch from 19f779a to 5428b52 on October 8, 2025 15:50
@NVShreyas (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #20807 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #20807 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #15731 completed with status: 'FAILURE'

@NVShreyas (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #20811 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #20811 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15735 completed with status: 'FAILURE'

@NVShreyas (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #20821 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #20821 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #15742 completed with status: 'FAILURE'

@NVShreyas (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #21041 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #21041 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #15907 completed with status: 'FAILURE'

(Commits added to the branch, each signed off by Shreyas Misra <[email protected]>, including one that reverts commit 29665dd.)
@NVShreyas force-pushed the user/shreyasm/flashinfer-all-reduce branch from 1b576b8 to 0ed45fe on October 15, 2025 17:03
@NVShreyas (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #21494 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #21494 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16227 completed with status: 'FAILURE'
