Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9780215
Restructure iris into host/device split
mawad-amd Apr 22, 2026
90772e0
Delete old files and update all imports to new paths
mawad-amd Apr 22, 2026
3012e60
Apply Ruff auto-fixes
mawad-amd Apr 22, 2026
db71785
Remove all iris.experimental.iris_gluon references
mawad-amd Apr 22, 2026
7d1aa92
Add iris.gluon shortcut, consolidate docs into host/device split
mawad-amd Apr 22, 2026
8431c37
Fix remaining stale imports in tests and examples
mawad-amd Apr 22, 2026
7f0ff4b
Apply Ruff auto-fixes
mawad-amd Apr 22, 2026
4d78e0a
Address review feedback: remove get_backend, fix error messages
mawad-amd Apr 22, 2026
c014524
Remove dead hasattr checks and stale IrisGluon references in CCL
mawad-amd Apr 22, 2026
077bc73
Fix docs build: mock triton.language.extra.hip, fix cross-ref links
mawad-amd Apr 22, 2026
20fb324
Mock triton.language.target_info for docs build
mawad-amd Apr 22, 2026
22716da
Fix broadcast test: src_rank -> source_rank
mawad-amd Apr 22, 2026
8fef12e
Rename broadcast source_rank -> src to match torch.distributed API
mawad-amd Apr 22, 2026
3e7914b
Remove duplicate broadcast test, drop backend suffix
mawad-amd Apr 22, 2026
4efb825
Fix test_device_context import: iris.iris -> iris.device.triton.context
mawad-amd Apr 22, 2026
d0dfb9c
Fix stale import: iris.allocators -> iris.host.memory.allocators
mawad-amd Apr 22, 2026
0ac37ad
Fix copy() dst pointer cast to use dst_ptr.dtype
mawad-amd Apr 25, 2026
14ae045
Resolve merge conflict in test_topology.py
mawad-amd Apr 25, 2026
c1cd286
Fix stale iris.topology imports in drivers
mawad-amd Apr 25, 2026
a8ab81d
Merge main into muhaawad/refactor
mawad-amd Apr 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
import iris.experimental.iris_gluon as iris_gl
import iris
from iris.gluon import IrisDeviceCtx

# Device-side APIs - context encapsulates heap_bases
@gluon.jit
Expand Down Expand Up @@ -163,20 +164,20 @@ def _worker(rank, world_size):

# Iris initialization
heap_size = 2**30 # 1GiB symmetric heap
iris_ctx = iris_gl.iris(heap_size)
iris_ctx = iris.iris(heap_size)
context_tensor = iris_ctx.get_device_context() # Get encoded context
cur_rank = iris_ctx.get_rank()

# Iris tensor allocation
buffer_size = 4096 # 4K elements buffer
buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32)

# Launch the kernel on rank 0
block_size = 1024
grid = (buffer_size + block_size - 1) // block_size
source_rank = 0
if cur_rank == source_rank:
kernel[(grid,)](iris_gl.IrisDeviceCtx, context_tensor,
kernel[(grid,)](IrisDeviceCtx, context_tensor,
buffer, buffer_size, block_size, num_warps=1)

# Synchronize all ranks
Expand Down
9 changes: 6 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@
autodoc_typehints = "description"
autodoc_typehints_format = "short"

# Render objects without full module path (e.g., show "Iris" instead of "iris.iris.Iris")
# Render objects without full module path (e.g., show "Iris" instead of "iris.host.iris.Iris")
add_module_names = False

# Mock heavy/runtime-only dependencies when building docs
autodoc_mock_imports = [
"torch",
"numpy",
"iris._distributed_helpers",
"iris.hip",
"iris.host.distributed.helpers",
"iris.host.platform.hip",
"tritonblas",
]

Expand All @@ -118,6 +118,9 @@ def __call__(self, func=None, **kwargs):
sys.modules["triton.language"] = triton_language_mock
sys.modules["triton.language.core"] = MagicMock()
sys.modules["triton.language.core"]._aggregate = lambda cls: cls # Preserve class
sys.modules["triton.language.extra"] = MagicMock()
sys.modules["triton.language.extra.hip"] = MagicMock()
sys.modules["triton.language.target_info"] = MagicMock()


# Mock triton modules with docstring-preserving jit decorator
Expand Down
11 changes: 6 additions & 5 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
import iris.experimental.iris_gluon as iris_gl
import iris
from iris.gluon import IrisDeviceCtx

# Device-side APIs - context encapsulates heap_bases
@gluon.jit
Expand Down Expand Up @@ -167,20 +168,20 @@ def _worker(rank, world_size):

# Iris initialization
heap_size = 2**30 # 1GiB symmetric heap
iris_ctx = iris_gl.iris(heap_size)
iris_ctx = iris.iris(heap_size)
context_tensor = iris_ctx.get_device_context() # Get encoded context
cur_rank = iris_ctx.get_rank()

# Iris tensor allocation
buffer_size = 4096 # 4K elements buffer
buffer = iris_ctx.zeros(buffer_size, device="cuda", dtype=torch.float32)

# Launch the kernel on rank 0
block_size = 1024
grid = (buffer_size + block_size - 1) // block_size
source_rank = 0
if cur_rank == source_rank:
kernel[(grid,)](iris_gl.IrisDeviceCtx, context_tensor,
kernel[(grid,)](IrisDeviceCtx, context_tensor,
buffer, buffer_size, block_size, num_warps=1)

# Synchronize all ranks
Expand Down
40 changes: 18 additions & 22 deletions docs/reference/api-reference.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
# API Reference

Explore Iris APIs. The reference is broken down into focused sections to mirror common workflows:

- The `Iris` class itself (constructor and helper utilities)
- Tensor-like creation methods on the `Iris` context
- Triton device-side functions for remote memory ops and atomics
- Collective communication operations (CCL)
- Fused GEMM + CCL operations
- Experimental Gluon APIs (using `@aggregate` and `@gluon.jit`)

Use the links below to navigate:

- [Triton](triton/overview.md)
- [Iris Class](triton/class.md)
- [Tensor Creation](triton/tensor-creation.md)
- [Device Functions](triton/device-functions.md)
- [Collective Communication (CCL)](triton/ccl.md)
- [Fused GEMM + CCL Operations](triton/ops.md)
- [Gluon (Experimental)](gluon/overview.md)
- [Iris Class](gluon/class.md)
- [Tensor Creation](gluon/tensor-creation.md)
- [Device Functions](gluon/device-functions.md)
- [Collective Communication (CCL)](gluon/ccl.md)
Explore Iris APIs. The host-side API (class, tensor creation, CCL) is identical across all backends. Only the device-side API differs between Triton and Gluon.

## Host API (All Backends)

- [Iris Class](host/class.md)
- [Tensor Creation](host/tensor-creation.md)
- [Collective Communication (CCL)](host/ccl.md)

## Triton Backend

- [Overview](triton/overview.md)
- [Device Functions](triton/device-functions.md)
- [Fused GEMM + CCL Operations](triton/ops.md)

## Gluon Backend (Experimental)

- [Overview](gluon/overview.md)
- [Device Functions](gluon/device-functions.md)

22 changes: 0 additions & 22 deletions docs/reference/gluon/ccl.md

This file was deleted.

54 changes: 0 additions & 54 deletions docs/reference/gluon/class.md

This file was deleted.

30 changes: 15 additions & 15 deletions docs/reference/gluon/device-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,95 +10,95 @@ Device-side functions provided by Iris Gluon for remote memory operations and at

### initialize
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.initialize
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.initialize
:noindex:
```

## Memory transfer operations

### load
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.load
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.load
:noindex:
```

### store
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.store
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.store
:noindex:
```

### copy
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.copy
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.copy
:noindex:
```

### get
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.get
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.get
:noindex:
```

### put
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.put
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.put
:noindex:
```

## Atomic operations

### atomic_add
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_add
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_add
:noindex:
```

### atomic_sub
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_sub
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_sub
:noindex:
```

### atomic_cas
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_cas
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_cas
:noindex:
```

### atomic_xchg
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_xchg
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_xchg
:noindex:
```

### atomic_xor
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_xor
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_xor
:noindex:
```

### atomic_and
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_and
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_and
:noindex:
```

### atomic_or
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_or
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_or
:noindex:
```

### atomic_min
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_min
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_min
:noindex:
```

### atomic_max
```{eval-rst}
.. automethod:: iris.experimental.iris_gluon.IrisDeviceCtx.atomic_max
.. automethod:: iris.device.gluon.context.IrisDeviceCtx.atomic_max
:noindex:
```

20 changes: 11 additions & 9 deletions docs/reference/gluon/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ The Gluon API provides a Triton Gluon-based implementation of Iris that uses the
## Usage Example

```python
import iris.experimental.iris_gluon as iris_gl
import iris
from iris.gluon import IrisDeviceCtx
from triton.experimental import gluon
from triton.experimental.gluon import language as gl

# Host-side: Initialize Iris Gluon context
ctx = iris_gl.iris(heap_size=2**30) # 1GB heap
ctx = iris.iris(heap_size=2**30) # 1GB heap
context_tensor = ctx.get_device_context()

# Device-side: Use in Gluon kernels
Expand All @@ -49,10 +50,10 @@ def kernel(IrisDeviceCtx: gl.constexpr, context_tensor, buffer):

Explore the API by section:

- [Iris Class](class.md)
- [Tensor Creation](tensor-creation.md)
- [Iris Class](../host/class.md)
- [Tensor Creation](../host/tensor-creation.md)
- [Device Functions](device-functions.md)
- [Collective Communication (CCL)](ccl.md)
- [Collective Communication (CCL)](../host/ccl.md)

## Complete Example: Producer-Consumer Pattern

Expand All @@ -64,7 +65,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
import iris.experimental.iris_gluon as iris_gl
import iris
from iris.gluon import IrisDeviceCtx

@gluon.jit
def producer_kernel(
Expand Down Expand Up @@ -137,7 +139,7 @@ def worker(rank, world_size):
)

# Initialize Iris Gluon
ctx = iris_gl.iris(heap_size=2**30)
ctx = iris.iris(heap_size=2**30)
context_tensor = ctx.get_device_context()

# Allocate buffers
Expand All @@ -159,7 +161,7 @@ def worker(rank, world_size):
if rank == producer_rank:
ctx.info(f"Rank {rank} producing data...")
producer_kernel[grid](
iris_gl.IrisDeviceCtx,
IrisDeviceCtx,
context_tensor,
source,
target,
Expand All @@ -173,7 +175,7 @@ def worker(rank, world_size):
else:
ctx.info(f"Rank {rank} consuming data...")
consumer_kernel[grid](
iris_gl.IrisDeviceCtx,
IrisDeviceCtx,
context_tensor,
target,
flag,
Expand Down
Loading
Loading