tensorrt_llm/_torch/auto_deploy/config/default.yaml (2 additions, 0 deletions)
@@ -96,6 +96,8 @@ transforms:
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
fuse_mamba_a_log:
stage: post_load_fusion
fuse_gemms:
stage: post_load_fusion
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
@@ -94,7 +94,7 @@ def cuda_causal_conv_prepare_metadata_fake(
)


@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={})
@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={"input"})
def _cuda_cached_causal_conv1d(
# INPUTS (dense but may be flattened across sequences)
input: torch.Tensor, # [b, s, c_in]
@@ -114,13 +114,15 @@ def _cuda_cached_causal_conv1d(
groups: int,
padding_mode: str,
activation: Optional[str],
) -> torch.Tensor:
) -> None:
"""Flattened cached causal conv that respects slot-indexed state caches (CUDA backend).

Supports two layouts from the attention interface:
- Generate-only: input is [b, 1, c_in]. We'll gather caches using slot_idx[:b].
- Flattened context/mixed: input is [1, total_s, c_in] and seq_len/seq_start
describe per-sequence segments. We'll process each segment and scatter final states to caches.

NOTE: This op modifies `input` in-place.
"""
b, s = input.shape[:2]
num_seq = seq_len.shape[0]
@@ -137,8 +139,6 @@ def _cuda_cached_causal_conv1d(
# Flatten tokens
bs = b * s
inp_flat = input.reshape(bs, *input.shape[2:]) # [total_s, C_in]
y = torch.empty(b, s, weight.shape[0], device=input.device, dtype=input.dtype)
y_flat = y.view(bs, *y.shape[2:])

# Prepare weight as [dim, width] (depthwise)
if weight.ndim == 3:
@@ -155,6 +155,7 @@ def _cuda_cached_causal_conv1d(
total_prefill_tokens = int(seq_len_prefill.sum().item())

# x_varlen: (dim, cu_seq_len)
# We must clone to make it contiguous for the kernel
Reviewer comment (Collaborator): the comment about cloning is probably from an earlier version.
x_varlen = inp_flat[:total_prefill_tokens].transpose(0, 1).contiguous()

# Metadata
@@ -181,17 +182,17 @@ def _cuda_cached_causal_conv1d(
pad_slot_id=PAD_SLOT_ID,
) # (dim, total_prefill_tokens)

# Scatter outputs back to y
y_prefill = y_varlen.transpose(0, 1) # [total_prefill_tokens, C_out]
y_flat[:total_prefill_tokens].copy_(y_prefill)
# Scatter outputs back to input buffer
inp_flat[:total_prefill_tokens] = y_varlen.transpose(0, 1)

# DECODE: batch update for single-token sequences
if num_decode > 0:
x_decode = inp_flat[
total_prefill_tokens : total_prefill_tokens + num_decode
] # [num_decode, C_in]

y_dec = causal_conv1d_update(
# causal_conv1d_update modifies x_decode in-place
causal_conv1d_update(
x_decode, # [batch, dim]
conv_state_cache,
w2d,
Expand All @@ -201,13 +202,9 @@ def _cuda_cached_causal_conv1d(
conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
pad_slot_id=PAD_SLOT_ID,
)
# No copy needed!

if y_dec.dim() == 3:
y_dec = y_dec.squeeze(-1)
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec)

# Custom op must not return an alias of any input; return a fresh tensor
return y
return


@_cuda_cached_causal_conv1d.register_fake
@@ -231,9 +228,12 @@ def _cuda_cached_causal_conv1d_fake(
padding_mode: str,
activation: Optional[str],
):
return torch.empty(
input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype
)
return


def cuda_cached_causal_conv1d_wrapper(input, *args, **kwargs):
torch.ops.auto_deploy.cuda_cached_causal_conv1d(input, *args, **kwargs)
return input


@AttentionRegistry.register("cuda_causal_conv")
@@ -259,7 +259,7 @@ def get_source_attention_op(cls) -> OpOverloadPacket:

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.cuda_cached_causal_conv1d
return cuda_cached_causal_conv1d_wrapper

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
@@ -185,13 +185,13 @@ def _triton_cached_ssm(
C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N]
dt_flat = dt.reshape(bs, dt.shape[2]) # [bs, H]

y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
y_flat = y.view(bs, *y.shape[2:])

ssm_state_size = B.shape[3]

num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist()

y_prefill = None
y_dec = None

# Prefill: concatenate tokens at the front and run combined scan
if num_prefill > 0:
hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
@@ -232,7 +232,6 @@ def _triton_cached_ssm(
mamba_ssm_cache_dtype=ssm_state_cache.dtype,
)

y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
ssm_state_cache.index_copy_(
0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
)
@@ -265,9 +264,19 @@ def _triton_cached_ssm(
state_batch_indices=slot_idx_decode,
) # [nd, H, D]

y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype))

return y
# Dispatch return logic
if num_prefill > 0 and num_decode > 0:
y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
y_flat = y.view(bs, *y.shape[2:])
y_flat[:num_prefill_tokens].copy_(y_prefill[0])
y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec)
return y
elif num_prefill > 0:
return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype)
elif num_decode > 0:
return y_dec.view(b, s, num_heads, head_dim).to(hidden_states.dtype)
else:
return torch.empty_like(hidden_states)


@_triton_cached_ssm.register_fake
@@ -85,10 +85,17 @@ def _apply(
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph

# Import wrapper to match against
# We use the wrapper because the underlying op returns None (void) to avoid aliasing,
# but the wrapper returns the tensor to maintain graph data flow.
from ...custom_ops.mamba.cuda_backend_causal_conv import cuda_cached_causal_conv1d_wrapper

target_op = cuda_cached_causal_conv1d_wrapper

# Step 1: Identify causal_conv + activation pattern
matches = _match_causal_conv_activation_pattern(
graph,
target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d,
target_op=target_op,
)

# Step 2: Replace matched patterns with fused version
Expand All @@ -98,7 +105,7 @@ def _apply(
# Replace the last arg (activation=None) with activation_name
new_args = list(conv_node.args[:-1]) + [activation_name]
fused_node = graph.call_function(
torch.ops.auto_deploy.cuda_cached_causal_conv1d,
target_op,
args=tuple(new_args),
)

@@ -0,0 +1,140 @@
"""Transform to fuse A_log into A for Mamba/NemotronH models."""
Reviewer comment (Collaborator): looks like there is a memory leak. Could you also add a unit test that checks memory usage before and after this transformation.
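A rough, self-contained sketch of such a memory check (`apply_transform` is a placeholder argument standing in for invoking the registered fuse_mamba_a_log transform, not an existing fixture or API; `_TinyMambaBlock` and the 1 MiB bound are likewise illustrative):

import gc

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class _TinyMambaBlock(nn.Module):
    """Minimal module exhibiting the A = -exp(A_log) pattern."""

    def __init__(self, num_heads: int = 128):
        super().__init__()
        self.A_log = nn.Parameter(torch.randn(num_heads))

    def forward(self, x):
        A = -torch.exp(self.A_log.float())
        return x * A


def _cuda_mem_bytes() -> int:
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated()


def check_fuse_mamba_a_log_memory(apply_transform) -> None:
    gm = symbolic_trace(_TinyMambaBlock()).cuda()
    baseline = _cuda_mem_bytes()
    gm = apply_transform(gm)  # placeholder: run the fuse_mamba_a_log transform on gm
    growth = _cuda_mem_bytes() - baseline
    # Fusion should add at most one A_fused-sized tensor; anything much larger hints at a leak.
    assert growth <= 1 << 20, f"unexpected CUDA memory growth: {growth} bytes"
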
import operator
from typing import Tuple

import torch
import torch.nn as nn
from torch.fx import GraphModule

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


def _get_attr_by_name(obj, name):
for part in name.split("."):
obj = getattr(obj, part)
return obj


def _set_attr_by_name(obj, name, value):
parts = name.split(".")
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], value)


@TransformRegistry.register("fuse_mamba_a_log")
class FuseMambaALog(BaseTransform):
"""Fuse A_log parameter into A constant/parameter.
Replaces:
A = -torch.exp(self.A_log.float())
With:
A = self.A_fused
"""
Reviewer comment (Member) on lines +29 to +37: can we use the new pattern-matcher-style replacement here? This is going to be really hard to maintain. Please check with @Fridah-nv if you have any questions.
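For reference, a rough sketch of a pattern-matcher-style rewrite, using torch.fx.subgraph_rewriter as a stand-in for whatever matcher utility gets adopted (illustrative only: the function names are made up, the captured graph may use aten overloads or a different cast op so the pattern would need variants, and a real pass would fold only the parameters that actually matched):

import torch
from torch.fx import GraphModule, subgraph_rewriter


def _a_log_pattern(a_log: torch.Tensor) -> torch.Tensor:
    return -torch.exp(a_log.float())


def _a_log_replacement(a_log: torch.Tensor) -> torch.Tensor:
    # Keep only the cast; the -exp is folded into the parameter values below.
    return a_log.float()


def fuse_a_log_with_matcher(gm: GraphModule) -> int:
    matches = subgraph_rewriter.replace_pattern(gm, _a_log_pattern, _a_log_replacement)
    if matches:
        with torch.no_grad():
            for name, param in gm.named_parameters():
                if name.endswith("A_log"):
                    # Promotes the parameter to float32, matching the original computation.
                    param.data = -torch.exp(param.data.float())
    return len(matches)
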
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
num_matches = 0

# Candidates for operations
exp_ops = {torch.exp, torch.ops.aten.exp.default, "exp"}
neg_ops = {operator.neg, torch.neg, torch.ops.aten.neg.default, "neg"}

# We search bottom-up starting from A_log parameters to be more robust
# pattern: A_log -> [optional cast] -> exp -> neg

# Snapshot nodes to avoid modification issues during iteration
nodes = list(gm.graph.nodes)

for node in nodes:
if node.op != "get_attr":
continue

if not node.target.endswith("A_log"):
continue
# Found an A_log node. Check its usage.
users = list(node.users.keys())
for user in users:
# 1. Check for optional Cast
current_node = user

# Skip cast/to nodes
exp_node = None

# Walk forward looking for exp
cursor = current_node
for _ in range(3): # Max depth for casts
if (cursor.op == "call_function" and cursor.target in exp_ops) or (
cursor.op == "call_method" and cursor.target == "exp"
):
exp_node = cursor
break

if len(cursor.users) != 1:
break
cursor = list(cursor.users.keys())[0]

if not exp_node:
continue

# 2. Check for Neg
if len(exp_node.users) != 1:
continue

neg_node = list(exp_node.users.keys())[0]
is_neg = (neg_node.op == "call_function" and neg_node.target in neg_ops) or (
neg_node.op == "call_method" and neg_node.target == "neg"
)

if not is_neg:
continue
# Found the pattern: node -> ... -> exp_node -> neg_node
num_matches += 1

# Perform Fusion
param_name = node.target
try:
a_log = _get_attr_by_name(gm, param_name)
except AttributeError:
ad_logger.warning(f"Could not find attribute {param_name} in gm.")
continue
Reviewer comment (Collaborator): maybe move num_matches += 1 to follow this try/except.
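A minimal sketch of that reordering, so only matches that can actually be fused are counted:

try:
    a_log = _get_attr_by_name(gm, param_name)
except AttributeError:
    ad_logger.warning(f"Could not find attribute {param_name} in gm.")
    continue
num_matches += 1  # count the match only once the parameter lookup succeeds
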
# Compute A_fused
with torch.no_grad():
# Replicate the logic: -exp(a_log.float())
a_fused = -torch.exp(a_log.float())

new_param_name = param_name.replace("A_log", "A_fused")

# Check if we already created this param (if A_log used multiple times)
try:
_get_attr_by_name(gm, new_param_name)
except AttributeError:
_set_attr_by_name(
gm, new_param_name, nn.Parameter(a_fused, requires_grad=False)
)

# Replace usage
with gm.graph.inserting_before(neg_node):
new_node = gm.graph.create_node("get_attr", new_param_name)

neg_node.replace_all_uses_with(new_node)

if num_matches > 0:
gm.graph.eliminate_dead_code()

return gm, TransformInfo(
skipped=False,
num_matches=num_matches,
is_clean=num_matches == 0,
has_valid_shapes=True,
)