TP SP examples improvement #1354


Open · wants to merge 2 commits into main
6 changes: 3 additions & 3 deletions distributed/tensor_parallelism/log_utils.py
@@ -17,6 +17,6 @@ def rank_log(_rank, logger, msg):

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """ verification that we have at least 2 gpus to run dist examples """
-    has_cuda = torch.cuda.is_available()
-    gpu_count = torch.cuda.device_count()
-    return has_cuda and gpu_count >= min_gpus
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus
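
For readers on PyTorch builds that predate the torch.accelerator API, here is a minimal sketch of how this check could be guarded with a CUDA fallback; the hasattr guard and the fallback branch are assumptions, not part of this PR:

import torch

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """Device-agnostic check that at least `min_gpus` accelerators are present."""
    if hasattr(torch, "accelerator"):
        # newer builds expose a generic accelerator API (CUDA, XPU, ...)
        return torch.accelerator.is_available() and torch.accelerator.device_count() >= min_gpus
    # older builds (assumed fallback): the CUDA-only check this PR replaces
    return torch.cuda.is_available() and torch.cuda.device_count() >= min_gpus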
24 changes: 18 additions & 6 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
import os
import sys
import torch
@@ -13,6 +14,7 @@

from log_utils import rank_log, get_logger, verify_min_gpu_count

+from torch.distributed.tensor.debug import CommDebugMode

# ---- GPU check ------------
_min_gpu_count = 2
@@ -63,9 +65,10 @@ def forward(self, x):
"""
logger = get_logger()

+device_type = torch.accelerator.current_accelerator().type
# create a device mesh based on the given world_size.
device_mesh = init_device_mesh(
-    device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
+    device_type=device_type, mesh_shape=(int(os.environ["WORLD_SIZE"]),)
)

_rank = device_mesh.get_rank()
@@ -75,7 +78,7 @@ def forward(self, x):
rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

# create model and move it to GPU. Init_device_mesh has already assigned gpu ids...
model = ToyModel().to("cuda")
model = ToyModel().to(device_type)

# Custom parallelization plan for the model
sp_model = parallelize_module(
@@ -87,6 +90,8 @@ def forward(self, x):
},
)

+if torch.distributed.get_rank() == 0:
+    print(f"model {sp_model}")

# Create an optimizer for the parallelized module.
lr = 0.25
@@ -98,12 +103,19 @@ def forward(self, x):
num_iters = 10
rank_log(_rank, logger, "Sequence Parallel training starting...")


for i in range(num_iters):
    # For SP, input can be different across all ranks.
-    inp = torch.rand(20, 10, device="cuda")
-    output = sp_model(inp)
-    output.sum().backward()
-    optimizer.step()
+    inp = torch.rand(1, 10, device=device_type)
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = sp_model(inp)
+        output.sum().backward()
+        optimizer.step()
    rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")

+    if i == 0:
+        print(
+            f"rank {torch.distributed.get_rank()} iter {i}: "
+            f"comm counts {comm_mode.get_comm_counts()}, "
+            f"sharding info {comm_mode.get_sharding_info()}\n"
+            f"{comm_mode.generate_comm_debug_tracing_table(noise_level=1)}"
+        )

rank_log(_rank, logger, "Sequence Parallel training completed!")
25 changes: 17 additions & 8 deletions distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
import os
import sys
import torch
@@ -10,6 +11,7 @@
)

from log_utils import rank_log, get_logger, verify_min_gpu_count
+from torch.distributed.tensor.debug import CommDebugMode

# ---- GPU check ------------
_min_gpu_count = 2
@@ -76,8 +78,8 @@ def forward(self, x):

# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
device_type = torch.accelerator.current_accelerator().type
device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()


@@ -88,8 +90,8 @@ def forward(self, x):

rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

-# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
-tp_model = ToyModel().to("cuda")
+# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to(device_type)


# Custom parallelization plan for the model
@@ -102,6 +104,9 @@ def forward(self, x):
},
)

+if torch.distributed.get_rank() == 0:
+    print(f"model {tp_model}")

# Create an optimizer for the parallelized module.
lr = 0.25
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
@@ -116,10 +121,14 @@ def forward(self, x):
    # For TP, input needs to be same across all TP ranks.
    # Setting the random seed is to mimic the behavior of dataloader.
    torch.manual_seed(i)
-    inp = torch.rand(20, 10, device="cuda")
-    output = tp_model(inp)
-    output.sum().backward()
-    optimizer.step()
+    inp = torch.rand(4, 10, device=device_type)
+    comm_mode = CommDebugMode()
Review comment (Member): Does this work on non-CUDA devices? It would be great to share some local logs of your tests.

+    with comm_mode:
+        output = tp_model(inp)
+        output.sum().backward()
+        optimizer.step()
    rank_log(_rank, logger, f"Tensor Parallel iter {i} completed")
+    if i == 1:
+        print(
+            f"rank {torch.distributed.get_rank()} iter {i}: "
+            f"comm counts {comm_mode.get_comm_counts()}, "
+            f"sharding info {comm_mode.get_sharding_info()}\n"
+            f"{comm_mode.generate_comm_debug_tracing_table(noise_level=1)}"
+        )

rank_log(_rank, logger, "Tensor Parallel training completed!")