
Commit e47c787

[TRTLLM-8535][feat] Support DeepSeek V3.2 with FP8 + BF16 KV cache/NVFP4 + BF16 KV cache (#8405)
Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: Chang Liu <[email protected]>
Signed-off-by: Tracin <[email protected]>
1 parent 2d86d6b commit e47c787


43 files changed: +4914 -153 lines

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,9 @@ tensorrt_llm/deep_gemm/
 tensorrt_llm/deep_gemm_cpp_tllm.*.so
 tensorrt_llm/deep_gemm_cpp_tllm.pyi
 tensorrt_llm/pg_utils_bindings.*.so
+tensorrt_llm/flash_mla/
+tensorrt_llm/flash_mla_cpp_tllm.*.so
+tensorrt_llm/flash_mla_cpp_tllm.pyi
 *docs/cpp_docs*
 *docs/source/_cpp_gen*
 docs/source/**/*.rst

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -30,3 +30,6 @@
 	path = 3rdparty/DeepGEMM
 	url = https://github.com/ruoqianguo/DeepGEMM.git
 	branch = swapab_sm100
+[submodule "3rdparty/flash-mla"]
+	path = 3rdparty/flash-mla
+	url = https://github.com/deepseek-ai/FlashMLA.git

3rdparty/flash-mla

Submodule flash-mla added at 1408756

cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ option(BUILD_TESTS "Build Google tests" ON)
 option(BUILD_BENCHMARKS "Build benchmarks" ON)
 option(BUILD_DEEP_EP "Build the Deep EP module" ON)
 option(BUILD_DEEP_GEMM "Build the DeepGEMM module" ON)
+option(BUILD_FLASH_MLA "Build the FlashMLA module" ON)
 option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
 option(NVTX_DISABLE "Disable all NVTX features" ON)
 option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)

cpp/tensorrt_llm/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -298,4 +298,8 @@ if(BUILD_DEEP_GEMM)
   add_subdirectory(deep_gemm)
 endif()
 
+if(BUILD_FLASH_MLA)
+  add_subdirectory(flash_mla)
+endif()
+
 add_subdirectory(plugins)
cpp/tensorrt_llm/flash_mla/CMakeLists.txt

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
add_custom_target(flash_mla)

if(WIN32)
  return()
endif()

# Prepare files
# =============

# Use FlashMLA submodule
set(FLASH_MLA_SOURCE_DIR
    ${CMAKE_CURRENT_SOURCE_DIR}/../../../3rdparty/flash-mla)
get_filename_component(FLASH_MLA_SOURCE_DIR ${FLASH_MLA_SOURCE_DIR} ABSOLUTE)

if(NOT EXISTS ${FLASH_MLA_SOURCE_DIR})
  message(
    FATAL_ERROR
      "FlashMLA submodule not found at ${FLASH_MLA_SOURCE_DIR}. Please run: git submodule update --init --recursive"
  )
endif()

# Check if submodules are initialized
if(NOT EXISTS ${FLASH_MLA_SOURCE_DIR}/csrc/cutlass/include)
  message(
    FATAL_ERROR
      "FlashMLA submodules not initialized. Please run: cd ${FLASH_MLA_SOURCE_DIR} && git submodule update --init --recursive"
  )
endif()

# Compiler compatibility for SM100 inline assembly
# =================================================
# FlashMLA SM100 contains PTX inline assembly that Clang++ on ARM64 incorrectly
# validates against host architecture. Use GCC for flash_mla on ARM64+Clang++.
# This follows the same pattern as DeepEP (see deep_ep/CMakeLists.txt:125-132).
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64" AND CMAKE_CXX_COMPILER_ID
                                                      MATCHES "Clang")
  message(
    STATUS
      "FlashMLA: ARM64 + Clang++ detected, switching to GCC for SM100 compatibility"
  )

  # Find GCC (required on ARM64 when Clang++ is used)
  find_program(GCC_EXECUTABLE NAMES g++ REQUIRED)
  if(NOT GCC_EXECUTABLE)
    message(
      FATAL_ERROR
        "FlashMLA: GCC (g++) is required on ARM64 with Clang++ but was not found. "
        "Install GCC or use GCC as the primary compiler.")
  endif()

  # Override CUDA host compiler for this target only
  set(CMAKE_CUDA_HOST_COMPILER ${GCC_EXECUTABLE})
  message(
    STATUS "FlashMLA: Using GCC at ${GCC_EXECUTABLE} for CUDA compilation")
endif()

# Check CUDA version and architecture support
# ============================================

# FlashMLA requires CUDA 12.3+ for SM90, 12.9+ for SM100
set(SUPPORT_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3)
  list(APPEND SUPPORT_ARCHS "90")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
  list(APPEND SUPPORT_ARCHS "100" "103")
endif()

# Find intersection of supported and requested architectures
set(FLASH_MLA_ARCHS)
foreach(ARCH ${SUPPORT_ARCHS})
  if("${ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
    list(APPEND FLASH_MLA_ARCHS ${ARCH})
  endif()
endforeach()

message(STATUS "flash_mla FLASH_MLA_ARCHS: ${FLASH_MLA_ARCHS}")
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/cuda_architectures.txt
     "${FLASH_MLA_ARCHS}")
if(NOT FLASH_MLA_ARCHS)
  message(
    STATUS
      "FlashMLA requires SM90 (CUDA 12.3+) or SM100 (CUDA 12.9+), skipping. "
      "Current CUDA version: ${CMAKE_CUDA_COMPILER_VERSION}, "
      "Current architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}")
  return()
endif()

message(STATUS "Building FlashMLA for architectures: ${FLASH_MLA_ARCHS}")

# Copy and update python files
# =============================
set(FLASH_MLA_PYTHON_DEST ${CMAKE_CURRENT_BINARY_DIR}/python/flash_mla)
file(REMOVE_RECURSE ${FLASH_MLA_PYTHON_DEST})
file(MAKE_DIRECTORY ${FLASH_MLA_PYTHON_DEST})

# Copy all files from flash_mla directory
file(GLOB_RECURSE FLASH_MLA_ALL_FILES ${FLASH_MLA_SOURCE_DIR}/flash_mla/*)
configure_file(${FLASH_MLA_SOURCE_DIR}/LICENSE ${FLASH_MLA_PYTHON_DEST}/LICENSE
               COPYONLY)

foreach(SOURCE_FILE ${FLASH_MLA_ALL_FILES})
  file(RELATIVE_PATH REL_PATH ${FLASH_MLA_SOURCE_DIR}/flash_mla ${SOURCE_FILE})
  get_filename_component(REL_DIR ${REL_PATH} DIRECTORY)
  file(MAKE_DIRECTORY ${FLASH_MLA_PYTHON_DEST}/${REL_DIR})

  # Check if it's a Python file that needs import renaming
  get_filename_component(FILE_EXT ${SOURCE_FILE} EXT)
  if(FILE_EXT STREQUAL ".py")
    # Read file content and replace module imports for Python files
    file(READ ${SOURCE_FILE} _content)
    # Replace the C++ extension module import
    string(REPLACE "flash_mla.cuda" "tensorrt_llm.flash_mla_cpp_tllm" _content
                   "${_content}")
    # Replace absolute imports with relative imports for internal modules
    string(REPLACE "from flash_mla." "from ." _content "${_content}")

    # Add adaptation header
    string(
      PREPEND
      _content
      "# Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/${REL_PATH}\n"
    )

    # Write modified content
    set(_dst "${FLASH_MLA_PYTHON_DEST}/${REL_PATH}")
    file(WRITE ${_dst} "${_content}")
  else()
    # Copy non-Python files as-is
    set(_dst "${FLASH_MLA_PYTHON_DEST}/${REL_PATH}")
    file(COPY ${SOURCE_FILE} DESTINATION ${FLASH_MLA_PYTHON_DEST}/${REL_DIR})
  endif()

  # Add dependency tracking
  set_property(
    DIRECTORY
    APPEND
    PROPERTY CMAKE_CONFIGURE_DEPENDS ${SOURCE_FILE})
endforeach()

# Build flash_mla_cpp_tllm extension
# ===================================

# Find torch_python
find_library(TORCH_PYTHON_LIB torch_python REQUIRED
             HINTS ${TORCH_INSTALL_PREFIX}/lib)

# Define source files matching FlashMLA's setup.py. Note: pybind.cpp has
# runtime checks (arch.is_sm90()) but still references all kernel symbols, so
# we compile all sources to avoid undefined symbols at module load time.
set(FLASH_MLA_SOURCES
    ${FLASH_MLA_SOURCE_DIR}/csrc/pybind.cpp
    ${FLASH_MLA_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
    ${FLASH_MLA_SOURCE_DIR}/csrc/smxx/mla_combine.cu)

# Add SM90 sources (always include if CUDA >= 12.3)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3)
  list(
    APPEND
    FLASH_MLA_SOURCES
    ${FLASH_MLA_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
    ${FLASH_MLA_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
    ${FLASH_MLA_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu)
endif()

# Add SM100 sources (always include if CUDA >= 12.9)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
  list(
    APPEND
    FLASH_MLA_SOURCES
    ${FLASH_MLA_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
    ${FLASH_MLA_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
    ${FLASH_MLA_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
    ${FLASH_MLA_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu)
endif()

# Disable LTO before creating the target (similar to DeepEP); let CMake
# generate fatbinData for CUDA separable compilation.
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION FALSE)

pybind11_add_module(flash_mla_cpp_tllm ${FLASH_MLA_SOURCES})

set_target_properties(
  flash_mla_cpp_tllm
  PROPERTIES CXX_STANDARD_REQUIRED ON
             CXX_STANDARD 17
             CXX_SCAN_FOR_MODULES OFF
             CUDA_STANDARD 17
             CUDA_STANDARD_REQUIRED ON
             CUDA_SEPARABLE_COMPILATION ON
             CUDA_RESOLVE_DEVICE_SYMBOLS ON
             LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/flash_mla_cpp_tllm.version
             INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib"
             BUILD_WITH_INSTALL_RPATH TRUE)

# Set CUDA architectures. Compile kernels for all CUDA-version-supported
# architectures since pybind.cpp references them.
set(FLASH_MLA_BUILD_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3)
  list(APPEND FLASH_MLA_BUILD_ARCHS "90")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
  list(APPEND FLASH_MLA_BUILD_ARCHS "100f")
endif()
set_cuda_architectures(flash_mla_cpp_tllm ${FLASH_MLA_BUILD_ARCHS})

# Copy of compiler options from FlashMLA setup.py
target_compile_options(
  flash_mla_cpp_tllm
  PRIVATE
    ${TORCH_CXX_FLAGS}
    $<$<COMPILE_LANGUAGE:CXX>:-std=c++17>
    $<$<COMPILE_LANGUAGE:CXX>:-O3>
    $<$<COMPILE_LANGUAGE:CXX>:-fPIC>
    $<$<COMPILE_LANGUAGE:CXX>:-DNDEBUG>
    $<$<COMPILE_LANGUAGE:CXX>:-Wno-deprecated-declarations>
    $<$<COMPILE_LANGUAGE:CXX>:-Wno-c++11-narrowing>
    $<$<COMPILE_LANGUAGE:CXX>:-fno-lto>
    $<$<COMPILE_LANGUAGE:CUDA>:-O3>
    $<$<COMPILE_LANGUAGE:CUDA>:-std=c++17>
    $<$<COMPILE_LANGUAGE:CUDA>:-DNDEBUG>
    $<$<COMPILE_LANGUAGE:CUDA>:-D_USE_MATH_DEFINES>
    $<$<COMPILE_LANGUAGE:CUDA>:-U__CUDA_NO_HALF_OPERATORS__>
    $<$<COMPILE_LANGUAGE:CUDA>:-U__CUDA_NO_HALF_CONVERSIONS__>
    $<$<COMPILE_LANGUAGE:CUDA>:-U__CUDA_NO_HALF2_OPERATORS__>
    $<$<COMPILE_LANGUAGE:CUDA>:-U__CUDA_NO_BFLOAT16_CONVERSIONS__>
    $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
    $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
    $<$<COMPILE_LANGUAGE:CUDA>:--use_fast_math>
    $<$<COMPILE_LANGUAGE:CUDA>:--ptxas-options=-v,--register-usage-level=10>)

# Extension name definition
target_compile_definitions(flash_mla_cpp_tllm
                           PRIVATE TORCH_EXTENSION_NAME=flash_mla_cpp_tllm)

# Include directories matching FlashMLA setup.py
target_include_directories(
  flash_mla_cpp_tllm
  PRIVATE ${CUDA_INCLUDE_DIRS}
          ${FLASH_MLA_SOURCE_DIR}/csrc
          ${FLASH_MLA_SOURCE_DIR}/csrc/sm90
          ${FLASH_MLA_SOURCE_DIR}/csrc/cutlass/include
          ${FLASH_MLA_SOURCE_DIR}/csrc/cutlass/tools/util/include)

# Link libraries (matching FlashMLA setup.py: cuda, cudart + torch)
target_link_libraries(
  flash_mla_cpp_tllm PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIB}
                             CUDA::cuda_driver CUDA::cudart)
target_link_options(
  flash_mla_cpp_tllm PRIVATE
  -Wl,--version-script,${CMAKE_CURRENT_SOURCE_DIR}/flash_mla_cpp_tllm.version
  -Wl,--no-undefined-version)

# Link directories
target_link_directories(
  flash_mla_cpp_tllm PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/lib64
                             ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

# Set targets
# ===========
add_dependencies(flash_mla flash_mla_cpp_tllm)
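
The copy loop above is what lets the vendored FlashMLA Python sources import cleanly from inside the tensorrt_llm package: it retargets the flash_mla.cuda extension import to tensorrt_llm.flash_mla_cpp_tllm, relativizes intra-package imports, and prepends a provenance header. A minimal Python sketch of the same transformation, for illustration only (the function name and src/dst arguments are hypothetical; the real rewrite runs at CMake configure time):

# Illustrative sketch of the rewrite the CMake foreach() loop applies to each
# vendored .py file. Function name and paths are hypothetical.
from pathlib import Path


def rewrite_flash_mla_imports(src: Path, dst: Path, rel_path: str) -> None:
    content = src.read_text()
    # Retarget the C++ extension import at the module this commit builds.
    content = content.replace("flash_mla.cuda", "tensorrt_llm.flash_mla_cpp_tllm")
    # Make intra-package imports relative so the copied package is relocatable.
    content = content.replace("from flash_mla.", "from .")
    # Record provenance, as string(PREPEND ...) does in the CMake above.
    header = ("# Adapted from https://github.com/deepseek-ai/FlashMLA/"
              f"blob/main/flash_mla/{rel_path}\n")
    dst.parent.mkdir(parents=True, exist_ok=True)
    dst.write_text(header + content)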
cpp/tensorrt_llm/flash_mla/flash_mla_cpp_tllm.version

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
{
  global: PyInit_flash_mla_cpp_tllm;
  local: *;
};
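
The four lines above are a GNU linker version script. Applied via -Wl,--version-script in the CMakeLists above, it exports exactly one global symbol, PyInit_flash_mla_cpp_tllm (the entry point CPython resolves when importing the extension), and hides everything else via local: *, so the statically linked kernel and CUTLASS symbols cannot clash with other extensions loaded in the same process; -Wl,--no-undefined-version additionally fails the link if the named symbol is missing. A hedged smoke test, assuming a TensorRT-LLM build with BUILD_FLASH_MLA=ON on a supported CUDA/GPU combination:

# Minimal sketch: the module path comes from the import rewrite above, and the
# import succeeds only if PyInit_flash_mla_cpp_tllm is exported.
import importlib

flash_mla_ext = importlib.import_module("tensorrt_llm.flash_mla_cpp_tllm")
print(flash_mla_ext.__file__)  # location of the flash_mla_cpp_tllm.*.so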
