Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions iris/ccl/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def persistent_all_gather_gluon(
COMM_SMS: gl.constexpr,
THREADS_PER_WARP: gl.constexpr,
WARPS_PER_CTA: gl.constexpr,
TRACING: gl.constexpr = False,
):
"""
Persistent all-gather kernel using Gluon with flat-2D tiling.
Expand Down Expand Up @@ -374,7 +375,7 @@ def persistent_all_gather_gluon(
THREADS_PER_WARP: Threads per warp/wavefront (64 for AMD, 32 for NVIDIA).
WARPS_PER_CTA: Number of warps per workgroup. Must match num_warps.
"""
ctx = IrisDeviceCtx.initialize(context_tensor, tracing=False)
ctx = IrisDeviceCtx.initialize(context_tensor, tracing=TRACING)

pid = gl.program_id(0)

Expand Down Expand Up @@ -445,17 +446,17 @@ def persistent_all_gather_gluon(
def all_gather(
output_tensor,
input_tensor,
shmem,
ctx,
group=None,
async_op=False,
config=None,
):
"""
Internal all-gather collective operation implementation.

This function is called internally by shmem.ccl.all_gather().
This function is called internally by ctx.ccl.all_gather().
Users should use the Iris instance method instead:
>>> shmem.ccl.all_gather(output_tensor, input_tensor)
>>> ctx.ccl.all_gather(output_tensor, input_tensor)

Each rank sends its input tensor to all ranks, and all ranks receive
and concatenate all input tensors along dimension 0 (rows), matching
Expand All @@ -464,7 +465,7 @@ def all_gather(
Args:
output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs
input_tensor: Input tensor of shape (M, N) - local rank's data to send
shmem: Iris shmem context
ctx: Iris context
group: ProcessGroup or None. If None, uses all ranks in `iris` context.
Default: None.
async_op: If False, performs a barrier at the end. If True, returns immediately.
Expand All @@ -480,7 +481,7 @@ def all_gather(
# Extract group information
# rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel
# rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations
rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem)
rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx)

M, N = input_tensor.shape[:2]
expected_output_shape = (world_size * M, N)
Expand All @@ -496,8 +497,8 @@ def all_gather(

# Choose between Triton and Gluon implementation
if config.use_gluon and GLUON_AVAILABLE:
# Check if shmem is Iris Gluon (has get_device_context method)
if not hasattr(shmem, "get_device_context"):
# Check if ctx is Iris Gluon (has get_device_context method)
if not hasattr(ctx, "get_device_context"):
raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()")

# Gluon only supports the persistent variant
Expand Down Expand Up @@ -535,7 +536,9 @@ def all_gather(
f"Recommended: block_size_m=8, block_size_n=256."
)

context_tensor = shmem.get_device_context()
context_tensor = ctx.get_device_context()
tracing = getattr(ctx, "tracing", None)
tracing_enabled = bool(tracing and getattr(tracing, "enabled", False))

iris_launch(
persistent_all_gather_gluon,
Expand All @@ -561,6 +564,7 @@ def all_gather(
config.comm_sms,
config.threads_per_warp,
config.num_warps,
tracing_enabled,
num_stages=config.num_stages,
num_warps=config.num_warps,
waves_per_eu=config.waves_per_eu,
Expand All @@ -579,7 +583,7 @@ def all_gather(
f"Please adjust config.comm_sms to be a multiple of {world_size}."
)

heap_bases = shmem.get_heap_bases()
heap_bases = ctx.get_heap_bases()

# Dispatch to the appropriate kernel based on variant
if config.all_gather_variant == "persistent":
Expand Down Expand Up @@ -621,4 +625,4 @@ def all_gather(
)

if not async_op:
shmem.barrier()
ctx.barrier()
88 changes: 33 additions & 55 deletions iris/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,90 +4,68 @@
"""
Device-side utility functions for Iris.

This module provides low-level device intrinsics for accessing hardware
information and timing within Triton kernels.
Provides portable device intrinsics for timestamps and hardware topology
that work across all supported AMD GPU architectures. Uses Triton's
architecture-aware APIs (``tl.extra.hip``) where available.
"""

import triton
import triton.language as tl
from triton.language.extra.hip import memrealtime, smid
from triton.language.target_info import is_hip_cdna3, is_hip_cdna4
Comment thread
mawad-amd marked this conversation as resolved.


@triton.jit
def read_realtime():
"""
Read GPU wall clock timestamp from s_memrealtime.
Read GPU wall clock timestamp.

Returns a 64-bit timestamp from a constant 100MHz clock (not affected
by power modes or core clock frequency changes).
Returns a 64-bit value from the GPU's constant-frequency real-time
counter (100 MHz, unaffected by power states or clock gating).

Delegates to ``tl.extra.hip.memrealtime()`` which emits the correct
instruction for each architecture family.

Returns:
int64: Current timestamp in cycles (100MHz constant clock)
int64: Current timestamp in cycles (100 MHz constant clock)
"""
tmp = tl.inline_asm_elementwise(
asm="""s_waitcnt vmcnt(0)
s_memrealtime $0
s_waitcnt lgkmcnt(0)""",
constraints=("=s"),
args=[],
dtype=tl.int64,
is_pure=False,
pack=1,
)
return tmp
return memrealtime()


@triton.jit
def get_xcc_id():
"""
Get XCC (GPU chiplet) ID.

On multi-XCC parts (CDNA3/CDNA4) reads ``HW_REG_XCC_ID``.
On single-die architectures returns 0.

Returns:
int32: XCC ID for the current execution
"""
xcc_id = tl.inline_asm_elementwise(
asm="s_getreg_b32 $0, hwreg(HW_REG_XCC_ID, 0, 16)",
constraints=("=s"),
args=[],
dtype=tl.int32,
is_pure=False,
pack=1,
)
return xcc_id
if is_hip_cdna3() or is_hip_cdna4():
return tl.inline_asm_elementwise(
asm="s_getreg_b32 $0, hwreg(HW_REG_XCC_ID, 0, 16)",
constraints=("=s"),
args=[],
dtype=tl.int32,
is_pure=False,
pack=1,
)
else:
return tl.cast(0, tl.int32)
Comment thread
mawad-amd marked this conversation as resolved.


@triton.jit
def get_cu_id():
"""
Get Compute Unit ID.

Returns:
int32: CU ID for the current execution
"""
cu_id = tl.inline_asm_elementwise(
asm="s_getreg_b32 $0, hwreg(HW_REG_HW_ID, 8, 4)",
constraints=("=s"),
args=[],
dtype=tl.int32,
is_pure=False,
pack=1,
)
return cu_id

Get compute-unit / workgroup-processor ID for the current wave.

@triton.jit
def get_se_id():
"""
Get Shader Engine ID.
Delegates to ``tl.extra.hip.smid()`` which reads the appropriate
hardware register for each architecture family (CU_ID on CDNA,
WGP_ID on RDNA).

Returns:
int32: SE ID for the current execution
int32: CU / WGP ID for the current execution
"""
se_id = tl.inline_asm_elementwise(
asm="s_getreg_b32 $0, hwreg(HW_REG_HW_ID, 13, 3)",
constraints=("=s"),
args=[],
dtype=tl.int32,
is_pure=False,
pack=1,
)
return se_id
return smid()
Loading