diff --git a/.github/workflows/ascend-build-and-test.yml b/.github/workflows/ascend-build-and-test.yml index 50dad4678..9e33248f2 100644 --- a/.github/workflows/ascend-build-and-test.yml +++ b/.github/workflows/ascend-build-and-test.yml @@ -12,8 +12,15 @@ concurrency: jobs: ascend-build-and-test: - runs-on: ascend + runs-on: flagtree-ascend + if: ${{ github.repository == 'FlagTree/flagtree' || github.repository == 'flagos-ai/flagtree' }} steps: + - name: Setup environment + shell: bash + run: | + source ~/env.sh + env | grep -E '^(http_proxy|https_proxy|all_proxy|no_proxy)=' >> $GITHUB_ENV || true + - name: Checkout code (attempt 1) id: checkout1 uses: actions/checkout@v6 @@ -60,11 +67,10 @@ jobs: shell: bash run: | set -x - pip uninstall -y triton export FLAGTREE_BACKEND=ascend source ~/env.sh cd python - MAX_JOBS=32 python3 -m pip install . --no-build-isolation + MAX_JOBS=32 python3 -m pip install . --no-build-isolation -vvv - name: FlagTree Test on Ascend if: steps.check_files.outputs.only_docs_changed != 'true' diff --git a/.github/workflows/nv-build-and-test.yml b/.github/workflows/nv-build-and-test.yml index e65cbc809..d4049e177 100644 --- a/.github/workflows/nv-build-and-test.yml +++ b/.github/workflows/nv-build-and-test.yml @@ -111,7 +111,10 @@ jobs: shell: bash run: | set -x - python3.11 -m pytest -s python/test/unit + python3.11 -m pytest -s python/test/unit \ + --ignore=python/test/unit/test_debug.py \ + --ignore=python/test/unit/test_debug_dump.py \ + --ignore=python/test/unit/tools/test_disasm.py if [ -d "python/test/operators" ]; then python3.11 -m pytest -s python/test/operators fi diff --git a/CMakeLists.txt b/CMakeLists.txt index 570725b65..bf2ea4f53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -410,7 +410,7 @@ if(TRITON_BUILD_PYTHON_MODULE) elseif(FLAGTREE_BACKEND STREQUAL "ascend") set(PYTHON_ROOT_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) include_directories(${PYTHON_ROOT_SRC_PATH}) - add_library(triton SHARED ${PYTHON_ROOT_SRC_PATH}/main.cc + add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/ir.cc ${PYTHON_ROOT_SRC_PATH}/passes.cc ${PYTHON_ROOT_SRC_PATH}/interpreter.cc diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 574017654..4ed63ff87 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -13,16 +13,14 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "triton/Dialect/Triton/IR/Dialect.h.inc" -#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" -#include "triton/Dialect/Triton/IR/Traits.h" -#include "triton/Dialect/Triton/IR/Types.h" #if __has_include("flagtree_spec.h") #include "flagtree_spec.h" #endif -#if __has_include("triton/Dialect/Triton/IR/OpInterfaces.h") -#include "triton/Dialect/Triton/IR/OpInterfaces.h" -#endif + +#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/Triton/IR/Types.h" #define GET_OP_CLASSES #include "triton/Dialect/Triton/IR/Ops.h.inc" diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 5857ca60c..294683d33 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -822,7 +822,6 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp // We could revert it back once MLIR has a better inliner interface. //-- FuncOp -- -#ifndef FLAGTREE_SPEC_Dialect_Triton_IR_Ops_build void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, FunctionType type, ArrayRef attrs, ArrayRef argAttrs) { @@ -835,11 +834,14 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); +#if LLVM_VERSION_MAJOR < 21 function_interface_impl::addArgAndResultAttrs( +#else // triton_v3.3.x + call_interface_impl::addArgAndResultAttrs( +#endif builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } -#endif ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto buildFuncType = @@ -918,7 +920,6 @@ LogicalResult ReturnOp::verify() { } // -- JoinOp -- -#ifndef FLAGTREE_SPEC_Dialect_Triton_IR_Ops_inferReturnTypes LogicalResult JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, @@ -950,7 +951,6 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); return success(); } -#endif // -- SplitOp -- LogicalResult SplitOp::inferReturnTypes( diff --git a/python/setup_tools/utils/ascend.py b/python/setup_tools/utils/ascend.py index 53bf7178a..b311c2f29 100644 --- a/python/setup_tools/utils/ascend.py +++ b/python/setup_tools/utils/ascend.py @@ -4,7 +4,7 @@ downloader = DownloadManager() submodules = (Module(name="ascendnpu-ir", url="https://gitcode.com/Ascend/AscendNPU-IR.git", - commit_id="04045a06ec7c9592b17de659307d5debe7be590a", + commit_id="0501294d3e", dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "ascendnpu-ir")), ) diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py index e1c74b677..05bf1fe49 100644 --- a/python/test/unit/test_debug.py +++ b/python/test/unit/test_debug.py @@ -4,8 +4,6 @@ import triton.language as tl import triton - -@pytest.mark.skip(reason="flagtree") @pytest.mark.parametrize('cond, opt_flag, env_var', [ (cond, opt_flag, env_var) for cond in [True, False] \ for opt_flag in [True, False] \ @@ -30,7 +28,6 @@ def _kernel(COND: tl.constexpr): getattr(torch, device).synchronize() -@pytest.mark.skip(reason="flagtree") @pytest.mark.parametrize("cond", [False, True]) def test_static_assert(cond): @@ -64,7 +61,6 @@ def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref # integer overflow sanitization -@pytest.mark.skip(reason="flagtree") @pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ (-2**31, -1, 'int32', 'int32', False, False), (-2**31, -1, 'int32', 'int32', True, True), @@ -89,7 +85,6 @@ def _kernel_add(X, Y, Z): # mul overflow -@pytest.mark.skip(reason="flagtree") @pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ (2**30, 4, 'int32', 'int32', False, False), (2**30, 4, 'int32', 'int32', True, True), @@ -111,7 +106,6 @@ def _kernel_mul(X, Y, Z): # sub overflow -@pytest.mark.skip(reason="flagtree") @pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ (-2**31, 1, 'int32', 'int32', False, False), (-2**31, 1, 'int32', 'int32', True, True), diff --git a/python/test/unit/test_debug_dump.py b/python/test/unit/test_debug_dump.py index a387df42d..4f522941e 100644 --- a/python/test/unit/test_debug_dump.py +++ b/python/test/unit/test_debug_dump.py @@ -16,8 +16,6 @@ def enable_dump_context(pass_name="1"): def test_fn_dump(capfd, device, fresh_triton_cache): - return # TODO: flagtree - N = 1024 src = torch.zeros(N, device=device) diff --git a/python/test/unit/tools/test_disasm.py b/python/test/unit/tools/test_disasm.py index f2c9bcc0d..cc4982706 100644 --- a/python/test/unit/tools/test_disasm.py +++ b/python/test/unit/tools/test_disasm.py @@ -5,7 +5,6 @@ import triton.language as tl -@pytest.mark.skip(reason="flagtree") def test_disam_cubin(): if not triton.runtime.driver.active.get_current_target().backend == "cuda": pytest.skip("Test requires CUDA.") diff --git a/third_party/ascend/backend/spec/include/flagtree_spec.h b/third_party/ascend/backend/spec/include/flagtree_spec.h index 22ae6730b..c1da69adf 100644 --- a/third_party/ascend/backend/spec/include/flagtree_spec.h +++ b/third_party/ascend/backend/spec/include/flagtree_spec.h @@ -1,7 +1,7 @@ #ifndef ASCEND_FLAGTREE_SPEC_H_ #define ASCEND_FLAGTREE_SPEC_H_ -#include "triton/Dialect/Triton/IR/ascend_Ops.h" #include "triton/Dialect/Triton/IR/ascend_Traits.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" #endif // ASCEND_FLAGTREE_SPEC_H_ diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h index 101bae834..e11593e80 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/OpInterfaces.h @@ -1,5 +1,4 @@ -// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream -// Triton file. +//FIXME TODO: When upgrading to Triton 3.4.0, del this file #ifndef TRITON_IR_OP_INTERFACES_H_ #define TRITON_IR_OP_INTERFACES_H_ diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td index 720bfbd7b..0dc837eaf 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td @@ -1,4 +1,4 @@ -// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file. +// FIXME TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file. #ifndef TRITON_OP_INTERFACES #define TRITON_OP_INTERFACES diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td index c4284caba..dba1e9343 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td @@ -352,256 +352,6 @@ def TT_StoreOp : TT_Op<"store", [ let hasCanonicalizer = 1; } -// -// EmbeddingGather Op -// -def TT_EmbeddingGatherOp : TT_Op<"embedding_gather", [ - DeclareOpInterfaceMethods, - SameVariadicOperandSize, -]> { - let summary = "Gather load from a tensor pointer with the embedding semantics"; - - let arguments = ( - ins TT_Ptr:$src, - TT_Tensor:$idx, - AnyTypeOf<[I32, I64]>:$bound, - AnyTypeOf<[I32, I64]>:$blocksize, - Variadic>:$offsets, - Variadic>:$numels - ); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = [{ - $src `:` type($src) `,` $idx `:` type($idx) `,` - $bound `:` type($bound) `,` $blocksize `:` type($blocksize) `,` - `[` $offsets `:` type($offsets) `]` `,` `[` $numels `:` type($numels) `]` - attr-dict `->` type($result) - }]; - - let builders = [ - OpBuilder<(ins - "Value":$src, - "Value":$idx, - "Value":$bound, - "Value":$blocksize, - "ValueRange":$offsets, - "ValueRange":$numels - )> - ]; - // let hasCanonicalizer = 1; -} - -// -// IndexPut Op -// -def TT_IndexPutOp : TT_Op<"index_put", [ - MemoryEffects<[MemWrite]>, - SameVariadicOperandSize, -]> { - let summary = "Scatter store to a tensor pointer with embedding semantics"; - - let description = [{ - Index put values from a tensor into a destination tensor. - - The operation takes: - - ptr: pointer type, the destination tensor pointer (in GM) - - index: tensor, a index to scatter (in UB) - - value: tensor, a value to store (in UB) - - dim: int32, the dimension to scatter along - - index_boundary: int64, the upper boundary for index values - - end_offset: tuple of int, the offsets of each dimension for the end of the scatter region - - start_offset: tuple of int, the offsets of each dimension for the start of the scatter region - - dst_stride: tuple of int, the stride of each dimension of destination tensor - - - Constraints: - - `ptr` and `value` must have the same rank. - - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor. If `index.rank` != 1, it will be reshaped to 1D. - - `index.numel` must equal `value.shape[dim]`. - - `value` support 2~5D tensors. - - `dim` must be valid (0 <= dim < rank(value) - 1). - }]; - - let arguments = ( - ins TT_Ptr:$ptr, - TT_Tensor:$index, - TT_Tensor:$value, - TT_Int:$dim, - TT_Int:$indexBoundary, - Variadic>:$endOffset, - Variadic>:$startOffset, - Variadic>:$dstStride - ); - - let assemblyFormat = [{ - $ptr `:` type($ptr) `,` $index `:` type($index) `,` - $value `:` type($value) `,` $dim `:` type($dim) `,` $indexBoundary `:` type($indexBoundary) `,` - `[` $endOffset `:` type($endOffset) `]` `,` `[` $startOffset `:` type($startOffset) `]` `,` - `[` $dstStride `:` type($dstStride) `]` - attr-dict - }]; -} - -// -// GatherOutToUb Op -// -def TT_GatherOutToUbOp : TT_Op<"gather_out_to_ub", [ - DeclareOpInterfaceMethods, - AttrSizedOperandSegments, -]> { - let summary = "Gather load from a tensor pointer with the embedding semantics"; - - let description = [{ - Gather from a source tensor in Global Memory (GM) to Unified Buffer (UB) - along a specified dimension with out-of-bound handling. - - The operation takes: - - src: pointer type, the source tensor pointer (in GM) - - index: tensor, a tensor to gather (in UB) - - index_boundary: int64, the upper boundary for index values - - dim: int32, the dimension to gather along - - src_stride: tuple of int64, the stride of each dimension of src tensor - - end_offset: tuple of int32, the end offsets of each dimension for index tensor - - start_offset: tuple of int32, the start offsets of each dimension for index tensor - - other(Optional): scalar value, the default value when index is out of boundary (in UB) - - Returns: - a tensor, with the same shape as `index.shape` (in UB) - - Constraints: - - `src` and `index` must have the same rank. - - `src.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor, with rank between 1 and 5. - - `dim` must be valid (0 <= dim < rank(index)). - - `other` must be a scalar value. - - For every dimension `i` not equal to `dim`, `index.size[i]` <= `src.size[i]`. - - The output shape is the same as `index.shape`. If `index` is None, \ - the output tensor will be an empty tensor with the same shape as `index`. - }]; - - let arguments = ( - ins TT_Ptr:$src, - TT_Tensor:$index, - TT_Int:$indexBoundary, - TT_Int:$dim, - Variadic>:$srcStride, - Variadic>:$endOffset, - Variadic>:$startOffset, - Optional:$other - ); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = [{ - $src `:` type($src) `,` $index `:` type($index) `,` - $indexBoundary `:` type($indexBoundary) `,` $dim `:` type($dim) `,` - `[` $srcStride `:` type($srcStride) `]` `,` `[` $endOffset `:` type($endOffset) `]` `,` - `[` $startOffset `:` type($startOffset) `]` (`,` $other^ `:` type($other))? - attr-dict `->` type($result) - }]; -} - -// -// ScatterUbToOut Op -// -def TT_ScatterUbToOutOp : TT_Op<"scatter_ub_to_out", [ - MemoryEffects<[MemWrite]>, - SameVariadicOperandSize, -]> { - let summary = "scatter store from a tensor pointer with the embedding semantics"; - - let description = [{ - Scatter a tile from Unified Buffer (UB) into a destination tensor in Global Memory (GM) - along a specified dimension, with index-boundary checking. - - The operation takes: - - ptr: pointer type, the destination tensor pointer (in GM) - - value: tensor, a tile value to store (in UB) - - index: tensor, a tile index to scatter (in UB) - - index_boundary: int, the upper boundary for index values - - dim: int, the dimension to scatter along - - dst_stride: tuple of int, the stride of each dimension of destination tensor - - end_offset: tuple of int32, the end offsets of each dimension for index tensor - - start_offset: tuple of int32, the start offsets of each dimension for index tensor - - Constraints: - - `ptr` and `index` must have the same rank. - - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor, with rank between 1 and 5. - - `dim` must be valid (0 <= dim < rank(index)). - - For every dimension `i` not equal to `dim`, `index.size[i]` <= `ptr.size[i]`. - - The output shape is the same as `index.shape`. If `index` is None, \ - the output tensor will be an empty tensor with the same shape as `index`. - }]; - - let arguments = ( - ins TT_Ptr:$ptr, - TT_Tensor:$value, - TT_Tensor:$index, - TT_Int:$indexBoundary, - TT_Int:$dim, - Variadic>:$dstStride, - Variadic>:$endOffset, - Variadic>:$startOffset - ); - - let assemblyFormat = [{ - $ptr `:` type($ptr) `,` $value `:` type($value) `,` `,` $index `:` type($index) `,` - $indexBoundary `:` type($indexBoundary) `,` $dim `:` type($dim) `,` - `[` $dstStride `:` type($dstStride) `]` `,` `[` $endOffset `:` type($endOffset) `]` `,` - `[` $startOffset `:` type($startOffset) `]` - attr-dict - }]; -} - -def TT_IndexSelectSimdOp : TT_Op<"index_select_simd", [ - MemoryEffects<[MemRead]>, - DeclareOpInterfaceMethods, - AttrSizedOperandSegments -]> { - let summary = "Index select SIMD operation from global memory"; - - let description = [{ - Index select operation (SIMD version) that loads data from multiple indices along a - specified dimension. The operation selects data from GM and loads them - as tiles directly to UB with zero-copy semantics. - - The operation takes: - - src: Source pointer (in GM) - - index: 1D tensor of indices to select (already in UB) - - dim: The dimension along which to select - - src_shape: Complete shape of the source tensor - - src_offset: Starting offset for reading - - read_shape: Size to read (tile shape) - - Constraints: - - read_shape[dim] must be -1 - - src_offset[dim] can be -1 (will be ignored) - }]; - - let arguments = ( - ins - TT_PtrLike:$src, - TT_IntTensor:$index, - I32Attr:$dim, - Variadic:$src_shape, - Variadic:$src_offset, - DenseI32ArrayAttr:$read_shape - ); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = [{ - $src `,` $index `,` $dim `,` - `[` $src_shape `]` `,` - `[` $src_offset `]` `,` - $read_shape - attr-dict `:` type($src) `,` type($index) `->` type($result) - }]; -} - // // Atomic Ops // @@ -792,53 +542,6 @@ def TT_SplitOp : TT_Op<"split", [ let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; } -def TT_FlipOp : TT_Op<"flip", [ - NoMemoryEffect, - DeclareOpInterfaceMethods -]> { - let summary = "Reverse a tensor along a given dimension"; - let description = [{ - Reverses the elements of the input tensor along the specified dimension. - The output tensor has the same shape and element type as the input. - }]; - - let arguments = (ins - TT_Tensor:$src, // Input tensor - I64Attr:$dim // Dimension to flip along - ); - - let results = (outs - TT_Tensor:$flipped // Flipped values - ); - - let assemblyFormat = - "$src `,` $dim attr-dict `:` type($src) `->` type($flipped)"; -} - -def TT_SortOp : TT_Op<"sort", [ - NoMemoryEffect, - DeclareOpInterfaceMethods -]> { - let summary = "Sorts a tensor along a given dimension and returns sorted values."; - let description = [{ - Sorts the elements of the input tensor along the specified dimension. - Returns one tensor: - The sorted tensor (same shape and element type as input). - }]; - - let arguments = (ins - TT_Tensor:$src, // Input tensor - I64Attr:$dim, // Dimension to sort along - BoolAttr:$descending // Sort order - ); - - let results = (outs - TT_Tensor:$sorted // Sorted values - ); - - let assemblyFormat = "$src `,` $dim `,` $descending attr-dict `:` type($src) `->` type($sorted)"; -} - def TT_TransOp : TT_Op<"trans", [Pure, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { @@ -1167,20 +870,6 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> { }]; } -// -// Mod Op -// -def TT_ModOp : TT_Op<"mod", [Pure]> { - let summary = "Mod operation (%) of input tensors."; - let description = [{ - Performs element-wise division with remainder of input tensors. - }]; - - let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) type($rhs) `->` type($result)"; -} // // Gather Op @@ -1208,6 +897,7 @@ def TT_GatherOp : TT_Op<"gather", [Pure, let hasVerifier = 1; } + // // Print Op // @@ -1285,89 +975,6 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", ]; } -// -// Make Tensor Descriptor Op -// -def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ - Pure, - SameVariadicOperandSize, -]> { - let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size"; - - let description = [{ - `tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size, - and returns a descriptor object which can be used to load/store from the tensor in global memory. - }]; - - let arguments = (ins - TT_Ptr:$base, - Variadic:$shape, - Variadic:$strides - ); - - let results = (outs TT_TensorDescType:$result); - - let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)"; - - let builders = [ - OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef":$blockShape, "bool":$isSignedInteger)> - ]; - - let extraClassDeclaration = [{ - ArrayRef getTensorShape() { - return getType().getBlockType().getShape(); - } - }]; -} - -def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [TT_DescriptorOpInterface]> { - let summary = "Load from descriptor"; - let description = [{ - This operation will be lowered to Nvidia TMA load operation on targets supporting it. - `desc` is a tensor descriptor object. - The destination tensor type and shape must match the descriptor otherwise the result is undefined. - }]; - let arguments = (ins - Arg]>:$desc, - Variadic:$indices, - DefaultValuedAttr:$cache, - DefaultValuedAttr:$evict - ); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = [{ - $desc `[` $indices `]` - oilist( - `cacheModifier` `=` $cache | - `evictionPolicy` `=` $evict - ) - attr-dict `:` qualified(type($desc)) `->` type($result) - }]; - - let hasVerifier = 1; -} - -def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpInterface]> { - let summary = "store value based on descriptor"; - let description = [{ - This operation will be lowered to Nvidia TMA store operation on targets supporting it. - `desc` is a tensor descriptor object. - The shape and types of `src` must match the descriptor otherwise the result is undefined. - }]; - let arguments = (ins - Arg, MemWrite]>:$desc, - TT_Tensor:$src, - Variadic:$indices - ); - - let assemblyFormat = [{ - $desc `[` $indices `]` `,` $src - attr-dict `:` qualified(type($desc)) `,` type($src) - }]; - let hasVerifier = 1; -} - // The following ops, including `call`, `func`, and `return` are copied and modified from // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td // We could revert it back once MLIR has a better inliner interface. @@ -1389,10 +996,7 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, - Variadic:$operands, - OptionalAttr:$arg_attrs, // triton_v3.3.x - OptionalAttr:$res_attrs); // triton_v3.3.x + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); let results = (outs Variadic); let builders = [ @@ -1681,131 +1285,89 @@ def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< $desc_ptr attr-dict `:` qualified(type($desc_ptr)) }]; } - - +//FIXME TODO: When upgrading to Triton 3.4.0, remove the commented line below // -// Annotation Op +// Make Tensor Descriptor Op FIXME // -def TT_AnnotationOp : TT_Op<"annotation", [Pure, MemoryEffects<[MemWrite]>]> { - let summary = "Annotate a tensor with key-value attribute pairs"; - let description = [{ - `tt.annotation` operation can be used to annotate a tensor with - key-value attribute pairs. - - Example: - ```mlir - tt.annotation %target {key : val} - ``` - }]; - let arguments = (ins TT_Tensor:$src); - let assemblyFormat = [{ - $src attr-dict `:` type($src) - }]; -} +def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ + Pure, + SameVariadicOperandSize, +]> { + let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size"; -// -// Custom Op -// -def TT_CustomOp : TT_Op<"custom", [Pure, MemoryEffects<[MemWrite]>]> { - let summary = "self-defined custom operation"; let description = [{ - `tt.custom` triton custom op is designed to pass self-defined custom operation. - - Example: - ```tt.custom {str_args = ["sync_block_wait", "cube"]} loc(#loc12) - ``` + `tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size, + and returns a descriptor object which can be used to load/store from the tensor in global memory. }]; - let arguments = (ins StrAttr:$op_name, ArrayAttr:$str_args, Variadic:$args); - let assemblyFormat = "$op_name attr-dict ($args^ `:` type($args))?"; -} + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides + ); -// -// Built-in: IndirectLoad Op -// -def TT_IndirectLoadOp : TT_Op<"indirect_load", [ - DeclareOpInterfaceMethods, - AttrSizedOperandSegments -]> { - let summary = "Built-in: indirect load from global memory using per-element offsets with optional mask/other"; + let results = (outs TT_TensorDescType:$result); - let description = [{ - Built-in operation emitted by the compiler for unstructured (discrete) memory - accesses.These are not written directly in the user IR. + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)"; - Load values from global memory based on per-element offsets. If `mask` - is provided, false lanes return `other`. + let builders = [ + OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef":$blockShape, "bool":$isSignedInteger)> + ]; - The operation takes: - - src: Source pointer - - offsets: Tensor of per-element offsets (relative to `src`) for accessing source memory - - mask (optional): if mask[idx] is false, do not load the data at address pointer[idx] - - other (optional): if mask[idx] is false, return other[idx] + let extraClassDeclaration = [{ + ArrayRef getTensorShape() { + return getType().getBlockType().getShape(); + } }]; +} - let arguments = ( - ins TT_Ptr:$src, - TT_IntTensor:$offsets, - Optional:$mask, - Optional:$other +def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [TT_DescriptorOpInterface]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc` is a tensor descriptor object. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. + }]; + let arguments = (ins + Arg]>:$desc, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict ); let results = (outs TT_Tensor:$result); let assemblyFormat = [{ - $src `:` type($src) `,` - $offsets `:` type($offsets) - (`,` $mask^ `:` type($mask))? - (`,` $other^ `:` type($other))? - attr-dict `->` type($result) + $desc `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc)) `->` type($result) }]; - let builders = [ - OpBuilder<(ins - "Value":$src, - "Value":$offsets, - "Value":$mask, - "Value":$other - )> - ]; + let hasVerifier = 1; } - -// -// Built-in: IndirectStore Op -// -def TT_IndirectStoreOp : TT_Op<"indirect_store", [ - MemoryEffects<[MemWrite]> -]> { - let summary = "Built-in: indirect store from UB using per-element offsets with optional mask/other"; - +def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpInterface]> { + let summary = "store value based on descriptor"; let description = [{ - Built-in operation emitted by the compiler for unstructured (discrete) memory - accesses.These are not written directly in the user IR. - - Store values from UB based to GM on per-element offsets. - - The operation takes: - - src: Source pointer - - offsets: Tensor of per-element offsets (relative to `src`) for accessing source memory - - value: The tensor of elements to be stored - - mask (optional): If mask[idx] is false, do not store value[idx] at pointer[idx] + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc` is a tensor descriptor object. + The shape and types of `src` must match the descriptor otherwise the result is undefined. }]; - - let arguments = ( - ins TT_Ptr:$src, - TT_IntTensor:$offsets, - TT_Type:$value, - Optional:$mask + let arguments = (ins + Arg, MemWrite]>:$desc, + TT_Tensor:$src, + Variadic:$indices ); let assemblyFormat = [{ - $src `:` type($src) `,` - $offsets `:` type($offsets) `,` - $value `:` type($value) - (`,` $mask^ `:` type($mask))? - attr-dict + $desc `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc)) `,` type($src) }]; + let hasVerifier = 1; } + #endif // Triton_OPS diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonTypes.td b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonTypes.td index 5f9f19a07..318f1e3bc 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -1,4 +1,3 @@ -// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file. #ifndef TRITON_TYPES #define TRITON_TYPES @@ -140,7 +139,7 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> ]; let hasCustomAssemblyFormat = 1; } - +//FIXME TODO: When upgrading to Triton 3.4.0, remove the commented line below // Result type of MakeTensorDescriptor def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> { let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system"; diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/ascend_Ops.h b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/ascend_Ops.h deleted file mode 100644 index b88f3d565..000000000 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/ascend_Ops.h +++ /dev/null @@ -1,7 +0,0 @@ -#ifndef ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_ -#define ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_ - -#define FLAGTREE_SPEC_Dialect_Triton_IR_Ops_build -#define FLAGTREE_SPEC_Dialect_Triton_IR_Ops_inferReturnTypes - -#endif // ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_ diff --git a/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Dialect.cpp index 5ff5c3859..8129cbd20 100644 --- a/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Dialect.cpp +++ b/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Dialect.cpp @@ -1,5 +1,3 @@ #include "triton/Dialect/Triton/IR/Dialect.h" -#if __has_include("triton/Dialect/Triton/IR/OpInterfaces.cpp.inc") #include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc" -#endif diff --git a/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Ops.cpp b/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Ops.cpp index 4370791f8..8234e4012 100644 --- a/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Ops.cpp +++ b/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Ops.cpp @@ -38,77 +38,6 @@ namespace mlir { namespace triton { -void EmbeddingGatherOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), - triton::GlobalMemory::get()); -} - -void GatherOutToUbOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), - triton::GlobalMemory::get()); -} - -void IndirectLoadOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), - triton::GlobalMemory::get()); -} - -// FlipOp -LogicalResult -FlipOp::inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - auto inputTy = dyn_cast(operands[0].getType()); - if (!inputTy) { - if (location) - return emitOptionalError(location, - "expected ranked tensor for flip input"); - return failure(); - } - inferredReturnTypes.push_back(inputTy); - return success(); -} - -//-- SortOp -- -LogicalResult -SortOp::inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (operands.size() != 1) { - return emitOptionalError(location, - "expected exactly one operand for SortOp"); - } - - if (!isa(operands[0].getType())) { - return emitOptionalError(location, - "operand must be a ranked tensor type for SortOp"); - } - - Value src = operands[0]; - auto srcTy = cast(src.getType()); - auto srcShape = srcTy.getShape(); - auto srcEnc = srcTy.getEncoding(); - - if (srcShape.empty()) { - return emitOptionalError(location, "input tensor must have rank >= 1"); - } - - Type sortedTy = - RankedTensorType::get(srcShape, srcTy.getElementType(), srcEnc); - - inferredReturnTypes.push_back(sortedTy); - - return success(); -} - //-- MakeTensorDescOp -- void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, Value base, ValueRange shape, ValueRange strides, diff --git a/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Traits.cpp b/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Traits.cpp index af4653fb8..5a7615cad 100644 --- a/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Traits.cpp +++ b/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Traits.cpp @@ -21,6 +21,7 @@ LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { return op->emitError("Maximum allowed number of elements is ") << maxTensorNumElements << ", but " << *op << " has more than that"; + // FIXME:patched triton community // if ((numElements & (numElements - 1)) != 0) // return op->emitError("Number of elements must be power-of-two, but ") // << *op << " doesn't follow the rule (" << numElements << ")" @@ -36,6 +37,7 @@ LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { return op->emitError("Maximum allowed number of elements is ") << maxTensorNumElements << ", but " << *op << " has more than that"; + // FIXME:patched triton community // if ((numElements & (numElements - 1)) != 0) // return op->emitError("Number of elements must be power-of-two, but ") // << *op << " doesn't follow the rule (" << numElements << ")" diff --git a/third_party/ascend/python/src/ir.cc b/third_party/ascend/python/src/ir.cc index 5855ea22e..b2e743fec 100644 --- a/third_party/ascend/python/src/ir.cc +++ b/third_party/ascend/python/src/ir.cc @@ -1,27 +1,3 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. - * Copyright 2018-2020 Philippe Tillet - * Copyright 2020-2022 OpenAI - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - #include #include #include @@ -33,12 +9,12 @@ #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/ValueRange.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" @@ -56,99 +32,97 @@ #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/SourceMgr.h" +#include "ir.h" namespace { namespace py = pybind11; using namespace mlir; using namespace triton; +// FIXME:modify community // A custom op builder that keeps track of the last location -class TritonOpBuilder { -public: - TritonOpBuilder(MLIRContext *context, - const std::string &compile_mode = "simd") { - builder = std::make_unique(context); - lastLoc = std::make_unique(builder->getUnknownLoc()); - this->compile_mode = compile_mode; - } - - OpBuilder &getBuilder() { return *builder; } - - bool isLineInfoEnabled() { return lineInfoEnabled; } - - bool isSimtMode() const { return compile_mode == "simt"; } - - void setLastLoc(Location loc) { - if (lineInfoEnabled) - lastLoc = std::make_unique(loc); - } - - void setLastLoc(const std::string &fileName, int line, int column) { - auto context = builder->getContext(); - setLastLoc(FileLineColLoc::get(context, fileName, line, column)); - } - - Location getLastLoc() { - assert(lastLoc); - return *lastLoc; - } - - void setInsertionPointToStart(Block &block) { - if (!block.empty()) - setLastLoc(block.begin()->getLoc()); - else - setLastLoc(builder->getUnknownLoc()); - builder->setInsertionPointToStart(&block); - } - - void setInsertionPointToEnd(Block &block) { - if (!block.empty()) - setLastLoc(block.back().getLoc()); - else - setLastLoc(builder->getUnknownLoc()); - builder->setInsertionPointToEnd(&block); - } - - void setInsertionPointAfter(Operation &op) { - setLastLoc(op.getLoc()); - builder->setInsertionPointAfter(&op); - } - - void restoreInsertionPoint(OpBuilder::InsertPoint pt) { - if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) - setLastLoc(pt.getPoint()->getLoc()); - else - setLastLoc(builder->getUnknownLoc()); - builder->restoreInsertionPoint(pt); - } - - template OpTy create(Args &&...args) { - auto loc = getLastLoc(); - return builder->create(loc, std::forward(args)...); - } - - // Overload to create or fold a single result operation. - template - std::enable_if_t(), Value> - createOrFold(Args &&...args) { - auto loc = getLastLoc(); - return builder->createOrFold(loc, std::forward(args)...); - } - - // Overload to create or fold a zero result operation. - template - std::enable_if_t(), OpTy> - createOrFold(Args &&...args) { - auto loc = getLastLoc(); - return builder->createOrFold(loc, std::forward(args)...); - } - -private: - std::unique_ptr builder; - std::unique_ptr lastLoc; - bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); - std::string compile_mode; -}; +// class TritonOpBuilder { +// public: +// TritonOpBuilder(MLIRContext *context) { +// builder = std::make_unique(context); +// lastLoc = std::make_unique(builder->getUnknownLoc()); +// } + +// OpBuilder &getBuilder() { return *builder; } + +// bool isLineInfoEnabled() { return lineInfoEnabled; } + +// void setLastLoc(Location loc) { +// if (lineInfoEnabled) +// lastLoc = std::make_unique(loc); +// } + +// void setLastLoc(const std::string &fileName, int line, int column) { +// auto context = builder->getContext(); +// setLastLoc(FileLineColLoc::get(context, fileName, line, column)); +// } + +// Location getLastLoc() { +// assert(lastLoc); +// return *lastLoc; +// } + +// void setInsertionPointToStart(Block &block) { +// if (!block.empty()) +// setLastLoc(block.begin()->getLoc()); +// else +// setLastLoc(builder->getUnknownLoc()); +// builder->setInsertionPointToStart(&block); +// } + +// void setInsertionPointToEnd(Block &block) { +// if (!block.empty()) +// setLastLoc(block.back().getLoc()); +// else +// setLastLoc(builder->getUnknownLoc()); +// builder->setInsertionPointToEnd(&block); +// } + +// void setInsertionPointAfter(Operation &op) { +// setLastLoc(op.getLoc()); +// builder->setInsertionPointAfter(&op); +// } + +// void restoreInsertionPoint(OpBuilder::InsertPoint pt) { +// if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) +// setLastLoc(pt.getPoint()->getLoc()); +// else +// setLastLoc(builder->getUnknownLoc()); +// builder->restoreInsertionPoint(pt); +// } + +// template OpTy create(Args &&...args) { +// auto loc = getLastLoc(); +// return builder->create(loc, std::forward(args)...); +// } + +// // Overload to create or fold a single result operation. +// template +// std::enable_if_t(), Value> +// createOrFold(Args &&...args) { +// auto loc = getLastLoc(); +// return builder->createOrFold(loc, std::forward(args)...); +// } + +// // Overload to create or fold a zero result operation. +// template +// std::enable_if_t(), OpTy> +// createOrFold(Args &&...args) { +// auto loc = getLastLoc(); +// return builder->createOrFold(loc, std::forward(args)...); +// } + +// private: +// std::unique_ptr builder; +// std::unique_ptr lastLoc; +// bool lineInfoEnabled = +// !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +// }; std::string locationToString(Location loc) { std::string str; @@ -171,6 +145,15 @@ void outputWarning(Location loc, const std::string &msg) { /* Python bindings for ir */ /*****************************************************************************/ +namespace ir { + +// Pointer to the TritonOpBuilder class, used to register IR ops for third-party +// dialects. +static py::class_ *builderClassPtr = nullptr; +py::class_ *getBuilderClass() { return builderClassPtr; } + +} // namespace ir + void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; @@ -617,8 +600,10 @@ void init_triton_ir(py::module &&m) { py::class_(m, "InsertPoint", py::module_local()); - py::class_(m, "builder", py::module_local(), - py::dynamic_attr()) + static py::class_ builderClass( + m, "builder", py::module_local(), py::dynamic_attr()); + ir::builderClassPtr = &builderClass; + builderClass .def(py::init(), py::arg("context"), py::arg("compile_mode") = "simd", "Create a TritonOpBuilder with optional compile_mode (simt or simd, " @@ -745,6 +730,16 @@ void init_triton_ir(py::module &&m) { return self.create( self.getBuilder().getF64FloatAttr(v)); }) + .def("get_fp8e4nv", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + FloatAttr::get(self.getBuilder().getFloat8E4M3FNType(), v)); + }) + .def("get_fp8e5", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + FloatAttr::get(self.getBuilder().getFloat8E5M2Type(), v)); + }) .def("get_null_value", [](TritonOpBuilder &self, Type type) -> Value { if (auto floatTy = dyn_cast(type)) @@ -834,6 +829,21 @@ void init_triton_ir(py::module &&m) { std::vector &shape) -> Type { return RankedTensorType::get(shape, elementType); }) + .def("get_buffer_ty", + [](TritonOpBuilder &self, std::vector &shape, + Type &elementType, const Attribute &memorySpace) -> Type { + return MemRefType::get(shape, elementType, + MemRefLayoutAttrInterface{}, memorySpace); + }) + .def("get_buffer_ty_with_strides", + [](TritonOpBuilder &self, std::vector &shape, + Type &elementType, const std::vector &strides, + const Attribute &memorySpace) -> Type { + // create a layout with strides, using dynamic offset + auto layout = StridedLayoutAttr::get( + self.getBuilder().getContext(), ShapedType::kDynamic, strides); + return MemRefType::get(shape, elementType, layout, memorySpace); + }) .def("get_function_ty", [](TritonOpBuilder &self, std::vector inTypes, std::vector outTypes) -> Type { @@ -1331,6 +1341,14 @@ void init_triton_ir(py::module &&m) { std::vector &indices) -> void { self.create(desc, value, indices); }) + // Make a tensor descriptor + .def("create_make_tensor_descriptor", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &tensorShape, + bool isSignedInteger) -> Value { + return self.create(base, shape, strides, + tensorShape, isSignedInteger); + }) .def("create_tensormap_create", [](TritonOpBuilder &self, Value desc_ptr, Value global_address, std::vector box_dim, std::vector global_dim, @@ -1387,90 +1405,6 @@ void init_triton_ir(py::module &&m) { auto op = self.create(a); return std::vector(op->result_begin(), op->result_end()); }) - .def("create_extract_scalar", - [](TritonOpBuilder &self, Value &src, - std::vector &indices) -> Value { - llvm::SmallVector arg_indices; - for (const auto &i : indices) { - auto iTy = i.getType(); - if (!iTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), i); - arg_indices.push_back(v); - } else { - arg_indices.push_back(i); - } - } - auto ret = self.create(src, arg_indices); - return ret; - }) - .def("create_extract_slice", - [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, - std::vector &sizs_vec, std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get( - retSizes, - cast(ful.getType()).getElementType()); - - return self.create(retTy, ful, offsets, - sizes, strides); - }) - .def("create_insert_slice", - [](TritonOpBuilder &self, Value &ful, Value &sub, - std::vector &offs_vec, std::vector &sizs_vec, - std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get( - retSizes, - cast(ful.getType()).getElementType()); - auto ret = self.create(sub, ful, offsets, - sizes, strides); - return ret; - }) // Implements tl.trans and tl.permute. .def("create_trans", [](TritonOpBuilder &self, Value &arg, @@ -1608,10 +1542,6 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) - .def("create_tanh", - [](TritonOpBuilder &self, Value &val) -> Value { - return self.create(val); - }) .def("create_sqrt", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); @@ -1725,218 +1655,6 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &ptr, std::vector &offsets) -> Value { return self.create(ptr.getType(), ptr, offsets); - }) - // Add custom op - .def("create_custom_op_for_inter_core_sync", - [](TritonOpBuilder &self, std::string &op_name, - std::string &mode_or_sender, int id) -> void { - auto args = self.getBuilder().getArrayAttr( - {self.getBuilder().getStringAttr(mode_or_sender), - self.getBuilder().getI32IntegerAttr(id)}); - self.create(op_name, args, ValueRange()); - }) - // Make a tensor descriptor - .def("create_make_tensor_descriptor", - [](TritonOpBuilder &self, Value &base, std::vector &shape, - std::vector &strides, std::vector &tensorShape, - bool isSignedInteger) -> Value { - return self.create(base, shape, strides, - tensorShape, isSignedInteger); - }) - // Index select SIMD operation - .def("create_index_select_simd", - [](TritonOpBuilder &self, Value &src, Value &index, int32_t dim, - std::vector &srcShape, std::vector &srcOffset, - std::vector &readShape, - std::vector &returnShape) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - // Get element type from source pointer - Type elemType; - if (auto ptrTy = dyn_cast(src.getType())) { - elemType = ptrTy.getPointeeType(); - } else { - llvm::report_fatal_error( - "index_select_simd: src must be pointer type"); - } - - // Create return tensor type - llvm::SmallVector retShape; - for (const auto &s : returnShape) { - retShape.push_back(s); - } - auto retTensorType = RankedTensorType::get(retShape, elemType); - - // Convert srcShape and srcOffset values to index type if needed - llvm::SmallVector srcShapeIndex; - for (auto val : srcShape) { - if (!val.getType().isIndex()) { - val = self.create(builder.getIndexType(), - val); - } - srcShapeIndex.push_back(val); - } - - llvm::SmallVector srcOffsetIndex; - for (auto val : srcOffset) { - if (!val.getType().isIndex()) { - val = self.create(builder.getIndexType(), - val); - } - srcOffsetIndex.push_back(val); - } - - // Create attributes - auto dimAttr = builder.getI32IntegerAttr(dim); - auto readShapeAttr = builder.getDenseI32ArrayAttr(readShape); - - // Create the IndexSelectSimdOp - // Parameter order must match TritonOps.td definition: - // src, index, dim, src_shape, src_offset, read_shape - auto indexSelectSimdOp = builder.create( - loc, - retTensorType, // result type - src, // src pointer - index, // index tensor - dimAttr, // dim attribute - srcShapeIndex, // src_shape (variadic, index type) - srcOffsetIndex, // src_offset (variadic, index type) - readShapeAttr // read_shape attribute - ); - - return indexSelectSimdOp.getResult(); - }) - // Add an annotation - .def("create_annotation", - [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, - Attribute &attrVal) { - auto annotationOp = self.create(ptr); - annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), - attrVal); - }) - .def("create_embedding_gather", - [](TritonOpBuilder &self, Value &src, Value &idx, - const int64_t bound, const int64_t blksiz, - std::vector &offsets, - std::vector &numels) -> Value { - auto elemTy = cast(src.getType()).getPointeeType(); - auto idxTy = cast(idx.getType()); - auto idxShape = idxTy.getShape(); - std::vector retShape(idxShape.begin(), idxShape.end()); - retShape.push_back(blksiz); - auto resType = RankedTensorType::get(retShape, elemTy); - auto idxBitWidth = idxTy.getElementType().getIntOrFloatBitWidth(); - auto bound_val = - self.create(bound, idxBitWidth); - auto blksiz_val = - self.create(blksiz, idxBitWidth); - - return self.create(resType, src, idx, bound_val, - blksiz_val, offsets, numels); - }) - .def("create_index_put", - [](TritonOpBuilder &self, Value &ptr, Value &index, Value &value, - const int32_t dim, const int64_t indexBoundary, - std::vector &endOffset, std::vector &startOffset, - std::vector &dstStride) -> void { - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = self.create(dim, dimI32Ty); - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = - self.create(indexBoundary, BoundI64Ty); - - self.create(ptr, index, value, dim_val, bound_val, - endOffset, startOffset, dstStride); - }) - .def("create_gather_out_to_ub", - [](TritonOpBuilder &self, Value &src, Value &index, - const int64_t indexBoundary, const int32_t dim, - std::vector &srcStride, std::vector &endOffset, - std::vector &startOffset, - std::optional &other) -> Value { - auto elemTy = cast(src.getType()).getPointeeType(); - auto idxTy = cast(index.getType()); - auto idxShape = idxTy.getShape(); - std::vector retShape(idxShape.begin(), idxShape.end()); - auto resType = RankedTensorType::get(retShape, elemTy); - - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = - self.create(indexBoundary, BoundI64Ty); - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = self.create(dim, dimI32Ty); - return self.create( - resType, src, index, bound_val, dim_val, srcStride, endOffset, - startOffset, other.value_or(Value())); - }) - .def("create_scatter_ub_to_out", - [](TritonOpBuilder &self, Value &ptr, Value &value, Value &index, - const int64_t indexBoundary, const int32_t dim, - std::vector &dstStride, std::vector &endOffset, - std::vector &startOffset) -> void { - auto idxTy = cast(index.getType()); - - // indexBoundary need to be i64 type - auto BoundI64Ty = self.getBuilder().getI64Type(); - auto bound_val = - self.create(indexBoundary, BoundI64Ty); - // dim need to be i32 type - auto dimI32Ty = self.getBuilder().getI32Type(); - auto dim_val = self.create(dim, dimI32Ty); - - self.create(ptr, value, index, bound_val, - dim_val, dstStride, endOffset, - startOffset); - }) - // Add mod - .def("create_mod", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - auto type = dyn_cast(lhs.getType()); - if (!type) { - type = RankedTensorType::get({1}, lhs.getType()); - auto tensorFromLhs = - self.create(type, lhs); - auto tensorFromRhs = - self.create(type, rhs); - auto resultTensor = self.create( - type, tensorFromLhs, tensorFromRhs); - SmallVector indices{ - self.create(0)}; - return self.create(resultTensor, indices); - } - return self.create(type, lhs, rhs); - }) - // Add sort - .def("create_sort", - [](TritonOpBuilder &self, Value src, int64_t dim, - bool descending) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - auto dimAttr = builder.getI64IntegerAttr(dim); - auto descendingAttr = builder.getBoolAttr(descending); - - auto op = builder.create(loc, src, dimAttr, - descendingAttr); - - return op->getResult(0); - }) - // Add flip - .def("create_flip", - [](TritonOpBuilder &self, Value src, int64_t dim) -> Value { - auto &builder = self.getBuilder(); - auto loc = self.getLastLoc(); - - auto dimAttr = builder.getI64IntegerAttr(dim); - - auto op = builder.create(loc, src, dimAttr); - - return op->getResult(0); }); py::class_(m, "pass_manager", py::module_local()) diff --git a/third_party/ascend/python/src/main.cc b/third_party/ascend/python/src/main.cc new file mode 100644 index 000000000..7664c6bda --- /dev/null +++ b/third_party/ascend/python/src/main.cc @@ -0,0 +1,57 @@ +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Signals.h" +#include + +namespace py = pybind11; + +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + +void init_triton_env_vars(pybind11::module &m); +void init_triton_ir(pybind11::module &&m); +void init_buffer_ir(pybind11::module &&m); +void init_triton_llvm(pybind11::module &&m); +void init_triton_interpreter(pybind11::module &&m); +void init_triton_passes(pybind11::module &&m); +void init_triton_stacktrace_hook(pybind11::module &m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton_stacktrace_hook(m); + init_triton_env_vars(m); + init_triton_ir(m.def_submodule("ir")); + init_buffer_ir(m.def_submodule("buffer_ir")); + init_triton_passes(m.def_submodule("passes")); + init_triton_interpreter(m.def_submodule("interpreter")); + init_triton_llvm(m.def_submodule("llvm")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) +} diff --git a/third_party/ascend/python/src/passes.h b/third_party/ascend/python/src/passes.h deleted file mode 100644 index 46801d802..000000000 --- a/third_party/ascend/python/src/passes.h +++ /dev/null @@ -1,40 +0,0 @@ -#define ADD_PASS_WRAPPER_0(name, builder) \ - m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); }) - -#define ADD_PASS_WRAPPER_1(name, builder, ty0) \ - m.def(name, \ - [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); }) - -#define ADD_PASS_WRAPPER_2(name, builder, ty0, ty1) \ - m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ - pm.addPass(builder(val0, val1)); \ - }) - -#define ADD_PASS_WRAPPER_3(name, builder, ty0, ty1, ty2) \ - m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ - pm.addPass(builder(val0, val1, val2)); \ - }) - -#define ADD_PASS_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ - m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ - ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); }) - -#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \ - m.def(name, \ - [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); }) - -#define ADD_PASS_OPTION_WRAPPER_2(name, builder, ty0, ty1) \ - m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ - pm.addPass(builder({val0, val1})); \ - }) - -#define ADD_PASS_OPTION_WRAPPER_3(name, builder, ty0, ty1, ty2) \ - m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ - pm.addPass(builder({val0, val1, val2})); \ - }) - -#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ - m.def(name, \ - [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3) { \ - pm.addPass(builder({val0, val1, val2, val3})); \ - }) diff --git a/third_party/ascend/triton_ascend.cpp b/third_party/ascend/triton_ascend.cpp index b3d67f752..4077908ab 100644 --- a/third_party/ascend/triton_ascend.cpp +++ b/third_party/ascend/triton_ascend.cpp @@ -1,18 +1,18 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. */ -#include "incubated/Conversion/BubbleUpOperation/BubbleUpOperation.h" #include "incubated/Conversion/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.h" #include "incubated/Conversion/TritonLinearize/TritonLinearize.h" #include "incubated/Conversion/TritonToAnnotation/TritonToAnnotation.h" #include "incubated/Conversion/TritonToLinalgIncubated/TritonToLinalgIncubatedPass.h" -#include "incubated/Conversion/TritonToUnstructureIncubated/UnstructureConversionPass.h" +#include "incubated/Conversion/TritonToUnstructureIncubated/Passes.h" #include "mlir/Pass/PassManager.h" -#include "npu/Conversion/TritonToHFusion/TritonToHFusion.h" -#include "npu/Conversion/TritonToHIVM/TritonToHIVM.h" +#include "npu/Conversion/TritonToHFusion/Passes.h" +#include "npu/Conversion/TritonToHIVM/Passes.h" #include "npu/Conversion/TritonToLLVM/TritonToLLVM.h" #include "passes.h" #include "triton-shared/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimental.h" +#include "npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" #define PY_SSIZE_T_CLEAN #include @@ -32,8 +32,6 @@ void init_triton_ascend_passes_convert(py::module &&m) { mlir::triton::createTritonToHFusionPass); ADD_PASS_WRAPPER_0("add_triton_to_llvm", mlir::triton::createTritonToLLVMPass); - ADD_PASS_WRAPPER_0("add_bubble_up_operation", - mlir::triton::createBubbleUpOperationPass); m.def( "add_triton_discretemaskaccessconversion", [](mlir::PassManager &pm, bool compile_on_910_95, @@ -46,18 +44,14 @@ void init_triton_ascend_passes_convert(py::module &&m) { }, py::arg("pm"), py::arg("compile_on_910_95"), py::arg("force_simt_template")); - m.def( - "add_triton_to_unstructure_incubated", - [](mlir::PassManager &pm, bool compile_on_910_95, - bool force_simt_template) { - TritonToUnstructureIncubatedOptions options; - options.compileOn91095 = compile_on_910_95; - options.forceSimtTemplate = force_simt_template; - pm.addPass( - mlir::triton::createTritonToUnstructureIncubatedPass(options)); - }, - py::arg("pm"), py::arg("compile_on_910_95"), - py::arg("force_simt_template")); + m.def("add_triton_to_unstructure_incubated", [](mlir::PassManager &pm, + bool compileOn91095, bool forceSimtTemplate) { + TritonToUnstructureIncubatedOptions opts; + opts.compileOn91095 = compileOn91095; + opts.forceSimtTemplate = forceSimtTemplate; + pm.addPass(mlir::triton::createTritonToUnstructureIncubatedPass(opts));}); + m.def("add_bubble_up_operation", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createBubbleUpOperationPass());}); m.def( "add_triton_to_linalg_incubated", [](mlir::PassManager &pm, bool global_kernel, bool named_ops,