diff --git a/benchmark/ops/all_gather_matmul/auto_config.py b/benchmark/ops/all_gather_matmul/auto_config.py index 0e8990886..fbe7628e6 100644 --- a/benchmark/ops/all_gather_matmul/auto_config.py +++ b/benchmark/ops/all_gather_matmul/auto_config.py @@ -81,8 +81,9 @@ # TN/NT/TT would require kernel-level changes to permute strides. SUPPORTED_TRANSPOSES = ("NN",) -# Supported GPU architectures with tuned configs -SUPPORTED_ARCHITECTURES = ("mi300x", "mi355x") +# Supported GPU architectures for auto-config selection. NVIDIA currently uses +# heuristic fallback configs rather than tuned JSON files. +SUPPORTED_ARCHITECTURES = ("mi300x", "mi355x", "nvidia") # Map gfx target IDs to architecture names used in config paths _GFX_TO_ARCH = { @@ -96,8 +97,8 @@ def detect_gpu_arch() -> str: Detection order: 1. IRIS_GPU_ARCH environment variable (override) - 2. rocm-smi --showproductname parsing - 3. rocminfo gfx target parsing + 2. PyTorch CUDA-without-HIP detection for NVIDIA + 3. rocminfo gfx target parsing for AMD 4. Falls back to "mi300x" (most common deployment target) Returns: @@ -113,7 +114,18 @@ def detect_gpu_arch() -> str: _detected_arch = env_arch return _detected_arch - # 2. Try rocminfo for gfx target + # 2. Check for NVIDIA CUDA via PyTorch. ROCm PyTorch also exposes + # torch.cuda, so require CUDA availability without a HIP version. + try: + import torch + + if torch.cuda.is_available() and not getattr(torch.version, "hip", None): + _detected_arch = "nvidia" + return _detected_arch + except ImportError: + pass + + # 3. Try rocminfo for AMD gfx target try: result = subprocess.run( ["rocminfo"], @@ -132,7 +144,7 @@ def detect_gpu_arch() -> str: except (FileNotFoundError, subprocess.TimeoutExpired, OSError): pass - # 3. Fallback to MI300X (most common deployment target) + # 4. Fallback to MI300X (most common deployment target) _detected_arch = "mi300x" return _detected_arch @@ -318,6 +330,25 @@ def _apply_heuristic(M: int, N: int, K: int, arch: str = "mi300x") -> Tuple[Dict bk = 64 num_k_blocks = K // bk + if arch == "nvidia": + config_params = { + "block_size_m": 128, + "block_size_n": 128, + "block_size_k": bk, + "group_size_m": 8, + "num_xcds": 1, + "allow_tf32": True, + } + hbm_params = { + "k_per_flag": 8, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 32, + "num_warps": 4, + "num_stages": 2, + } + return config_params, hbm_params + if arch == "mi355x": bm = 256 num_m_tiles = M // bm @@ -510,7 +541,19 @@ def select_ag_mm_config( source=f"Heuristic (no exact shape match in {arch}/{transpose}/ws{world_size}.json)", ) - # Step 2: No config file found — check global default + # Step 2: No config file found for this architecture. For new/untuned + # architectures such as NVIDIA, enable heuristic configs directly instead + # of applying AMD-specific global world-size gates. + if arch not in ("mi300x", "mi355x"): + heuristic_config, heuristic_hbm = _apply_heuristic(M, N, K, arch=arch) + return AutoConfigResult( + enabled=True, + config_params=heuristic_config, + hbm_buffer_params=heuristic_hbm, + source=f"Heuristic fallback for {arch} (no tuned configs available)", + ) + + # Step 3: No AMD config file found — check global default default_data = _load_default_config() ws_gate = default_data.get("world_size_gate", {}) min_ws = ws_gate.get("min_world_size", 8) diff --git a/examples/14_all_gather_gemm/example_run_pull.py b/examples/14_all_gather_gemm/example_run_pull.py index 3dfe9733f..b588f3919 100644 --- a/examples/14_all_gather_gemm/example_run_pull.py +++ b/examples/14_all_gather_gemm/example_run_pull.py @@ -18,6 +18,7 @@ import torch.distributed as dist import iris import argparse +import os from all_gather_gemm_pull import persistent_ag_gemm @@ -36,6 +37,11 @@ def parse_args(): parser.add_argument( "--dtype", type=str, default="float16", choices=["float16", "bfloat16"], help="PyTorch data type to use." ) + parser.add_argument( + "--print_topology", + action="store_true", + help="Print the Iris-discovered topology before initializing the symmetric heap.", + ) return parser.parse_args() @@ -72,17 +78,39 @@ def setup_example_data(rank, world_size, args, dtype): } -def example_run(rank: int, world_size: int, init_url: str, args: argparse.Namespace): +def example_run( + rank: int, + world_size: int, + init_url: str, + args: argparse.Namespace, + local_rank: int | None = None, +): backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group( - backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") - ) + if local_rank is None: + local_rank = rank + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + init_kwargs = { + "backend": backend, + "init_method": init_url, + "world_size": world_size, + "rank": rank, + } + if backend == "nccl": + init_kwargs["device_id"] = torch.device(f"cuda:{local_rank}") + dist.init_process_group(**init_kwargs) + + if args.print_topology: + from iris.host.distributed.topology import TopologyDiscovery + + topology = TopologyDiscovery().discover() + if rank == 0: + print(topology.summary(), flush=True) # Initialize Iris for distributed communication shmem = iris.iris() torch.manual_seed(42) # Use a fixed seed for consistent random data - torch.cuda.set_device(rank) dtype = getattr(torch, args.dtype) if rank == 0: @@ -103,7 +131,7 @@ def example_run(rank: int, world_size: int, init_url: str, args: argparse.Namesp C_fused = torch.empty(args.M, args.N, dtype=dtype).cuda() # Output tensor for our kernel - NUM_SMS = torch.cuda.get_device_properties(rank).multi_processor_count + NUM_SMS = torch.cuda.get_device_properties(local_rank).multi_processor_count grid = (NUM_SMS,) # Launch the fused Triton kernel @@ -165,14 +193,20 @@ def example_run(rank: int, world_size: int, init_url: str, args: argparse.Namesp def main(): args = parse_args() - num_ranks = args.num_ranks - init_url = "tcp://127.0.0.1:29504" - mp.spawn( - fn=example_run, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + example_run(rank, world_size, "env://", args, local_rank=local_rank) + else: + num_ranks = args.num_ranks + init_url = "tcp://127.0.0.1:29504" + mp.spawn( + fn=example_run, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/iris/bench/_runner.py b/iris/bench/_runner.py index e53410d95..fa2ff990a 100644 --- a/iris/bench/_runner.py +++ b/iris/bench/_runner.py @@ -555,6 +555,55 @@ def main(argv: list[str] | None = None) -> None: print("No benchmark configurations to run after applying filters/skips.", file=sys.stderr) sys.exit(1) + # If launched by torchrun/srun for a multi-node job, do not spawn another + # local elastic job. The current process is already one benchmark rank. + if all(key in os.environ for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE")): + world_size = int(os.environ["WORLD_SIZE"]) + global_rank = int(os.environ["RANK"]) + if world_size not in all_num_ranks: + if global_rank == 0: + configured = ", ".join(str(n) for n in sorted(all_num_ranks)) + print( + f"torchrun WORLD_SIZE={world_size} does not match benchmark num_ranks selection " + f"{{{configured}}}. Pass --axis_num_ranks={world_size} or launch with a matching world size.", + file=sys.stderr, + ) + sys.exit(1) + + dropped_num_ranks = sorted(all_num_ranks - {world_size}) + if dropped_num_ranks and global_rank == 0: + dropped = ", ".join(str(n) for n in dropped_num_ranks) + print( + f"Warning: torchrun WORLD_SIZE={world_size}; skipping benchmark num_ranks values: {dropped}", + file=sys.stderr, + ) + + all_results = _run_benchmarks_worker( + benchmarks, + axis_overrides, + skip_overrides, + args.heap_size, + args.use_gluon, + args.n_warmup, + args.n_repeat, + args.benchmark_filter, + ) + + if global_rank == 0: + if args.benchmark_format == "json": + output = _format_json(all_results) + elif args.benchmark_format == "csv": + output = _format_csv(all_results) + else: + output = _format_console(all_results) + + print(output, end="") + + if args.benchmark_out: + with open(args.benchmark_out, "w") as f: + f.write(output) + return + # Launch once per unique num_ranks, collecting results across runs all_results: list[Result] = [] diff --git a/iris/drivers/local/nvidia.py b/iris/drivers/local/nvidia.py index 9bb8bef84..3f4530012 100644 --- a/iris/drivers/local/nvidia.py +++ b/iris/drivers/local/nvidia.py @@ -49,6 +49,8 @@ _CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = 0x1 _CU_MEM_ALLOC_GRANULARITY_MINIMUM = 0 _CU_MEM_ACCESS_FLAGS_PROT_READWRITE = 0x3 +_CU_POINTER_ATTRIBUTE_RANGE_START_ADDR = 11 +_CU_POINTER_ATTRIBUTE_RANGE_SIZE = 12 class LocalCudaError(DriverError): @@ -115,6 +117,8 @@ def _configure_signatures() -> None: cu_mem_set_access = _get_required_cuda_symbol("cuMemSetAccess") cu_mem_export_to_shareable_handle = _get_required_cuda_symbol("cuMemExportToShareableHandle") cu_mem_import_from_shareable_handle = _get_required_cuda_symbol("cuMemImportFromShareableHandle") + cu_mem_get_address_range = _get_required_cuda_symbol("cuMemGetAddressRange") + cu_pointer_get_attribute = _get_required_cuda_symbol("cuPointerGetAttribute") cu_init.argtypes = [ctypes.c_uint] cu_init.restype = ctypes.c_int @@ -189,6 +193,28 @@ def _configure_signatures() -> None: ] cu_mem_export_to_shareable_handle.restype = ctypes.c_int + cu_mem_get_address_range.argtypes = [ + ctypes.POINTER(ctypes.c_uint64), + ctypes.POINTER(ctypes.c_size_t), + ctypes.c_uint64, + ] + cu_mem_get_address_range.restype = ctypes.c_int + + cu_pointer_get_attribute.argtypes = [ + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_uint64, + ] + cu_pointer_get_attribute.restype = ctypes.c_int + + cu_mem_retain_allocation_handle = getattr(_cuda_driver, "cuMemRetainAllocationHandle", None) + if cu_mem_retain_allocation_handle is not None: + cu_mem_retain_allocation_handle.argtypes = [ + ctypes.POINTER(ctypes.c_uint64), + ctypes.c_void_p, + ] + cu_mem_retain_allocation_handle.restype = ctypes.c_int + cu_mem_import_from_shareable_handle.argtypes = [ ctypes.POINTER(ctypes.c_uint64), ctypes.c_void_p, @@ -281,10 +307,13 @@ def __init__(self) -> None: self._device_ordinal: int = 0 self._granularity: Optional[int] = None self._initialized: bool = False + self._context: Optional[ctypes.c_void_p] = None def _check_initialized(self) -> None: if not self._initialized: raise LocalCudaError("LocalCudaDriver not initialized - call initialize() first") + if self._context is not None: + _cuda_try(_cuda_driver.cuCtxSetCurrent(self._context), "cuCtxSetCurrent") def _make_alloc_props(self) -> _MemAllocationProp: props = _MemAllocationProp() @@ -339,6 +368,7 @@ def initialize(self, device_ordinal: int) -> None: _cuda_try(_cuda_driver.cuCtxSetCurrent(ctx), "cuCtxSetCurrent") self._device_ordinal = device_ordinal self._granularity = None + self._context = ctypes.c_void_p(ctx.value) self._initialized = True logger.info("LocalCudaDriver initialized (device %d)", device_ordinal) @@ -439,6 +469,43 @@ def export_handle(self, allocation: LocalAllocation) -> bytes: ) return struct.pack(_CUDA_HANDLE_FMT, int(fd.value)) + def export_pointer_handle(self, ptr: int, size: int) -> bytes: + """Export the VMM allocation containing ptr as a 4-byte native-endian POSIX FD.""" + self._check_initialized() + + retain_handle = getattr(_cuda_driver, "cuMemRetainAllocationHandle", None) + if retain_handle is None: + raise LocalCudaNotSupported("cuMemRetainAllocationHandle is not available in this CUDA driver") + + handle = ctypes.c_uint64() + try: + _cuda_try( + retain_handle(ctypes.byref(handle), ctypes.c_void_p(ptr)), + "cuMemRetainAllocationHandle", + ) + except LocalCudaError as exc: + raise LocalCudaNotSupported( + "CUDA can only export allocations backed by its virtual memory management API" + ) from exc + + try: + fd = ctypes.c_int(-1) + try: + _cuda_try( + _cuda_driver.cuMemExportToShareableHandle( + ctypes.byref(fd), + handle.value, + _CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, + 0, + ), + "cuMemExportToShareableHandle", + ) + except LocalCudaError as exc: + raise LocalCudaNotSupported("CUDA could not export the retained allocation handle") from exc + return struct.pack(_CUDA_HANDLE_FMT, int(fd.value)) + finally: + _cuda_try(_cuda_driver.cuMemRelease(handle.value), "cuMemRelease") + def _import_handle(self, handle_bytes: bytes) -> int: handle_bytes = _normalize_handle_bytes(handle_bytes) fd_value = struct.unpack(_CUDA_HANDLE_FMT, handle_bytes)[0] @@ -615,3 +682,36 @@ def free_va(self, va: int, size: int) -> None: """Free a CUDA VA range previously returned by reserve_va.""" self._check_initialized() _cuda_try(_cuda_driver.cuMemAddressFree(va, size), "cuMemAddressFree") + + def get_address_range(self, ptr: int) -> tuple[int, int]: + """Return base address and size for the CUDA allocation containing ptr.""" + self._check_initialized() + base = ctypes.c_uint64() + size = ctypes.c_size_t() + try: + _cuda_try( + _cuda_driver.cuMemGetAddressRange( + ctypes.byref(base), + ctypes.byref(size), + ctypes.c_uint64(ptr), + ), + "cuMemGetAddressRange", + ) + except LocalCudaError: + _cuda_try( + _cuda_driver.cuPointerGetAttribute( + ctypes.byref(base), + _CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + ctypes.c_uint64(ptr), + ), + "cuPointerGetAttribute(RANGE_START_ADDR)", + ) + _cuda_try( + _cuda_driver.cuPointerGetAttribute( + ctypes.byref(size), + _CU_POINTER_ATTRIBUTE_RANGE_SIZE, + ctypes.c_uint64(ptr), + ), + "cuPointerGetAttribute(RANGE_SIZE)", + ) + return int(base.value), int(size.value) diff --git a/iris/host/distributed/topology.py b/iris/host/distributed/topology.py index cbf20ce22..20da03e94 100644 --- a/iris/host/distributed/topology.py +++ b/iris/host/distributed/topology.py @@ -8,6 +8,7 @@ import os import re import socket +import hashlib from dataclasses import dataclass, field from enum import IntEnum from typing import Any, Dict, List, Optional, Set, Tuple @@ -343,6 +344,111 @@ def _get_gpu_fabric_info(gpu_id: int, vendor: str, pci_bus_id: str = "") -> Fabr return _nvidia_get_gpu_fabric_info(gpu_id, pci_bus_id=pci_bus_id) +def _probe_nvidia_fabric_connectivity(gpu_id: int, rank: int, world_size: int) -> Optional[List[List[bool]]]: + """ + Probe NVIDIA fabric reachability with CUDA fabric memory handles. + + Some GB200/MNNVL environments expose working fabric handles while NVML + reports an empty GPU fabric UUID. This collective probe uses the same + driver interface as Iris memory sharing: each rank exports a tiny fabric + allocation and every other rank tries to import it. A successful symmetric + import means the ranks share an NVIDIA fabric memory domain. + """ + if not dist.is_initialized() or world_size <= 1: + return None + + local_record: dict[str, Any] = { + "rank": rank, + "ok": False, + "handle": b"", + "size": 0, + } + driver = None + allocation = None + imported_mappings = [] + + try: + from iris.drivers.fabric.nvidia import NvidiaFabricDriver + + driver = NvidiaFabricDriver() + driver.initialize(gpu_id) + size = driver.get_minimum_granularity() + allocation = driver.allocate_exportable(size) + local_record = { + "rank": rank, + "ok": True, + "handle": driver.export_handle(allocation), + "size": allocation.size, + } + except Exception as exc: + logger.debug("[Rank %d] CUDA fabric handle export probe failed: %s", rank, exc) + + records: List[Optional[dict[str, Any]]] = [None] * world_size + dist.all_gather_object(records, local_record) + + local_row = [False] * world_size + for record in records: + if not record or not record.get("ok"): + continue + peer_rank = int(record["rank"]) + if peer_rank == rank: + local_row[peer_rank] = True + continue + if driver is None: + continue + try: + mapping = driver.import_and_map(peer_rank, record["handle"], int(record["size"])) + imported_mappings.append(mapping) + local_row[peer_rank] = True + except Exception as exc: + logger.debug("[Rank %d] CUDA fabric handle import from rank %d failed: %s", rank, peer_rank, exc) + + rows: List[Optional[List[bool]]] = [None] * world_size + dist.all_gather_object(rows, local_row) + + if driver is not None: + for mapping in imported_mappings: + try: + driver.cleanup_import(mapping) + except Exception as exc: + logger.debug("[Rank %d] CUDA fabric probe import cleanup failed: %s", rank, exc) + if allocation is not None: + try: + driver.cleanup_local(allocation) + except Exception as exc: + logger.debug("[Rank %d] CUDA fabric probe local cleanup failed: %s", rank, exc) + + if any(row is None for row in rows): + return None + return [list(row) for row in rows if row is not None] + + +def _fabric_components_from_connectivity(connectivity: List[List[bool]]) -> List[Set[int]]: + """Return bidirectionally reachable components from a fabric probe matrix.""" + world_size = len(connectivity) + visited: Set[int] = set() + components: List[Set[int]] = [] + + for start in range(world_size): + if start in visited: + continue + stack = [start] + component: Set[int] = set() + visited.add(start) + while stack: + rank = stack.pop() + component.add(rank) + for peer in range(world_size): + if peer in visited: + continue + if connectivity[rank][peer] and connectivity[peer][rank]: + visited.add(peer) + stack.append(peer) + components.append(component) + + return components + + def _normalize_pci_bus_id(bus_id: str) -> str: """ Normalize a PCI bus ID to a canonical lowercase form for comparison. @@ -1225,6 +1331,30 @@ def discover(self) -> TopologyMap: info = GPUInfo.from_dict(json.loads(gpu_json)) gpu_info_map[info.global_rank] = info + if ( + vendor == "nvidia" + and self.world_size > 1 + and any(not info.fabric_info.domain_key for info in gpu_info_map.values()) + ): + connectivity = _probe_nvidia_fabric_connectivity(self.gpu_id, self.rank, self.world_size) + if connectivity is not None: + for component in _fabric_components_from_connectivity(connectivity): + if len(component) <= 1: + continue + component_uuids = sorted(gpu_info_map[r].uuid for r in component) + cluster_uuid = "cuda-probe-" + hashlib.sha1(",".join(component_uuids).encode()).hexdigest()[:16] + for component_rank in component: + if not gpu_info_map[component_rank].fabric_info.domain_key: + gpu_info_map[component_rank].fabric_info = FabricInfo( + cluster_uuid=cluster_uuid, + clique_id=0, + ) + logger.info( + "Detected NVIDIA fabric domain via CUDA fabric handle probe: ranks=%s domain=%s", + sorted(component), + cluster_uuid, + ) + all_node_infos = [json.loads(s) for s in all_node_jsons] # Group ranks by hostname diff --git a/iris/host/memory/allocators/vmem_chunked_allocator.py b/iris/host/memory/allocators/vmem_chunked_allocator.py index 1dbc6cde8..62e922902 100644 --- a/iris/host/memory/allocators/vmem_chunked_allocator.py +++ b/iris/host/memory/allocators/vmem_chunked_allocator.py @@ -31,7 +31,7 @@ import torch from .base import BaseAllocator -from iris.drivers.base import LocalAllocation, PeerMapping +from iris.drivers.base import DriverNotSupported, LocalAllocation, PeerMapping from iris.drivers.factory import DriverFactory from iris.host.distributed.topology import ( InterconnectLevel, @@ -448,6 +448,15 @@ def get_num_chunks(self): """Return the number of exported heap regions.""" return len(self._shared_regions) + def _can_return_external_tensor_alias(self) -> bool: + return self.num_ranks == 1 and self.driver.__class__.__name__ == "LocalCudaDriver" + + def _external_tensor_alias(self, external_tensor: torch.Tensor) -> torch.Tensor: + logger.info( + "Returning a local CUDA tensor alias because this external allocation cannot be mapped into the Iris heap" + ) + return external_tensor.view(external_tensor.shape) + def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: """Import an external tensor into the symmetric heap (zero-copy). @@ -457,10 +466,10 @@ def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: until allocator.close() so peer translation remains valid for RMA. Raises: - DriverNotSupported: This operation requires DMA-BUF support and is - currently AMD-only. On NVIDIA, the local driver does not - implement export_pointer_handle for arbitrary device pointers, - and this method will raise. + DriverNotSupported: If the active driver cannot export and remap + the external allocation. Single-rank local CUDA jobs may return + a direct tensor alias instead, because no peer VA translation is + needed. RuntimeError: If the input tensor is not on a CUDA/HIP device or is not contiguous. """ @@ -469,10 +478,18 @@ def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: raise RuntimeError("Can only import CUDA/HIP tensors") if not external_tensor.is_contiguous(): raise RuntimeError("Only contiguous tensors can be imported; call .contiguous() before as_symmetric()") + if self._can_return_external_tensor_alias(): + return self._external_tensor_alias(external_tensor) external_ptr = external_tensor.data_ptr() tensor_size = external_tensor.numel() * external_tensor.element_size() - alloc_base, alloc_size = self.driver.get_address_range(external_ptr) + try: + alloc_base, alloc_size = self.driver.get_address_range(external_ptr) + except DriverNotSupported: + if self._can_return_external_tensor_alias(): + return self._external_tensor_alias(external_tensor) + raise + offset_in_alloc = external_ptr - alloc_base aligned_alloc_size = (alloc_size + self.granularity - 1) & ~(self.granularity - 1) @@ -485,7 +502,13 @@ def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: ) target_base_va = self.base_va + target_offset - handle_bytes = self.driver.export_pointer_handle(alloc_base, alloc_size) + try: + handle_bytes = self.driver.export_pointer_handle(alloc_base, alloc_size) + except DriverNotSupported: + if self._can_return_external_tensor_alias(): + return self._external_tensor_alias(external_tensor) + raise + import_kwargs = {} if self.driver.__class__.__name__ == "LocalHipDriver": import_kwargs = { diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index 78720b18e..4e18a4ea1 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -646,7 +646,9 @@ def all_gather_matmul_hbm_buffer( else: ctx.tracing.reset() - launch_kwargs = {"matrix_instr_nonkdim": 16} + launch_kwargs = {} + if getattr(torch.version, "hip", None): + launch_kwargs["matrix_instr_nonkdim"] = 16 if num_warps is not None: launch_kwargs["num_warps"] = num_warps if num_stages is not None: