|
| 1 | +"""A pass that dispatch generic calls of triton kernels to specific kernel implementations.""" |
| 2 | + |
| 3 | +# pylint: disable=invalid-name |
| 4 | + |
| 5 | +from typing import List |
| 6 | + |
| 7 | +import tvm |
| 8 | +from tvm import IRModule, relax |
| 9 | +from tvm.relax.expr_functor import PyExprMutator, mutator |
| 10 | + |
| 11 | +from mlc_llm.op.triton import ( |
| 12 | + get_tir_w8a8_block_fp8_group_matmul, |
| 13 | + get_tir_w8a8_block_fp8_matmul, |
| 14 | +) |
| 15 | +from mlc_llm.support import logging |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +@mutator |
| 21 | +class _Rewriter(PyExprMutator): # pylint: disable=abstract-method |
| 22 | + def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: |
| 23 | + super().__init__(mod) |
| 24 | + self.mod = mod |
| 25 | + self.target = target |
| 26 | + self.extern_mods: List[tvm.runtime.Module] = [] |
| 27 | + |
| 28 | + def transform(self) -> tvm.IRModule: # pylint: disable=too-many-locals |
| 29 | + """Entry point of the transformation""" |
| 30 | + for g_var, func in self.mod.functions_items(): |
| 31 | + if not isinstance(func, relax.Function): |
| 32 | + continue |
| 33 | + new_func = self.visit_expr(func) |
| 34 | + # new_func = remove_all_unused(new_func) |
| 35 | + self.builder_.update_func(g_var, new_func) |
| 36 | + |
| 37 | + mod = self.builder_.finalize() |
| 38 | + mod_attrs = dict(mod.attrs) if mod.attrs else {} |
| 39 | + mod = mod.with_attr( |
| 40 | + "external_mods", list(mod_attrs.get("external_mods", [])) + self.extern_mods |
| 41 | + ) |
| 42 | + return mod |
| 43 | + |
| 44 | + def visit_call_(self, call: relax.Call) -> relax.Expr: # pylint: disable=arguments-renamed |
| 45 | + call = super().visit_call_(call) |
| 46 | + |
| 47 | + if ( |
| 48 | + call.op != tvm.ir.Op.get("relax.call_dps_packed") |
| 49 | + or not isinstance(call.args[0], relax.ExternFunc) |
| 50 | + or not str(call.args[0].global_symbol).startswith("mlc.triton.") |
| 51 | + ): |
| 52 | + return call |
| 53 | + |
| 54 | + global_symbol = str(call.args[0].global_symbol) |
| 55 | + assert isinstance(call.args[1], relax.Tuple) |
| 56 | + if global_symbol == "mlc.triton.w8a8_block_fp8_matmul": |
| 57 | + return self.w8a8_block_fp8_matmul(call.args[1].fields, call.struct_info) |
| 58 | + if global_symbol == "mlc.triton.w8a8_block_fp8_group_matmul": |
| 59 | + return self.w8a8_block_fp8_group_matmul(call.args[1].fields, call.struct_info) |
| 60 | + raise ValueError(f"Unknown mlc.triton kernel identifier: {global_symbol}") |
| 61 | + |
| 62 | + def w8a8_block_fp8_matmul( # pylint: disable=too-many-locals |
| 63 | + self, args: List[relax.Expr], out_sinfo: relax.StructInfo |
| 64 | + ) -> relax.Expr: |
| 65 | + """Emit the w8a8_block_fp8_matmul triton kernel.""" |
| 66 | + assert len(args) == 16 |
| 67 | + x, weight, x_scale, weight_scale = args[:4] |
| 68 | + ( |
| 69 | + N, |
| 70 | + K, |
| 71 | + block_n, |
| 72 | + block_k, |
| 73 | + BLOCK_SIZE_M, |
| 74 | + BLOCK_SIZE_N, |
| 75 | + BLOCK_SIZE_K, |
| 76 | + GROUP_SIZE_M, |
| 77 | + num_warps, |
| 78 | + num_stages, |
| 79 | + ) = [arg.value.value for arg in args[4:14]] |
| 80 | + in_dtype, out_dtype = str(args[14].value), str(args[15].value) |
| 81 | + |
| 82 | + prim_func, func_name = get_tir_w8a8_block_fp8_matmul( |
| 83 | + N, |
| 84 | + K, |
| 85 | + block_n, |
| 86 | + block_k, |
| 87 | + in_dtype, # type: ignore |
| 88 | + out_dtype, # type: ignore |
| 89 | + BLOCK_SIZE_M, |
| 90 | + BLOCK_SIZE_N, |
| 91 | + BLOCK_SIZE_K, |
| 92 | + GROUP_SIZE_M, |
| 93 | + num_warps, |
| 94 | + num_stages, |
| 95 | + self.extern_mods, |
| 96 | + ) |
| 97 | + if prim_func is None: |
| 98 | + # The TIR function is already in the IRModule |
| 99 | + gv = self.builder_.get().get_global_var(func_name) |
| 100 | + else: |
| 101 | + # Add the TIR function to the IRModule |
| 102 | + gv = self.builder_.add_func(prim_func, func_name) |
| 103 | + |
| 104 | + return relax.call_tir(gv, [x, weight, x_scale, weight_scale], out_sinfo=out_sinfo) |
| 105 | + |
| 106 | + def w8a8_block_fp8_group_matmul( # pylint: disable=too-many-locals |
| 107 | + self, args: List[relax.Expr], out_sinfo: relax.StructInfo |
| 108 | + ) -> relax.Expr: |
| 109 | + """Emit the w8a8_block_fp8_group_matmul triton kernel.""" |
| 110 | + assert len(args) == 19 |
| 111 | + x, weight, x_scale, weight_scale, expert_ids, indptr = args[:6] |
| 112 | + ( |
| 113 | + N, |
| 114 | + K, |
| 115 | + num_experts, |
| 116 | + block_n, |
| 117 | + block_k, |
| 118 | + BLOCK_SIZE_M, |
| 119 | + BLOCK_SIZE_N, |
| 120 | + BLOCK_SIZE_K, |
| 121 | + GROUP_SIZE_M, |
| 122 | + num_warps, |
| 123 | + num_stages, |
| 124 | + ) = [arg.value.value for arg in args[6:17]] |
| 125 | + in_dtype, out_dtype = str(args[17].value), str(args[18].value) |
| 126 | + |
| 127 | + prim_func, func_name = get_tir_w8a8_block_fp8_group_matmul( |
| 128 | + N, |
| 129 | + K, |
| 130 | + num_experts, |
| 131 | + block_n, |
| 132 | + block_k, |
| 133 | + in_dtype, # type: ignore |
| 134 | + out_dtype, # type: ignore |
| 135 | + BLOCK_SIZE_M, |
| 136 | + BLOCK_SIZE_N, |
| 137 | + BLOCK_SIZE_K, |
| 138 | + GROUP_SIZE_M, |
| 139 | + num_warps, |
| 140 | + num_stages, |
| 141 | + self.extern_mods, |
| 142 | + ) |
| 143 | + if prim_func is None: |
| 144 | + # The TIR function is already in the IRModule |
| 145 | + gv = self.builder_.get().get_global_var(func_name) |
| 146 | + else: |
| 147 | + # Add the TIR function to the IRModule |
| 148 | + gv = self.builder_.add_func(prim_func, func_name) |
| 149 | + |
| 150 | + return relax.call_tir( |
| 151 | + gv, |
| 152 | + [x, weight, x_scale, weight_scale, expert_ids, indptr], |
| 153 | + out_sinfo=out_sinfo, |
| 154 | + ) |
| 155 | + |
| 156 | + |
| 157 | +@tvm.transform.module_pass(opt_level=0, name="DispatchTritonKernel") |
| 158 | +class DispatchTritonKernel: # pylint: disable=too-many-instance-attributes,too-few-public-methods |
| 159 | + """Rewrite KV cache creation functions to IRModule.""" |
| 160 | + |
| 161 | + def __init__(self, target: tvm.target.Target) -> None: |
| 162 | + """Initializer. |
| 163 | +
|
| 164 | + Parameters |
| 165 | + ---------- |
| 166 | + """ |
| 167 | + self.target = target |
| 168 | + |
| 169 | + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: |
| 170 | + """Entrypoint""" |
| 171 | + if self.target.kind.name != "cuda": |
| 172 | + return mod |
| 173 | + |
| 174 | + return _Rewriter(mod, self.target).transform() |
0 commit comments