Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions benchmark/ops/all_gather_matmul/auto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@
# TN/NT/TT would require kernel-level changes to permute strides.
SUPPORTED_TRANSPOSES = ("NN",)

# Supported GPU architectures with tuned configs
SUPPORTED_ARCHITECTURES = ("mi300x", "mi355x")
# Supported GPU architectures for auto-config selection. NVIDIA currently uses
# heuristic fallback configs rather than tuned JSON files.
SUPPORTED_ARCHITECTURES = ("mi300x", "mi355x", "nvidia")

# Map gfx target IDs to architecture names used in config paths
_GFX_TO_ARCH = {
Expand All @@ -96,8 +97,8 @@ def detect_gpu_arch() -> str:

Detection order:
1. IRIS_GPU_ARCH environment variable (override)
2. rocm-smi --showproductname parsing
3. rocminfo gfx target parsing
2. PyTorch CUDA-without-HIP detection for NVIDIA
3. rocminfo gfx target parsing for AMD
4. Falls back to "mi300x" (most common deployment target)

Returns:
Expand All @@ -113,7 +114,18 @@ def detect_gpu_arch() -> str:
_detected_arch = env_arch
return _detected_arch

# 2. Try rocminfo for gfx target
# 2. Check for NVIDIA CUDA via PyTorch. ROCm PyTorch also exposes
# torch.cuda, so require CUDA availability without a HIP version.
try:
import torch

if torch.cuda.is_available() and not getattr(torch.version, "hip", None):
_detected_arch = "nvidia"
return _detected_arch
except ImportError:
pass

# 3. Try rocminfo for AMD gfx target
try:
result = subprocess.run(
["rocminfo"],
Expand All @@ -132,7 +144,7 @@ def detect_gpu_arch() -> str:
except (FileNotFoundError, subprocess.TimeoutExpired, OSError):
pass

# 3. Fallback to MI300X (most common deployment target)
# 4. Fallback to MI300X (most common deployment target)
_detected_arch = "mi300x"
return _detected_arch

Expand Down Expand Up @@ -318,6 +330,25 @@ def _apply_heuristic(M: int, N: int, K: int, arch: str = "mi300x") -> Tuple[Dict
bk = 64
num_k_blocks = K // bk

if arch == "nvidia":
config_params = {
"block_size_m": 128,
"block_size_n": 128,
"block_size_k": bk,
"group_size_m": 8,
"num_xcds": 1,
"allow_tf32": True,
}
hbm_params = {
"k_per_flag": 8,
"num_fetch_sms": 16,
"num_fetch_stages": 1,
"first_stage_fetch_sms": 32,
"num_warps": 4,
"num_stages": 2,
}
return config_params, hbm_params

if arch == "mi355x":
bm = 256
num_m_tiles = M // bm
Expand Down Expand Up @@ -510,7 +541,19 @@ def select_ag_mm_config(
source=f"Heuristic (no exact shape match in {arch}/{transpose}/ws{world_size}.json)",
)

# Step 2: No config file found — check global default
# Step 2: No config file found for this architecture. For new/untuned
# architectures such as NVIDIA, enable heuristic configs directly instead
# of applying AMD-specific global world-size gates.
if arch not in ("mi300x", "mi355x"):
heuristic_config, heuristic_hbm = _apply_heuristic(M, N, K, arch=arch)
return AutoConfigResult(
enabled=True,
config_params=heuristic_config,
hbm_buffer_params=heuristic_hbm,
source=f"Heuristic fallback for {arch} (no tuned configs available)",
)
Comment on lines +547 to +554

# Step 3: No AMD config file found — check global default
default_data = _load_default_config()
ws_gate = default_data.get("world_size_gate", {})
min_ws = ws_gate.get("min_world_size", 8)
Expand Down
62 changes: 48 additions & 14 deletions examples/14_all_gather_gemm/example_run_pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch.distributed as dist
import iris
import argparse
import os
from all_gather_gemm_pull import persistent_ag_gemm


Expand All @@ -36,6 +37,11 @@ def parse_args():
parser.add_argument(
"--dtype", type=str, default="float16", choices=["float16", "bfloat16"], help="PyTorch data type to use."
)
parser.add_argument(
"--print_topology",
action="store_true",
help="Print the Iris-discovered topology before initializing the symmetric heap.",
)

return parser.parse_args()

Expand Down Expand Up @@ -72,17 +78,39 @@ def setup_example_data(rank, world_size, args, dtype):
}


def example_run(rank: int, world_size: int, init_url: str, args: argparse.Namespace):
def example_run(
rank: int,
world_size: int,
init_url: str,
args: argparse.Namespace,
local_rank: int | None = None,
):
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(
backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}")
)
if local_rank is None:
local_rank = rank
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
init_kwargs = {
"backend": backend,
"init_method": init_url,
"world_size": world_size,
"rank": rank,
}
if backend == "nccl":
init_kwargs["device_id"] = torch.device(f"cuda:{local_rank}")
dist.init_process_group(**init_kwargs)

if args.print_topology:
from iris.host.distributed.topology import TopologyDiscovery

topology = TopologyDiscovery().discover()
if rank == 0:
print(topology.summary(), flush=True)

# Initialize Iris for distributed communication
shmem = iris.iris()

torch.manual_seed(42) # Use a fixed seed for consistent random data
torch.cuda.set_device(rank)
dtype = getattr(torch, args.dtype)

if rank == 0:
Expand All @@ -103,7 +131,7 @@ def example_run(rank: int, world_size: int, init_url: str, args: argparse.Namesp

C_fused = torch.empty(args.M, args.N, dtype=dtype).cuda() # Output tensor for our kernel

NUM_SMS = torch.cuda.get_device_properties(rank).multi_processor_count
NUM_SMS = torch.cuda.get_device_properties(local_rank).multi_processor_count
grid = (NUM_SMS,)

# Launch the fused Triton kernel
Expand Down Expand Up @@ -165,14 +193,20 @@ def example_run(rank: int, world_size: int, init_url: str, args: argparse.Namesp

def main():
    """Run the example, either as one rank of an external launch or by spawning ranks.

    If both ``RANK`` and ``WORLD_SIZE`` are set in the environment, this
    process is treated as a single, already-launched rank: it calls
    ``example_run`` directly with the ``env://`` init method and does not
    spawn any child processes. Otherwise ``args.num_ranks`` local worker
    processes are spawned, each running ``example_run``.
    """
    args = parse_args()

    # Launcher path: the surrounding launcher already created one process per
    # rank, so spawning here as well would nest a second set of workers.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        # LOCAL_RANK is optional here; fall back to 0 when it is not set.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        example_run(rank, world_size, "env://", args, local_rank=local_rank)
        return

    # Standalone path: spawn the ranks locally over a fixed loopback endpoint.
    # NOTE(review): `mp` is presumably torch.multiprocessing, whose spawn()
    # passes the process index as the first argument (the rank) — confirm
    # against the file's import block.
    num_ranks = args.num_ranks
    init_url = "tcp://127.0.0.1:29504"
    mp.spawn(
        fn=example_run,
        args=(num_ranks, init_url, args),
        nprocs=num_ranks,
        join=True,
    )


if __name__ == "__main__":
Expand Down
49 changes: 49 additions & 0 deletions iris/bench/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,55 @@ def main(argv: list[str] | None = None) -> None:
print("No benchmark configurations to run after applying filters/skips.", file=sys.stderr)
sys.exit(1)

# If launched by torchrun/srun for a multi-node job, do not spawn another
# local elastic job. The current process is already one benchmark rank.
if all(key in os.environ for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE")):
Comment on lines +558 to +560
Comment thread
mawad-amd marked this conversation as resolved.
world_size = int(os.environ["WORLD_SIZE"])
global_rank = int(os.environ["RANK"])
if world_size not in all_num_ranks:
if global_rank == 0:
configured = ", ".join(str(n) for n in sorted(all_num_ranks))
print(
f"torchrun WORLD_SIZE={world_size} does not match benchmark num_ranks selection "
f"{{{configured}}}. Pass --axis_num_ranks={world_size} or launch with a matching world size.",
file=sys.stderr,
)
sys.exit(1)

dropped_num_ranks = sorted(all_num_ranks - {world_size})
if dropped_num_ranks and global_rank == 0:
dropped = ", ".join(str(n) for n in dropped_num_ranks)
print(
f"Warning: torchrun WORLD_SIZE={world_size}; skipping benchmark num_ranks values: {dropped}",
file=sys.stderr,
)

all_results = _run_benchmarks_worker(
benchmarks,
axis_overrides,
skip_overrides,
args.heap_size,
args.use_gluon,
args.n_warmup,
args.n_repeat,
args.benchmark_filter,
)

if global_rank == 0:
if args.benchmark_format == "json":
output = _format_json(all_results)
elif args.benchmark_format == "csv":
output = _format_csv(all_results)
else:
output = _format_console(all_results)

print(output, end="")

if args.benchmark_out:
with open(args.benchmark_out, "w") as f:
f.write(output)
return

# Launch once per unique num_ranks, collecting results across runs
all_results: list[Result] = []

Expand Down
Loading
Loading