diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index acf93456e..f3c8c6d85 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -326,6 +326,7 @@ def persistent_all_gather_gluon( COMM_SMS: gl.constexpr, THREADS_PER_WARP: gl.constexpr, WARPS_PER_CTA: gl.constexpr, + TRACING: gl.constexpr = False, ): """ Persistent all-gather kernel using Gluon with flat-2D tiling. @@ -374,7 +375,7 @@ def persistent_all_gather_gluon( THREADS_PER_WARP: Threads per warp/wavefront (64 for AMD, 32 for NVIDIA). WARPS_PER_CTA: Number of warps per workgroup. Must match num_warps. """ - ctx = IrisDeviceCtx.initialize(context_tensor, tracing=False) + ctx = IrisDeviceCtx.initialize(context_tensor, tracing=TRACING) pid = gl.program_id(0) @@ -445,7 +446,7 @@ def persistent_all_gather_gluon( def all_gather( output_tensor, input_tensor, - shmem, + ctx, group=None, async_op=False, config=None, @@ -453,9 +454,9 @@ def all_gather( """ Internal all-gather collective operation implementation. - This function is called internally by shmem.ccl.all_gather(). + This function is called internally by ctx.ccl.all_gather(). Users should use the Iris instance method instead: - >>> shmem.ccl.all_gather(output_tensor, input_tensor) + >>> ctx.ccl.all_gather(output_tensor, input_tensor) Each rank sends its input tensor to all ranks, and all ranks receive and concatenate all input tensors along dimension 0 (rows), matching @@ -464,7 +465,7 @@ def all_gather( Args: output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs input_tensor: Input tensor of shape (M, N) - local rank's data to send - shmem: Iris shmem context + ctx: Iris context group: ProcessGroup or None. If None, uses all ranks in `iris` context. Default: None. async_op: If False, performs a barrier at the end. If True, returns immediately. @@ -480,7 +481,7 @@ def all_gather( # Extract group information # rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel # rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations - rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem) + rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx) M, N = input_tensor.shape[:2] expected_output_shape = (world_size * M, N) @@ -496,8 +497,8 @@ def all_gather( # Choose between Triton and Gluon implementation if config.use_gluon and GLUON_AVAILABLE: - # Check if shmem is Iris Gluon (has get_device_context method) - if not hasattr(shmem, "get_device_context"): + # Check if ctx is Iris Gluon (has get_device_context method) + if not hasattr(ctx, "get_device_context"): raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()") # Gluon only supports the persistent variant @@ -535,7 +536,9 @@ def all_gather( f"Recommended: block_size_m=8, block_size_n=256." ) - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() + tracing = getattr(ctx, "tracing", None) + tracing_enabled = bool(tracing and getattr(tracing, "enabled", False)) iris_launch( persistent_all_gather_gluon, @@ -561,6 +564,7 @@ def all_gather( config.comm_sms, config.threads_per_warp, config.num_warps, + tracing_enabled, num_stages=config.num_stages, num_warps=config.num_warps, waves_per_eu=config.waves_per_eu, @@ -579,7 +583,7 @@ def all_gather( f"Please adjust config.comm_sms to be a multiple of {world_size}." ) - heap_bases = shmem.get_heap_bases() + heap_bases = ctx.get_heap_bases() # Dispatch to the appropriate kernel based on variant if config.all_gather_variant == "persistent": @@ -621,4 +625,4 @@ def all_gather( ) if not async_op: - shmem.barrier() + ctx.barrier() diff --git a/iris/device_utils.py b/iris/device_utils.py index 1e328ebcf..ad7bb67e3 100644 --- a/iris/device_utils.py +++ b/iris/device_utils.py @@ -4,36 +4,32 @@ """ Device-side utility functions for Iris. -This module provides low-level device intrinsics for accessing hardware -information and timing within Triton kernels. +Provides portable device intrinsics for timestamps and hardware topology +that work across all supported AMD GPU architectures. Uses Triton's +architecture-aware APIs (``tl.extra.hip``) where available. """ import triton import triton.language as tl +from triton.language.extra.hip import memrealtime, smid +from triton.language.target_info import is_hip_cdna3, is_hip_cdna4 @triton.jit def read_realtime(): """ - Read GPU wall clock timestamp from s_memrealtime. + Read GPU wall clock timestamp. - Returns a 64-bit timestamp from a constant 100MHz clock (not affected - by power modes or core clock frequency changes). + Returns a 64-bit value from the GPU's constant-frequency real-time + counter (100 MHz, unaffected by power states or clock gating). + + Delegates to ``tl.extra.hip.memrealtime()`` which emits the correct + instruction for each architecture family. Returns: - int64: Current timestamp in cycles (100MHz constant clock) + int64: Current timestamp in cycles (100 MHz constant clock) """ - tmp = tl.inline_asm_elementwise( - asm="""s_waitcnt vmcnt(0) - s_memrealtime $0 - s_waitcnt lgkmcnt(0)""", - constraints=("=s"), - args=[], - dtype=tl.int64, - is_pure=False, - pack=1, - ) - return tmp + return memrealtime() @triton.jit @@ -41,53 +37,35 @@ def get_xcc_id(): """ Get XCC (GPU chiplet) ID. + On multi-XCC parts (CDNA3/CDNA4) reads ``HW_REG_XCC_ID``. + On single-die architectures returns 0. + Returns: int32: XCC ID for the current execution """ - xcc_id = tl.inline_asm_elementwise( - asm="s_getreg_b32 $0, hwreg(HW_REG_XCC_ID, 0, 16)", - constraints=("=s"), - args=[], - dtype=tl.int32, - is_pure=False, - pack=1, - ) - return xcc_id + if is_hip_cdna3() or is_hip_cdna4(): + return tl.inline_asm_elementwise( + asm="s_getreg_b32 $0, hwreg(HW_REG_XCC_ID, 0, 16)", + constraints=("=s"), + args=[], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + else: + return tl.cast(0, tl.int32) @triton.jit def get_cu_id(): """ - Get Compute Unit ID. - - Returns: - int32: CU ID for the current execution - """ - cu_id = tl.inline_asm_elementwise( - asm="s_getreg_b32 $0, hwreg(HW_REG_HW_ID, 8, 4)", - constraints=("=s"), - args=[], - dtype=tl.int32, - is_pure=False, - pack=1, - ) - return cu_id - + Get compute-unit / workgroup-processor ID for the current wave. -@triton.jit -def get_se_id(): - """ - Get Shader Engine ID. + Delegates to ``tl.extra.hip.smid()`` which reads the appropriate + hardware register for each architecture family (CU_ID on CDNA, + WGP_ID on RDNA). Returns: - int32: SE ID for the current execution + int32: CU / WGP ID for the current execution """ - se_id = tl.inline_asm_elementwise( - asm="s_getreg_b32 $0, hwreg(HW_REG_HW_ID, 13, 3)", - constraints=("=s"), - args=[], - dtype=tl.int32, - is_pure=False, - pack=1, - ) - return se_id + return smid()