Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ ptxas
# Third-party include
third_party/nvidia/backend/include
third_party/nvidia/backend/lib/cupti
third_party/sunrise/backend/lib

# Docs
docs/_build/
Expand Down
94 changes: 91 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,43 @@ elseif(FLAGTREE_BACKEND STREQUAL "aipu")
elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro")
set(CMAKE_C_COMPILER clang)
set(CMAKE_CXX_COMPILER clang++)
elseif(FLAGTREE_BACKEND STREQUAL "sunrise")
# remove_definitions(-D_GLIBCXX_USE_CXX11_ABI=1)
# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
find_package(Python3 3.10 REQUIRED COMPONENTS Development.Module Interpreter)
if(EDITABLE_MODE)
set (DEFAULT_PLUGIN_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/sunrise")
else()
set (DEFAULT_PLUGIN_DIR "${Python3_SITELIB}/triton/_C")
endif()
add_definitions(-DDEFAULT_PLUGIN_DIR="${DEFAULT_PLUGIN_DIR}")
endif()
set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}")
if(FLAGTREE_PLUGIN)
add_definitions(-D__FLAGTREE_PLUGIN__)
endif()

# FLAGTREE SPEC LIB GET FUNC
function(get_flagtree_backend_lib lib_name output_lib)
set(ret FlagTree_${FLAGTREE_BACKEND}_${lib_name})
if(NOT TARGET ${ret})
set(ret "")
endif()
set(${output_lib} ${ret} PARENT_SCOPE)
endfunction()

# FLAGTREE SPEC TD FILE GET FUNC
function(set_flagtree_backend_td output_td td_filename)
set(ret ${td_filename})
file(RELATIVE_PATH relative_path "${PROJECT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}")
get_filename_component(BACKEND_SPEC_ROOT "${BACKEND_SPEC_INCLUDE_DIR}" DIRECTORY)
set(BACKEND_SPEC_TD ${BACKEND_SPEC_ROOT}/${relative_path}/${td_filename})
if(EXISTS ${BACKEND_SPEC_TD})
set(ret ${BACKEND_SPEC_TD})
endif()
set(${output_td} ${ret} PARENT_SCOPE)
endfunction()

project(triton CXX C)
include(CTest)

Expand Down Expand Up @@ -119,12 +150,20 @@ if(TRITON_BUILD_UT)
endif()

# Compiler flags
set(BACKEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include)
set(FLAGTREE_BACKEND_DIR ${PROJECT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND})
## flagtree spec include dir
set(BACKEND_SPEC_INCLUDE_DIR ${FLAGTREE_BACKEND_DIR}/backend/spec/include)
if(FLAGTREE_BACKEND AND EXISTS ${BACKEND_SPEC_INCLUDE_DIR})
include_directories(${BACKEND_SPEC_INCLUDE_DIR})
endif()
## flagtree third_party include dir
set(BACKEND_INCLUDE_DIR ${FLAGTREE_BACKEND_DIR}/include)
if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_INCLUDE_DIR}")
include_directories(${BACKEND_INCLUDE_DIR})
else()
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
endif()

if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
else()
Expand Down Expand Up @@ -378,6 +417,53 @@ if(TRITON_BUILD_PYTHON_MODULE)
LLVMRISCVCodeGen
LLVMRISCVAsmParser
)
elseif(FLAGTREE_BACKEND STREQUAL "sunrise")
set(TRITON_LIBRARIES
${triton_libs}
${triton_plugins}

# mlir
# MLIRAMDGPUDialect
# MLIRNVVMDialect
MLIRSTVMDialect # STVM
MLIRNVVMToLLVMIRTranslation
MLIRSTVMToLLVMIRTranslation
MLIRGPUToNVVMTransforms
MLIRGPUToSTVMTransforms
MLIRGPUToGPURuntimeTransforms
MLIRGPUTransforms
MLIRIR
MLIRControlFlowToLLVM
MLIRBytecodeWriter
MLIRPass
MLIRTransforms
MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
MLIRMathToLLVM
# MLIRROCDLToLLVMIRTranslation
MLIRGPUDialect
MLIRSCFToControlFlow
MLIRIndexToLLVM
MLIRGPUToROCDLTransforms
MLIRUBToLLVM

# LLVM
LLVMPasses
# LLVMNVPTXCodeGen
# LLVMAMDGPUCodeGen
# LLVMAMDGPUAsmParser
LLVMSTCUCodeGen
LLVMSTCUAsmParser
LLVMAArch64CodeGen
LLVMAArch64AsmParser
LLVMRISCVCodeGen
LLVMRISCVAsmParser

Python3::Module
pybind11::headers
)

endif()

if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64
Expand Down Expand Up @@ -424,7 +510,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
endif()

# Link triton with its dependencies
target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES})
#target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES})
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
if(WIN32)
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
Expand All @@ -450,7 +537,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
endif()

if (UNIX AND NOT APPLE)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL")
#set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--export-dynamic")
endif()

if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
Expand Down
3 changes: 3 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ using ::mlir::triton::gpu::CTALayoutAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
#ifdef FLAGTREE_SPEC_BackendMmaEncodingAttr
using FLAGTREE_SPEC_BackendMmaEncodingAttr;
#endif

Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
ArrayRef<Value> strides);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
#set(LLVM_TARGET_DEFINITIONS Passes.td)
set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU)
add_public_tablegen_target(TritonConversionPassIncGen)
12 changes: 10 additions & 2 deletions include/triton/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

set(LLVM_TARGET_DEFINITIONS TritonOps.td)
# set(LLVM_TARGET_DEFINITIONS TritonOps.td)
set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonOps.td)
# mlir_tablegen(Ops.h.inc -gen-op-decls)
# mlir_tablegen(Ops.cpp.inc -gen-op-defs)
# mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
# mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
# add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc)

mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
Expand All @@ -12,7 +19,8 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc)

set(LLVM_TARGET_DEFINITIONS TritonTypes.td)
#set(LLVM_TARGET_DEFINITIONS TritonTypes.td)
set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonTypes.td)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)

Expand Down
3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
#set(LLVM_TARGET_DEFINITIONS Passes.td)
set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
add_public_tablegen_target(TritonTransformsIncGen)
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
# set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg)
mlir_tablegen(Ops.h.inc -gen-op-decls)
Expand All @@ -11,7 +12,7 @@ add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc)
add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonGPUTableGen)

set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls)
Expand All @@ -21,6 +22,7 @@ mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
#set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(TritonGPUTypeInterfacesIncGen)
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
#set(LLVM_TARGET_DEFINITIONS Passes.td)
set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU)
add_public_tablegen_target(TritonGPUTransformsIncGen)
9 changes: 9 additions & 0 deletions lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
if (FLAGTREE_BACKEND)
set(NVGPUIR "")
else()
set(NVGPUIR "TritonNvidiaGPUIR")
endif()

get_flagtree_backend_lib("TritonAnalysis" _EXTRA_LINK_LIBS)

add_triton_library(TritonAnalysis
AxisInfo.cpp
Allocation.cpp
Expand All @@ -17,4 +25,5 @@ add_triton_library(TritonAnalysis
TritonIR
TritonGPUIR
TritonNvidiaGPUIR
${_EXTRA_LINK_LIBS}
)
12 changes: 12 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
if (FLAGTREE_BACKEND)
set(NVGPUIR "")
set(NVGPUTransforms "")
else()
set(NVGPUIR "NVGPUIR")
set(NVGPUTransforms "TritonNvidiaGPUTransforms")
endif()

get_flagtree_backend_lib("TritonGPUToLLVM" _EXTRA_LINK_LIBS)

add_triton_library(TritonGPUToLLVM
DotOpToLLVM/FMA.cpp
DotOpToLLVM/FMADotUtility.cpp
Expand Down Expand Up @@ -36,4 +46,6 @@ add_triton_library(TritonGPUToLLVM
TritonGPUIR
TritonGPUTransforms
TritonNvidiaGPUTransforms

${_EXTRA_LINK_LIBS}
)
4 changes: 4 additions & 0 deletions lib/Conversion/TritonToTritonGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
get_flagtree_backend_lib("TritonToTritonGPU" _EXTRA_LINK_LIBS)

add_triton_library(TritonToTritonGPU
RelayoutTritonGPU.cpp
TritonGPUConversion.cpp
Expand All @@ -13,4 +15,6 @@ add_triton_library(TritonToTritonGPU
TritonIR
ProtonIR
TritonGPUIR

${_EXTRA_LINK_LIBS}
)
4 changes: 4 additions & 0 deletions lib/Dialect/TritonGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
get_flagtree_backend_lib("TritonGPUIR" _EXTRA_LINK_LIBS)

add_triton_library(TritonGPUIR
Dialect.cpp
LinearLayoutConversions.cpp
Expand All @@ -14,4 +16,6 @@ add_triton_library(TritonGPUIR
MLIRGPUDialect
TritonIR
TritonTools

${_EXTRA_LINK_LIBS}
)
8 changes: 8 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#if __has_include("flagtree_spec.h")
#include "flagtree_spec.h"
#endif

#ifndef FLAGTREE_SPEC_Dialect_TritonGPU_IR_Dialect_cpp

#include "triton/Dialect/Triton/IR/Dialect.h"

#include <cstdint>
Expand Down Expand Up @@ -3215,3 +3221,5 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(ArrayRef<int64_t> srcShape,
auto dst = reshapeLayout(ctx, src, dstShape);
return dst;
}

#endif// FLAGTREE_SPEC_Dialect_TritonGPU_IR_Dialect_cpp
8 changes: 8 additions & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#if __has_include("flagtree_spec.h")
#include "flagtree_spec.h"
#endif

#ifndef FLAGTREE_SPEC_Triton_Dialect_TritonGPU_IR_sunrise_LinearLayoutConversion

#include <vector>

#include "triton/Dialect/Triton/IR/Dialect.h"
Expand Down Expand Up @@ -1820,3 +1826,5 @@ LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
}

} // namespace mlir::triton::gpu

#endif//FLAGTREE_SPEC_Triton_Dialect_TritonGPU_IR_sunrise_LinearLayoutConversion
11 changes: 11 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
if(IS_COMPILE_TritonNvidiaGPU)
set(_TMA_LINK_CPP "Pipeliner/TMAStoresPipeline.cpp")
set(_NVIDIA_LINK_LIBS "TritonNvidiaGPUIR")
else()
set(_TMA_LINK_CPP "")
set(_NVIDIA_LINK_LIBS "")
endif()
get_flagtree_backend_lib("TritonGPUTransforms" _EXTRA_LINK_LIBS)

add_triton_library(TritonGPUTransforms
AccelerateMatmul.cpp
Canonicalize.cpp
Expand Down Expand Up @@ -50,4 +59,6 @@ add_triton_library(TritonGPUTransforms
TritonNvidiaGPUIR
TritonToTritonGPU
MLIRTransformUtils

${_EXTRA_LINK_LIBS}
)
8 changes: 8 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#if __has_include("flagtree_spec.h")
#include "flagtree_spec.h"
#endif

#ifndef FLAGTREE_SPEC_Triton_Dialect_TritonGPU_Transforms_Sunrise_Coalesce

#include <iterator>
#include <numeric>

Expand Down Expand Up @@ -193,3 +199,5 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
} // namespace gpu
} // namespace triton
} // namespace mlir

#endif//FLAGTREE_SPEC_Triton_Dialect_TritonGPU_Transforms_Sunrise_Coalesce
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#if __has_include("flagtree_spec.h")
#include "flagtree_spec.h"
#endif

#ifndef FLAGTREE_SPEC_Dialect_TritonGPU_Transforms_PipeliningUtility

#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down Expand Up @@ -743,3 +749,5 @@ scf::ForOp triton::lowerTMADescriptors(scf::ForOp forOp,
}
return forOp;
}

#endif//FLAGTREE_SPEC_Dialect_TritonGPU_Transforms_PipeliningUtility
7 changes: 7 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
// scf.yield %next_a, ..., %a_prefetch_next
// }
//===----------------------------------------------------------------------===//
#if __has_include("flagtree_spec.h")
#include "flagtree_spec.h"
#endif

#ifndef FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_Prefetch

#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -459,3 +464,5 @@ struct PrefetchPass : public impl::TritonGPUPrefetchBase<PrefetchPass> {
} // namespace gpu
} // namespace triton
} // namespace mlir

#endif// FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_Prefetch
Loading