Add sharding support for latent attention block
Baselined the latent attention block against its sharded version; a condensed sketch of that check follows the file summary below.
rsuderman committed Feb 7, 2025
1 parent 3cccc20 commit 190ca77
Showing 6 changed files with 193 additions and 2 deletions.
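
As a rough sketch of the baselining this commit adds (condensed from the new sharktank/tests/models/deepseek/test_sharded.py at the end of this diff; the shard count of 2 and the model/embedding objects come from that test's setup, not from a general API): the block's theta is sharded with the new spec, the reference and sharded blocks are run on the same input, and the unsharded result is compared against the reference.

# Condensed from test_sharded.py below; reference_model, sharded_model,
# embedding, sharded_embedding, and input are built as in that test.
spec = LatentAttentionBlockSharding(shard_count=2)
sharded_theta = shard_theta(theta, spec)

reference = reference_model.forward(embedding=embedding, h=input)
sharded = sharded_model.forward(
    embedding=sharded_embedding, h=ops.replicate(input, count=2)
)
assert torch.isclose(reference, ops.unshard(sharded), atol=1e-5).all()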
12 changes: 12 additions & 0 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -51,6 +51,18 @@ def forward(
start_index: int,
):
table = self.rotary_embed_table
if isinstance(xt, ReplicatedTensor):
return ReplicatedTensor(
ts=[
self.forward_unsharded(
xt=unbox_tensor(s),
start_index=start_index,
rotary_embed_table=unbox_tensor(t),
)
for s, t in zip(xt.shards, table.shards)
]
)

if not isinstance(xt, SplitPrimitiveTensor):
return self.forward_unsharded(
xt=xt,
10 changes: 9 additions & 1 deletion sharktank/sharktank/models/deepseek/deepseek.py
@@ -163,7 +163,7 @@ def forward(
q_nope = q[:, :, :, :qk_nope_head_dim]
q_rope = q[:, :, :, qk_nope_head_dim:]
q_rope = embedding(xt=q_rope, start_index=0)
q = torch.cat((q_nope, q_rope), dim=-1)
q = ops.cat((q_nope, q_rope), dim=-1)

kv = self.wkv_a(h)
kv_nope_size = kv.shape[-1] - self.rope_dimension_count
@@ -179,6 +179,14 @@ def forward(
v = wkv_b[:, :, :, qk_nope_head_dim:]

k_rope = ops.repeat(k_rope, (1, 1, k_nope.shape[2] // k_rope.shape[2], 1))

if isinstance(k_rope, ReplicatedTensor) and isinstance(
k_nope, SplitPrimitiveTensor
):
k_rope = ops.reshard_split(
k_rope, dim=k_nope.shard_dim, count=k_nope.shard_count
)

k = ops.cat((k_nope, k_rope), dim=-1)

q = q.transpose(1, 2)
57 changes: 57 additions & 0 deletions sharktank/sharktank/models/deepseek/sharding.py
@@ -0,0 +1,57 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Specifications describing how the Llama model is sharded."""

from ...types.sharding import *
from ...types import Theta
from ... import ops

from ..llama.llama import LlamaModelConfig


class LatentAttentionBlockSharding(ThetaLayerSharding):
def __init__(self, shard_count: int):
super().__init__()
self.shard_count = shard_count

def theta_sharding(self) -> ThetaSharding:
return ThetaSharding(
{
# The size of this is the token embedding length, which is not a memory
# space concern if replicated even for all attention blocks.
"attn_norm": RmsNormReplicatedSharding(
self.shard_count
).theta_sharding(),
"q_norm": RmsNormReplicatedSharding(self.shard_count).theta_sharding(),
"kv_norm": RmsNormReplicatedSharding(self.shard_count).theta_sharding(),
"wq": LinearSplitParallelWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"wq_a": LinearSplitParallelWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"wq_b": LinearSplitParallelWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"wkv_a": LinearReplicatedWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"wkv_b": LinearSplitParallelWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"wo": LinearSplitReductionDimSharding(
shard_count=self.shard_count
).theta_sharding(),
"attn_output": LinearSplitReductionDimSharding(
shard_count=self.shard_count
).theta_sharding(),
}
)


def shard_theta(theta: Theta, sharding: ThetaLayerSharding) -> Theta:
return ops.reshard(theta, sharding)
18 changes: 17 additions & 1 deletion sharktank/sharktank/ops/sharded_impls.py
@@ -686,6 +686,17 @@ def linear_sharded(
# Sharded matmuls.


@matmul.override(ReplicatedTensor, ReplicatedTensor)
def matmul_replicated(
    lhs: ReplicatedTensor, rhs: ReplicatedTensor, *, transpose_rhs: bool
) -> ReplicatedTensor:
    # Both operands are replicated, so compute shard-by-shard and keep the
    # result replicated; transpose_rhs is forwarded to the per-shard matmul.
    shards = [
        matmul(lhs_shard, rhs_shard, transpose_rhs=transpose_rhs)
        for (lhs_shard, rhs_shard) in zip(lhs.shards, rhs.shards)
    ]
    return ReplicatedTensor(ts=shards)


@matmul.override(ReplicatedTensor, SplitPrimitiveTensor)
def matmul_replicated_lhs_split_rhs(
lhs: ReplicatedTensor, rhs: SplitPrimitiveTensor, *, transpose_rhs: bool
@@ -1172,7 +1183,12 @@ def transpose_split(
def unflatten_split(
input: SplitPrimitiveTensor, dim: int, sizes: Tuple[int]
) -> SplitPrimitiveTensor:
assert dim != input.shard_dim, "Unflattening the split dimension is not supported."
    if dim == input.shard_dim:
        assert (
            sizes[0] != -1
        ), "Unflattening the split dimension with an inferred (-1) size is not supported."
        # Each shard holds 1/shard_count of the split dimension, so scale the
        # leading unflattened size down by the shard count.
        sizes = tuple([sizes[0] // input.shard_count] + [s for s in sizes[1:]])
shards = [unflatten(shard, dim, sizes) for shard in input.shards]
shard_dim = input.shard_dim
if dim < shard_dim:
10 changes: 10 additions & 0 deletions sharktank/sharktank/types/sharding.py
@@ -162,6 +162,16 @@ def theta_sharding(self) -> ThetaSharding:
)


class LinearReplicatedWeightAndBiasSharding(LinearLayerSharding):
    def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0):
        """The linear operation is replicated across devices.

        The split-dim argument is unused here, since the weight and bias are
        fully replicated rather than split."""
        super().__init__(
            premul_input=Replicated(shard_count=shard_count),
            weight=Replicated(shard_count=shard_count),
            bias=Replicated(shard_count=shard_count),
        )


class LinearSplitParallelWeightAndBiasSharding(LinearLayerSharding):
def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0):
"""Split one parallel dimension for both the weight and bias.
88 changes: 88 additions & 0 deletions sharktank/tests/models/deepseek/test_sharded.py
@@ -0,0 +1,88 @@
# Copyright 2025 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


from sharktank.models.deepseek.deepseek import PagedLatentAttentionBlock
from sharktank.layers.rotary_embedding import RotaryEmbeddingLayer
from sharktank.models.deepseek.sharding import (
flat_to_nested_dict,
shard_theta,
LatentAttentionBlockSharding,
)
from sharktank.models.deepseek.toy_deepseek import generate
from sharktank import ops
from sharktank.types.theta import Theta

import pytest
import torch


def test_deepseek():
theta, config = generate(12345)
theta = theta("blk", 0)

sharding = 2
spec = LatentAttentionBlockSharding(sharding)

keys = {k for k in theta.keys}
sharding_keys = {k for k in spec.theta_sharding().keys()}
flattened = theta.flatten()

t = {}
for k in flattened:
if k.split(".")[0] in sharding_keys:
t[k] = flattened[k]
theta = Theta(flat_to_nested_dict(t))

sharded_theta = shard_theta(theta, spec)

hp = config.hp
reference_model = PagedLatentAttentionBlock(
theta=theta,
block_index=0,
cache=None,
head_count=hp.attention_head_count,
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
rope_dimension_count=hp.rope_dimension_count,
)

sharded_model = PagedLatentAttentionBlock(
theta=sharded_theta,
block_index=0,
cache=None,
head_count=hp.attention_head_count,
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
rope_dimension_count=hp.rope_dimension_count,
)

bs = 1
seq = 11
embed = hp.embedding_length
input = torch.rand((bs, seq, embed))

embedding = RotaryEmbeddingLayer(
rope_dimension_count=hp.rope_dimension_count,
rope_freq_base=hp.rope_freq_base,
max_seqlen=hp.context_length,
)

sharded_embedding = RotaryEmbeddingLayer(
rope_dimension_count=hp.rope_dimension_count,
rope_freq_base=hp.rope_freq_base,
max_seqlen=hp.context_length,
tensor_parallelism_size=sharding,
)

reference = reference_model.forward(embedding=embedding, h=input)
sharded = sharded_model.forward(
embedding=sharded_embedding, h=ops.replicate(input, count=sharding)
)
sharded = ops.unshard(sharded)
assert torch.isclose(reference, sharded, atol=1e-5).all()
