Commit 1434760

[Model] Deepseek-v3 support (#3192)

1 parent e89a484, commit 1434760

26 files changed: +3107 -74 lines

ci/task/pylint.sh (+1 -1)

@@ -10,7 +10,7 @@ if [[ -n ${MLC_CI_SETUP_DEPS:-} ]]; then
 echo "MLC_CI_SETUP_DEPS=1 start setup deps"
 # TVM Unity is a dependency to this testing
 pip install --quiet --pre -U --no-index -f https://mlc.ai/wheels mlc-ai-nightly-cpu
-pip install requests
+pip install requests triton
 pip install --quiet --pre -U cuda-python
 fi

python/mlc_llm/compiler_pass/dispatch_triton_kernel.py (new file, +174)

@@ -0,0 +1,174 @@
"""A pass that dispatches generic calls of triton kernels to specific kernel implementations."""

# pylint: disable=invalid-name

from typing import List

import tvm
from tvm import IRModule, relax
from tvm.relax.expr_functor import PyExprMutator, mutator

from mlc_llm.op.triton import (
    get_tir_w8a8_block_fp8_group_matmul,
    get_tir_w8a8_block_fp8_matmul,
)
from mlc_llm.support import logging

logger = logging.getLogger(__name__)


@mutator
class _Rewriter(PyExprMutator):  # pylint: disable=abstract-method
    def __init__(self, mod: IRModule, target: tvm.target.Target) -> None:
        super().__init__(mod)
        self.mod = mod
        self.target = target
        self.extern_mods: List[tvm.runtime.Module] = []

    def transform(self) -> tvm.IRModule:  # pylint: disable=too-many-locals
        """Entry point of the transformation"""
        for g_var, func in self.mod.functions_items():
            if not isinstance(func, relax.Function):
                continue
            new_func = self.visit_expr(func)
            # new_func = remove_all_unused(new_func)
            self.builder_.update_func(g_var, new_func)

        mod = self.builder_.finalize()
        mod_attrs = dict(mod.attrs) if mod.attrs else {}
        mod = mod.with_attr(
            "external_mods", list(mod_attrs.get("external_mods", [])) + self.extern_mods
        )
        return mod

    def visit_call_(self, call: relax.Call) -> relax.Expr:  # pylint: disable=arguments-renamed
        call = super().visit_call_(call)

        if (
            call.op != tvm.ir.Op.get("relax.call_dps_packed")
            or not isinstance(call.args[0], relax.ExternFunc)
            or not str(call.args[0].global_symbol).startswith("mlc.triton.")
        ):
            return call

        global_symbol = str(call.args[0].global_symbol)
        assert isinstance(call.args[1], relax.Tuple)
        if global_symbol == "mlc.triton.w8a8_block_fp8_matmul":
            return self.w8a8_block_fp8_matmul(call.args[1].fields, call.struct_info)
        if global_symbol == "mlc.triton.w8a8_block_fp8_group_matmul":
            return self.w8a8_block_fp8_group_matmul(call.args[1].fields, call.struct_info)
        raise ValueError(f"Unknown mlc.triton kernel identifier: {global_symbol}")

    def w8a8_block_fp8_matmul(  # pylint: disable=too-many-locals
        self, args: List[relax.Expr], out_sinfo: relax.StructInfo
    ) -> relax.Expr:
        """Emit the w8a8_block_fp8_matmul triton kernel."""
        assert len(args) == 16
        x, weight, x_scale, weight_scale = args[:4]
        (
            N,
            K,
            block_n,
            block_k,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            GROUP_SIZE_M,
            num_warps,
            num_stages,
        ) = [arg.value.value for arg in args[4:14]]
        in_dtype, out_dtype = str(args[14].value), str(args[15].value)

        prim_func, func_name = get_tir_w8a8_block_fp8_matmul(
            N,
            K,
            block_n,
            block_k,
            in_dtype,  # type: ignore
            out_dtype,  # type: ignore
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            GROUP_SIZE_M,
            num_warps,
            num_stages,
            self.extern_mods,
        )
        if prim_func is None:
            # The TIR function is already in the IRModule
            gv = self.builder_.get().get_global_var(func_name)
        else:
            # Add the TIR function to the IRModule
            gv = self.builder_.add_func(prim_func, func_name)

        return relax.call_tir(gv, [x, weight, x_scale, weight_scale], out_sinfo=out_sinfo)

    def w8a8_block_fp8_group_matmul(  # pylint: disable=too-many-locals
        self, args: List[relax.Expr], out_sinfo: relax.StructInfo
    ) -> relax.Expr:
        """Emit the w8a8_block_fp8_group_matmul triton kernel."""
        assert len(args) == 19
        x, weight, x_scale, weight_scale, expert_ids, indptr = args[:6]
        (
            N,
            K,
            num_experts,
            block_n,
            block_k,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            GROUP_SIZE_M,
            num_warps,
            num_stages,
        ) = [arg.value.value for arg in args[6:17]]
        in_dtype, out_dtype = str(args[17].value), str(args[18].value)

        prim_func, func_name = get_tir_w8a8_block_fp8_group_matmul(
            N,
            K,
            num_experts,
            block_n,
            block_k,
            in_dtype,  # type: ignore
            out_dtype,  # type: ignore
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            GROUP_SIZE_M,
            num_warps,
            num_stages,
            self.extern_mods,
        )
        if prim_func is None:
            # The TIR function is already in the IRModule
            gv = self.builder_.get().get_global_var(func_name)
        else:
            # Add the TIR function to the IRModule
            gv = self.builder_.add_func(prim_func, func_name)

        return relax.call_tir(
            gv,
            [x, weight, x_scale, weight_scale, expert_ids, indptr],
            out_sinfo=out_sinfo,
        )


@tvm.transform.module_pass(opt_level=0, name="DispatchTritonKernel")
class DispatchTritonKernel:  # pylint: disable=too-many-instance-attributes,too-few-public-methods
    """Dispatch generic mlc.triton kernel calls to concrete Triton kernel implementations."""

    def __init__(self, target: tvm.target.Target) -> None:
        """Initializer.

        Parameters
        ----------
        target : tvm.target.Target
            The compilation target of the pass.
        """
        self.target = target

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Entrypoint"""
        if self.target.kind.name != "cuda":
            return mod

        return _Rewriter(mod, self.target).transform()
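
For context, the decorated class behaves as a standard TVM module pass, so it can also be applied outside the full compilation pipeline. Below is a minimal sketch of standalone use, not taken from the commit: `mod` is just an empty placeholder IRModule, whereas real modules emitted by MLC model definitions carry relax.call_dps_packed calls to "mlc.triton.*" extern symbols that this pass rewrites; running it assumes an environment with mlc_llm and triton installed.

import tvm
from mlc_llm.compiler_pass.dispatch_triton_kernel import DispatchTritonKernel

# Placeholder module for illustration; a real module would contain
# relax.call_dps_packed("mlc.triton.w8a8_block_fp8_matmul", ...) calls.
mod = tvm.IRModule()
# Applying the pass rewrites matching calls into call_tir of generated TIR kernels
# and attaches the compiled Triton modules under the "external_mods" attribute;
# non-CUDA targets are returned unchanged.
mod = DispatchTritonKernel(tvm.target.Target("cuda"))(mod)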

python/mlc_llm/compiler_pass/fuse_add_norm.py (+1)

@@ -182,6 +182,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: # pylint: disable=argume
         call = super().visit_call_(call)

         # Match the "rms_norm(add(x1, x2), w)" pattern
+        # Todo: support bf16  # pylint: disable=fixme
         if call.op != tvm.ir.Op.get("relax.nn.rms_norm") or call.struct_info.dtype != "float16":
             return call
         assert len(call.args) == 2

python/mlc_llm/compiler_pass/fuse_dequantize_take.py (+1 -1)

@@ -15,7 +15,7 @@
 class FuseDequantizeTake:  # pylint: disable=too-few-public-methods
     """A compiler pass that fuses dequantize + take."""

-    def transform_module(
+    def transform_module(  # pylint: disable=too-many-locals
         self,
         mod: IRModule,
         _ctx: tvm.transform.PassContext,

python/mlc_llm/compiler_pass/pipeline.py (+3 -1)

@@ -28,6 +28,7 @@
 from .blas_dispatch import BLASDispatch
 from .clean_up_tir_attrs import CleanUpTIRAttrs
 from .dispatch_kv_cache_creation import DispatchKVCacheCreation
+from .dispatch_triton_kernel import DispatchTritonKernel
 from .estimate_memory_usage import AttachMetadataWithMemoryUsage
 from .fuse_add_norm import FuseAddRMSNorm
 from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise

@@ -117,6 +118,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
                 _DebugDump("debug-phase0.py", debug_dump, show_meta=False),
                 # Phase 1. Passes on high-level operator graph
                 _LogProgress("Running TVM Relax graph-level optimizations"),
+                DispatchTritonKernel(target),
                 FuseFTDequantizeEpilogue(),
                 FuseDequantizeTranspose(),
                 BLASDispatch(target) if cublas_gemm else tvm.transform.Sequential([]),

@@ -185,6 +187,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
                 ),
                 tvm.relax.transform.StaticPlanBlockMemory(),
                 AttachMetadataWithMemoryUsage(metadata),
+                _DebugDump("debug-phase5.py", debug_dump, show_meta=False),
                 tvm.relax.transform.RewriteCUDAGraph(),
                 AttachCUDAGraphAllocInitFunc(),
                 tvm.relax.transform.LowerGPUIPCAllocStorage(),

@@ -193,7 +196,6 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
                 tvm.relax.transform.LowerRuntimeBuiltin(),
                 tvm.relax.transform.VMShapeLower(),
                 tvm.relax.transform.AttachGlobalSymbol(),
-                _DebugDump("debug-final.py", debug_dump, show_meta=False),
                 _LogProgress("Compiling external modules"),
                 tvm.relax.transform.AttachExternModules(ext_mods),
                 _LogProgress("Compilation complete! Exporting to disk"),

python/mlc_llm/conversation_template/deepseek.py (+16)

@@ -36,6 +36,22 @@
     )
 )

+# DeepSeek-V3
+ConvTemplateRegistry.register_conv_template(
+    Conversation(
+        name="deepseek_v3",
+        system_template=f"<|begin▁of▁sentence|>{MessagePlaceholders.SYSTEM.value}",
+        system_message="You are Deepseek-V3, an AI assistant created exclusively by the Chinese "
+        "Company DeepSeek. You'll provide helpful, harmless, and detailed responses to all "
+        "user inquiries.",
+        roles={"user": "<|User|>", "assistant": "<|Assistant|>"},
+        seps=["", "<|end▁of▁sentence|>"],
+        role_content_sep="",
+        role_empty_sep="",
+        stop_token_ids=[1],
+    )
+)
+
 # DeepSeek-R1-Distill-Qwen
 ConvTemplateRegistry.register_conv_template(
     Conversation(
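
For a rough sense of what this template produces, the sketch below hand-assembles a single-turn prompt from the fields registered above. It is an illustration only: the exact assembly is performed by the Conversation class, the layout here is an assumption based on the registered fields, and the user message is made up.

# Hypothetical illustration of the deepseek_v3 prompt layout (not produced by running MLC code).
system = "You are Deepseek-V3, an AI assistant created exclusively by the Chinese Company DeepSeek."
prompt = (
    "<|begin▁of▁sentence|>" + system    # system_template with the system message filled in
    + "<|User|>" + "What is MLC LLM?"   # user role tag; role_content_sep and the first sep are ""
    + "<|Assistant|>"                   # assistant role tag; generation continues from here
)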

python/mlc_llm/interface/gen_config.py (+1)

@@ -309,6 +309,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
         "aya-23",
         "deepseek",
         "deepseek_v2",
+        "deepseek_v3",
         "deepseek_r1_qwen",
         "deepseek_r1_llama",
         "olmo",

python/mlc_llm/loader/utils.py (+10 -2)

@@ -55,13 +55,21 @@ def load_torch_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]:
 def load_safetensor_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]:
     """Load and yield SafeTensor format parameters."""
     import safetensors  # pylint: disable=import-outside-toplevel,import-error
+    import torch  # pylint: disable=import-outside-toplevel

     with safetensors.safe_open(path, framework="pt", device="cpu") as in_file:
         for name in in_file.keys():
             param = in_file.get_tensor(name)
             param = param.detach().cpu()
             dtype = str(param.dtype)
             if dtype == "torch.bfloat16":
-                param = param.float()
-                param = param.numpy()
+                import ml_dtypes  # pylint: disable=import-outside-toplevel
+
+                param = param.view(torch.float16).cpu().numpy().view(ml_dtypes.bfloat16)
+            elif dtype == "torch.float8_e4m3fn":
+                import ml_dtypes  # pylint: disable=import-outside-toplevel
+
+                param = param.view(torch.uint8).cpu().numpy().view(ml_dtypes.float8_e4m3fn)
+            else:
+                param = param.numpy()
             yield name, param
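
The new branches avoid widening bfloat16 and fp8 weights to float32: the tensor's raw bits are viewed as a same-width torch dtype that numpy accepts, then relabeled with the matching ml_dtypes dtype. A small self-contained sketch of the same round trip (not part of the commit), assuming torch and ml_dtypes are installed:

import ml_dtypes
import numpy as np
import torch

t = torch.randn(4, dtype=torch.bfloat16)
# Reinterpret the 16-bit payload as float16 so .numpy() accepts it,
# then relabel the same bytes as bfloat16 on the numpy side.
a = t.view(torch.float16).numpy().view(ml_dtypes.bfloat16)
assert a.dtype == ml_dtypes.bfloat16
# No precision is lost: both paths decode to identical float32 values.
assert np.array_equal(a.astype(np.float32), t.to(torch.float32).numpy())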
