This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit c0f6152

Rearrange code
1 parent 37e110c

2 files changed (+11, -10 lines)


examples/llama/2d_llama.py

Lines changed: 8 additions & 9 deletions

@@ -3,8 +3,8 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
-from torch.distributed._tensor import init_device_mesh
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import init_device_mesh, DTensor
+from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel


 # We set this flag to true to allow operations on a mix of tensor and dtensor
@@ -15,28 +15,27 @@
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
-    torch_dtype=torch.float16
 )
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+llama.eval()

+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 prompts = (
     "How do you", "I like to", "Can I help", "You need to",
     "The weather is", "I found a", "What is your", "You are so",
 ) # bs = 8
 tokenizer.pad_token = tokenizer.eos_token
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)

 rank = int(os.environ["RANK"])
 world_size = int(os.environ["WORLD_SIZE"])
 device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

+# Initialize 2D device mesh
 pp_group_size = 2
 tp_group_size = 4
 mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
 pp_group = mesh_2d["pp"].get_group()

-llama.eval()
-inputs = tokenizer(prompts, return_tensors="pt", padding=True)
-
 # Cut model by equal number of layers per rank
 layers_per_stage = llama.config.num_hidden_layers // pp_group_size
 for i in range(1, pp_group_size):
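
For orientation, the hunk above mainly moves llama.eval(), the tokenizer call, and the batch encoding ahead of the mesh setup, and labels the mesh block with "# Initialize 2D device mesh". Below is a minimal sketch of that 2D (pipeline x tensor parallel) mesh setup; the group sizes and init_device_mesh call mirror the example, while the torchrun launch on 8 GPUs, the tp_mesh variable, and the stage_idx derivation are assumptions about code outside the lines shown.

# Hedged sketch, not the example verbatim: assumes launch via
# `torchrun --nproc-per-node 8` so WORLD_SIZE == pp_group_size * tp_group_size.
import os
import torch
from torch.distributed._tensor import init_device_mesh

rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

pp_group_size = 2  # number of pipeline stages ("pp" mesh dimension)
tp_group_size = 4  # GPUs sharing each stage via tensor parallelism ("tp" dimension)

# Ranks are laid out row-major: stage 0 owns ranks 0-3, stage 1 owns ranks 4-7.
mesh_2d = init_device_mesh(
    "cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp")
)
pp_group = mesh_2d["pp"].get_group()  # process group for stage-to-stage traffic
tp_mesh = mesh_2d["tp"]               # assumed: 1D sub-mesh later passed to parallelize_module
stage_idx = rank // tp_group_size     # assumed: which pipeline stage this rank serves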
@@ -51,7 +50,6 @@
 stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)

 # Tensor parallel
-from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 starting_layer = stage_idx * layers_per_stage
 attn_plan = {}
 mlp_plan = {}
@@ -77,8 +75,9 @@
 parallelize_module(
     stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
 )
-inputs = inputs.to(device)
+
 # Run
+inputs = inputs.to(device)
 if stage_idx == 0:
     args = inputs["input_ids"]
 else:
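
The "Tensor parallel" hunk drops the late import (now at the top of the file) and leaves the per-layer plan construction untouched. For readers who want to see what such a plan can look like, here is a hedged sketch in the spirit of the example: the module paths (model.layers.{i}.self_attn.q_proj and friends) follow the Hugging Face Llama naming and are assumptions, not keys taken from this diff, and the fragment reuses starting_layer, layers_per_stage, stage, and tp_mesh from the surrounding example rather than being standalone.

# Hedged sketch of a column/row-wise TP plan; module paths are assumptions.
from torch.distributed.tensor.parallel import (
    parallelize_module, ColwiseParallel, RowwiseParallel,
)

attn_plan = {}
mlp_plan = {}
for i in range(starting_layer, starting_layer + layers_per_stage):
    # Attention: shard q/k/v column-wise, reduce the output projection row-wise.
    attn_plan[f"model.layers.{i}.self_attn.q_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.k_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.v_proj"] = ColwiseParallel()
    attn_plan[f"model.layers.{i}.self_attn.o_proj"] = RowwiseParallel()
    # MLP: gate/up column-wise, down projection row-wise.
    mlp_plan[f"model.layers.{i}.mlp.gate_proj"] = ColwiseParallel()
    mlp_plan[f"model.layers.{i}.mlp.up_proj"] = ColwiseParallel()
    mlp_plan[f"model.layers.{i}.mlp.down_proj"] = RowwiseParallel()

parallelize_module(stage.submod, tp_mesh, {**attn_plan, **mlp_plan})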

pippy/PipelineStage.py

Lines changed: 3 additions & 1 deletion

@@ -478,7 +478,9 @@ def _send_activations(
 work = dist.isend(
     # HACK: we convert DTensor to regular tensor here for it to
     # work with send ops. DTensor may show up in PP + TP cases.
-    out.to_local() if isinstance(out, torch.distributed._tensor.DTensor) else out,
+    out.to_local()
+    if isinstance(out, torch.distributed._tensor.DTensor)
+    else out,
     peer_rank
     if self.group is None
     else dist.get_global_rank(self.group, peer_rank), # TODO
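
The PipelineStage.py change only re-wraps the conditional expression across three lines; behavior is identical. As a standalone illustration of the pattern the HACK comment describes, converting a DTensor to its local shard before a point-to-point send, here is a small sketch; the helper name to_sendable is made up for illustration and is not part of PiPPy's API.

# Illustrative helper (name is hypothetical): unwrap a DTensor before isend.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor


def to_sendable(out: torch.Tensor) -> torch.Tensor:
    # dist.isend works on plain tensors; for a DTensor, send its local shard.
    return out.to_local() if isinstance(out, DTensor) else out


# Usage inside a stage's send path (peer rank resolution omitted):
# work = dist.isend(to_sendable(out), peer_rank, group=pp_group)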
