@@ -3,8 +3,8 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
-from torch.distributed._tensor import init_device_mesh
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import init_device_mesh, DTensor
+from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 
 
 # We set this flag to true to allow operations on a mix of tensor and dtensor
@@ -15,28 +15,27 @@
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
-    torch_dtype=torch.float16
 )
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+llama.eval()
 
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 prompts = (
     "How do you", "I like to", "Can I help", "You need to",
     "The weather is", "I found a", "What is your", "You are so",
 ) # bs = 8
 tokenizer.pad_token = tokenizer.eos_token
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 
 rank = int(os.environ["RANK"])
 world_size = int(os.environ["WORLD_SIZE"])
 device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
 
+# Initialize 2D device mesh
 pp_group_size = 2
 tp_group_size = 4
 mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
 pp_group = mesh_2d["pp"].get_group()
 
-llama.eval()
-inputs = tokenizer(prompts, return_tensors="pt", padding=True)
-
 # Cut model by equal number of layers per rank
 layers_per_stage = llama.config.num_hidden_layers // pp_group_size
 for i in range(1, pp_group_size):
@@ -51,7 +50,6 @@
 stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
 
 # Tensor parallel
-from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 starting_layer = stage_idx * layers_per_stage
 attn_plan = {}
 mlp_plan = {}
@@ -77,8 +75,9 @@
 parallelize_module(
     stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
 )
-inputs = inputs.to(device)
+
 # Run
+inputs = inputs.to(device)
 if stage_idx == 0:
     args = inputs["input_ids"]
 else:
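Note on launching: with pp_group_size = 2 and tp_group_size = 4, init_device_mesh expects WORLD_SIZE = 2 * 4 = 8 ranks, and the script reads RANK / WORLD_SIZE from the environment. On a single 8-GPU host the run command would presumably be along the lines of torchrun --nproc_per_node 8 <this_script>.py (torchrun exports RANK and WORLD_SIZE; the script path here is a placeholder, not taken from the diff).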