Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions bagua/torch_api/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torch.distributed import ProcessGroup as TorchProcessGroup
import gorilla
import weakref
import os

# fmt: off
__all__ = [
Expand Down Expand Up @@ -383,7 +384,6 @@ def get_backend(model_name: str):
def run_flask_app(port):
from flask import Flask
from gevent.pywsgi import WSGIServer
import os

os.environ["WERKZEUG_RUN_MAIN"] = "true"

Expand Down Expand Up @@ -443,7 +443,8 @@ def _find_free_bagua_service_port(store) -> int:
return service_port


def init_process_group(store: Optional[torch.distributed.Store] = None):
def init_process_group(store: Optional[torch.distributed.Store] = None, rank: int = -1,
world_size: int = -1, local_world_size: int = -1):
"""Initializes the PyTorch builtin distributed process group, and this will
also initialize the distributed package. It should be executed before any other
Bagua API.
Expand All @@ -452,6 +453,10 @@ def init_process_group(store: Optional[torch.distributed.Store] = None):
store: Key/value store accessible to all workers, used to exchange
connection/address information. If ``None``, a TCP-based store will be created.
Default: ``None``.
rank: Rank of the current process (it should be a number between 0 and world_size-1).
Required if store is specified.
world_size: Number of processes participating in the job. Required if store is specified.
local_world_size: Number of processes per node. Required if store is specified.

Examples::
>>> import torch
Expand All @@ -474,11 +479,10 @@ def init_process_group(store: Optional[torch.distributed.Store] = None):

.. note::
Each process should be associated with a CUDA device using `torch.cuda.set_device()`,
before calling :meth:`init_process_group`. Otherwise you may encounter the
before calling :meth:`init_process_group`. Otherwise, you may encounter the
`fatal runtime error: Rust cannot catch foreign exceptions` error.
"""


global _default_pg
global _default_store
global _autotune_service_port
Expand All @@ -495,11 +499,24 @@ def init_process_group(store: Optional[torch.distributed.Store] = None):
store.set_timeout(timeout)
_default_store = store
else:
assert rank >= 0
assert world_size > 0
assert local_world_size > 0

os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(rank % local_world_size)
os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size)

_default_store = store

_autotune_service_port = _find_free_bagua_service_port(_default_store)
if get_rank() == 0 and _autotune_server is None:
start_autotune_server(_autotune_service_port)
if _autotune_service_port is None:
if get_rank() == 0:
_autotune_service_port = _find_free_bagua_service_port(_default_store)
store.set("bagua_autotune_service_port", str(_autotune_service_port))
start_autotune_server(_autotune_service_port)
else:
_autotune_service_port = int(store.get("bagua_autotune_service_port"))

AUTOTUNE_SERVER_WAIT_TIME = 30
wait_time = get_autotune_server_wait_time()
Expand Down