[None][feat] AutoDeploy: Precompute the A log for mamba layers #9344
base: main
Changes from all commits
9603844
1a0f90b
4692586
948e5d4
0439b3b
9c95592
3f71b77
c5179ce
f78ae5a
73e2b1b
df1daca
9a1cf7e
5b411c5
4783eb3
06e2fb3
ad94c87
448700e
26e8041
@@ -94,7 +94,7 @@ def cuda_causal_conv_prepare_metadata_fake(
     )


-@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={})
+@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={"input"})
 def _cuda_cached_causal_conv1d(
     # INPUTS (dense but may be flattened across sequences)
     input: torch.Tensor,  # [b, s, c_in]

@@ -114,13 +114,15 @@ def _cuda_cached_causal_conv1d(
     groups: int,
     padding_mode: str,
     activation: Optional[str],
-) -> torch.Tensor:
+) -> None:
     """Flattened cached causal conv that respects slot-indexed state caches (CUDA backend).

     Supports two layouts from the attention interface:
     - Generate-only: input is [b, 1, c_in]. We'll gather caches using slot_idx[:b].
     - Flattened context/mixed: input is [1, total_s, c_in] and seq_len/seq_start
       describe per-sequence segments. We'll process each segment and scatter final states to caches.

+    NOTE: This op modifies `input` in-place.
     """
     b, s = input.shape[:2]
     num_seq = seq_len.shape[0]
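As a quick illustration of the two layouts described in the docstring above, the shapes involved look roughly like this (hypothetical sizes, not taken from the project's calling code):

```python
import torch

b, c_in = 4, 64

# Generate-only layout: one new token per sequence -> input is [b, 1, c_in];
# caches are gathered via slot_idx[:b].
x_gen = torch.randn(b, 1, c_in)
slot_idx = torch.arange(b, dtype=torch.int32)

# Flattened context/mixed layout: all tokens of all sequences packed into a
# single row -> input is [1, total_s, c_in]; seq_len/seq_start describe the
# per-sequence segments.
seq_len = torch.tensor([5, 3, 7, 1])
seq_start = torch.cumsum(seq_len, dim=0) - seq_len  # start offset of each segment
x_ctx = torch.randn(1, int(seq_len.sum().item()), c_in)
```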
@@ -137,8 +139,6 @@ def _cuda_cached_causal_conv1d(
     # Flatten tokens
     bs = b * s
     inp_flat = input.reshape(bs, *input.shape[2:])  # [total_s, C_in]
-    y = torch.empty(b, s, weight.shape[0], device=input.device, dtype=input.dtype)
-    y_flat = y.view(bs, *y.shape[2:])

     # Prepare weight as [dim, width] (depthwise)
     if weight.ndim == 3:

@@ -155,6 +155,7 @@ def _cuda_cached_causal_conv1d(
         total_prefill_tokens = int(seq_len_prefill.sum().item())

         # x_varlen: (dim, cu_seq_len)
         # We must clone to make it contiguous for the kernel
Collaborator comment: the comment about cloning is probably from an earlier version.
         x_varlen = inp_flat[:total_prefill_tokens].transpose(0, 1).contiguous()

         # Metadata

@@ -181,17 +182,17 @@ def _cuda_cached_causal_conv1d(
             pad_slot_id=PAD_SLOT_ID,
         )  # (dim, total_prefill_tokens)

-        # Scatter outputs back to y
-        y_prefill = y_varlen.transpose(0, 1)  # [total_prefill_tokens, C_out]
-        y_flat[:total_prefill_tokens].copy_(y_prefill)
+        # Scatter outputs back to input buffer
+        inp_flat[:total_prefill_tokens] = y_varlen.transpose(0, 1)

     # DECODE: batch update for single-token sequences
     if num_decode > 0:
         x_decode = inp_flat[
             total_prefill_tokens : total_prefill_tokens + num_decode
         ]  # [num_decode, C_in]

-        y_dec = causal_conv1d_update(
+        # causal_conv1d_update modifies x_decode in-place
+        causal_conv1d_update(
             x_decode,  # [batch, dim]
             conv_state_cache,
             w2d,

@@ -201,13 +202,9 @@ def _cuda_cached_causal_conv1d(
             conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
             pad_slot_id=PAD_SLOT_ID,
         )
+        # No copy needed!

-        if y_dec.dim() == 3:
-            y_dec = y_dec.squeeze(-1)
-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec)

-    # Custom op must not return an alias of any input; return a fresh tensor
-    return y
+    return


 @_cuda_cached_causal_conv1d.register_fake

@@ -231,9 +228,12 @@ def _cuda_cached_causal_conv1d_fake(
     padding_mode: str,
     activation: Optional[str],
 ):
-    return torch.empty(
-        input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype
-    )
+    return


+def cuda_cached_causal_conv1d_wrapper(input, *args, **kwargs):
+    torch.ops.auto_deploy.cuda_cached_causal_conv1d(input, *args, **kwargs)
+    return input


 @AttentionRegistry.register("cuda_causal_conv")

@@ -259,7 +259,7 @@ def get_source_attention_op(cls) -> OpOverloadPacket:

     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:
-        return torch.ops.auto_deploy.cuda_cached_causal_conv1d
+        return cuda_cached_causal_conv1d_wrapper

     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
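Taken together, the changes in this file switch the cached conv op from allocating and returning a fresh output tensor to writing its result into `input` and returning nothing, with a thin wrapper handing the mutated buffer back to the attention interface. A minimal, self-contained sketch of that pattern using the same PyTorch custom-op API (the toy op name and shapes are hypothetical, not the project's code):

```python
import torch


@torch.library.custom_op("toy::inplace_double", mutates_args={"x"})
def inplace_double(x: torch.Tensor) -> None:
    # Write the result into the input buffer. A custom op must not return an
    # alias of any input, so a mutating op like this returns nothing instead.
    x.mul_(2.0)


@inplace_double.register_fake
def _(x: torch.Tensor) -> None:
    # Fake (meta) implementation: no output, only the declared mutation.
    return


def inplace_double_wrapper(x: torch.Tensor) -> torch.Tensor:
    # Callers that expect an output tensor get the mutated input buffer back.
    torch.ops.toy.inplace_double(x)
    return x


y = inplace_double_wrapper(torch.ones(3))
assert torch.equal(y, torch.full((3,), 2.0))
```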
@@ -0,0 +1,140 @@ (new file)
"""Transform to fuse A_log into A for Mamba/NemotronH models."""
Collaborator comment: looks like there is a memory leak. Could you also add a unit test that checks memory usage before and after this transformation?
(A sketch of such a test is included after this file's diff below.)
import operator
from typing import Tuple

import torch
import torch.nn as nn
from torch.fx import GraphModule

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


def _get_attr_by_name(obj, name):
    for part in name.split("."):
        obj = getattr(obj, part)
    return obj


def _set_attr_by_name(obj, name, value):
    parts = name.split(".")
    for part in parts[:-1]:
        obj = getattr(obj, part)
    setattr(obj, parts[-1], value)


@TransformRegistry.register("fuse_mamba_a_log")
class FuseMambaALog(BaseTransform):
    """Fuse A_log parameter into A constant/parameter.

    Replaces:
        A = -torch.exp(self.A_log.float())
    With:
        A = self.A_fused
    """
Member comment (on lines +29 to +37): can we use the new pattern-matcher-style replacement here? This is going to be really hard to maintain. Please check with @Fridah-nv if you have any questions.
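One possible shape for such a declarative match, sketched with the stock torch.fx subgraph matcher rather than the project's own pattern-matcher utilities. The `ToyMamba` module is a hypothetical stand-in, and the actual replacement with a precomputed `A_fused` parameter is only described in the trailing comment:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher


class ToyMamba(torch.nn.Module):
    """Hypothetical stand-in for a mamba mixer that recomputes A every forward."""

    def __init__(self):
        super().__init__()
        self.A_log = torch.nn.Parameter(torch.rand(8))

    def forward(self, x):
        A = -torch.exp(self.A_log.float())
        return x * A


def _pattern(a_log):
    # The chain we want to find: A_log -> float() cast -> exp -> neg
    return -torch.exp(a_log.float())


gm = symbolic_trace(ToyMamba())
pattern_gm = symbolic_trace(_pattern)
matcher = SubgraphMatcher(pattern_gm.graph, match_output=False, match_placeholder=False)
matches = matcher.match(gm.graph)
print(f"found {len(matches)} A_log -> exp -> neg chain(s)")
# Each match maps pattern nodes to graph nodes; the transform would then replace
# the matched neg node with a get_attr pointing at a precomputed A_fused parameter.
```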
    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        num_matches = 0

        # Candidates for operations
        exp_ops = {torch.exp, torch.ops.aten.exp.default, "exp"}
        neg_ops = {operator.neg, torch.neg, torch.ops.aten.neg.default, "neg"}

        # We search bottom-up starting from A_log parameters to be more robust
        # pattern: A_log -> [optional cast] -> exp -> neg

        # Snapshot nodes to avoid modification issues during iteration
        nodes = list(gm.graph.nodes)

        for node in nodes:
            if node.op != "get_attr":
                continue
            if not node.target.endswith("A_log"):
                continue

            # Found an A_log node. Check its usage.
            users = list(node.users.keys())
            for user in users:
                # 1. Check for optional Cast
                current_node = user

                # Skip cast/to nodes
                exp_node = None

                # Walk forward looking for exp
                cursor = current_node
                for _ in range(3):  # Max depth for casts
                    if (cursor.op == "call_function" and cursor.target in exp_ops) or (
                        cursor.op == "call_method" and cursor.target == "exp"
                    ):
                        exp_node = cursor
                        break

                    if len(cursor.users) != 1:
                        break
                    cursor = list(cursor.users.keys())[0]

                if not exp_node:
                    continue

                # 2. Check for Neg
                if len(exp_node.users) != 1:
                    continue

                neg_node = list(exp_node.users.keys())[0]
                is_neg = (neg_node.op == "call_function" and neg_node.target in neg_ops) or (
                    neg_node.op == "call_method" and neg_node.target == "neg"
                )

                if not is_neg:
                    continue

                # Found the pattern: node -> ... -> exp_node -> neg_node
                num_matches += 1

                # Perform Fusion
                param_name = node.target
                try:
                    a_log = _get_attr_by_name(gm, param_name)
                except AttributeError:
                    ad_logger.warning(f"Could not find attribute {param_name} in gm.")
                    continue
Collaborator comment: maybe move …
                # Compute A_fused
                with torch.no_grad():
                    # Replicate the logic: -exp(a_log.float())
                    a_fused = -torch.exp(a_log.float())

                new_param_name = param_name.replace("A_log", "A_fused")

                # Check if we already created this param (if A_log used multiple times)
                try:
                    _get_attr_by_name(gm, new_param_name)
                except AttributeError:
                    _set_attr_by_name(
                        gm, new_param_name, nn.Parameter(a_fused, requires_grad=False)
                    )

                # Replace usage
                with gm.graph.inserting_before(neg_node):
                    new_node = gm.graph.create_node("get_attr", new_param_name)

                neg_node.replace_all_uses_with(new_node)

        if num_matches > 0:
            gm.graph.eliminate_dead_code()

        return gm, TransformInfo(
            skipped=False,
            num_matches=num_matches,
            is_clean=num_matches == 0,
            has_valid_shapes=True,
        )
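Regarding the memory-usage test requested in the first review comment on this file, here is a minimal sketch of what it could look like. `build_toy_mamba_gm` and `run_fuse_mamba_a_log` are hypothetical pytest fixtures standing in for the project's real test harness and transform invocation:

```python
import torch


def param_bytes(gm: torch.fx.GraphModule) -> int:
    """Total bytes held by parameters and buffers attached to the graph module."""
    return sum(
        t.numel() * t.element_size() for t in list(gm.parameters()) + list(gm.buffers())
    )


def test_fuse_mamba_a_log_memory(build_toy_mamba_gm, run_fuse_mamba_a_log):
    gm = build_toy_mamba_gm()  # hypothetical fixture: traced module containing A_log
    before = param_bytes(gm)

    gm, info = run_fuse_mamba_a_log(gm)  # hypothetical fixture: applies the transform
    after = param_bytes(gm)

    assert info.num_matches > 0

    # The old A_log parameters should actually be released, not kept alongside
    # the new A_fused parameters.
    leaked = [name for name, _ in gm.named_parameters() if name.endswith("A_log")]
    assert not leaked, f"A_log still attached after fusion: {leaked}"

    # With float32 A_log in the toy model, A_fused is the same size, so total
    # parameter memory should not grow after the transform.
    assert after <= before
```

On CUDA, the same comparison can be made with torch.cuda.memory_allocated() measured before and after applying the transform.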