Add sharding support for latent attention block
Baselined the latent attention block against its sharded version.
Showing 6 changed files with 193 additions and 2 deletions.
@@ -0,0 +1,57 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Specifications describing how the DeepSeek model is sharded."""

from ...types.sharding import *
from ...types import Theta
from ... import ops

from ..llama.llama import LlamaModelConfig


class LatentAttentionBlockSharding(ThetaLayerSharding):
    def __init__(self, shard_count: int):
        super().__init__()
        self.shard_count = shard_count

    def theta_sharding(self) -> ThetaSharding:
        return ThetaSharding(
            {
                # The size of this is the token embedding length, which is not a memory
                # space concern if replicated even for all attention blocks.
                "attn_norm": RmsNormReplicatedSharding(
                    self.shard_count
                ).theta_sharding(),
                "q_norm": RmsNormReplicatedSharding(self.shard_count).theta_sharding(),
                "kv_norm": RmsNormReplicatedSharding(self.shard_count).theta_sharding(),
                # Query projections are split along the parallel (output) dimension.
                "wq": LinearSplitParallelWeightAndBiasSharding(
                    shard_count=self.shard_count
                ).theta_sharding(),
                "wq_a": LinearSplitParallelWeightAndBiasSharding(
                    shard_count=self.shard_count
                ).theta_sharding(),
                "wq_b": LinearSplitParallelWeightAndBiasSharding(
                    shard_count=self.shard_count
                ).theta_sharding(),
                # The KV down-projection is replicated; the up-projection is split
                # along the parallel (output) dimension.
                "wkv_a": LinearReplicatedWeightAndBiasSharding(
                    shard_count=self.shard_count
                ).theta_sharding(),
                "wkv_b": LinearSplitParallelWeightAndBiasSharding(
                    shard_count=self.shard_count
                ).theta_sharding(),
                # Output projections are split along the reduction (input) dimension.
                "wo": LinearSplitReductionDimSharding(
                    shard_count=self.shard_count
                ).theta_sharding(),
                "attn_output": LinearSplitReductionDimSharding(
                    shard_count=self.shard_count
                ).theta_sharding(),
            }
        )


def shard_theta(theta: Theta, sharding: ThetaLayerSharding) -> Theta:
    return ops.reshard(theta, sharding)
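
For orientation, a minimal usage sketch of the new spec (the shard count of 2 is an arbitrary example, and the printed key list is read off the spec above; the test below exercises the full reshard-and-compare flow):

from sharktank.models.deepseek.sharding import LatentAttentionBlockSharding

# Build a 2-way tensor-parallel spec for one latent attention block and
# inspect which tensors it covers.
spec = LatentAttentionBlockSharding(shard_count=2)
layout = spec.theta_sharding()
print(list(layout.keys()))
# attn_norm, q_norm, kv_norm, wq, wq_a, wq_b, wkv_a, wkv_b, wo, attn_output

Applying the spec to a block's parameters is then a single shard_theta(theta, spec) call, which forwards to ops.reshard.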
@@ -0,0 +1,88 @@
# Copyright 2025 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


from sharktank.models.deepseek.deepseek import PagedLatentAttentionBlock
from sharktank.layers.rotary_embedding import RotaryEmbeddingLayer
from sharktank.models.deepseek.sharding import (
    flat_to_nested_dict,
    shard_theta,
    LatentAttentionBlockSharding,
)
from sharktank.models.deepseek.toy_deepseek import generate
from sharktank import ops
from sharktank.types.theta import Theta

import pytest
import torch


def test_deepseek():
    # Build a toy DeepSeek model and take the parameters of the first block.
    theta, config = generate(12345)
    theta = theta("blk", 0)

    sharding = 2
    spec = LatentAttentionBlockSharding(sharding)

    # Keep only the tensors that the latent attention sharding spec covers.
    sharding_keys = {k for k in spec.theta_sharding().keys()}
    flattened = theta.flatten()

    t = {}
    for k in flattened:
        if k.split(".")[0] in sharding_keys:
            t[k] = flattened[k]
    theta = Theta(flat_to_nested_dict(t))

    sharded_theta = shard_theta(theta, spec)

    hp = config.hp
    reference_model = PagedLatentAttentionBlock(
        theta=theta,
        block_index=0,
        cache=None,
        head_count=hp.attention_head_count,
        head_dim=hp.attn_head_dim,
        head_count_kv=hp.attention_head_count_kv,
        rms_epsilon=hp.attention_layer_norm_rms_epsilon,
        rope_dimension_count=hp.rope_dimension_count,
    )

    sharded_model = PagedLatentAttentionBlock(
        theta=sharded_theta,
        block_index=0,
        cache=None,
        head_count=hp.attention_head_count,
        head_dim=hp.attn_head_dim,
        head_count_kv=hp.attention_head_count_kv,
        rms_epsilon=hp.attention_layer_norm_rms_epsilon,
        rope_dimension_count=hp.rope_dimension_count,
    )

    bs = 1
    seq = 11
    embed = hp.embedding_length
    input = torch.rand((bs, seq, embed))

    embedding = RotaryEmbeddingLayer(
        rope_dimension_count=hp.rope_dimension_count,
        rope_freq_base=hp.rope_freq_base,
        max_seqlen=hp.context_length,
    )

    sharded_embedding = RotaryEmbeddingLayer(
        rope_dimension_count=hp.rope_dimension_count,
        rope_freq_base=hp.rope_freq_base,
        max_seqlen=hp.context_length,
        tensor_parallelism_size=sharding,
    )

    # Run the unsharded reference and the sharded model on the same input,
    # gather the sharded result, and check that the two match.
    reference = reference_model.forward(embedding=embedding, h=input)
    sharded = sharded_model.forward(
        embedding=sharded_embedding, h=ops.replicate(input, count=sharding)
    )
    sharded = ops.unshard(sharded)
    assert torch.isclose(reference, sharded, atol=1e-5).all()
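
The test follows the usual replicate → compute → unshard pattern for baselining sharded layers. A stripped-down sketch of just that plumbing on a plain tensor, with the sharded layer itself omitted and arbitrary shapes:

import torch
from sharktank import ops

x = torch.rand(1, 11, 8)
replicated = ops.replicate(x, count=2)  # one logical copy per shard
# ... a sharded layer would consume `replicated` here ...
gathered = ops.unshard(replicated)  # back to a single unsharded tensor
assert gathered.shape == x.shape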