Skip to content

Commit

Permalink
Fix embedding sharding: use ColWiseParallel (Shard on dim 1) for the wte embedding in the Qwen model
Browse files Browse the repository at this point in the history
  • Loading branch information
liym27 committed Feb 27, 2025
1 parent 45b5012 commit 9e93201
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion paddlenlp/transformers/qwen/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def __init__(self, config):
self.recompute_granularity = config.recompute_granularity

self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
self.wte.weight = dist.shard_tensor(self.wte.weight, get_mesh(), [dist.Replicate(), dist.Shard(0)])
self.wte.weight = dist.shard_tensor(self.wte.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)])

Check warning on line 541 in paddlenlp/transformers/qwen/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/qwen/modeling_auto.py#L541

Added line #L541 was not covered by tests
self.drop = nn.Dropout(config.emb_dropout_prob)

self.h = nn.LayerList(
Expand Down
4 changes: 2 additions & 2 deletions paddlenlp/transformers/qwen/modeling_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def auto_dist_config(self, prefix=""):
"sp_config": {
"parallelize_plan": {
f"{prefix}qwen.wte": [
dist.RowWiseParallel(),
dist.ColWiseParallel(),
dist.SequenceParallelBegin(),
],
f"{prefix}qwen.h.*.attn.c_attn": dist.ColWiseParallel(),
Expand All @@ -684,7 +684,7 @@ def auto_dist_config(self, prefix=""):
},
"mp_config": {
"parallelize_plan": {
f"{prefix}qwen.wte": dist.RowWiseParallel(),
f"{prefix}qwen.wte": dist.ColWiseParallel(),
f"{prefix}qwen.h.*.attn.c_attn": dist.ColWiseParallel(),
f"{prefix}qwen.h.*.attn.c_proj": dist.RowWiseParallel(),
f"{prefix}qwen.h.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
Expand Down

0 comments on commit 9e93201

Please sign in to comment.