[TRTLLM-6342][feat] Factory TP sharding of quantized models #8123
New file:

```diff
@@ -0,0 +1,4 @@
+from transformers.models.starcoder2.configuration_starcoder2 import Starcoder2Config
+
+# Remove this patch after TRT-LLM upgrades to the HF transformers version >= 4.57
+Starcoder2Config.base_model_tp_plan["layers.*.mlp.c_proj"] = "rowwise"
```
Comment on lines +3 to +4:

🛠️ Refactor suggestion | 🟠 Major

**Verify transformers version 4.57 availability and add tracking for patch removal.**

The comment references transformers >= 4.57 as the removal threshold, but version 4.56.0 was released in August 2025 and 4.57 has not been released yet. Please verify the actual version that includes the fix and confirm the timeline. Additionally, consider adding a TODO or FIXME comment with an issue reference so this temporary patch is tracked and removed appropriately.

Run the following script to check the latest transformers version and search for related PRs:

```shell
#!/bin/bash
# Check latest transformers version
curl -s https://pypi.org/pypi/transformers/json | jq -r '.info.version'

# Search for Starcoder2 TP plan fixes in the transformers repo
gh api graphql -f query='
{
  search(query: "repo:huggingface/transformers starcoder2 base_model_tp_plan", type: ISSUE, first: 5) {
    nodes {
      ... on Issue {
        number
        title
        state
        createdAt
      }
      ... on PullRequest {
        number
        title
        state
        createdAt
        merged
      }
    }
  }
}'
```

**Update the patch-removal comment with the correct version and a tracking reference.** Replace the existing comment line with:

```python
# TODO: Remove this patch after upgrading to HuggingFace transformers >= 4.56.2 (fix included in HF PR #40814)
```
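The review's concern about forgetting to remove the workaround can also be addressed mechanically by gating the patch on the installed version, so it becomes a no-op once the upgrade lands. A minimal sketch, using a stand-in config class rather than the real `Starcoder2Config` (the class, version tuple, and `apply_patch` helper here are illustrative, not part of the PR):

```python
# Stand-in for transformers' Starcoder2Config: base_model_tp_plan maps a
# parameter-name glob to a tensor-parallel action ("colwise" / "rowwise").
class FakeStarcoder2Config:
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.mlp.c_fc": "colwise",
        # "layers.*.mlp.c_proj" missing: the upstream gap the patch works around
    }


def apply_patch(config_cls, installed_version, fixed_in=(4, 56, 2)):
    """Add the c_proj row-wise entry only if transformers predates the fix."""
    if installed_version < fixed_in:
        config_cls.base_model_tp_plan["layers.*.mlp.c_proj"] = "rowwise"


# On an old transformers (< 4.56.2) the patch applies...
apply_patch(FakeStarcoder2Config, (4, 55, 0))
print(FakeStarcoder2Config.base_model_tp_plan["layers.*.mlp.c_proj"])  # rowwise
```

Once the dependency is bumped past the fix, the same call silently does nothing, so no manual cleanup is strictly required (though the TODO is still useful for tracking).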
```diff
@@ -292,7 +292,7 @@ def detect_sharding_from_factory_config(
     num_simple_shards = 0
     num_row_col_shards = 0

-    for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
+    for lin_node in filtered_nodes(gm.graph.nodes, [is_linear_op, is_fake_quantized_linear_op]):

         # use node's weight name to get the module name
         module_name = lin_node.args[1].target
```

Comment on the changed line:

**Duplicate nodes will be processed due to the `filtered_nodes` bug.** Passing a list of predicates to `filtered_nodes` yields a node once per matching predicate, so the same linear node can be processed more than once here. This issue will be resolved once the bug in `filtered_nodes` is fixed (see the review comment on that function below).
```diff
@@ -368,7 +368,7 @@ def detect_sharding_from_factory_config(
                 )
                 num_row_col_shards += 1
             else:
-                ad_logger.warning("Invalid sharding config. Skipping.")
+                ad_logger.warning(f"Unsupported sharding action {config}. Skipping.")
         else:
             # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
             ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
```
```diff
@@ -387,7 +387,19 @@ def detect_sharding_from_factory_config(
                 )
                 num_simple_shards += 1
             else:
-                ad_logger.warning("Invalid sharding config. Skipping.")
+                ad_logger.warning(
+                    f"Unsupported sharding action {config}. Fallback to simple shard"
+                )
+                sharding_config.tp_transforms.append(
+                    TPShardingInfo.from_node(
+                        lin_node,
+                        split_dim=SplitDimension.COLUMN,
+                        rank=rank,
+                        world_size=world_size,
+                        dist_op="all_gather",
+                        min_local_shape=1,
+                    )
+                )
             # after successful match, break the loop
             break
```
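The fallback above splits the linear layer along the output (column) dimension and recombines the per-rank partial outputs with an `all_gather`. The numerical idea can be sketched with plain Python lists standing in for tensors and collectives (`column_shard` and `matvec` are illustrative helpers, not TRT-LLM APIs):

```python
def column_shard(rows, rank, world_size):
    """Take this rank's contiguous slice of the weight's output dimension."""
    per_rank = len(rows) // world_size
    return rows[rank * per_rank:(rank + 1) * per_rank]


def matvec(w_rows, x):
    """Dense layer: one output element per weight row."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w_rows]


W = [[1, 0], [0, 1], [2, 0], [0, 2]]  # 4 output features, 2 input features
x = [3, 5]
world_size = 2

# Each rank computes outputs for its shard; all_gather concatenates the shards
# in rank order, which is simulated here by the sequential loop + extend.
gathered = []
for rank in range(world_size):
    gathered.extend(matvec(column_shard(W, rank, world_size), x))

assert gathered == matvec(W, x)  # all_gather recovers the full-layer output
print(gathered)  # [3, 5, 6, 10]
```

Because every rank ends up with the full activation after the gather, this scheme is always valid (hence a safe fallback), at the cost of extra communication compared with a matched column/row split pair.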
```diff
@@ -239,6 +239,11 @@ def filtered_nodes(
         for node in nodes:
             if target(node):
                 yield node
+    elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target):
+        for node in nodes:
+            for t in target:
+                if t(node):
+                    yield node
     else:
         # Handle the case where target or ops contains operations
         operations = ops if ops is not None else target
```

Comment on lines +242 to +246:

**Add a break after yielding to prevent duplicate nodes.**

The inner loop at lines 244-246 will yield the same node multiple times if more than one predicate matches. This creates duplicates in the iteration results. Apply this diff to add a break statement:

```diff
 elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target):
     for node in nodes:
         for t in target:
             if t(node):
                 yield node
+                break
```
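The duplicate-yield behavior and the effect of the suggested `break` are easy to reproduce in isolation. A self-contained sketch with stand-in predicates and integer "nodes" (the helper names are illustrative, mirroring `filtered_nodes` rather than importing it):

```python
def filtered_nodes_buggy(nodes, predicates):
    """Mirrors the new elif branch as written: one yield per matching predicate."""
    for node in nodes:
        for p in predicates:
            if p(node):
                yield node  # a node matching k predicates is yielded k times


def filtered_nodes_fixed(nodes, predicates):
    """With the suggested break: each node is yielded at most once."""
    for node in nodes:
        for p in predicates:
            if p(node):
                yield node
                break  # stop checking further predicates for this node


is_even = lambda n: n % 2 == 0
is_small = lambda n: n < 3

nodes = [0, 1, 2, 3, 4]
print(list(filtered_nodes_buggy(nodes, [is_even, is_small])))  # [0, 0, 1, 2, 2, 4]
print(list(filtered_nodes_fixed(nodes, [is_even, is_small])))  # [0, 1, 2, 4]
```

In the sharding pass this matters because a node that is both a regular and a fake-quantized linear match would otherwise be sharded twice.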
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add the required NVIDIA Apache-2.0 copyright header.
Per coding guidelines, all Python source files must include the NVIDIA Apache-2.0 copyright header with the current year at the top of the file.
Add this header at the top of the file:
As per coding guidelines.
📝 Committable suggestion
🤖 Prompt for AI Agents